Spark源码分析:从collect入手

要研究spark的源码,输入和输出是很好的切入点。Dataset的基本操作就是collect了。

// spark/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan)

private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
  SQLExecution.withNewExecutionId(qe, Some(name)) {
    qe.executedPlan.resetMetrics()
    action(qe.executedPlan)
  }
}

private def collectFromPlan(plan: SparkPlan): Array[T] = {
  val fromRow = resolvedEnc.createDeserializer()
  plan.executeCollect().map(fromRow)
}

这段代码看上去有点绕,我们做一些语义上的替换,就变成下面这样。

def collect(): Array[T] = {
  queryExecution.executedPlan.executeCollect().map(fromRow)
}

这样看起来简单多了,下面再逐步深入,看executedPlan和executeCollect做了什么。

executedPlan

executedPlan是QueryExecution的一个方法,我们来跟踪一下。

lazy val executedPlan: SparkPlan = {
  assertOptimized()
  executePhase(QueryPlanningTracker.PLANNING) {
    QueryExecution.prepareForExecution(preparations, sparkPlan.clone())
  }
}

在阅读代码的过程中,发现Spark里有一些helper函数,比如上面的executePhase,主要目的是统计运行时间,这类函数不影响主要流程,为了便于理解核心的逻辑,我们可以把它们简化掉。

// sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
lazy val executedPlan: SparkPlan = {
  assertOptimized()
  QueryExecution.prepareForExecution(preparations, sparkPlan.clone())
}

lazy val optimizedPlan: LogicalPlan = {
  val plan = sparkSession.sessionState.optimizer.executeAndTrack(withCachedData.clone(), tracker)
  plan.setAnalyzed()
  plan
}

lazy val withCachedData: LogicalPlan = sparkSession.withActive {
  assertAnalyzed()
  assertSupported()
  sparkSession.sharedState.cacheManager.useCachedData(analyzed.clone())
}

lazy val analyzed: LogicalPlan = {
  sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker)
}

lazy val sparkPlan: SparkPlan = {
  QueryExecution.createSparkPlan(sparkSession, planner, optimizedPlan.clone())
}

上面的代码列出了executedPlan及它所调用的函数。在调用 assertOptimized 的时候,真正执行的是optimizedPlan。同样的,withCachedData又会调用analyzed。所以 logicalPlan先经过analyzed,再optimized。再通过createSparkPlan转为SparkPlan,最后再变成executedPlan。整个过程如下图所示。

经过这些操作,就是把LogicalPlan转为了SparkPlan。

executeCollect

def executeCollect(): Array[InternalRow] = {
  val byteArrayRdd = getByteArrayRdd()

  val results = ArrayBuffer[InternalRow]()
  byteArrayRdd.collect().foreach { countAndBytes =>
    decodeUnsafeRows(countAndBytes._2).foreach(results.+=)
  }
  results.toArray
}

executeCollect看起来也非常的简单,就是先得到一个rdd,然后通过rdd的collect方法再得到InternalRow类型的数组。在Dataset的collect方法中,再对这个数组执行 .map(fromRow) 就得到真实类型的数组了。

getByteArrayRdd 方法中,调用了execute,而execute又调用了doExecute方法。而 doExecute 是个抽象函数。

// sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
private def getByteArrayRdd(n: Int = -1, takeFromEnd: Boolean = false): RDD[(Long, Array[Byte])] = {
  execute().mapPartitionsInternal { iter =>
  }
}

final def execute(): RDD[InternalRow] = executeQuery {
  if (isCanonicalizedPlan) {
    throw new IllegalStateException("A canonicalized plan is not supposed to be executed.")
  }
  doExecute()
}

protected def doExecute(): RDD[InternalRow]

SparkPlan有很多的子类,doExecute的具体逻辑跟子类相关。而到底是哪个子类,跟DataFrame是怎么构建的又有关。比如,如果读取本地的csv文件来生成DataFrame,那这里的SparkPlan类型就是 FileSourceScanExec

fromRow

fromRow是这样创建的: val fromRow = resolvedEnc.createDeserializer() 。就是Deserializer类的一个实例,然后支持apply方法,所以可以直接map。它所做的事情就是把 InternalRow 转为实际类型 T

case class ExpressionEncoder[T] {
  def createDeserializer(): Deserializer[T] = new Deserializer[T](optimizedDeserializer)
}

object ExpressionEncoder {
  class Deserializer[T](private val expressions: Seq[Expression])
      extends (InternalRow => T) with Serializable {
    override def apply(row: InternalRow): T = {
      // ...
    }
  }
}

总结

collect 大概经历了这些环节:

  1. 将LogicalPlan转换为SparkPlan。

  2. 拿到RDD,再用RDD的collect获得 InternalRow 类型的数组。

  3. 通过Deserializer,将 InternalRow 转换为 T

也许你会觉得,看完这些又好像什么都没看,因为细节我们没有再深入下去。其实看代码就是要把握好一个度,知道什么时候应该深入细节,什么时候应该跳出来看整体框架。如果一直想搞清楚细节,就会无穷无尽的深入下去。过度沉溺于细节,反而会看不清楚整个的脉络和架构。

虽然我们还不了解细节,但至少我们脑子里已经产生了许多疑问。

  1. LogicalPlan和SparkPlan是啥,从LogicalPlan到SparkPlan中间那几步到底是处理什么事情?具体使用哪个子类是如何确定的?

  2. 看起来还是最终通过rdd的collect方法实现的,那rdd是怎么生成的?rdd的collect又做了什么操作?

  3. Deserializer是如何把InternalRow转为T的?

带着这些疑问,我们可以再继续深入下去。