@@ -53,9 +53,11 @@ object CometExecUtils {
       limit: Int,
       offset: Int = 0): RDD[ColumnarBatch] = {
     val numParts = childPlan.getNumPartitions
+    // Serialize the plan once before mapping to avoid repeated serialization per partition
+    val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit, offset).get
+    val serializedPlan = CometExec.serializeNativePlan(limitOp)
     childPlan.mapPartitionsWithIndexInternal { case (idx, iter) =>
-      val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit, offset).get
-      CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp, numParts, idx)
+      CometExec.getCometIterator(Seq(iter), outputAttribute.length, serializedPlan, numParts, idx)
     }
   }
 
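The change above moves plan construction and protobuf serialization out of the per-partition closure: the plan is serialized once on the driver, and only the resulting byte array is captured by the task closure, so each partition no longer rebuilds and re-serializes the same plan. A minimal sketch of the same pattern in plain Spark, for illustration only (SerializeOnceSketch and its placeholder "plan" string are hypothetical, not Comet code):

// Illustration only: the serialize-once pattern in plain Spark, not Comet code.
import org.apache.spark.sql.SparkSession

object SerializeOnceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("serialize-once").getOrCreate()
    val rdd = spark.sparkContext.parallelize(1 to 100, numSlices = 4)

    // Build and serialize the "plan" once on the driver; only these bytes are
    // captured by the task closure and shipped to each partition.
    val planBytes: Array[Byte] = "hypothetical plan".getBytes("UTF-8")

    val summary = rdd
      .mapPartitionsWithIndex { (idx, iter) =>
        // Each task reuses planBytes instead of re-serializing a plan here.
        Iterator.single(s"partition $idx: ${iter.size} rows, plan ${planBytes.length} bytes")
      }
      .collect()

    summary.foreach(println)
    spark.stop()
  }
}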
@@ -133,12 +133,15 @@ case class CometTakeOrderedAndProjectExec(
       CometExecUtils.getNativeLimitRDD(childRDD, child.output, limit)
     } else {
       val numParts = childRDD.getNumPartitions
+      // Serialize the plan once before mapping to avoid repeated serialization per partition
+      val topK =
+        CometExecUtils
+          .getTopKNativePlan(child.output, sortOrder, child, limit)
+          .get
+      val serializedTopK = CometExec.serializeNativePlan(topK)
+      val numOutputCols = child.output.length
       childRDD.mapPartitionsWithIndexInternal { case (idx, iter) =>
-        val topK =
-          CometExecUtils
-            .getTopKNativePlan(child.output, sortOrder, child, limit)
-            .get
-        CometExec.getCometIterator(Seq(iter), child.output.length, topK, numParts, idx)
+        CometExec.getCometIterator(Seq(iter), numOutputCols, serializedTopK, numParts, idx)
       }
     }
 
@@ -154,11 +157,19 @@ case class CometTakeOrderedAndProjectExec(
         new CometShuffledBatchRDD(dep, readMetrics)
       }
 
+      // Serialize the plan once before mapping to avoid repeated serialization per partition
+      val topKAndProjection = CometExecUtils
+        .getProjectionNativePlan(projectList, child.output, sortOrder, child, limit, offset)
+        .get
+      val serializedTopKAndProjection = CometExec.serializeNativePlan(topKAndProjection)
+      val finalOutputLength = output.length
       singlePartitionRDD.mapPartitionsInternal { iter =>
-        val topKAndProjection = CometExecUtils
-          .getProjectionNativePlan(projectList, child.output, sortOrder, child, limit, offset)
-          .get
-        val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection, 1, 0)
+        val it = CometExec.getCometIterator(
+          Seq(iter),
+          finalOutputLength,
+          serializedTopKAndProjection,
+          1,
+          0)
         setSubqueries(it.id, this)
 
         Option(TaskContext.get()).foreach { context =>
spark/src/main/scala/org/apache/spark/sql/comet/operators.scala (36 additions, 5 deletions)
@@ -113,6 +113,19 @@ object CometExec {
 
   def newIterId: Long = curId.getAndIncrement()
 
+  /**
+   * Serialize a native plan to bytes. Use this method to serialize the plan once before calling
+   * getCometIterator for each partition, avoiding repeated serialization.
+   */
+  def serializeNativePlan(nativePlan: Operator): Array[Byte] = {
+    val size = nativePlan.getSerializedSize
+    val bytes = new Array[Byte](size)
+    val codedOutput = CodedOutputStream.newInstance(bytes)
+    nativePlan.writeTo(codedOutput)
+    codedOutput.checkNoSpaceLeft()
+    bytes
+  }
+
   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,
@@ -130,6 +143,28 @@ object CometExec {
       encryptedFilePaths = Seq.empty)
   }
 
+  /**
+   * Create a CometExecIterator with a pre-serialized native plan. Use this overload when
+   * executing the same plan across multiple partitions to avoid serializing the plan repeatedly.
+   */
+  def getCometIterator(
+      inputs: Seq[Iterator[ColumnarBatch]],
+      numOutputCols: Int,
+      serializedPlan: Array[Byte],
+      numParts: Int,
+      partitionIdx: Int): CometExecIterator = {
+    new CometExecIterator(
+      newIterId,
+      inputs,
+      numOutputCols,
+      serializedPlan,
+      CometMetricNode(Map.empty),
+      numParts,
+      partitionIdx,
+      broadcastedHadoopConfForEncryption = None,
+      encryptedFilePaths = Seq.empty)
+  }
+
   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,
@@ -139,11 +174,7 @@ object CometExec {
       partitionIdx: Int,
       broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]],
       encryptedFilePaths: Seq[String]): CometExecIterator = {
-    val size = nativePlan.getSerializedSize
-    val bytes = new Array[Byte](size)
-    val codedOutput = CodedOutputStream.newInstance(bytes)
-    nativePlan.writeTo(codedOutput)
-    codedOutput.checkNoSpaceLeft()
+    val bytes = serializeNativePlan(nativePlan)
     new CometExecIterator(
       newIterId,
       inputs,
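Taken together, the new helpers are meant to be used as in the call sites updated above: serialize the plan once on the driver with CometExec.serializeNativePlan, then pass the bytes to the new getCometIterator overload inside the per-partition closure. A rough sketch of that calling pattern (nativePlan, outputAttrs, and childRDD are placeholders for whatever the operator already has in scope, not new API):

// Sketch of the intended calling pattern; placeholder names, not a complete operator.
val planBytes = CometExec.serializeNativePlan(nativePlan) // serialize once, on the driver
val numParts = childRDD.getNumPartitions
childRDD.mapPartitionsWithIndexInternal { case (idx, iter) =>
  // Each partition reuses the pre-serialized bytes instead of serializing the plan again.
  CometExec.getCometIterator(Seq(iter), outputAttrs.length, planBytes, numParts, idx)
}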