对于简单的并行任务,你可以通过“线程池+Future”的方案来解决;如果任务之间有聚合关系,无论是AND聚合还是OR聚合,都可以通过CompletableFuture来解决;而批量的并行任务,则可以通过CompletionService来解决。
分治,即分而治之,是一种解决复杂问题的思维方法和模式;具体来讲,指的是把一个复杂的问题分解成多个相似的子问题,然后再把子问题分解成更小的子问题,直到子问题简单到可以直接求解。
1. 分治任务模型
分为两个阶段:
- 一个阶段是任务分解,也就是将任务迭代地分解为子任务,直至子任务可以直接计算出结果;
- 另一个阶段是结果合并,即逐层合并子任务的执行结果,直至获得最终结果。
2. Fork/Join的使用
Fork对应任务分解,Join对应结果合并。
Fork/Join计算框架主要包含分治任务的线程池ForkJoinPool和分治任务ForkJoinTask。这两部分的关系类似于ThreadPoolExecutor和Runnable的关系。
ForkJoinTask是一个抽象类,最核心的是fork()方法和join()方法,其中fork()方法会异步地执行一个子任务,而join()方法则会阻塞当前线程来等待子任务的执行结果。
ForkJoinTask有两个子类——RecursiveAction和RecursiveTask,都是用递归的方式来处理分治任务的。这两个子类都定义了抽象方法compute(),RecursiveAction定义的compute()没有返回值,而RecursiveTask定义的compute()方法是有返回值的。两个子类也是抽象类,需要定义子类去扩展。
代码例子
public class MyTest2 {
public static void main(String[] args) {
// 创建分治任务线程池
ForkJoinPool fjp = new ForkJoinPool(4);
// 创建分治任务
Fibonacci fib = new Fibonacci(30);
// 启动分治任务
Integer result = fjp.invoke(fib);
// 输出结果
System.out.println(result);
}
// 递归任务
static class Fibonacci extends RecursiveTask<Integer> {
final int n;
Fibonacci(int n) {
this.n = n;
}
protected Integer compute() {
if (n <= 1)
return n;
Fibonacci f1 = new Fibonacci(n - 1);
// 创建子任务
f1.fork();
Fibonacci f2 = new Fibonacci(n - 2);
// 等待子任务结果,并合并结果
return f2.compute() + f1.join();
}
}
}
3. ForkJoinPool工作原理
ThreadPoolExecutor本质上是一个生产者-消费者模式的实现,内部有一个任务队列,这个任务队列是生产者和消费者通信的媒介;ThreadPoolExecutor可以有多个工作线程,但是这些工作线程都共享一个任务队列。
ForkJoinPool本质上也是一个生产者-消费者的实现, 内部有多个任务队列,当我们通过ForkJoinPool的invoke()或者submit()方法提交任务时,ForkJoinPool根据一定的路由规则把任务提交到一个任务队列中,如果任务在执行过程中会创建出子任务,那么子任务会提交到工作线程对应的任务队列中。
ForkJoinPool支持一种叫做“任务窃取”的机制,如果工作线程空闲了,那它可以“窃取”其他工作任务队列里的任务,所有的工作线程都不会闲下来。
ForkJoinPool中的任务队列采用的是双端队列,工作线程正常获取任务和“窃取任务”分别是从任务队列不同的端消费,这样能避免很多不必要的数据竞争。
4. 模拟MapReduce统计单词数量
统计一个文件里面每个单词的数量,先用二分法递归地将一个文件拆分成更小的文件,直到文件里只有一行数据,然后统计这一行数据里单词的数量,最后再逐级汇总结果。
示例程序用一个字符串数组 String[] fc 来模拟文件内容,fc里面的元素与文件里面的行数据一一对应。关键的代码在 compute() 这个方法里面,这是一个递归方法,前半部分数据fork一个递归任务去处理(关键代码mr1.fork()),后半部分数据则在当前任务中递归处理(mr2.compute())。
public class MyTest2 {
public static void main(String[] args) {
String[] fc = { "hello world", "hello me", "hello fork", "hello join", "fork join in world" };
// 创建ForkJoin线程池
ForkJoinPool fjp = new ForkJoinPool(3);
// 创建任务
MR mr = new MR(fc, 0, fc.length);
// 启动任务
Map<String, Long> result = fjp.invoke(mr);
// 输出结果
result.forEach((k, v) -> System.out.println(k + ":" + v));
}
// MR模拟类
static class MR extends RecursiveTask<Map<String, Long>> {
private String[] fc;
private int start, end;
// 构造函数
MR(String[] fc, int fr, int to) {
this.fc = fc;
this.start = fr;
this.end = to;
}
@Override
protected Map<String, Long> compute() {
if (end - start == 1) {
return calc(fc[start]);
} else {
int mid = (start + end) / 2;
MR mr1 = new MR(fc, start, mid);
mr1.fork();
MR mr2 = new MR(fc, mid, end);
// 计算子任务,并返回合并的结果
return merge(mr2.compute(), mr1.join());
}
}
// 合并结果
private Map<String, Long> merge(Map<String, Long> r1, Map<String, Long> r2) {
Map<String, Long> result = new HashMap<>();
result.putAll(r1);
// 合并结果
r2.forEach((k, v) -> {
Long c = result.get(k);
if (c != null)
result.put(k, c + v);
else
result.put(k, v);
});
return result;
}
// 统计单词数量
private Map<String, Long> calc(String line) {
Map<String, Long> result = new HashMap<>();
// 分割单词
String[] words = line.split("\\s+");
// 统计单词数量
for (String w : words) {
Long v = result.get(w);
if (v != null)
result.put(w, v + 1);
else
result.put(w, 1L);
}
return result;
}
}
}