12 changes: 12 additions & 0 deletions common/src/main/scala/org/apache/comet/CometConf.scala
@@ -319,6 +319,18 @@ object CometConf extends ShimCometConf
.booleanConf
.createWithDefault(true)

val COMET_SHUFFLE_DIRECT_NATIVE_ENABLED: ConfigEntry[Boolean] =
conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.directNative.enabled")
.category(CATEGORY_SHUFFLE)
.doc(
"When enabled, the native shuffle writer will directly execute the child native plan " +
"instead of reading intermediate batches via JNI. This optimization avoids the " +
"JNI round-trip for single-source native plans (e.g., Scan -> Filter -> Project). " +
"This is an experimental feature and is disabled by default.")
.internal()
.booleanConf
.createWithDefault(false)

val COMET_SHUFFLE_MODE: ConfigEntry[String] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.mode")
.category(CATEGORY_SHUFFLE)
.doc(
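As an aside, here is a minimal usage sketch for the new flag. It assumes COMET_EXEC_CONFIG_PREFIX resolves to "spark.comet.exec", so the full key would be "spark.comet.exec.shuffle.directNative.enabled"; check CometConf for the authoritative value.

// Hypothetical usage sketch (not part of this diff); the config key below assumes
// COMET_EXEC_CONFIG_PREFIX == "spark.comet.exec".
import org.apache.spark.sql.SparkSession

object DirectNativeShuffleDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("comet-direct-native-shuffle-demo")
      .config("spark.comet.exec.shuffle.directNative.enabled", "true")
      .getOrCreate()

    // A single-source native plan (scan -> filter -> project) feeding a shuffle is the
    // shape this experimental path targets.
    spark
      .range(0L, 1000L)
      .filter("id % 2 = 0")
      .selectExpr("id * 2 AS doubled")
      .repartition(8)
      .count()

    spark.stop()
  }
}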
CometNativeShuffleWriter.scala
@@ -43,6 +43,11 @@ import org.apache.comet.serde.QueryPlanSerde.serializeDataType

/**
* A [[ShuffleWriter]] that will delegate shuffle write to native shuffle.
*
* @param childNativePlan
* When provided, the shuffle writer will execute this native plan directly and pipe its output
* to the ShuffleWriter, avoiding the JNI round-trip for intermediate batches. This is used for the
* direct native execution optimization when the shuffle's child is a single-source native plan.
*/
class CometNativeShuffleWriter[K, V](
outputPartitioning: Partitioning,
@@ -53,7 +58,8 @@ class CometNativeShuffleWriter[K, V](
mapId: Long,
context: TaskContext,
metricsReporter: ShuffleWriteMetricsReporter,
rangePartitionBounds: Option[Seq[InternalRow]] = None)
rangePartitionBounds: Option[Seq[InternalRow]] = None,
childNativePlan: Option[Operator] = None)
extends ShuffleWriter[K, V]
with Logging {

@@ -163,150 +169,150 @@
}

private def getNativePlan(dataFile: String, indexFile: String): Operator = {
val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput")
val opBuilder = OperatorOuterClass.Operator.newBuilder()

val scanTypes = outputAttributes.flatten { attr =>
serializeDataType(attr.dataType)
}

if (scanTypes.length == outputAttributes.length) {
// When childNativePlan is provided, we use it directly as the input to ShuffleWriter.
// Otherwise, we create a Scan operator that reads from JNI input ("ShuffleWriterInput").
val inputOperator: Operator = childNativePlan.getOrElse {
val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput")
val scanTypes = outputAttributes.flatten { attr =>
serializeDataType(attr.dataType)
}
if (scanTypes.length != outputAttributes.length) {
throw new UnsupportedOperationException(
s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.")
}
scanBuilder.addAllFields(scanTypes.asJava)
OperatorOuterClass.Operator.newBuilder().setScan(scanBuilder).build()
}

val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder()
shuffleWriterBuilder.setOutputDataFile(dataFile)
shuffleWriterBuilder.setOutputIndexFile(indexFile)
val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder()
shuffleWriterBuilder.setOutputDataFile(dataFile)
shuffleWriterBuilder.setOutputIndexFile(indexFile)

if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) {
val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match {
case "zstd" => CompressionCodec.Zstd
case "lz4" => CompressionCodec.Lz4
case "snappy" => CompressionCodec.Snappy
case other => throw new UnsupportedOperationException(s"invalid codec: $other")
}
shuffleWriterBuilder.setCodec(codec)
} else {
shuffleWriterBuilder.setCodec(CompressionCodec.None)
if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) {
val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match {
case "zstd" => CompressionCodec.Zstd
case "lz4" => CompressionCodec.Lz4
case "snappy" => CompressionCodec.Snappy
case other => throw new UnsupportedOperationException(s"invalid codec: $other")
}
shuffleWriterBuilder.setCompressionLevel(
CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get)
shuffleWriterBuilder.setWriteBufferSize(
CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().max(Int.MaxValue).toInt)
shuffleWriterBuilder.setCodec(codec)
} else {
shuffleWriterBuilder.setCodec(CompressionCodec.None)
}
shuffleWriterBuilder.setCompressionLevel(
CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get)
shuffleWriterBuilder.setWriteBufferSize(
CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().max(Int.MaxValue).toInt)

outputPartitioning match {
case p if isSinglePartitioning(p) =>
val partitioning = PartitioningOuterClass.SinglePartition.newBuilder()
outputPartitioning match {
case p if isSinglePartitioning(p) =>
val partitioning = PartitioningOuterClass.SinglePartition.newBuilder()

val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder()
shuffleWriterBuilder.setPartitioning(
partitioningBuilder.setSinglePartition(partitioning).build())
case _: HashPartitioning =>
val hashPartitioning = outputPartitioning.asInstanceOf[HashPartitioning]
val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder()
shuffleWriterBuilder.setPartitioning(
partitioningBuilder.setSinglePartition(partitioning).build())
case _: HashPartitioning =>
val hashPartitioning = outputPartitioning.asInstanceOf[HashPartitioning]

val partitioning = PartitioningOuterClass.HashPartition.newBuilder()
partitioning.setNumPartitions(outputPartitioning.numPartitions)
val partitioning = PartitioningOuterClass.HashPartition.newBuilder()
partitioning.setNumPartitions(outputPartitioning.numPartitions)

val partitionExprs = hashPartitioning.expressions
.flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes))
val partitionExprs = hashPartitioning.expressions
.flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes))

if (partitionExprs.length != hashPartitioning.expressions.length) {
throw new UnsupportedOperationException(
s"Partitioning $hashPartitioning is not supported.")
}
if (partitionExprs.length != hashPartitioning.expressions.length) {
throw new UnsupportedOperationException(
s"Partitioning $hashPartitioning is not supported.")
}

partitioning.addAllHashExpression(partitionExprs.asJava)

val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder()
shuffleWriterBuilder.setPartitioning(
partitioningBuilder.setHashPartition(partitioning).build())
case _: RangePartitioning =>
val rangePartitioning = outputPartitioning.asInstanceOf[RangePartitioning]

val partitioning = PartitioningOuterClass.RangePartition.newBuilder()
partitioning.setNumPartitions(outputPartitioning.numPartitions)

// Detect duplicates by tracking expressions directly, similar to DataFusion's LexOrdering
// DataFusion will deduplicate identical sort expressions in LexOrdering,
// so we need to transform boundary rows to match the deduplicated structure
val seenExprs = mutable.HashSet[Expression]()
val deduplicationMap = mutable.ArrayBuffer[(Int, Boolean)]() // (originalIndex, isKept)

rangePartitioning.ordering.zipWithIndex.foreach { case (sortOrder, idx) =>
if (seenExprs.contains(sortOrder.child)) {
deduplicationMap += (idx -> false) // Will be deduplicated by DataFusion
} else {
seenExprs += sortOrder.child
deduplicationMap += (idx -> true) // Will be kept by DataFusion
}
partitioning.addAllHashExpression(partitionExprs.asJava)

val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder()
shuffleWriterBuilder.setPartitioning(
partitioningBuilder.setHashPartition(partitioning).build())
case _: RangePartitioning =>
val rangePartitioning = outputPartitioning.asInstanceOf[RangePartitioning]

val partitioning = PartitioningOuterClass.RangePartition.newBuilder()
partitioning.setNumPartitions(outputPartitioning.numPartitions)

// Detect duplicates by tracking expressions directly, similar to DataFusion's LexOrdering
// DataFusion will deduplicate identical sort expressions in LexOrdering,
// so we need to transform boundary rows to match the deduplicated structure
val seenExprs = mutable.HashSet[Expression]()
val deduplicationMap = mutable.ArrayBuffer[(Int, Boolean)]() // (originalIndex, isKept)

rangePartitioning.ordering.zipWithIndex.foreach { case (sortOrder, idx) =>
if (seenExprs.contains(sortOrder.child)) {
deduplicationMap += (idx -> false) // Will be deduplicated by DataFusion
} else {
seenExprs += sortOrder.child
deduplicationMap += (idx -> true) // Will be kept by DataFusion
}
}

{
// Serialize the ordering expressions for comparisons
val orderingExprs = rangePartitioning.ordering
.flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes))
if (orderingExprs.length != rangePartitioning.ordering.length) {
throw new UnsupportedOperationException(
s"Partitioning $rangePartitioning is not supported.")
}
partitioning.addAllSortOrders(orderingExprs.asJava)
{
// Serialize the ordering expressions for comparisons
val orderingExprs = rangePartitioning.ordering
.flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes))
if (orderingExprs.length != rangePartitioning.ordering.length) {
throw new UnsupportedOperationException(
s"Partitioning $rangePartitioning is not supported.")
}
partitioning.addAllSortOrders(orderingExprs.asJava)
}

// Convert Spark's sequence of InternalRows that represent partitioning boundaries to
// sequences of Literals, where each outer entry represents a boundary row, and each
// internal entry is a value in that row. In other words, these are stored in row major
// order, not column major
val boundarySchema = rangePartitioning.ordering.flatMap(e => Some(e.dataType))

// Transform boundary rows to match DataFusion's deduplicated structure
val transformedBoundaryExprs: Seq[Seq[Literal]] =
rangePartitionBounds.get.map((row: InternalRow) => {
// For every InternalRow, map its values to Literals
val allLiterals =
row.toSeq(boundarySchema).zip(boundarySchema).map { case (value, valueType) =>
Literal(value, valueType)
}

// Keep only the literals that correspond to non-deduplicated expressions
allLiterals
.zip(deduplicationMap)
.filter(_._2._2) // Keep only where isKept = true
.map(_._1) // Extract the literal
// Convert Spark's sequence of InternalRows that represent partitioning boundaries to
// sequences of Literals, where each outer entry represents a boundary row, and each
// internal entry is a value in that row. In other words, these are stored in row major
// order, not column major
val boundarySchema = rangePartitioning.ordering.flatMap(e => Some(e.dataType))

// Transform boundary rows to match DataFusion's deduplicated structure
val transformedBoundaryExprs: Seq[Seq[Literal]] =
rangePartitionBounds.get.map((row: InternalRow) => {
// For every InternalRow, map its values to Literals
val allLiterals =
row.toSeq(boundarySchema).zip(boundarySchema).map { case (value, valueType) =>
Literal(value, valueType)
}

// Keep only the literals that correspond to non-deduplicated expressions
allLiterals
.zip(deduplicationMap)
.filter(_._2._2) // Keep only where isKept = true
.map(_._1) // Extract the literal
})

{
// Convert the sequences of Literals to a collection of serialized BoundaryRows
val boundaryRows: Seq[PartitioningOuterClass.BoundaryRow] = transformedBoundaryExprs
.map((rowLiterals: Seq[Literal]) => {
// Serialize each sequence of Literals as a BoundaryRow
val rowBuilder = PartitioningOuterClass.BoundaryRow.newBuilder();
val serializedExprs =
rowLiterals.map(lit_value =>
QueryPlanSerde.exprToProto(lit_value, outputAttributes).get)
rowBuilder.addAllPartitionBounds(serializedExprs.asJava)
rowBuilder.build()
})
partitioning.addAllBoundaryRows(boundaryRows.asJava)
}

{
// Convert the sequences of Literals to a collection of serialized BoundaryRows
val boundaryRows: Seq[PartitioningOuterClass.BoundaryRow] = transformedBoundaryExprs
.map((rowLiterals: Seq[Literal]) => {
// Serialize each sequence of Literals as a BoundaryRow
val rowBuilder = PartitioningOuterClass.BoundaryRow.newBuilder();
val serializedExprs =
rowLiterals.map(lit_value =>
QueryPlanSerde.exprToProto(lit_value, outputAttributes).get)
rowBuilder.addAllPartitionBounds(serializedExprs.asJava)
rowBuilder.build()
})
partitioning.addAllBoundaryRows(boundaryRows.asJava)
}

val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder()
shuffleWriterBuilder.setPartitioning(
partitioningBuilder.setRangePartition(partitioning).build())

case _ =>
throw new UnsupportedOperationException(
s"Partitioning $outputPartitioning is not supported.")
}
val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder()
shuffleWriterBuilder.setPartitioning(
partitioningBuilder.setRangePartition(partitioning).build())

val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder()
shuffleWriterOpBuilder
.setShuffleWriter(shuffleWriterBuilder)
.addChildren(opBuilder.setScan(scanBuilder).build())
.build()
} else {
// There are unsupported scan types
throw new UnsupportedOperationException(
s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.")
case _ =>
throw new UnsupportedOperationException(
s"Partitioning $outputPartitioning is not supported.")
}

val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder()
shuffleWriterOpBuilder
.setShuffleWriter(shuffleWriterBuilder)
.addChildren(inputOperator)
.build()
}

override def stop(success: Boolean): Option[MapStatus] = {
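For readers skimming the refactor above, a reduced sketch of the composition getNativePlan now performs, using only the protobuf builder calls visible in this diff; the object and method names are illustrative and not part of the change.

import org.apache.comet.serde.OperatorOuterClass
import org.apache.comet.serde.OperatorOuterClass.Operator

object ShuffleWriterPlanSketch {
  // The ShuffleWriter operator always gets exactly one child: either the child's
  // serialized native plan (direct native execution) or, as before this change,
  // a Scan over the JNI source "ShuffleWriterInput".
  def composeShuffleWriterPlan(
      childNativePlan: Option[Operator],
      shuffleWriterBuilder: OperatorOuterClass.ShuffleWriter.Builder,
      jniScan: => Operator): Operator = {
    val input = childNativePlan.getOrElse(jniScan)
    OperatorOuterClass.Operator
      .newBuilder()
      .setShuffleWriter(shuffleWriterBuilder)
      .addChildren(input)
      .build()
  }
}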
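The boundary-row handling for range partitioning is easier to see on concrete values. This sketch mirrors the HashSet-based deduplication above with plain strings standing in for the Catalyst sort expressions; the values are made up for illustration.

import scala.collection.mutable

object BoundaryRowDedupSketch extends App {
  // Ordering (a, b, a): DataFusion's LexOrdering keeps only the first occurrence of
  // each expression, so the kept positions are (true, true, false).
  val ordering = Seq("a", "b", "a")
  val seen = mutable.HashSet[String]()
  val isKept = ordering.map(expr => seen.add(expr)) // add() returns false for duplicates

  // Boundary rows are row-major; drop the values at deduplicated positions so each
  // row lines up with the deduplicated LexOrdering.
  val boundaryRows = Seq(Seq(10, 100, 10), Seq(20, 200, 20))
  val transformed = boundaryRows.map(row => row.zip(isKept).collect { case (v, true) => v })

  println(isKept)      // List(true, true, false)
  println(transformed) // List(List(10, 100), List(20, 200))
}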
@@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType

import org.apache.comet.serde.OperatorOuterClass.Operator

/**
* A [[ShuffleDependency]] that allows us to identify the shuffle dependency as a Comet shuffle.
*/
@@ -49,7 +51,9 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
val outputAttributes: Seq[Attribute] = Seq.empty,
val shuffleWriteMetrics: Map[String, SQLMetric] = Map.empty,
val numParts: Int = 0,
val rangePartitionBounds: Option[Seq[InternalRow]] = None)
val rangePartitionBounds: Option[Seq[InternalRow]] = None,
// For direct native execution: the child's native plan to compose with ShuffleWriter
val childNativePlan: Option[Operator] = None)
extends ShuffleDependency[K, V, C](
_rdd,
partitioner,
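Finally, a hedged sketch of how the new field is expected to be consumed: the shuffle-manager side, which is outside this diff, would forward dep.childNativePlan into the matching constructor parameter added to CometNativeShuffleWriter above. The package of CometShuffleDependency and the elided constructor arguments are assumptions.

// Illustrative only; the writer-side wiring is not part of this diff.
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleDependency // package assumed

object ChildPlanWiringSketch {
  def childPlanFor(dep: CometShuffleDependency[_, _, _]): Option[Operator] =
    dep.childNativePlan

  // e.g. new CometNativeShuffleWriter(
  //   ...,                                      // other arguments elided
  //   rangePartitionBounds = dep.rangePartitionBounds,
  //   childNativePlan = dep.childNativePlan)
}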