diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
index fd97fe3fa2..a2af60142b 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
@@ -53,9 +53,11 @@ object CometExecUtils {
       limit: Int,
       offset: Int = 0): RDD[ColumnarBatch] = {
     val numParts = childPlan.getNumPartitions
+    // Serialize the plan once before mapping to avoid repeated serialization per partition
+    val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit, offset).get
+    val serializedPlan = CometExec.serializeNativePlan(limitOp)
     childPlan.mapPartitionsWithIndexInternal { case (idx, iter) =>
-      val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit, offset).get
-      CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp, numParts, idx)
+      CometExec.getCometIterator(Seq(iter), outputAttribute.length, serializedPlan, numParts, idx)
     }
   }
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
index 2517c19f26..2abe783172 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
@@ -133,12 +133,15 @@ case class CometTakeOrderedAndProjectExec(
           CometExecUtils.getNativeLimitRDD(childRDD, child.output, limit)
         } else {
           val numParts = childRDD.getNumPartitions
+          // Serialize the plan once before mapping to avoid repeated serialization per partition
+          val topK =
+            CometExecUtils
+              .getTopKNativePlan(child.output, sortOrder, child, limit)
+              .get
+          val serializedTopK = CometExec.serializeNativePlan(topK)
+          val numOutputCols = child.output.length
           childRDD.mapPartitionsWithIndexInternal { case (idx, iter) =>
-            val topK =
-              CometExecUtils
-                .getTopKNativePlan(child.output, sortOrder, child, limit)
-                .get
-            CometExec.getCometIterator(Seq(iter), child.output.length, topK, numParts, idx)
+            CometExec.getCometIterator(Seq(iter), numOutputCols, serializedTopK, numParts, idx)
           }
         }
 
@@ -154,11 +157,19 @@ case class CometTakeOrderedAndProjectExec(
           new CometShuffledBatchRDD(dep, readMetrics)
         }
 
+        // Serialize the plan once before mapping to avoid repeated serialization per partition
+        val topKAndProjection = CometExecUtils
+          .getProjectionNativePlan(projectList, child.output, sortOrder, child, limit, offset)
+          .get
+        val serializedTopKAndProjection = CometExec.serializeNativePlan(topKAndProjection)
+        val finalOutputLength = output.length
        singlePartitionRDD.mapPartitionsInternal { iter =>
-          val topKAndProjection = CometExecUtils
-            .getProjectionNativePlan(projectList, child.output, sortOrder, child, limit, offset)
-            .get
-          val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection, 1, 0)
+          val it = CometExec.getCometIterator(
+            Seq(iter),
+            finalOutputLength,
+            serializedTopKAndProjection,
+            1,
+            0)
           setSubqueries(it.id, this)
 
           Option(TaskContext.get()).foreach { context =>
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index f4f97b8312..cb70986170 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -113,6 +113,19 @@ object CometExec {
 
   def newIterId: Long = curId.getAndIncrement()
 
+  /**
+   * Serialize a native plan to bytes. Use this method to serialize the plan once before calling
+   * getCometIterator for each partition, avoiding repeated serialization.
+   */
+  def serializeNativePlan(nativePlan: Operator): Array[Byte] = {
+    val size = nativePlan.getSerializedSize
+    val bytes = new Array[Byte](size)
+    val codedOutput = CodedOutputStream.newInstance(bytes)
+    nativePlan.writeTo(codedOutput)
+    codedOutput.checkNoSpaceLeft()
+    bytes
+  }
+
   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,
@@ -130,6 +143,28 @@ object CometExec {
       encryptedFilePaths = Seq.empty)
   }
 
+  /**
+   * Create a CometExecIterator with a pre-serialized native plan. Use this overload when
+   * executing the same plan across multiple partitions to avoid serializing the plan repeatedly.
+   */
+  def getCometIterator(
+      inputs: Seq[Iterator[ColumnarBatch]],
+      numOutputCols: Int,
+      serializedPlan: Array[Byte],
+      numParts: Int,
+      partitionIdx: Int): CometExecIterator = {
+    new CometExecIterator(
+      newIterId,
+      inputs,
+      numOutputCols,
+      serializedPlan,
+      CometMetricNode(Map.empty),
+      numParts,
+      partitionIdx,
+      broadcastedHadoopConfForEncryption = None,
+      encryptedFilePaths = Seq.empty)
+  }
+
   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,
@@ -139,11 +174,7 @@ object CometExec {
       partitionIdx: Int,
       broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]],
       encryptedFilePaths: Seq[String]): CometExecIterator = {
-    val size = nativePlan.getSerializedSize
-    val bytes = new Array[Byte](size)
-    val codedOutput = CodedOutputStream.newInstance(bytes)
-    nativePlan.writeTo(codedOutput)
-    codedOutput.checkNoSpaceLeft()
+    val bytes = serializeNativePlan(nativePlan)
     new CometExecIterator(
       newIterId,
       inputs,
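
Note for reviewers: below is a minimal sketch of the call pattern this change enables, for any operator that wants to adopt it. runNativePlan and its parameters are illustrative only; the only new API used is CometExec.serializeNativePlan plus the Array[Byte] overload of CometExec.getCometIterator added above, and (as in the diff) it relies on Spark's internal mapPartitionsWithIndexInternal, so it only compiles inside the org.apache.spark package tree.

import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.comet.CometExec
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical helper, not part of this diff: runs an already-built native plan
// over an RDD, serializing the plan once on the driver.
def runNativePlan(
    child: RDD[ColumnarBatch],
    nativePlan: Operator,
    numOutputCols: Int): RDD[ColumnarBatch] = {
  val numParts = child.getNumPartitions
  // Serialize once, outside the partition closure ...
  val serializedPlan: Array[Byte] = CometExec.serializeNativePlan(nativePlan)
  child.mapPartitionsWithIndexInternal { case (idx, iter) =>
    // ... so each task reuses the pre-serialized bytes instead of rebuilding
    // and re-serializing the plan per partition.
    CometExec.getCometIterator(Seq(iter), numOutputCols, serializedPlan, numParts, idx)
  }
}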