diff --git a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala
index dbc3e17f83..f8da68d59f 100644
--- a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala
+++ b/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala
@@ -19,7 +19,6 @@
 
 package org.apache.comet.parquet
 
-import java.io.ByteArrayOutputStream
 import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort}
 import java.math.{BigDecimal => JBigDecimal}
 import java.sql.{Date, Timestamp}
@@ -43,6 +42,8 @@ import org.apache.spark.sql.sources
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.types.UTF8String
 
+import com.google.protobuf.CodedOutputStream
+
 import org.apache.comet.parquet.SourceFilterSerde.{createBinaryExpr, createNameExpr, createUnaryExpr, createValueExpr}
 import org.apache.comet.serde.ExprOuterClass
 import org.apache.comet.serde.QueryPlanSerde.scalarFunctionExprToProto
@@ -885,10 +886,12 @@ class ParquetFilters(
 
   def createNativeFilters(predicates: Seq[sources.Filter]): Option[Array[Byte]] = {
     predicates.reduceOption(sources.And).flatMap(createNativeFilter).map { expr =>
-      val outputStream = new ByteArrayOutputStream()
-      expr.writeTo(outputStream)
-      outputStream.close()
-      outputStream.toByteArray
+      val size = expr.getSerializedSize
+      val bytes = new Array[Byte](size)
+      val codedOutput = CodedOutputStream.newInstance(bytes)
+      expr.writeTo(codedOutput)
+      codedOutput.checkNoSpaceLeft()
+      bytes
     }
   }
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
index fd97fe3fa2..ef5507ef36 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
@@ -53,9 +53,13 @@ object CometExecUtils {
       limit: Int,
       offset: Int = 0): RDD[ColumnarBatch] = {
     val numParts = childPlan.getNumPartitions
+    val numOutputCols = outputAttribute.length
+    // Serialize the plan once and broadcast to all executors to avoid repeated serialization
+    val serializedPlan = CometExec.serializePlan(
+      CometExecUtils.getLimitNativePlan(outputAttribute, limit, offset).get)
+    val broadcastPlan = childPlan.sparkContext.broadcast(serializedPlan)
     childPlan.mapPartitionsWithIndexInternal { case (idx, iter) =>
-      val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit, offset).get
-      CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp, numParts, idx)
+      CometExec.getCometIterator(Seq(iter), numOutputCols, broadcastPlan.value, numParts, idx)
     }
   }
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
index f153a691ef..39e7ac6eef 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
@@ -19,8 +19,6 @@
 
 package org.apache.spark.sql.comet
 
-import java.io.ByteArrayOutputStream
-
 import scala.jdk.CollectionConverters._
 
 import org.apache.hadoop.fs.Path
@@ -34,6 +32,8 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.Utils
 
+import com.google.protobuf.CodedOutputStream
+
 import org.apache.comet.CometExecIterator
 import org.apache.comet.serde.OperatorOuterClass.Operator
 
@@ -75,10 +75,12 @@ case class CometNativeWriteExec(
     sparkContext.collectionAccumulator[FileCommitProtocol.TaskCommitMessage]("taskCommitMessages")
 
   override def serializedPlanOpt: SerializedPlan = {
-    val outputStream = new ByteArrayOutputStream()
-    nativeOp.writeTo(outputStream)
-    outputStream.close()
-    SerializedPlan(Some(outputStream.toByteArray))
+    val size = nativeOp.getSerializedSize
+    val bytes = new Array[Byte](size)
+    val codedOutput = CodedOutputStream.newInstance(bytes)
+    nativeOp.writeTo(codedOutput)
+    codedOutput.checkNoSpaceLeft()
+    SerializedPlan(Some(bytes))
   }
 
   override def withNewChildInternal(newChild: SparkPlan): SparkPlan =
@@ -196,10 +198,11 @@ case class CometNativeWriteExec(
 
     val nativeMetrics = CometMetricNode.fromCometPlan(this)
 
-    val outputStream = new ByteArrayOutputStream()
-    modifiedNativeOp.writeTo(outputStream)
-    outputStream.close()
-    val planBytes = outputStream.toByteArray
+    val size = modifiedNativeOp.getSerializedSize
+    val planBytes = new Array[Byte](size)
+    val codedOutput = CodedOutputStream.newInstance(planBytes)
+    modifiedNativeOp.writeTo(codedOutput)
+    codedOutput.checkNoSpaceLeft()
 
     val execIterator = new CometExecIterator(
       CometExec.newIterId,
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
index 2517c19f26..bec855a5a1 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
@@ -133,12 +133,20 @@ case class CometTakeOrderedAndProjectExec(
         CometExecUtils.getNativeLimitRDD(childRDD, child.output, limit)
       } else {
         val numParts = childRDD.getNumPartitions
+        val numOutputCols = child.output.length
+        // Serialize the plan once and broadcast to avoid repeated serialization
+        val serializedTopK = CometExec.serializePlan(
+          CometExecUtils
+            .getTopKNativePlan(child.output, sortOrder, child, limit)
+            .get)
+        val broadcastTopK = sparkContext.broadcast(serializedTopK)
         childRDD.mapPartitionsWithIndexInternal { case (idx, iter) =>
-          val topK =
-            CometExecUtils
-              .getTopKNativePlan(child.output, sortOrder, child, limit)
-              .get
-          CometExec.getCometIterator(Seq(iter), child.output.length, topK, numParts, idx)
+          CometExec.getCometIterator(
+            Seq(iter),
+            numOutputCols,
+            broadcastTopK.value,
+            numParts,
+            idx)
         }
       }
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 0a435e5b7a..94f46f0dd3 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -19,7 +19,6 @@
 
 package org.apache.spark.sql.comet
 
-import java.io.ByteArrayOutputStream
 import java.util.Locale
 
 import scala.collection.mutable
@@ -50,6 +49,7 @@ import org.apache.spark.util.SerializableConfiguration
 import org.apache.spark.util.io.ChunkedByteBuffer
 
 import com.google.common.base.Objects
+import com.google.protobuf.CodedOutputStream
 
 import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException, ConfigEntry}
 import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, withInfo}
@@ -113,6 +113,20 @@ object CometExec {
 
   def newIterId: Long = curId.getAndIncrement()
 
+  /**
+   * Serializes a native plan operator to a byte array. This method should be called once outside
+   * of partition iteration, and the resulting bytes can be reused across all partitions to avoid
+   * repeated serialization overhead.
+   */
+  def serializePlan(nativePlan: Operator): Array[Byte] = {
+    val size = nativePlan.getSerializedSize
+    val bytes = new Array[Byte](size)
+    val codedOutput = CodedOutputStream.newInstance(bytes)
+    nativePlan.writeTo(codedOutput)
+    codedOutput.checkNoSpaceLeft()
+    bytes
+  }
+
   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,
@@ -130,6 +144,28 @@ object CometExec {
       encryptedFilePaths = Seq.empty)
   }
 
+  /**
+   * Creates a CometExecIterator from pre-serialized plan bytes. Use this overload when the same
+   * plan is used across multiple partitions to avoid serializing the plan repeatedly.
+   */
+  def getCometIterator(
+      inputs: Seq[Iterator[ColumnarBatch]],
+      numOutputCols: Int,
+      serializedPlan: Array[Byte],
+      numParts: Int,
+      partitionIdx: Int): CometExecIterator = {
+    new CometExecIterator(
+      newIterId,
+      inputs,
+      numOutputCols,
+      serializedPlan,
+      CometMetricNode(Map.empty),
+      numParts,
+      partitionIdx,
+      broadcastedHadoopConfForEncryption = None,
+      encryptedFilePaths = Seq.empty)
+  }
+
   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,
@@ -139,10 +175,7 @@ object CometExec {
       partitionIdx: Int,
       broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]],
      encryptedFilePaths: Seq[String]): CometExecIterator = {
-    val outputStream = new ByteArrayOutputStream()
-    nativePlan.writeTo(outputStream)
-    outputStream.close()
-    val bytes = outputStream.toByteArray
+    val bytes = serializePlan(nativePlan)
     new CometExecIterator(
       newIterId,
       inputs,
@@ -414,10 +447,12 @@ abstract class CometNativeExec extends CometExec {
   def convertBlock(): CometNativeExec = {
     def transform(arg: Any): AnyRef = arg match {
       case serializedPlan: SerializedPlan if serializedPlan.isEmpty =>
-        val out = new ByteArrayOutputStream()
-        nativeOp.writeTo(out)
-        out.close()
-        SerializedPlan(Some(out.toByteArray))
+        val size = nativeOp.getSerializedSize
+        val bytes = new Array[Byte](size)
+        val codedOutput = CodedOutputStream.newInstance(bytes)
+        nativeOp.writeTo(codedOutput)
+        codedOutput.checkNoSpaceLeft()
+        SerializedPlan(Some(bytes))
       case other: AnyRef => other
       case null => null
     }
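Note: all of the hunks above converge on the same serialization pattern. The standalone sketch below is not part of this patch; ProtoSerdeSketch, toBytes, and the use of Int64Value as a stand-in message are illustrative only. It shows the idea with stock protobuf-java APIs: size the buffer with getSerializedSize, write through CodedOutputStream, and verify with checkNoSpaceLeft, instead of routing through a ByteArrayOutputStream and copying the result out with toByteArray.

import com.google.protobuf.{CodedOutputStream, Int64Value, MessageLite}

// Standalone sketch (not Comet code): serialize a protobuf message directly into a
// pre-sized byte array, mirroring the CodedOutputStream pattern introduced above.
object ProtoSerdeSketch {

  def toBytes(message: MessageLite): Array[Byte] = {
    val bytes = new Array[Byte](message.getSerializedSize)
    val output = CodedOutputStream.newInstance(bytes)
    message.writeTo(output)
    // Fails fast if the buffer was not filled exactly, i.e. a size/write mismatch.
    output.checkNoSpaceLeft()
    bytes
  }

  def main(args: Array[String]): Unit = {
    // Int64Value is only a stand-in; any generated protobuf message works the same way.
    val msg = Int64Value.newBuilder().setValue(42L).build()
    println(s"serialized ${toBytes(msg).length} bytes")
  }
}

Writing into a buffer sized by getSerializedSize avoids the intermediate growth and final copy of ByteArrayOutputStream, and broadcasting the resulting bytes once (as in the getNativeLimitRDD and TopK hunks) moves plan construction and serialization to the driver so it is no longer repeated for every partition.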