Spring Boot 多线程使用

Spring Boot 提供了非常优雅地使用多线程执行任务的方式,本文说明 Spring Boot 项目如何利用 ThreadPoolTaskExecutor 来使用多线程。

创建 Spring Boot 项目

使用 IntelliJ Idea 创建向导创建一个 Spring Boot 项目,或者在 Spring 官网创建一个 Spring Boot 项目,地址:https://start.spring.io/。由于创建过程比较简单,此处不再赘述。
其中,pom.xm 文件如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.4.1</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<groupId>me.leehao</groupId>
<artifactId>async-method</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>async-method</name>
<description>Demo project for Async Method</description>

<properties>
<java.version>1.8</java.version>
</properties>

<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
</plugins>
</build>
</project>

由于本文采用调用 Web 接口来进行测试,故引入了 spring-boot-starter-web 依赖。

配置多线程

为了在 Spring Boot 中使用多线程,需要在配置类中添加 @EnableAsync 注解。

@SpringBootApplication 是 Spring Boot 项目的核心注解,实际包含以下三个注解:

  • @Configuration:用于定义一个配置类
  • @EnableAutoConfiguration:Spring Boot 根据依赖来自动配置项目
  • @ComponentScan:告诉 Spring 哪个 package 用注解标识的类会被 Spring 自动扫描并装入 bean 容器

因此,可以直接在启动主类中添加 @EnableAsync 注解,这样一来,被添加 @Async 注解的方法,就会在线程池中执行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
package me.leehao.asyncmethod;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.web.client.RestTemplate;

import java.util.concurrent.Executor;

@SpringBootApplication
@EnableAsync
public class AsyncMethodApplication {
public static void main(String[] args) {
SpringApplication.run(AsyncMethodApplication.class, args);
}

@Bean("taskExecutor")
public Executor asyncExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
executor.setCorePoolSize(8);
executor.setMaxPoolSize(8);
executor.setQueueCapacity(500);
executor.setThreadNamePrefix("async-query-");
executor.initialize();
return executor;
}

@Bean
public RestTemplate restTemplate() {
return new RestTemplate();
}
}

同时,使用 @Bean 注解,定义一个名称为 taskExecutorExecutor,并对线程池进行配置。设置线程池初始化线程数量为 8,最大线程数量为 8 ,任务队列长度为 500,并设置线程名称前缀为 async-query- 以便于日志打印时确定执行的线程。

为了使用 RestTemplate 发送 HTTP GET 请求,定义了一个 RestTemplate bean。

使用多线程

完成线程池配置后,接下来,只需要在需要多线程执行的方法添加@Async 注解即可实现异步并发执行。

@Async("taskExecutor") 定义执行 Executor 的名称为 taskExecutor

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
package me.leehao.asyncmethod;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;

import java.util.concurrent.CompletableFuture;

@Service
public class AsyncService {
private static final Logger logger = LoggerFactory.getLogger(AsyncService.class);

@Autowired
private RestTemplate restTemplate;

@Async("taskExecutor")
public CompletableFuture<String> queryIp() throws InterruptedException {
String url = "http://httpbin.org/ip";
String results = restTemplate.getForObject(url, String.class);
logger.info("查询 ip 返回:{}", results);
// 等待 10 秒,模拟调用第三方接口较长时间返回结果
Thread.sleep(10000L);
return CompletableFuture.completedFuture(results);
}
}

queryIp 方法用于向 http://httpbin.org/ip 查询 ip 地址,为了模拟请求较长时间返回,在方法中 sleep 了 10 秒时间,即如果单线程调用三次 queryIp 方法,至少需要 30 秒,这里可以展示出多线程调用时的执行时间。

queryIp 方法返回结果是 CompletableFuture<String> ,这是异步方法要求的约定。

执行

为触发异步方法的调用,定义一个 controller 接口 /async

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
package me.leehao.asyncmethod;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;

import javax.annotation.Resource;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

@RestController
public class AsyncController {
private static final Logger logger = LoggerFactory.getLogger(AsyncController.class);

@Resource
private AsyncService asyncService;

@RequestMapping(value = "/async", method = RequestMethod.GET)
public String async() throws InterruptedException, ExecutionException {
long start = System.currentTimeMillis();

CompletableFuture<String> ip1 = asyncService.queryIp();
CompletableFuture<String> ip2 = asyncService.queryIp();
CompletableFuture<String> ip3 = asyncService.queryIp();

CompletableFuture.allOf(ip1, ip2, ip3).join();

float exc = (float)(System.currentTimeMillis() - start)/1000;
logger.info("完成所有查询所用时间:{}", exc);

return String.format("ip1: %s, ip2: %s, ip3: %s",
ip1.get(), ip2.get(), ip3.get());
}
}

/async 接口中,调用了三次查询 ip 的方法。通过 CompletableFuture.allOf 方法创建一个 CompletableFuture 对象数组,然后调用 join 方法,等待所有的调用完毕,并打印出最终的消耗时间。

执行程序,调用 /async 接口, 可以看到日志输出:

2021-01-05 17:08:56.589 INFO 3032 — [ async-query-3] me.leehao.asyncmethod.AsyncService : 查询 ip 返回:{
“origin”: “x.x.x.x”
}

async-query- 正是我们定义的线程名称前缀。

浏览器输出:

ip1: { “origin”: “x.x.x.x” } , ip2: { “origin”: “x.x.x.x” } , ip3: { “origin”: “x.x.x.x” }

说明三次查询 ip 的请求在线程池中执行,并已成功返回。

附:源代码

https://github.com/haozlee/async-method

参考资料