diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 89dbb6468d..65061282c9 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -319,6 +319,18 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(true) + val COMET_SHUFFLE_DIRECT_NATIVE_ENABLED: ConfigEntry[Boolean] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.directNative.enabled") + .category(CATEGORY_SHUFFLE) + .doc( + "When enabled, the native shuffle writer will directly execute the child native plan " + + "instead of reading intermediate batches via JNI. This optimization avoids the " + + "JNI round-trip for single-source native plans (e.g., Scan -> Filter -> Project). " + + "This is an experimental feature and is disabled by default.") + .internal() + .booleanConf + .createWithDefault(false) + val COMET_SHUFFLE_MODE: ConfigEntry[String] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.mode") .category(CATEGORY_SHUFFLE) .doc( diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index b5d15b41f4..5655c7e492 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -43,6 +43,11 @@ import org.apache.comet.serde.QueryPlanSerde.serializeDataType /** * A [[ShuffleWriter]] that will delegate shuffle write to native shuffle. + * + * @param childNativePlan + * When provided, the shuffle writer will execute this native plan directly and pipe its output + * to the ShuffleWriter, avoiding the JNI round-trip for intermediate batches. This is used for + * direct native execution optimization when the shuffle's child is a single-source native plan. */ class CometNativeShuffleWriter[K, V]( outputPartitioning: Partitioning, @@ -53,7 +58,8 @@ class CometNativeShuffleWriter[K, V]( mapId: Long, context: TaskContext, metricsReporter: ShuffleWriteMetricsReporter, - rangePartitionBounds: Option[Seq[InternalRow]] = None) + rangePartitionBounds: Option[Seq[InternalRow]] = None, + childNativePlan: Option[Operator] = None) extends ShuffleWriter[K, V] with Logging { @@ -163,150 +169,150 @@ class CometNativeShuffleWriter[K, V]( } private def getNativePlan(dataFile: String, indexFile: String): Operator = { - val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput") - val opBuilder = OperatorOuterClass.Operator.newBuilder() - - val scanTypes = outputAttributes.flatten { attr => - serializeDataType(attr.dataType) - } - - if (scanTypes.length == outputAttributes.length) { + // When childNativePlan is provided, we use it directly as the input to ShuffleWriter. + // Otherwise, we create a Scan operator that reads from JNI input ("ShuffleWriterInput"). 
+ val inputOperator: Operator = childNativePlan.getOrElse { + val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput") + val scanTypes = outputAttributes.flatten { attr => + serializeDataType(attr.dataType) + } + if (scanTypes.length != outputAttributes.length) { + throw new UnsupportedOperationException( + s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.") + } scanBuilder.addAllFields(scanTypes.asJava) + OperatorOuterClass.Operator.newBuilder().setScan(scanBuilder).build() + } - val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder() - shuffleWriterBuilder.setOutputDataFile(dataFile) - shuffleWriterBuilder.setOutputIndexFile(indexFile) + val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder() + shuffleWriterBuilder.setOutputDataFile(dataFile) + shuffleWriterBuilder.setOutputIndexFile(indexFile) - if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) { - val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match { - case "zstd" => CompressionCodec.Zstd - case "lz4" => CompressionCodec.Lz4 - case "snappy" => CompressionCodec.Snappy - case other => throw new UnsupportedOperationException(s"invalid codec: $other") - } - shuffleWriterBuilder.setCodec(codec) - } else { - shuffleWriterBuilder.setCodec(CompressionCodec.None) + if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) { + val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match { + case "zstd" => CompressionCodec.Zstd + case "lz4" => CompressionCodec.Lz4 + case "snappy" => CompressionCodec.Snappy + case other => throw new UnsupportedOperationException(s"invalid codec: $other") } - shuffleWriterBuilder.setCompressionLevel( - CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) - shuffleWriterBuilder.setWriteBufferSize( - CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().max(Int.MaxValue).toInt) + shuffleWriterBuilder.setCodec(codec) + } else { + shuffleWriterBuilder.setCodec(CompressionCodec.None) + } + shuffleWriterBuilder.setCompressionLevel( + CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) + shuffleWriterBuilder.setWriteBufferSize( + CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().max(Int.MaxValue).toInt) - outputPartitioning match { - case p if isSinglePartitioning(p) => - val partitioning = PartitioningOuterClass.SinglePartition.newBuilder() + outputPartitioning match { + case p if isSinglePartitioning(p) => + val partitioning = PartitioningOuterClass.SinglePartition.newBuilder() - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setSinglePartition(partitioning).build()) - case _: HashPartitioning => - val hashPartitioning = outputPartitioning.asInstanceOf[HashPartitioning] + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setSinglePartition(partitioning).build()) + case _: HashPartitioning => + val hashPartitioning = outputPartitioning.asInstanceOf[HashPartitioning] - val partitioning = PartitioningOuterClass.HashPartition.newBuilder() - partitioning.setNumPartitions(outputPartitioning.numPartitions) + val partitioning = PartitioningOuterClass.HashPartition.newBuilder() + partitioning.setNumPartitions(outputPartitioning.numPartitions) - val partitionExprs = hashPartitioning.expressions - .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) + val partitionExprs = hashPartitioning.expressions 
+ .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) - if (partitionExprs.length != hashPartitioning.expressions.length) { - throw new UnsupportedOperationException( - s"Partitioning $hashPartitioning is not supported.") - } + if (partitionExprs.length != hashPartitioning.expressions.length) { + throw new UnsupportedOperationException( + s"Partitioning $hashPartitioning is not supported.") + } - partitioning.addAllHashExpression(partitionExprs.asJava) - - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setHashPartition(partitioning).build()) - case _: RangePartitioning => - val rangePartitioning = outputPartitioning.asInstanceOf[RangePartitioning] - - val partitioning = PartitioningOuterClass.RangePartition.newBuilder() - partitioning.setNumPartitions(outputPartitioning.numPartitions) - - // Detect duplicates by tracking expressions directly, similar to DataFusion's LexOrdering - // DataFusion will deduplicate identical sort expressions in LexOrdering, - // so we need to transform boundary rows to match the deduplicated structure - val seenExprs = mutable.HashSet[Expression]() - val deduplicationMap = mutable.ArrayBuffer[(Int, Boolean)]() // (originalIndex, isKept) - - rangePartitioning.ordering.zipWithIndex.foreach { case (sortOrder, idx) => - if (seenExprs.contains(sortOrder.child)) { - deduplicationMap += (idx -> false) // Will be deduplicated by DataFusion - } else { - seenExprs += sortOrder.child - deduplicationMap += (idx -> true) // Will be kept by DataFusion - } + partitioning.addAllHashExpression(partitionExprs.asJava) + + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setHashPartition(partitioning).build()) + case _: RangePartitioning => + val rangePartitioning = outputPartitioning.asInstanceOf[RangePartitioning] + + val partitioning = PartitioningOuterClass.RangePartition.newBuilder() + partitioning.setNumPartitions(outputPartitioning.numPartitions) + + // Detect duplicates by tracking expressions directly, similar to DataFusion's LexOrdering + // DataFusion will deduplicate identical sort expressions in LexOrdering, + // so we need to transform boundary rows to match the deduplicated structure + val seenExprs = mutable.HashSet[Expression]() + val deduplicationMap = mutable.ArrayBuffer[(Int, Boolean)]() // (originalIndex, isKept) + + rangePartitioning.ordering.zipWithIndex.foreach { case (sortOrder, idx) => + if (seenExprs.contains(sortOrder.child)) { + deduplicationMap += (idx -> false) // Will be deduplicated by DataFusion + } else { + seenExprs += sortOrder.child + deduplicationMap += (idx -> true) // Will be kept by DataFusion } + } - { - // Serialize the ordering expressions for comparisons - val orderingExprs = rangePartitioning.ordering - .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) - if (orderingExprs.length != rangePartitioning.ordering.length) { - throw new UnsupportedOperationException( - s"Partitioning $rangePartitioning is not supported.") - } - partitioning.addAllSortOrders(orderingExprs.asJava) + { + // Serialize the ordering expressions for comparisons + val orderingExprs = rangePartitioning.ordering + .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) + if (orderingExprs.length != rangePartitioning.ordering.length) { + throw new UnsupportedOperationException( + s"Partitioning $rangePartitioning is not supported.") } + 
partitioning.addAllSortOrders(orderingExprs.asJava) + } - // Convert Spark's sequence of InternalRows that represent partitioning boundaries to - // sequences of Literals, where each outer entry represents a boundary row, and each - // internal entry is a value in that row. In other words, these are stored in row major - // order, not column major - val boundarySchema = rangePartitioning.ordering.flatMap(e => Some(e.dataType)) - - // Transform boundary rows to match DataFusion's deduplicated structure - val transformedBoundaryExprs: Seq[Seq[Literal]] = - rangePartitionBounds.get.map((row: InternalRow) => { - // For every InternalRow, map its values to Literals - val allLiterals = - row.toSeq(boundarySchema).zip(boundarySchema).map { case (value, valueType) => - Literal(value, valueType) - } - - // Keep only the literals that correspond to non-deduplicated expressions - allLiterals - .zip(deduplicationMap) - .filter(_._2._2) // Keep only where isKept = true - .map(_._1) // Extract the literal + // Convert Spark's sequence of InternalRows that represent partitioning boundaries to + // sequences of Literals, where each outer entry represents a boundary row, and each + // internal entry is a value in that row. In other words, these are stored in row major + // order, not column major + val boundarySchema = rangePartitioning.ordering.flatMap(e => Some(e.dataType)) + + // Transform boundary rows to match DataFusion's deduplicated structure + val transformedBoundaryExprs: Seq[Seq[Literal]] = + rangePartitionBounds.get.map((row: InternalRow) => { + // For every InternalRow, map its values to Literals + val allLiterals = + row.toSeq(boundarySchema).zip(boundarySchema).map { case (value, valueType) => + Literal(value, valueType) + } + + // Keep only the literals that correspond to non-deduplicated expressions + allLiterals + .zip(deduplicationMap) + .filter(_._2._2) // Keep only where isKept = true + .map(_._1) // Extract the literal + }) + + { + // Convert the sequences of Literals to a collection of serialized BoundaryRows + val boundaryRows: Seq[PartitioningOuterClass.BoundaryRow] = transformedBoundaryExprs + .map((rowLiterals: Seq[Literal]) => { + // Serialize each sequence of Literals as a BoundaryRow + val rowBuilder = PartitioningOuterClass.BoundaryRow.newBuilder(); + val serializedExprs = + rowLiterals.map(lit_value => + QueryPlanSerde.exprToProto(lit_value, outputAttributes).get) + rowBuilder.addAllPartitionBounds(serializedExprs.asJava) + rowBuilder.build() }) + partitioning.addAllBoundaryRows(boundaryRows.asJava) + } - { - // Convert the sequences of Literals to a collection of serialized BoundaryRows - val boundaryRows: Seq[PartitioningOuterClass.BoundaryRow] = transformedBoundaryExprs - .map((rowLiterals: Seq[Literal]) => { - // Serialize each sequence of Literals as a BoundaryRow - val rowBuilder = PartitioningOuterClass.BoundaryRow.newBuilder(); - val serializedExprs = - rowLiterals.map(lit_value => - QueryPlanSerde.exprToProto(lit_value, outputAttributes).get) - rowBuilder.addAllPartitionBounds(serializedExprs.asJava) - rowBuilder.build() - }) - partitioning.addAllBoundaryRows(boundaryRows.asJava) - } - - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setRangePartition(partitioning).build()) - - case _ => - throw new UnsupportedOperationException( - s"Partitioning $outputPartitioning is not supported.") - } + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + 
shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setRangePartition(partitioning).build()) - val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder() - shuffleWriterOpBuilder - .setShuffleWriter(shuffleWriterBuilder) - .addChildren(opBuilder.setScan(scanBuilder).build()) - .build() - } else { - // There are unsupported scan type - throw new UnsupportedOperationException( - s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.") + case _ => + throw new UnsupportedOperationException( + s"Partitioning $outputPartitioning is not supported.") } + + val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder() + shuffleWriterOpBuilder + .setShuffleWriter(shuffleWriterBuilder) + .addChildren(inputOperator) + .build() } override def stop(success: Boolean): Option[MapStatus] = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala index 2b74e5a168..3528b6d2c9 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType +import org.apache.comet.serde.OperatorOuterClass.Operator + /** * A [[ShuffleDependency]] that allows us to identify the shuffle dependency as a Comet shuffle. */ @@ -49,7 +51,9 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val outputAttributes: Seq[Attribute] = Seq.empty, val shuffleWriteMetrics: Map[String, SQLMetric] = Map.empty, val numParts: Int = 0, - val rangePartitionBounds: Option[Seq[InternalRow]] = None) + val rangePartitionBounds: Option[Seq[InternalRow]] = None, + // For direct native execution: the child's native plan to compose with ShuffleWriter + val childNativePlan: Option[Operator] = None) extends ShuffleDependency[K, V, C]( _rdd, partitioner, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 1805711d01..96b8bcbacc 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -34,8 +34,9 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Uns import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometPlan, CometSinkPlaceHolder} +import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometNativeScanExec, CometPlan, CometSinkPlaceHolder} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, 
SQLShuffleWriteMetricsReporter} @@ -52,6 +53,7 @@ import org.apache.comet.CometConf import org.apache.comet.CometConf.{COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MODE} import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleManagerEnabled, withInfo} import org.apache.comet.serde.{Compatible, OperatorOuterClass, QueryPlanSerde, SupportLevel, Unsupported} +import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.operator.CometSink import org.apache.comet.shims.ShimCometShuffleExchangeExec @@ -89,9 +91,94 @@ case class CometShuffleExchangeExec( private lazy val serializer: Serializer = new UnsafeRowSerializer(child.output.size, longMetric("dataSize")) + /** + * Information about direct native execution optimization. When the child is a single-source + * native plan with a fully native scan (CometNativeScanExec), we can pass the child's native + * plan to the shuffle writer and execute: Scan -> Filter -> Project -> ShuffleWriter all in + * native code, avoiding the JNI round-trip for intermediate batches. + * + * Currently only supports CometNativeScanExec (fully native scans that read files directly via + * DataFusion). JVM scan wrappers (CometScanExec, CometBatchScanExec) still require JNI input + * and are not optimized. + */ + @transient private lazy val directNativeExecutionInfo: Option[DirectNativeExecutionInfo] = { + if (!CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.get()) { + None + } else if (shuffleType != CometNativeShuffle) { + None + } else { + // Check if direct native execution is possible + outputPartitioning match { + case _: RangePartitioning => + // RangePartitioning requires sampling the data to compute bounds, + // which requires executing the child plan. Fall back to current behavior. + None + case _ => + child match { + case nativeChild: CometNativeExec => + // Find input sources using foreachUntilCometInput + val inputSources = scala.collection.mutable.ArrayBuffer.empty[SparkPlan] + nativeChild.foreachUntilCometInput(nativeChild)(inputSources += _) + + // Only optimize single-source native scan case for now + // JVM scan wrappers (CometScanExec, CometBatchScanExec) still need JNI input, + // so we don't optimize those yet + // Check if the plan contains subqueries (e.g., bloom filters with might_contain). + // Subqueries are registered with the parent execution context ID, but direct + // native shuffle creates a new execution context, so subquery lookup would fail. + val containsSubquery = nativeChild.exists { p => + p.expressions.exists(_.exists(_.isInstanceOf[ScalarSubquery])) + } + if (containsSubquery) { + // Fall back to avoid subquery lookup failures + None + } else if (inputSources.size == 1) { + inputSources.head match { + case scan: CometNativeScanExec => + // Fully native scan - no JNI input needed, native code reads files directly + // Get the partition count from the underlying scan + val numPartitions = scan.originalPlan.inputRDD.getNumPartitions + Some(DirectNativeExecutionInfo(nativeChild.nativeOp, numPartitions)) + case _ => + // Other input sources (JVM scans, shuffle, broadcast, etc.) - fall back + None + } + } else { + // Multiple input sources (joins, unions) - fall back for now + None + } + case _ => + None + } + } + } + } + + /** + * Returns true if direct native execution optimization is being used for this shuffle. This is + * primarily intended for testing to verify the optimization is applied correctly. 
+ */ + def isDirectNativeExecution: Boolean = directNativeExecutionInfo.isDefined + + /** + * Creates an RDD that provides empty iterators for each partition. Used when direct native + * execution is enabled - the shuffle writer will execute the full native plan which reads data + * directly (no JNI input needed). + */ + private def createEmptyPartitionRDD(numPartitions: Int): RDD[ColumnarBatch] = { + sparkContext.parallelize(Seq.empty[ColumnarBatch], numPartitions) + } + @transient lazy val inputRDD: RDD[_] = if (shuffleType == CometNativeShuffle) { - // CometNativeShuffle assumes that the input plan is Comet plan. - child.executeColumnar() + directNativeExecutionInfo match { + case Some(info) => + // Direct native execution: create an RDD with empty partitions. + // The shuffle writer will execute the full native plan which reads data directly. + createEmptyPartitionRDD(info.numPartitions) + case None => + // Fall back to current behavior: execute child and pass intermediate batches + child.executeColumnar() + } } else if (shuffleType == CometColumnarShuffle) { // CometColumnarShuffle uses Spark's row-based execute() API. For Spark row-based plans, // rows flow directly. For Comet native plans, their doExecute() wraps with ColumnarToRowExec @@ -142,7 +229,8 @@ case class CometShuffleExchangeExec( child.output, outputPartitioning, serializer, - metrics) + metrics, + directNativeExecutionInfo.map(_.childNativePlan)) metrics("numPartitions").set(dep.partitioner.numPartitions) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates( @@ -538,7 +626,9 @@ object CometShuffleExchangeExec outputAttributes: Seq[Attribute], outputPartitioning: Partitioning, serializer: Serializer, - metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { + metrics: Map[String, SQLMetric], + childNativePlan: Option[Operator] = None) + : ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { val numParts = rdd.getNumPartitions // The code block below is mostly brought over from @@ -605,7 +695,8 @@ object CometShuffleExchangeExec outputAttributes = outputAttributes, shuffleWriteMetrics = metrics, numParts = numParts, - rangePartitionBounds = rangePartitionBounds) + rangePartitionBounds = rangePartitionBounds, + childNativePlan = childNativePlan) dependency } @@ -810,3 +901,15 @@ object CometShuffleExchangeExec dependency } } + +/** + * Information needed for direct native execution optimization. 
+ * + * @param childNativePlan + * The child's native operator plan to compose with ShuffleWriter + * @param numPartitions + * The number of partitions (from the underlying scan) + */ +private[shuffle] case class DirectNativeExecutionInfo( + childNativePlan: Operator, + numPartitions: Int) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala index aa47dfa166..367ec4a90e 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala @@ -238,7 +238,8 @@ class CometShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { mapId, context, metrics, - dep.rangePartitionBounds) + dep.rangePartitionBounds, + dep.childNativePlan) case bypassMergeSortHandle: CometBypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new CometBypassMergeSortShuffleWriter( env.blockManager, diff --git a/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala new file mode 100644 index 0000000000..05c608310c --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.exec + +import org.scalactic.source.Position +import org.scalatest.Tag + +import org.apache.spark.sql.{CometTestBase, DataFrame} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.functions.col + +import org.apache.comet.CometConf + +/** + * Test suite for the direct native shuffle execution optimization. + * + * This optimization allows the native shuffle writer to directly execute the child native plan + * instead of reading intermediate batches via JNI. This avoids the JNI round-trip for + * single-source native plans (e.g., Scan -> Filter -> Project -> Shuffle). 
+ */ +class CometDirectNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper { + + override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + super.test(testName, testTags: _*) { + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_MODE.key -> "native", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion", + CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.key -> "true") { + testFun + } + } + } + + import testImplicits._ + + test("direct native execution: simple scan with hash partitioning") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1") + + // Verify the optimization is applied + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1, "Expected exactly one shuffle") + assert( + shuffles.head.isDirectNativeExecution, + "Direct native execution should be enabled for single-source native scan") + + // Verify correctness + checkSparkAnswer(df) + } + } + + test("direct native execution: scan with filter and project") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong, i.toString)), "tbl") { + val df = sql("SELECT _1, _2 * 2 as doubled FROM tbl WHERE _1 > 10") + .repartition(10, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + shuffles.head.isDirectNativeExecution, + "Direct native execution should work with filter and project") + + checkSparkAnswer(df) + } + } + + test("direct native execution: single partition") { + withParquetTable((0 until 50).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(1) + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + shuffles.head.isDirectNativeExecution, + "Direct native execution should work with single partition") + + checkSparkAnswer(df) + } + } + + test("direct native execution: multiple hash columns") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong, i.toString)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1", $"_2") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + shuffles.head.isDirectNativeExecution, + "Direct native execution should work with multiple hash columns") + + checkSparkAnswer(df) + } + } + + test("direct native execution: aggregation before shuffle") { + withParquetTable((0 until 100).map(i => (i % 10, (i + 1).toLong)), "tbl") { + val df = sql("SELECT _1, SUM(_2) as total FROM tbl GROUP BY _1") + .repartition(5, col("_1")) + + // This involves partial aggregation -> shuffle -> final aggregation + // The direct native execution applies to the shuffle that reads from the partial agg + checkSparkAnswer(df) + } + } + + test("direct native execution disabled: config is false") { + withSQLConf(CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.key -> "false") { + withParquetTable((0 until 50).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + !shuffles.head.isDirectNativeExecution, + "Direct native execution should be disabled when config is false") + + checkSparkAnswer(df) + } + } + } + + test("direct native execution disabled: range partitioning") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM 
tbl").repartitionByRange(10, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + !shuffles.head.isDirectNativeExecution, + "Direct native execution should not be used for range partitioning") + + checkSparkAnswer(df) + } + } + + test("direct native execution disabled: JVM columnar shuffle mode") { + withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { + withParquetTable((0 until 50).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1") + + // JVM shuffle mode uses CometColumnarShuffle, not CometNativeShuffle + val shuffles = findShuffleExchanges(df) + shuffles.foreach { shuffle => + assert( + !shuffle.isDirectNativeExecution, + "Direct native execution should not be used with JVM shuffle mode") + } + + checkSparkAnswer(df) + } + } + } + + test("direct native execution: multiple shuffles in same query") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl") + .repartition(10, $"_1") + .select($"_1", $"_2" + 1 as "_2_plus") + .repartition(5, $"_2_plus") + + // First shuffle reads from scan, second reads from previous shuffle output + // Only the first shuffle should use direct native execution + // AQE might combine some shuffles, so just verify results are correct + checkSparkAnswer(df) + } + } + + test("direct native execution: various data types") { + withParquetTable( + (0 until 50).map(i => + (i, i.toLong, i.toFloat, i.toDouble, i.toString, i % 2 == 0, BigDecimal(i))), + "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert(shuffles.head.isDirectNativeExecution) + + checkSparkAnswer(df) + } + } + + test("direct native execution: complex filter and multiple projections") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong, i % 5)), "tbl") { + val df = sql(""" + |SELECT _1 * 2 as doubled, + | _2 + _3 as sum_col, + | _1 + _2 as combined + |FROM tbl + |WHERE _1 > 20 AND _3 < 3 + |""".stripMargin) + .repartition(10, col("doubled")) + + // Note: Native shuffle might fall back depending on expression support + // Just verify correctness - the optimization is best-effort + checkSparkAnswer(df) + } + } + + test("direct native execution: results match non-optimized path") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong, i.toString)), "tbl") { + // Run with optimization enabled + val dfOptimized = sql("SELECT _1, _2 FROM tbl WHERE _1 > 50").repartition(10, $"_1") + val optimizedResult = dfOptimized.collect().sortBy(_.getInt(0)) + + // Run with optimization disabled and collect results + var nonOptimizedResult: Array[org.apache.spark.sql.Row] = Array.empty + withSQLConf(CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.key -> "false") { + val dfNonOptimized = sql("SELECT _1, _2 FROM tbl WHERE _1 > 50").repartition(10, $"_1") + nonOptimizedResult = dfNonOptimized.collect().sortBy(_.getInt(0)) + } + + // Results should match + assert(optimizedResult.length == nonOptimizedResult.length, "Row counts should match") + optimizedResult.zip(nonOptimizedResult).foreach { case (opt, nonOpt) => + assert(opt == nonOpt, s"Rows should match: $opt vs $nonOpt") + } + } + } + + test("direct native execution: large number of partitions") { + withParquetTable((0 until 1000).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(201, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + 
assert(shuffles.head.isDirectNativeExecution) + + checkSparkAnswer(df) + } + } + + test("direct native execution: empty table") { + withParquetTable(Seq.empty[(Int, Long)], "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1") + + // Should handle empty tables gracefully + val result = df.collect() + assert(result.isEmpty) + } + } + + test("direct native execution: all rows filtered out") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl WHERE _1 > 1000").repartition(10, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert(shuffles.head.isDirectNativeExecution) + + val result = df.collect() + assert(result.isEmpty, "Result should be empty when all rows are filtered") + } + } + + /** + * Helper method to find CometShuffleExchangeExec nodes in a DataFrame's execution plan. + */ + private def findShuffleExchanges(df: DataFrame): Seq[CometShuffleExchangeExec] = { + val plan = stripAQEPlan(df.queryExecution.executedPlan) + plan.collect { case s: CometShuffleExchangeExec => s } + } +}
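
For reference, a minimal sketch (not part of the patch) of the plan composition that the reworked `getNativePlan` performs: the `ShuffleWriter` operator gets exactly one child, which is either the serialized child native plan (direct native execution) or, as before, a JNI-backed `Scan` reading from `"ShuffleWriterInput"`. The helper name `composeShuffleWriterPlan` below is hypothetical and only illustrates the `getOrElse` fallback; only the protobuf builder calls shown in the diff are used.

```scala
import org.apache.comet.serde.OperatorOuterClass
import org.apache.comet.serde.OperatorOuterClass.Operator

// Hypothetical helper, shown only to illustrate the composition in getNativePlan:
// when a child native plan is available it becomes the ShuffleWriter's input;
// otherwise the writer falls back to a Scan that reads intermediate batches over JNI.
def composeShuffleWriterPlan(
    childNativePlan: Option[Operator],
    shuffleWriterBuilder: OperatorOuterClass.ShuffleWriter.Builder): Operator = {
  val input: Operator = childNativePlan.getOrElse {
    val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput")
    // The real code also serializes the scan's field types from outputAttributes here.
    Operator.newBuilder().setScan(scanBuilder).build()
  }
  Operator
    .newBuilder()
    .setShuffleWriter(shuffleWriterBuilder)
    .addChildren(input)
    .build()
}
```

The new `COMET_SHUFFLE_DIRECT_NATIVE_ENABLED` config (internal, default `false`) gates this path, and `CometDirectNativeShuffleSuite` above exercises both the optimized composition and the fallback.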