一 前言
CountDownLatch使用场景主要用于控制主线程等待所有子线程全部执行完成然后恢复主线程执行。
主要有两个核心方法:
countDown()用于减少计数器次数,每调用一次就会减少1。
await()方法表示什么时候CountDownLatch计数的值为0时,才返回主线程执行,否则处于阻塞状态。
假设有这样一种场景:
我们需要分批先从数据库查询出数据,然后将数据请求第三方,根据三方返回结果进行大量逻辑的数据处理。
如果单线程去处理,势必太慢了,大家当然会想到用多线程去处理,但是,有一个问题就是我们如何保证每一次
查询的数据不是正在处理的数据?
也就是要实现只有当每一批数据处理完之后再去数据库取下一批数据,并且每一批数据都是多线程处理的。
方案有很多,这里只针对自定义线程池和使用CountDownLatch来保证顺序处理批次数据进行讨论。
二 自定义线程池
在Java并发包下线程池工具类Excutors中含有多种现场池,我们可以参考其创建线程池的源码定义自己的线程池。
创建线程池需要注意线程池核心数,最大线程池数,线程阻塞队列,以及饱和策略。同时需要为自己的线程取一个
有意义的名称,方便日志查看和错误分析。
下面为自定义线程池源码,可以根据自己的需要进行调整。
package com.lanhuigu.thread.threadpool;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
/**
* 自定义线程池
* @author yihonglei
* @date 2018/8/31 15:07
*/
public class ThreadPoolUtil {
public static final int DEFAULT_CORE_THREADS = Runtime.getRuntime().availableProcessors();
/**
* 默认创建CPU核心线程数线程池
* @author yihonglei
* @date 2018/8/31 15:16
* @param threadPrefix 线程名称前缀,给线程取一个有意义的名称,方便问题排查
* @return java.util.concurrent.ExecutorService
*/
public static ExecutorService newThreadPool(String threadPrefix) {
return doNewThreadPool(threadPrefix, DEFAULT_CORE_THREADS, DEFAULT_CORE_THREADS);
}
/**
* 线程池创建
* @author yihonglei
* @date 2018/8/31 15:50
* @param threadPrefix 线程名称前缀,给线程取一个有意义的名称,方便问题排查
* @param coreThreads 核心线程池数
* @param maxThreads 最大线程池数
* @return java.util.concurrent.ExecutorService
*/
public static ExecutorService newThreadPool(String threadPrefix, int coreThreads, int maxThreads) {
return doNewThreadPool(threadPrefix, coreThreads, maxThreads);
}
/**
* 线程池创建
* @author yihonglei
* @date 2018/8/31 15:50
* @param threadPrefix 线程名称前缀,给线程取一个有意义的名称,方便问题排查
* @param coreThreads 核心线程池数
* @param maxThreads 最大线程池数
* @return java.util.concurrent.ExecutorService
*/
private static ExecutorService doNewThreadPool(final String threadPrefix, int coreThreads, int maxThreads) {
/** 线程池阻塞队列 */
final LinkedBlockingQueue<Runnable> blockingQueue = new LinkedBlockingQueue<>(1024);
/** 创建线程池 */
ExecutorService executorService = new ThreadPoolExecutor(coreThreads, maxThreads, 60, TimeUnit.SECONDS, blockingQueue, new ThreadFactory() {
/** 创建线程 */
final AtomicInteger counter = new AtomicInteger();
@Override
public Thread newThread(Runnable r) {
Thread t = new Thread(r);
t.setName(threadPrefix + "-" + counter.incrementAndGet());
t.setPriority(Thread.NORM_PRIORITY);
return t;
}
}, new RejectedExecutionHandler() {
/** 饱和策略 */
@Override
public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
try {
blockingQueue.put(r);
} catch (InterruptedException e) {
System.out.println("TODO-实际项目中用日志框架打日志");
}
}
});
return executorService;
}
/**
* 关闭线程池
* @author yihonglei
* @date 2018/8/31 19:00
* @param executorService
* @return void
*/
public static void shutdown(ExecutorService executorService) {
if (!executorService.isShutdown()) {
executorService.shutdown();
}
}
}
三 CountDownLatch线程执行顺序控制
线程池定义完后通过CountDownLatch控制主线程必须等待线程池子线程执行完才恢复执行主线程。
package com.lanhuigu.thread.countdownlatch;
import com.lanhuigu.thread.threadpool.ThreadPoolUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
/**
* 多线程+CountDownLatch演示
* @author yihonglei
* @date 2018/8/31 17:22
*/
public class CountDownLatchTest{
private static final int CORE_THREADS = 50;
private static ExecutorService executorService = null;
public static void main(String[] args) {
start();
}
/**
* 启动线程
*/
public static void start() {
int counterBatch = 1;
try {
// 开启线程池
openPool();
// 数据循环处理
while (true) {
// 模拟数据库查询出的List
List<String> list = new ArrayList<>();
for (int i = 0; i < 10; i++) {
list.add("aaa");
}
// 计数器大小定义为集合大小,避免处理不一致导致主线程无限等待
CountDownLatch countDownLatch = new CountDownLatch(list.size());
// 循环处理List
list.parallelStream().forEach(str -> {
// 任务提交线程池
CompletableFuture.supplyAsync(() -> {
try {
// 用户数据Check
dealUserData(str);
} finally {
countDownLatch.countDown();
}
return 1;
}, executorService);
});
// 主线程等待所有子线程都执行完成时,恢复执行主线程
countDownLatch.await();
System.out.println("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@");
// 数据批次计数器
counterBatch++;
// TODO 模拟执行30批
if (counterBatch >= 30) {
break;
}
}
System.out.println("counterBatch=" + counterBatch);
// 关闭线程池
closePool();
} catch (Exception e) {
System.out.println("异常日志");
}
}
/**
* 模拟根据用户Id处理用户数据的逻辑
*/
public static void dealUserData(String uId) {
System.out.println(Thread.currentThread().getName() + "======用户数据处理完成======");
}
/**
* 开启线程池
*/
public static void openPool() {
if (null == executorService || executorService.isShutdown()) {
executorService = ThreadPoolUtil.newThreadPool("MyThread", CORE_THREADS , CORE_THREADS);
}
}
/**
* 关闭线程池
*/
public static void closePool() {
if (executorService != null) {
ThreadPoolUtil.shutdown(executorService);
}
}
}
程序源码分析:
1)模拟从数据库每一次取出一批数据,每批数据为10条;
2)CountDownLatch计数器大小定义与数据条数相同,这里就为10;
3)然后循环List,没一条数据创建一个线程,然后提交线程池,同时需要进行countDown(),每次提交减1。
4)主线程也就是这里的main线程,调用了await()方法,await()方法表示等待线程池的线程执行完成,
然后才恢复主线程,进行下一次的循环取批数据处理。
从而我们可以实现每一批数据取出后,交由线程池多线程处理,并且主线程会等待子线程都执行完成,
然后才恢复执行,进行下一次的循环取批处理。