From ee779238d9e48c00b360879019caaebdee8b7b9c Mon Sep 17 00:00:00 2001
From: wakun
Date: Fri, 14 Oct 2022 10:57:10 +0800
Subject: [PATCH] [CARMEL-6185] Expose row count for RepeatableIterator (#1074)

* [CARMEL-6185] Expose row count for RepeatableIterator

* fix code style

* Fix code style

* Fix UT

* Update code

* Update code
---
 .../apache/spark/InternalAccumulator.scala    |  1 +
 .../apache/spark/executor/TaskMetrics.scala   |  8 +++++
 .../main/scala/org/apache/spark/rdd/RDD.scala |  8 ++---
 .../apache/spark/scheduler/DAGScheduler.scala |  2 +-
 .../spark/scheduler/IterableJobWaiter.scala   | 12 +++++--
 .../apache/spark/scheduler/JobListener.scala  |  5 +++
 .../spark/scheduler/TaskResultStore.scala     | 33 +++++++++++------
 .../apache/spark/util/JsonProtocolSuite.scala | 35 +++++++++++--------
 .../spark/sql/execution/SparkPlan.scala       |  6 ++--
 .../spark/sql/SpillDirectResultSuite.scala    |  4 +++
 .../spark/sql/SpillIndirectResultSuite.scala  |  5 +++
 .../SparkExecuteStatementOperation.scala      |  7 ++--
 12 files changed, 87 insertions(+), 39 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala
index a39f5b3e85f6c..454fc14e31ed9 100644
--- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala
+++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala
@@ -43,6 +43,7 @@ private[spark] object InternalAccumulator {
   val UPDATED_BLOCK_STATUSES = METRICS_PREFIX + "updatedBlockStatuses"
   val PRUNED_STATS = "index.prunedStats"
   val TEST_ACCUM = METRICS_PREFIX + "testAccumulator"
+  val RECORDS_OUTPUT = OUTPUT_METRICS_PREFIX + "recordsOutput"

   // scalastyle:off

diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index ab25efadc472c..2a164622c8012 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -57,6 +57,7 @@ class TaskMetrics private[spark] () extends Serializable {
   private val _peakExecutionMemory = new LongAccumulator
   private val _updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)]
   private val _prunedStats = new PrunedMetricsAccum
+  private val _recordsOutput = new LongAccumulator

   def prunedStats: PrunedMetricsAccum = _prunedStats

   /**
@@ -113,6 +114,11 @@ class TaskMetrics private[spark] () extends Serializable {
    */
  def peakExecutionMemory: Long = _peakExecutionMemory.sum

+  /**
+   * Total number of records output.
+   */
+  def recordsOutput: Long = _recordsOutput.sum
+
   /**
    * Storage statuses of any blocks that have been updated as a result of this task.
   *
@@ -152,6 +158,7 @@ class TaskMetrics private[spark] () extends Serializable {
   private[spark] def setPrunedStats(v: List[PrunedStats]): Unit = {
     _prunedStats.setValue(v)
   }
+  private[spark] def setRecordsOutput(v: Long): Unit = _recordsOutput.setValue(v)

   /**
    * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted
@@ -226,6 +233,7 @@ class TaskMetrics private[spark] () extends Serializable {
     PEAK_EXECUTION_MEMORY -> _peakExecutionMemory,
     UPDATED_BLOCK_STATUSES -> _updatedBlockStatuses,
     PRUNED_STATS -> _prunedStats,
+    RECORDS_OUTPUT -> _recordsOutput,
     shuffleRead.REMOTE_BLOCKS_FETCHED -> shuffleReadMetrics._remoteBlocksFetched,
     shuffleRead.LOCAL_BLOCKS_FETCHED -> shuffleReadMetrics._localBlocksFetched,
     shuffleRead.REMOTE_BYTES_READ -> shuffleReadMetrics._remoteBytesRead,
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index e9934424fee38..95cad1ef544c9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -324,12 +324,12 @@ abstract class RDD[T: ClassTag](
    * expansion rate = (number of output rows in the task) / (number of input rows in task).
    */
   final def expansionLimitedIterator(split: Partition, context: TaskContext): Iterator[T] = {
-    val innerItrator = iterator(split, context)
+    val innerIterator = iterator(split, context)
     if (maxExpandRate > 0) {
       new Iterator[T] {
         private var output = 0
         override def hasNext: Boolean = {
-          innerItrator.hasNext
+          innerIterator.hasNext
         }
         override def next(): T = {
           output += 1
@@ -345,11 +345,11 @@ abstract class RDD[T: ClassTag](
             }
             output = 0
           }
-          innerItrator.next()
+          innerIterator.next()
         }
       }
     } else {
-      innerItrator
+      innerIterator
     }
   }

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 9e960782207c1..9d31fffbf16f6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1942,7 +1942,7 @@ private[spark] class DAGScheduler(
             // taskSucceeded runs some user code that might throw an exception. Make sure
             // we are resilient against that.
             try {
-              job.listener.taskSucceeded(rt.outputId, event.result)
+              job.listener.taskSucceeded(rt.outputId, event.result, event.taskMetrics)
             } catch {
               case e: Throwable if !Utils.isFatalError(e) =>
                 // TODO: Perhaps we want to mark the resultStage as failed?
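
The DAGScheduler change above forwards the completed task's TaskMetrics to the job listener; the JobListener and IterableJobWaiter diffs that follow add the matching three-argument callback. A minimal sketch of that contract (illustration only, not part of this patch: RowCountingListener is a hypothetical class, and it has to live under org.apache.spark to see these private[spark] APIs):

package org.apache.spark.scheduler

import org.apache.spark.executor.TaskMetrics

// Hypothetical listener that sums per-task output row counts through the new
// three-argument callback; the two-argument overload keeps old call sites working.
private[spark] class RowCountingListener extends JobListener {
  @volatile private var totalRows: Long = 0L

  // Legacy callback: no metrics available, so nothing is added to the total.
  override def taskSucceeded(index: Int, result: Any): Unit =
    taskSucceeded(index, result, new TaskMetrics)

  // New callback: recordsOutput is filled in on the executor side
  // (see the SparkPlan change later in this patch).
  override def taskSucceeded(index: Int, result: Any, taskMetrics: TaskMetrics): Unit =
    totalRows += taskMetrics.recordsOutput

  override def jobFailed(exception: Exception): Unit = ()

  def rowsSoFar: Long = totalRows
}

The in-tree consumer is IterableJobWaiter, shown in the next file diff, which accumulates the same value into allRowCount and hands it to the result iterators.
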
diff --git a/core/src/main/scala/org/apache/spark/scheduler/IterableJobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/IterableJobWaiter.scala
index 5812645c01d25..44582c9e414e4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/IterableJobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/IterableJobWaiter.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicInteger
 import scala.concurrent.{Future, Promise}
 import scala.reflect.ClassTag

+import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.internal.Logging

 /**
@@ -48,6 +49,7 @@ private[spark] class IterableJobWaiter[U: ClassTag, R](
   // to hold the result data as spilled files or in memory array
   private var spilledResultData: Option[Array[SpilledPartitionResult]] = None
   private val resultData: Array[U] = new Array[U](totalTasks)
+  private var allRowCount: Long = 0L

   // indicate whether the in memory result data has been cleaned after
   // the result data is spilled to disk
@@ -66,7 +68,11 @@ private[spark] class IterableJobWaiter[U: ClassTag, R](
     dagScheduler.cancelJob(jobId, None)
   }

-  override def taskSucceeded(index: Int, result: Any): Unit = {
+  override def taskSucceeded(index: Int, result: Any): Unit =
+    taskSucceeded(index, result, new TaskMetrics)
+
+  override def taskSucceeded(index: Int, result: Any, taskMetrics: TaskMetrics): Unit = {
+    allRowCount += taskMetrics.recordsOutput
     result match {
       case spilledPartitionResult: Array[SpilledPartitionResult] =>
         spilledResultData = Some(spilledPartitionResult)
@@ -109,12 +115,12 @@ private[spark] class IterableJobWaiter[U: ClassTag, R](
     if (spilledResultData.nonEmpty) {
       logInfo(s"Return result as a SpilledResultIterator for job $jobId " +
         s"with files ${spilledResultData.get.map(_.file.getPath).mkString(",")}")
-      SpilledResultIterator[U, R](spilledResultData.get, resultConverter,
+      SpilledResultIterator[U, R](spilledResultData.get, resultConverter, allRowCount,
        dagScheduler.sc.conf.getBoolean("spark.sql.thriftserver.cleanShareResultFiles", false),
        dagScheduler.sc.conf.getBoolean("spark.sql.thriftserver.shareResult", true))
     } else {
       logInfo(s"Return result as a SimpleRepeatableIterator for job $jobId")
-      SimpleRepeatableIterator(resultData, resultConverter)
+      SimpleRepeatableIterator(resultData, resultConverter, allRowCount)
     }
   }

diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala
index e0f7c8f02132d..fe1f318b7e2e4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala
@@ -17,6 +17,8 @@

 package org.apache.spark.scheduler

+import org.apache.spark.executor.TaskMetrics
+
 /**
  * Interface used to listen for job completion or failure events after submitting a job to the
  * DAGScheduler. The listener is notified each time a task succeeds, as well as if the whole
@@ -24,5 +26,8 @@ package org.apache.spark.scheduler
  */
 private[spark] trait JobListener {
   def taskSucceeded(index: Int, result: Any): Unit
+
+  def taskSucceeded(index: Int, result: Any, taskMetrics: TaskMetrics): Unit =
+    taskSucceeded(index, result)
   def jobFailed(exception: Exception): Unit
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultStore.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultStore.scala
index 2383de710f2c5..fba2dfe098e67 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultStore.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultStore.scala
@@ -180,13 +180,18 @@ private[spark] case class SpilledPartitionResult(
 /**
  * the interface for iterator which supports read from the start again
  */
-private[spark] trait RepeatableIterator[T] extends Iterator[T] {
+private[spark] abstract class RepeatableIterator[T](_rowCount: Long) extends Iterator[T] {
   def backToStart(): Unit

   def close(): Unit

   def copy(): RepeatableIterator[T]
+
+  def rowCount(): Long = _rowCount
+
+  // Avoid calling this method: the Int result is negative once rowCount exceeds Int.MaxValue
+  override def length: Int = rowCount.toInt
 }

 /**
@@ -194,8 +199,9 @@
 private[spark] case class SimpleRepeatableIterator[T, U](
     originData: Array[U],
-    resultConverter: U => Iterator[T])
-  extends RepeatableIterator[T] {
+    resultConverter: U => Iterator[T],
+    _rowCount: Long)
+  extends RepeatableIterator[T](_rowCount) {

   private var it: Iterator[T] = originData.iterator.flatMap(resultConverter)

@@ -210,24 +216,29 @@ private[spark] case class SimpleRepeatableIterator[T, U](
   override def close(): Unit = {
   }

-  // length calculation is time consuming
-  override def length: Int = {
-    originData.iterator.flatMap(resultConverter).length
-  }
-
   override def copy(): RepeatableIterator[T] = {
-    new SimpleRepeatableIterator[T, U](originData, resultConverter)
+    new SimpleRepeatableIterator[T, U](originData, resultConverter, rowCount)
   }
 }

 /**
  * The iterator implementation to read from spilled files
+ * data of spilledResults: Array[SpilledPartitionResult]
+ *   file                  blockId           offset  length
+ *   /data/yarn/tmp/file1, "temp_local_001",      0,    100
+ *   /data/yarn/tmp/file1, "temp_local_002",    100,    200
+ *   /data/yarn/tmp/file2, "temp_local_003",      0,    400
+ *
+ * nextBatchStream() will clean the temp file and then read a new SpilledPartitionResult.
+ * readNextBatch() will convert the partition result to currentBatch: Iterator[R]
  */
 private[spark] case class SpilledResultIterator[U, R](
     spilledResults: Array[SpilledPartitionResult],
     converter: U => Iterator[R],
+    _rowCount: Long,
     cleanShareResultFiles: Boolean = false,
-    override val isTraversableAgain: Boolean) extends RepeatableIterator[R] with Logging {
+    override val isTraversableAgain: Boolean)
+  extends RepeatableIterator[R](_rowCount) with Logging {

   private val serializer = SparkEnv.get.serializer.newInstance()
   private val serializerManager = SparkEnv.get.serializerManager
@@ -378,7 +389,7 @@ private[spark] case class SpilledResultIterator[U, R](
   }

   override def copy(): RepeatableIterator[R] = {
-    new SpilledResultIterator[U, R](spilledResults, converter, cleanShareResultFiles,
+    new SpilledResultIterator[U, R](spilledResults, converter, rowCount, cleanShareResultFiles,
       isTraversableAgain)
   }
 }
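
With the count carried as a Long, callers should prefer the new rowCount() accessor over Iterator.length: it is O(1) (the removed SimpleRepeatableIterator.length re-ran the converters over all the data), and it avoids the Int truncation the comment above warns about. A self-contained illustration of that truncation, using nothing beyond the Scala standard library (the object name is made up for this example):

object RowCountTruncationDemo {
  def main(args: Array[String]): Unit = {
    // The iterator tracks its size as a Long ...
    val rowCount: Long = 2L * Int.MaxValue   // roughly 4.3 billion rows
    // ... but Iterator.length must return an Int, so the overridden length wraps around.
    val asLength: Int = rowCount.toInt
    println(s"rowCount = $rowCount, length = $asLength")  // prints: rowCount = 4294967294, length = -2
  }
}
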
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index f77932d439fb7..3dd263379bb02 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -2093,104 +2093,111 @@ private[spark] object JsonProtocolSuite extends Assertions {
       |  },
       |  {
       |    "ID": 11,
-      |    "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}",
+      |    "Name": "${RECORDS_OUTPUT}",
       |    "Update": 0,
       |    "Internal": true,
       |    "Count Failed Values": true
       |  },
       |  {
       |    "ID": 12,
-      |    "Name": "${shuffleRead.LOCAL_BLOCKS_FETCHED}",
+      |    "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}",
       |    "Update": 0,
       |    "Internal": true,
       |    "Count Failed Values": true
       |  },
       |  {
       |    "ID": 13,
-      |    "Name": "${shuffleRead.REMOTE_BYTES_READ}",
+      |    "Name": "${shuffleRead.LOCAL_BLOCKS_FETCHED}",
       |    "Update": 0,
       |    "Internal": true,
       |    "Count Failed Values": true
       |  },
       |  {
       |    "ID": 14,
-      |    "Name": "${shuffleRead.REMOTE_BYTES_READ_TO_DISK}",
+      |    "Name": "${shuffleRead.REMOTE_BYTES_READ}",
       |    "Update": 0,
       |    "Internal": true,
       |    "Count Failed Values": true
       |  },
       |  {
       |    "ID": 15,
-      |    "Name": "${shuffleRead.LOCAL_BYTES_READ}",
+      |    "Name": "${shuffleRead.REMOTE_BYTES_READ_TO_DISK}",
       |    "Update": 0,
       |    "Internal": true,
       |    "Count Failed Values": true
       |  },
       |  {
       |    "ID": 16,
-      |    "Name": "${shuffleRead.FETCH_WAIT_TIME}",
+      |    "Name": "${shuffleRead.LOCAL_BYTES_READ}",
       |    "Update": 0,
       |    "Internal": true,
       |    "Count Failed Values": true
       |  },
       |  {
       |    "ID": 17,
-      |    "Name": "${shuffleRead.RECORDS_READ}",
+      |    "Name": "${shuffleRead.FETCH_WAIT_TIME}",
       |    "Update": 0,
       |    "Internal": true,
       |    "Count Failed Values": true
       |  },
       |  {
       |    "ID": 18,
-      |    "Name": "${shuffleWrite.BYTES_WRITTEN}",
+      |    "Name": "${shuffleRead.RECORDS_READ}",
       |    "Update": 0,
       |    "Internal": true,
       |    "Count Failed Values": true
       |  },
       |  {
       |    "ID": 19,
-      |    "Name": "${shuffleWrite.RECORDS_WRITTEN}",
+      |    "Name": "${shuffleWrite.BYTES_WRITTEN}",
       |    "Update": 0,
       |    "Internal": true,
       |    "Count Failed Values": true
       |  },
       |  {
       |    "ID": 20,
-      |    "Name": "${shuffleWrite.WRITE_TIME}",
+      |    "Name": "${shuffleWrite.RECORDS_WRITTEN}",
       |    "Update": 0,
       |    "Internal": true,
       |    "Count Failed Values": true
       |  },
       |  {
       |    "ID": 21,
+      |    "Name": "${shuffleWrite.WRITE_TIME}",
+      |    "Update": 0,
+      |    "Internal": true,
+      |    "Count Failed Values": true
+      |  },
+      |  {
+      |    "ID": 22,
       |    "Name": "${input.BYTES_READ}",
       |    "Update": 2100,
       |    "Internal": true,
       |    "Count Failed Values": true
       |  },
       |  {
-      |    "ID": 22,
+      |    "ID": 23,
       |    "Name": "${input.RECORDS_READ}",
       |    "Update": 21,
       |    "Internal": true,
       |    "Count Failed Values": true
       |  },
       |  {
| "ID": 23, + | "ID": 24, | "Name": "${output.BYTES_WRITTEN}", | "Update": 1200, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 24, + | "ID": 25, | "Name": "${output.RECORDS_WRITTEN}", | "Update": 12, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 25, + | "ID": 26, | "Name": "$TEST_ACCUM", | "Update": 0, | "Internal": true, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 1c78385d7a8ff..a09ccfbf8c0bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.AbstractIterator import scala.collection.mutable.{ArrayBuffer, ListBuffer} -import org.apache.spark.{broadcast, SparkEnv, TaskKilledException} +import org.apache.spark.{broadcast, SparkEnv, TaskContext, TaskKilledException} import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec import org.apache.spark.rdd.{RDD, RDDOperationScope} @@ -362,6 +362,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ out.writeInt(-1) out.flush() out.close() + TaskContext.get().taskMetrics().setRecordsOutput(count) Iterator((count, bos.toByteArray)) } } @@ -544,7 +545,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ throw new IllegalArgumentException(s"Limit cannot exceed threshold ${conf.limitMaxRows}") } logInfo(s"Return limit result as a SimpleRepeatableIterator.") - SimpleRepeatableIterator[R, InternalRow](executeTake(n), row => Iterator(proj(row))) + val array = executeTake(n) + SimpleRepeatableIterator[R, InternalRow](array, row => Iterator(proj(row)), array.length) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SpillDirectResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SpillDirectResultSuite.scala index bd58ba86452fa..b7de4881fbdb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SpillDirectResultSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SpillDirectResultSuite.scala @@ -57,6 +57,7 @@ class SpillDirectResultSuite extends QueryTest with SQLTestUtils with SharedSpar val df = sql(query) val it = df.collectAsIterator() assert(it.isInstanceOf[SimpleRepeatableIterator[Any, Any]]) + assert(it.asInstanceOf[SimpleRepeatableIterator[Any, Any]].rowCount == 10) assert(it.length == 10) @@ -89,6 +90,7 @@ class SpillDirectResultSuite extends QueryTest with SQLTestUtils with SharedSpar val df = sql(query) val it = df.collectAsIterator() assert(it.isInstanceOf[SpilledResultIterator[Any, Any]]) + assert(it.asInstanceOf[SpilledResultIterator[Any, Any]].rowCount == 300) val rs = ArrayBuffer[Row]() while (it.hasNext) { @@ -115,6 +117,7 @@ class SpillDirectResultSuite extends QueryTest with SQLTestUtils with SharedSpar val df = sql(query) val it = df.collectAsIterator() assert(it.isInstanceOf[SpilledResultIterator[Any, Any]]) + assert(it.asInstanceOf[SpilledResultIterator[Any, Any]].rowCount == 300) val rs = ArrayBuffer[Row]() while (it.hasNext) { @@ -191,6 +194,7 @@ class SpillDirectResultSuite extends QueryTest with SQLTestUtils with SharedSpar val df = sql(query) val it = df.collectAsIterator() assert(it.isInstanceOf[SpilledResultIterator[Any, Any]]) + assert(it.asInstanceOf[SpilledResultIterator[Any, Any]].rowCount == 300) val rs = ArrayBuffer[Row]() while 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SpillIndirectResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SpillIndirectResultSuite.scala
index 3e906fd997082..44643ed56d786 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SpillIndirectResultSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SpillIndirectResultSuite.scala
@@ -57,6 +57,7 @@ class SpillIndirectResultSuite extends QueryTest with SQLTestUtils with SharedSp
       val df = sql(query)
       val it = df.collectAsIterator()
       assert(it.isInstanceOf[SimpleRepeatableIterator[Any, Any]])
+      assert(it.asInstanceOf[SimpleRepeatableIterator[Any, Any]].rowCount == 10)

       assert(it.length == 10)

@@ -89,6 +90,7 @@ class SpillIndirectResultSuite extends QueryTest with SQLTestUtils with SharedSp
       val df = sql(query)
       val it = df.collectAsIterator()
       assert(it.isInstanceOf[SpilledResultIterator[Any, Any]])
+      assert(it.asInstanceOf[SpilledResultIterator[Any, Any]].rowCount == 300)

       val rs = ArrayBuffer[Row]()
       while (it.hasNext) {
@@ -115,6 +117,7 @@ class SpillIndirectResultSuite extends QueryTest with SQLTestUtils with SharedSp
       val df = sql(query)
       val it = df.collectAsIterator()
       assert(it.isInstanceOf[SimpleRepeatableIterator[Any, Any]])
+      assert(it.asInstanceOf[SimpleRepeatableIterator[Any, Any]].rowCount == 100)

       assert(it.length == 100)

@@ -148,6 +151,7 @@ class SpillIndirectResultSuite extends QueryTest with SQLTestUtils with SharedSp
       val df = sql(query)
       val it = df.collectAsIterator()
       assert(it.isInstanceOf[SimpleRepeatableIterator[Any, Any]])
+      assert(it.asInstanceOf[SimpleRepeatableIterator[Any, Any]].rowCount == 100)

       assert(it.length == 100)

@@ -189,6 +193,7 @@ class SpillIndirectResultSuite extends QueryTest with SQLTestUtils with SharedSp
       val df = sql(query)
       val it = df.collectAsIterator()
       assert(it.isInstanceOf[SpilledResultIterator[Any, Any]])
+      assert(it.asInstanceOf[SpilledResultIterator[Any, Any]].rowCount == 300)

       val rs = ArrayBuffer[Row]()
       while (it.hasNext) {
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index b69ee2cfa28a2..3f5b1cb8af31c 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -505,11 +505,10 @@ private[hive] class SparkExecuteStatementOperation(
           Some(result.collectAsIterator())
         }
-        finalRowsLength = if (repeatableResultList.nonEmpty &&
-          repeatableResultList.get.isInstanceOf[SpilledResultIterator[_, _]]) {
+        finalRowsLength = repeatableResultList.get.rowCount()
+        if (repeatableResultList.get.isInstanceOf[SpilledResultIterator[_, _]]) {
           isSpill = true
-          ErrRowCountType.SpillSizeUnknown
-        } else ErrRowCountType.Unknown
+        }
         logInfo(s"collectAsIterator for statement $statement")
       } else {