Shuffle
的定义
- 我们都知道
Spark
是一个基于内存的、分布式的、迭代计算框架。在执行Spark
作业的时候,会将数据先加载到Spark
内存中,内存不够就会存储在磁盘中,那么数据就会以Partition
的方式存储在各个节点上,我们编写的代码就是操作节点上的Partiton
数据。之前我们也分析了怎么我们的代码是怎么做操Partition
上的数据,其实就是有Driver
将Task
发送到每个节点上的Executor
上去执行,在操作Partiton
上的数据时候,遇到Action
操作的时候会生成一个新的Partition
,而这个Partition
是由多个节点上的Partition
组成的,这样就实现了跨界点,我们管这种操作就叫Spark
的Shuffle
操作。其实比较比较通俗的来说,其实就是上一个Stage
的输出,下一个Stage
拉取这个输出的过程就是Shuffle
。
- 我们都知道
Shuffle
的原理普通
Shuffle
:- 在
Spark1.6
的版本之前,Shuffle
就是一个没有经过优化的Shuffle
,它的原理就是每一个ShuffleMapTask
都会根据ResultTask
数量在内存创建多个bucket
,并且会在该ShuffleMapTask
节点上的磁盘上为每一个bucket
创建一个blockFile
文件,如果是4个ResultTask
,那么就会有4*4=16
个blockFile
文件。 Task
的运行结果全部写入到缓存bucket
中,然后将bucket
的数据全部刷新到blockFile
文件中。ShuffleMapTask
就会将Task
的执行状态以及结果存储的地址封装成MapStatus
对象发送给Driver
。- 当
ResultTask
运行的时候,就会向Driver
发送请求来获取它所依赖的blockFile
文件的信息。 ResultTask根据这些信息利用
BlockManger
来将本地数据或者远程的数据通过网络或者直接读取的方式拉取过来存到内存中作为自己的输入数据,ResultTask
计算结束以后,将结果返回给我们。这就是早起版本Shuffle
的原理。
- 在
- 优化后的
Shuffle
:
- 在
Spark1.6
以后,对shuffle
进行了优化。优化的原理是根据core
来优化的,因为运行在每一个Executor
上的Task都是并行运行的,例如有两个core
,如果有4
个Task
,4
个ResultTask
。这个时候只能并行运行两个Task
,然后ShuffleMapTask
会根据ResultTask
的数量来创建8
个bucket
,然后在根据bucket
在本地磁盘创建8
个BlockFile
。 - 当另外两个Task执行的时候也会根据
ResultTask
创建8
个bucket
,但是这个时候,不会在本地磁盘上创建BlockFile
了,而是将结果追加到前两个Task
对应的8
个Block这里写代码片
File文件中。 - 将结果写入
BlockFile
的时候,首先不会等到结果全部写入内存以后在刷新到磁盘上,而是当内存达到一定的阈值就会将数据刷新到磁盘中个,这样防止了OOM
。
- 在
Shuffle
的写源码分析ShuffleMapTask
的runTask
方法:该方法中writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
方法就是Shuffle
写操作的开始,Spark
默认的写操作是HashShuffleWriter
//该方法的作用执行Task,然后将结果返回给调度器 //其中MapStatus封装了块管理器的地址,以及每个reduce的输出大小 //以便于传递reduce的任务 override def runTask(context: TaskContext): MapStatus = { //记录反序列化RDD的开始时间 val deserializeStartTime = System.currentTimeMillis() //创建一个序列化器 val ser = SparkEnv.get.closureSerializer.newInstance() // 反序列化广播变量来的得到RDD val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime metrics = Some(context.taskMetrics) var writer: ShuffleWriter[Any, Any] = null try { //获取ShuffleManager val manager = SparkEnv.get.shuffleManager //利用ShuffleManager获取ShuffleWriter,ShuffleWriter的功能就是将Task计算的结果 //持久化到shuffle文件中,作为子Stage的输入 writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) //通过ShuffleWriter将结果进行持久化到shuffle文件中,作为子Stage的输入 //rdd.iterator(partition, context)这个方法里就会执行我们自己编写的业务逻辑代码 writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) //关闭writer,将元数据写入MapStatus中,然后返回 writer.stop(success = true).get } catch { case e: Exception => try { if (writer != null) { writer.stop(success = false) } } catch { case e: Exception => log.debug("Could not stop writer", e) } throw e } }
HashShuffleWriter
中的writer
方法://ShuffleMapTask计算出来的结果写入磁盘的方法 override def write(records: Iterator[Product2[K, V]]): Unit = { //判断是否对ShuffleMapTask计算后得到的Partition对应的Iterator集合进行Map端的本地聚合 val iter = if (dep.aggregator.isDefined) { //如果dep.mapSideCombine和dep.aggregator.isDefined为true,那么就进行Map端的本地聚合 if (dep.mapSideCombine) { //开始本地聚合的方法 dep.aggregator.get.combineValuesByKey(records, context) } else { records } } else { require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") records } //遍历聚合以后的数据,调用partitioner,默认是HashPartitioner,生成bucketId //然后将数据写入bucketId对应的Bucket中去 for (elem <- iter) { val bucketId = dep.partitioner.getPartition(elem._1) //shuffle:它就是ShuffleWriterGroup对象,其实就是为ShuffleWriterGroup定义的一组writer //它是由FileShuffleBlockResolver这类调用forMapTask方法生成的, // shuffle.writers(bucketId)根据bucketId从ShuffleWriterGroup对应的一组writer中找出对应的DiskBlockObjectWriter //然后将数据写入对应的文件中 shuffle.writers(bucketId).write(elem._1, elem._2) } }
FileShuffleBlockResolver
类的forMapTask
方法:该方法的作用就是根据Map Task
得到一个ShuffleWriterGroup
。def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer, writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { //为每个ShuffleMapTask实例化一个ShuffleWriterGroup new ShuffleWriterGroup { //实例化ShuffleState并保存shuffleId与ShuffleState的对应关系 shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers)) //根据shuffleId获取ShuffleState private val shuffleState = shuffleStates(shuffleId) val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() //实例化Writer:DiskBlockObjectWriter val writers: Array[DiskBlockObjectWriter] = { Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId => //生成blockId val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) //根据blcokId生成blockFile,用于存储写入的数据 val blockFile = blockManager.diskBlockManager.getFile(blockId) //生成一个临时文件,也就是把数据写入那个目录里 val tmp = Utils.tempFileWith(blockFile) //生成writer:实例化Writer:DiskBlockObjectWriter,用于将数据写入磁盘 //这里的BufferSize的默认大小是32kb,可以通过spark.shuffle.file.buffer重新设置大小 blockManager.getDiskWriter(blockId, tmp, serializerInstance, bufferSize, writeMetrics) } } // Creating the file to write to and creating a disk writer both involve interacting with // the disk, so should be included in the shuffle write time. writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) override def releaseWriters(success: Boolean) { shuffleState.completedMapTasks.add(mapId) } } }
Shuffle
的读源码分析- 在上一讲Spark(十六)Executor执行Task的原理与源码分析(二)中我们已经对
ResultTask
进行了源码的分析,ResultTask
里的runTask
方法就是开始计算最终我们想要的结果,那么既然要计算结果就需要从远程或者本地拉去上一个Stage
处理过后的结果,也就是Shuffle
写入到磁盘中的数据。从上一篇的源码分析可以找出,ResultTask
拉取数据的起点是ShuffledRDD
类里的compute
方法。 ShuffleRDD
的compute
方法override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = { val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] //回调用shuffleManger的getReader方法获取SortShuffleManager, //调用它的read方法拉取数据 SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) .read() .asInstanceOf[Iterator[(K, C)]] }
SortShuffleManager
里的getRead
方法override def getReader[K, C]( handle: ShuffleHandle, startPartition: Int, endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { //创建一个BlockStoreShuffleReader new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) }
BlockStoreShuffleReader
的read
方法override def read(): Iterator[Product2[K, C]] = { //实例化ShuffleBlockFetcherIterator,在实例化这个对象的时候,回调用它内部的initialze方法,这个方法会调用splitLocalRemoteBlocks方法来路由拉取数据的策略, //拉取数据策略分为两种,一种是本地策略,一种是远程策略。 val blockFetcherItr = new ShuffleBlockFetcherIterator( context, blockManager.shuffleClient, blockManager, //通过消息发送获取ShuffleMapTask存储数据的位置 mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility //设置每次拉取数据的大小 SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, //设置从远程节点拉取快的数量 SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue)) // 根据配置对流进行压缩和加密,构建一个包装流 val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => serializerManager.wrapStream(blockId, inputStream) } //构建一个序列化器 val serializerInstance = dep.serializer.newInstance() // 为每个包装流创建一个键/值迭代器。 val recordIter = wrappedStreams.flatMap { wrappedStream => // Note: the asKeyValueIterator below wraps a key/value iterator inside of a // NextIterator. The NextIterator makes sure that close() is called on the // underlying InputStream when all records have been read. serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator } // 每条记录读取后更新指标。这样在UI界面上就能看见相关信息 val readMetrics = context.taskMetrics.createTempShuffleReadMetrics() //生成一个完整的迭代器 val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( recordIter.map { record => readMetrics.incRecordsRead(1) record }, context.taskMetrics().mergeShuffleReadMetrics()) // 为了这个任务可以取消,那么就必须使用可中断的迭代器 val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) //判断数据聚合操作是否被定义 val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { //判断是否启用数据的聚合操作,因为这个Stage相于下一个Stage是Map端 if (dep.mapSideCombine) { val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] //如果启用就调用combineCombinersByKey方法 dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) } else {//不启用数据聚合操作 val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] //如果不启用就调用combineValueByKey方法 dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else {//如果聚合操作没有被定义就会报错 require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } // 根据keyOrdering属性判断是否对输出结果 dep.keyOrdering match { //如果keyOrdering不为空,就对输出结果进行排序 case Some(keyOrd: Ordering[K]) => // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, // the ExternalSorter won't spill to disk. //创建ExternalSorter对象对数据进行排序,如果spark.shuffle.spill没有开启 //ExternalSorter不会将数据持久化到磁盘上的 val sorter = new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer) sorter.insertAll(aggregatedIter) context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) case None => aggregatedIter } }
MapOutputTracker
类里的getMapSizeByExecutorId
方法,该方法的作用就是告诉Executor
去获取每个shuffle block
服务器的Url
和输出大小。def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") //根据ShuffleId获取ShuffleBlock的元数据,包含数据地址和数据的大小 val statuses = getStatuses(shuffleId) // statuses.synchronized { //调用MapOutputTracker的converMapStatus方法,给定一组映射状态和一系列的映射分区, //然后返回一个Tuple2序列,该序列里的的元素是元组,元组的第一个元素是BlockManagerId,第二个元素是也是一个元组 //这个元组的第一个元素是BlockId,第二个元素是block快的大小 //也就是说这个方法返回的是数据在哪个节点以及这个节点的那些block return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } }
MapOutputTracker
类里的getStatus
方法,该方法的作用就是利用ShuffleId
过去对应的MapStatus
(Block
的元数据),它的原理就是首先从给本地获取MapStatus
,如果没有就通过网络拉取MapStatus
。//这个方法利用了算法 //这个方法就是根据shuffleId获取MapStatus方法 private def getStatuses(shuffleId: Int): Array[MapStatus] = { //根据shuffleId获取本地的MapStatus数组,因为当ResultTask拉取MapStatus的时候 //会把它放到内存缓存中。 val statuses = mapStatuses.get(shuffleId).orNull //如果status为空,那么就说明本地没有对应的status,这样就会利用网络拉去MapStatus if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") //拉取的开始时间 val startTime = System.currentTimeMillis //定义MapStatus数组 var fetchedStatuses: Array[MapStatus] = null //由于是并行处理,可能其他Task也在利用网络拉取MapStatus //所以避免数据的同步问题需要加上synchronize关键字 fetching.synchronized { // Someone else is fetching it; wait for them to be done while (fetching.contains(shuffleId)) { try { //如果其他Task正在拉取数据,那么就等待它完成再继续执行 fetching.wait() } catch { case e: InterruptedException => } } // Either while we waited the fetch happened successfully, or // someone fetched it in between the get and the fetching.synchronized. //等待过后继续调用mapStatus的get方法获取MapStatus fetchedStatuses = mapStatuses.get(shuffleId).orNull if (fetchedStatuses == null) { // We have to do the fetch, get others to wait for us. //加入到内存缓存中,等待这个线程编程等待状态,因为上边的while循环就是 //将线程编程等待状态,需要调用notifyAll来唤醒所有的线程 fetching += shuffleId } } //如果fetchedStatuses还是等于空的话,就会真正的开始从MapOutputTracker中获取MapStatus if (fetchedStatuses == null) { // We won the race to fetch the statuses; do so logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) // This try-finally prevents hangs due to timeouts: try { //调用askTracker方法,发送GetMapOutputStatuses消息获取MapStatus val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) //反序列化传过来的MapStatus数组对象 fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) logInfo("Got the output locations") //将拉取过来的MapStatus放入到内存缓冲中,内存缓存结构为HashMap mapStatuses.put(shuffleId, fetchedStatuses) } finally { fetching.synchronized { //移除当前执行完的fetch fetching -= shuffleId //由于当前线程正在执行的时候,其他线程正在处于等待状态 //需要调用notifyAll方法来唤醒其他线程,继续获取MapStatus fetching.notifyAll() } } } logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + s"${System.currentTimeMillis - startTime} ms") //如果fetchedStatuses不等于空的话就直接返回,如果为空就抛出异常 if (fetchedStatuses != null) { return fetchedStatuses } else { logError("Missing all output locations for shuffle " + shuffleId) throw new MetadataFetchFailedException( shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) } //如果本地有ShuffleId对应的MapStatus就直接返回 } else { return statuses } }
MapOutputTracker
类里的askTracker
方法,该方法的作用就是向MapOutputTrackerMasterEndpoint
发送GetMapOutputStatus
消息请求获取MapStatus
protected def askTracker[T: ClassTag](message: Any): T = { try { //向MapOutputTrackerMasterEndpoint发送GetMapOutputStatus消息,并设定请求失败重试的次数,超时时间使用默认的 trackerEndpoint.askWithRetry[T](message) } catch { case e: Exception => logError("Error communicating with MapOutputTracker", e) throw new SparkException("Error communicating with MapOutputTracker", e) } }
MapOutputTrackerMasterEndpoint
类里的receiveAndReply
方法,该方法的作用就是接收ResultTask
发送过来的GetMapOutputStatus
消息,调用MapOutputTrackerMaster
的post
方法,获取MapStatus
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { //当MapOutputTrackerMasterEndpoint接收到ResultTask发送的GetMapOutputStatus消息后,调用MapOutputTrackerMaster的post方法 case GetMapOutputStatuses(shuffleId: Int) => val hostPort = context.senderAddress.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) //MapOutputTrackerMaster调用post方法,获取MapStatus val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context)) case StopMapOutputTracker => logInfo("MapOutputTrackerMasterEndpoint stopped!") context.reply(true) stop() }
MapOutputTrackerMaster
的post
方法,该方法的作用就是将请求用到的GetMapOutputMessage
消息放入到队列里,后续会利用多线程的方式来执行请求,获取MapStatus
def post(message: GetMapOutputMessage): Unit = { //不直接发送请求,而是将GetMapOutputMessage放入请求MapStatus对队列中,利用多线程的方式来请求MapStatus mapOutputRequests.offer(message) }
MessageLoop
类是一个发送MapOutputMessage
消息的循环体,利用多线程的方式循环调用getSerializedMapOutputStatuses
方法从本地获取MapStatus
,然后返回给ResultTask
//该类是继承了Runnable抽象类的一个调度消息的循环体 private class MessageLoop extends Runnable { override def run(): Unit = { try { while (true) { try { val data = mapOutputRequests.take() if (data == PoisonPill) { // Put PoisonPill back so that other MessageLoops can see it. mapOutputRequests.offer(PoisonPill) return } val context = data.context val shuffleId = data.shuffleId val hostPort = context.senderAddress.hostPort logDebug("Handling request to send map output locations for shuffle " + shuffleId + " to " + hostPort) //获取MapStatus val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId) //返回数据 context.reply(mapOutputStatuses) } catch { case NonFatal(e) => logError(e.getMessage, e) } } } catch { case ie: InterruptedException => // exit } } }
MapOutputTracker
类里的converMapStatus
方法,该方法的作用就是根据参数组装数据结构。//该方法的作用就是利用传过来的参数组装成一个数据结构,以供后续使用 //该数据结构就是一个序列里边的结构就是block数据在哪个节点上以及对应哪些block和block的大小 private def convertMapStatuses( shuffleId: Int, startPartition: Int, endPartition: Int, statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] for ((status, mapId) <- statuses.zipWithIndex) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" logError(errorMessage) throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage) } else { for (part <- startPartition until endPartition) { splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part))) } } } splitsByAddress.toSeq }
ShuffleBlockFetcherIterator
类里的initilize
方法//在初始化ShuffleBlockFetcherIterator对象时候会调用这个initilize方法 private[this] def initialize(): Unit = { // Add a task completion callback (called in both success case and failure case) to cleanup. context.addTaskCompletionListener(_ => cleanup()) // Split local and remote blocks. val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order //打散远程的block的容器中的元素,放入队列中 fetchRequests ++= Utils.randomize(remoteRequests) assert ((0 == reqsInFlight) == (0 == bytesInFlight), "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight + ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight) // 从远程拉取block数据 fetchUpToMaxBytes() val numFetches = remoteRequests.size - fetchRequests.size logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime)) // 从本地拉取block数据 fetchLocalBlocks() logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime)) }
ShuffleBlockFetcherIterator
类里的splitLocalRemoteBlock
方法,该方法的作用就是根据block
所在的位置不同,封装不同的block
信息,为后续拉取block
数据做准备private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 // nodes, rather than blocking on reading output from one node. //设置每次从5个节点上同时拉去数据 val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) //创建FetcherRequest容器 //FetchRequest:封装了远程block块的信息 val remoteRequests = new ArrayBuffer[FetchRequest] // Tracks total number of blocks (including zero sized blocks) //记录块的总数 var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { totalBlocks += blockInfos.size //判断当前block块是否在本地 if (address.executorId == blockManager.blockManagerId.executorId) { //过滤掉为0的block块,并把blockId缓存在内存中 localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) //记录要获取本地的块的总数 numBlocksToFetch += localBlocks.size //如果块不在本地,那么就需要从远程拉去块的信息 } else { //生成blockInfos的迭代器 val iterator = blockInfos.iterator //block数据量的边界值 var curRequestSize = 0L var curBlocks = new ArrayBuffer[(BlockId, Long)] while (iterator.hasNext) { //获取远程block的Id以及大小 val (blockId, size) = iterator.next() // 判断远程block块大小是否为0 if (size > 0) { //将block信息加入到容器中 curBlocks += ((blockId, size)) //将blockId加入到内存缓存中,缓存结构为HashSet remoteBlocks += blockId //需要从远程拉取block块的个数 numBlocksToFetch += 1 //记录循环到现在为止的所有块的大小,目的是为了定义一个边界, //这个边界是为了防止拉取block数据的时候超出最大的允许拉取的block数据量 curRequestSize += size } else if (size < 0) { //如果远程的块的大小为0,跑出异常 throw new BlockException(blockId, "Negative block size " + size) } //因为拉去远程block块的时候只能并行从5个节点上拉取数据,当curRequestSize大于等于最大的请求数量 if (curRequestSize >= targetRequestSize) { //就会将block块封装成FetcherRequest然后加入到容器总 remoteRequests += new FetchRequest(address, curBlocks) //相当于清空curBlocks集合 curBlocks = new ArrayBuffer[(BlockId, Long)] logDebug(s"Creating fetch request of $curRequestSize at $address") //将curRequestSize curRequestSize = 0 } } // 因为block块信息遍历到最后curRequestSize >= targetRequestSize这个不成立 //所以就把最后一个block封装成FetchRequest加入到remoteRequests容器中 if (curBlocks.nonEmpty) { remoteRequests += new FetchRequest(address, curBlocks) } } } logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks") remoteRequests }
ShuffleBlockFetcherIterator
类里的fetchUpToMaxBytes
方法,该方法的作用就是循环远程消息队列里的block
信息,发送请求获取block
信息private def fetchUpToMaxBytes(): Unit = { //while循环远程队列里的block信息,向远程发送请求获取block数据 while (fetchRequests.nonEmpty && (bytesInFlight == 0 || (reqsInFlight + 1 <= maxReqsInFlight && bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) { //调用sendRequest方法,发送请求获取block数据 sendRequest(fetchRequests.dequeue()) } }
ShuffleBlockFetcherIterator
类里的sendRequest
方法,该方法的作用是发送请求获取block
数据private[this] def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) bytesInFlight += req.size reqsInFlight += 1 // so we can look up the size of each blockID、 //将blockid与block大小的元组结构转换成Map结构 val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap //将block的Id放到HashSet容器中 val remainingBlocks = new HashSet[String]() ++= sizeMap.keys val blockIds = req.blocks.map(_._1.toString) // 请求的远端地址 val address = req.address //ShuffleClient是一个抽象类,默认调用的是BlockTransferService这个子类的fetchBlock方法, //BlockTransferService里的fetchBlock方法也是一个抽象方法,这个方法是NettyBlockTransferService这个类实现的 shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, new BlockFetchingListener { override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { //请求成功 ShuffleBlockFetcherIterator.this.synchronized { if (!isZombie) { // Increment the ref count because we need to pass this to a different thread. // This needs to be released after use. buf.retain() remainingBlocks -= blockId results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf, remainingBlocks.isEmpty)) logDebug("remainingBlocks: " + remainingBlocks) } } logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } //请求失败 override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) results.put(new FailureFetchResult(BlockId(blockId), address, e)) } } ) }
NettyBlockTransferService
的fetcherBlock
方法,该方的作用就是拉取block
数据。NettyBlockTransferService
必须在调用init方法后才能提供服务。这个方法在执行前,必须执行以下步骤才能成功拉取block数据- 创建
RpcServer
(实际是其子类NettyBlockRpcServer
) - 创建
TransportContext
- 创建
Rpc
客户端工厂TransportClientFactory
- 创建
Netty
服务器TransportServer
,可以修改属性spark.blockManager.port
改变TransportServer
的端口
override def fetchBlocks( host: String, port: Int, execId: String, blockIds: Array[String], listener: BlockFetchingListener): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start() } } val maxRetries = transportConf.maxIORetries() if (maxRetries > 0) { // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's // a bug in this code. We should remove the if statement once we're sure of the stability. new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start() } else { blockFetchStarter.createAndStart(blockIds, listener) } } catch { case e: Exception => logError("Exception while beginning fetchBlocks", e) blockIds.foreach(listener.onBlockFetchFailure(_, e)) } }
- 创建
NettyBlockTransferService
的init
方法,override def init(blockDataManager: BlockDataManager): Unit = { //初始化NettyRpcServer,用于接受上传或者拉取block数据的请求 val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager) var serverBootstrap: Option[TransportServerBootstrap] = None var clientBootstrap: Option[TransportClientBootstrap] = None if (authEnabled) { serverBootstrap = Some(new SaslServerBootstrap(transportConf, securityManager)) clientBootstrap = Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager, securityManager.isSaslEncryptionEnabled())) } //初始化TransportContext,既可以创建Netty服务端,又可以创建Netty客户端 //transportConf主要控制Netty客户端与服务端的线程数量 //rpcHandle负责客户端请求服务端的时候,提供block的上传下载功能,其实就是NettyBlockRpcServer对象 transportContext = new TransportContext(transportConf, rpcHandler) //实例化一个能够创建Netty客户端的工厂类,用于创建Netty客户端 clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava) server = createServer(serverBootstrap.toList) appId = conf.getAppId logInfo(s"Server created on ${hostName}:${server.getPort}") }
NettyBlockRpcServer
的reveive
方法,该方法的作用就是接受拉取block
的请求override def receive( client: TransportClient, rpcMessage: ByteBuffer, responseContext: RpcResponseCallback): Unit = { val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage) logTrace(s"Received request: $message") message match { //处理拉去block数据请求 case openBlocks: OpenBlocks => val blocks: Seq[ManagedBuffer] = openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer) case uploadBlock: UploadBlock => // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer. val (level: StorageLevel, classTag: ClassTag[_]) = { serializer .newInstance() .deserialize(ByteBuffer.wrap(uploadBlock.metadata)) .asInstanceOf[(StorageLevel, ClassTag[_])] } val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) val blockId = BlockId(uploadBlock.blockId) blockManager.putBlockData(blockId, data, level, classTag) responseContext.onSuccess(ByteBuffer.allocate(0)) } }
- 在上一讲Spark(十六)Executor执行Task的原理与源码分析(二)中我们已经对
Spark Core(十七)Spark的Shuffle原理与源码分析
版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Suubyy/article/details/82023369
猜你喜欢
转载自blog.csdn.net/Suubyy/article/details/82023369
今日推荐
周排行