diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala
index 311bab084bbf7..11d025c55a96d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala
@@ -19,22 +19,25 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import org.roaringbitmap.longlong.Roaring64Bitmap
 
+import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.AttributeSet
 import org.apache.spark.sql.catalyst.expressions.BasePredicate
+import org.apache.spark.sql.catalyst.expressions.BindReferences
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.Projection
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
-import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral, GeneratePredicate, JavaCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
 import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Context, Copy, Delete, Discard, Insert, Instruction, Keep, ROW_ID, Split, Update}
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.UnaryExecNode
+import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.types.BooleanType
 
 case class MergeRowsExec(
     isSourceRowPresent: Expression,
@@ -44,7 +47,7 @@ case class MergeRowsExec(
     notMatchedBySourceInstructions: Seq[Instruction],
     checkCardinality: Boolean,
     output: Seq[Attribute],
-    child: SparkPlan) extends UnaryExecNode {
+    child: SparkPlan) extends UnaryExecNode with CodegenSupport {
 
   override lazy val metrics: Map[String, SQLMetric] = Map(
     "numTargetRowsCopied" -> SQLMetrics.createMetric(sparkContext,
@@ -92,6 +95,314 @@ case class MergeRowsExec(
     child.execute().mapPartitions(processPartition)
   }
 
+  /** Input RDDs are taken from the child, which is cast to [[CodegenSupport]]. */
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].inputRDDs()
+  }
+
+  /** Asks the child (cast to [[CodegenSupport]]) to produce code with this node as parent. */
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  /**
+   * Generates the per-row code by delegating to [[generateInstructionExecutionCode]].
+   * Only the column variables in `input` are used; `row` is not referenced.
+   */
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    generateInstructionExecutionCode(ctx, input)
+  }
+
+  /**
+   * Generates the cardinality check for the current row.
+   *
+   * A [[Roaring64Bitmap]] is registered as mutable state. The emitted code throws
+   * `mergeCardinalityViolationError` if the bitmap already contains the current row ID and
+   * otherwise adds the row ID to the bitmap.
+   *
+   * @param rowIdOrdinal index of the row ID column within `input`
+   * @param input column variables of the current row
+   * @return an [[ExprCode]] whose `code` is the check and whose `value` names the bitmap
+   */
+  private def generateCardinalityValidationCode(ctx: CodegenContext, rowIdOrdinal: Int,
+      input: Seq[ExprCode]): ExprCode = {
+    val bitmapClass = classOf[Roaring64Bitmap]
+    val rowIdBitmap = ctx.addMutableState(bitmapClass.getName, "matchedRowIds",
+      v => s"$v = new ${bitmapClass.getName}();")
+
+    val currentRowId = input(rowIdOrdinal)
+    // The generated Java reaches the Scala object through its MODULE$ field.
+    val queryExecutionErrorsClass = QueryExecutionErrors.getClass.getName + ".MODULE$"
+    val code =
+      code"""
+         |${currentRowId.code}
+         |if ($rowIdBitmap.contains(${currentRowId.value})) {
+         |  throw $queryExecutionErrorsClass.mergeCardinalityViolationError();
+         |}
+         |$rowIdBitmap.add(${currentRowId.value});
+       """.stripMargin
+    ExprCode(code, FalseLiteral, JavaCode.variable(rowIdBitmap, bitmapClass))
+  }
+
+  /**
+   * Generates the top-level dispatch on the source/target presence predicates.
+   *
+   * When both predicates hold, the optional cardinality check runs before the matched
+   * instructions. When only the source predicate holds, the not-matched instructions run; when
+   * only the target predicate holds, the not-matched-by-source instructions run.
+   */
+  private def generateInstructionExecutionCode(ctx: CodegenContext,
+      inputExprs: Seq[ExprCode]): String = {
+
+    // code for evaluating src/tgt presence conditions
+    val sourcePresentExpr = generatePredicateCode(ctx, isSourceRowPresent, child.output, inputExprs)
+    val targetPresentExpr = generatePredicateCode(ctx, isTargetRowPresent, child.output, inputExprs)
+
+    // code for each instruction type
+    val matchedInstructionsCode = generateInstructionsCode(ctx, matchedInstructions,
+      "matched", inputExprs, sourcePresent = true)
+    val notMatchedInstructionsCode = generateInstructionsCode(ctx, notMatchedInstructions,
+      "notMatched", inputExprs, sourcePresent = true)
+    val notMatchedBySourceInstructionsCode = generateInstructionsCode(ctx,
+      notMatchedBySourceInstructions, "notMatchedBySource", inputExprs, sourcePresent = false)
+
+    val cardinalityValidationCode = if (checkCardinality) {
+      val rowIdOrdinal = child.output.indexWhere(attr => conf.resolver(attr.name, ROW_ID))
+      assert(rowIdOrdinal != -1, "Cannot find row ID attr")
+      generateCardinalityValidationCode(ctx, rowIdOrdinal, inputExprs).code
+    } else {
+      ""
+    }
+
+    s"""
+       |${sourcePresentExpr.code}
+       |${targetPresentExpr.code}
+       |
+       |if (${targetPresentExpr.value} && ${sourcePresentExpr.value}) {
+       |  $cardinalityValidationCode
+       |  $matchedInstructionsCode
+       |} else if (${sourcePresentExpr.value}) {
+       |  $notMatchedInstructionsCode
+       |} else if (${targetPresentExpr.value}) {
+       |  $notMatchedBySourceInstructionsCode
+       |}
+      """.stripMargin
+  }
+
+  /**
+   * Generates the code for one group of instructions: the code of every instruction in order,
+   * followed by a `return;`. An empty group yields an empty string.
+   *
+   * `instructionType` is accepted but not referenced in the body.
+   */
+  private def generateInstructionsCode(ctx: CodegenContext, instructions: Seq[Instruction],
+      instructionType: String,
+      inputExprs: Seq[ExprCode],
+      sourcePresent: Boolean): String = {
+    if (instructions.isEmpty) {
+      ""
+    } else {
+      val instructionCodes = instructions.map(instruction =>
+        generateSingleInstructionCode(ctx, instruction, inputExprs, sourcePresent))
+
+      s"""
+         |${instructionCodes.mkString("\n")}
+         |return;
+       """.stripMargin
+    }
+  }
+
+  /**
+   * Generates the code for a single instruction. Each supported instruction evaluates its
+   * condition and, when it holds, updates metrics, emits its output row(s) via `consume`
+   * (none for [[Discard]]) and returns.
+   *
+   * @throws SparkUnsupportedOperationException for any other instruction type
+   */
+  private def generateSingleInstructionCode(ctx: CodegenContext,
+      instruction: Instruction,
+      inputExprs: Seq[ExprCode],
+      sourcePresent: Boolean): String = {
+    instruction match {
+      case Keep(context, condition, outputExprs) =>
+        val projectionExpr = generateProjectionCode(ctx, outputExprs, inputExprs)
+        val code = generatePredicateCode(ctx, condition, child.output, inputExprs)
+
+        // Generate metric updates based on context
+        val metricUpdateCode = generateMetricUpdateCode(ctx, context, sourcePresent)
+
+        s"""
+           |${code.code}
+           |if (${code.value}) {
+           |  $metricUpdateCode
+           |  ${consume(ctx, projectionExpr)}
+           |  return;
+           |}
+         """.stripMargin
+
+      case Discard(condition) =>
+        val code = generatePredicateCode(ctx, condition, child.output, inputExprs)
+        val metricUpdateCode = generateDeleteMetricUpdateCode(ctx, sourcePresent)
+
+        s"""
+           |${code.code}
+           |if (${code.value}) {
+           |  $metricUpdateCode
+           |  return; // Discard row
+           |}
+         """.stripMargin
+
+      case Split(condition, outputExprs, otherOutputExprs) =>
+        val projectionExpr = generateProjectionCode(ctx, outputExprs, inputExprs)
+        val otherProjectionExpr = generateProjectionCode(ctx, otherOutputExprs, inputExprs)
+        val code = generatePredicateCode(ctx, condition, child.output, inputExprs)
+        val metricUpdateCode = generateUpdateMetricUpdateCode(ctx, sourcePresent)
+
+        s"""
+           |${code.code}
+           |if (${code.value}) {
+           |  $metricUpdateCode
+           |  ${consume(ctx, projectionExpr)}
+           |  ${consume(ctx, otherProjectionExpr)}
+           |  return;
+           |}
+         """.stripMargin
+      case _ =>
+        // Codegen not implemented
+        throw new SparkUnsupportedOperationException(
+          errorClass = "_LEGACY_ERROR_TEMP_3073",
+          messageParameters = Map("instruction" -> instruction.toString))
+    }
+  }
+
+  /**
+   * Generates the metric update statement(s) for a [[Keep]] instruction from its context.
+   */
+  private def generateMetricUpdateCode(ctx: CodegenContext, context: Context,
+      sourcePresent: Boolean): String = {
+    context match {
+      case Copy =>
+        val copyMetric = metricTerm(ctx, "numTargetRowsCopied")
+        s"$copyMetric.add(1);"
+
+      case Insert =>
+        val insertMetric = metricTerm(ctx, "numTargetRowsInserted")
+        s"$insertMetric.add(1);"
+
+      case Update =>
+        generateUpdateMetricUpdateCode(ctx, sourcePresent)
+
+      case Delete =>
+        generateDeleteMetricUpdateCode(ctx, sourcePresent)
+
+      case _ =>
+        throw new IllegalArgumentException(s"Unexpected context for KeepExec: $context")
+    }
+  }
+
+  /**
+   * Generates increments for the total update metric plus either the matched or the
+   * not-matched-by-source update metric, depending on `sourcePresent`.
+   */
+  private def generateUpdateMetricUpdateCode(ctx: CodegenContext,
+      sourcePresent: Boolean): String = {
+    val updateMetric = metricTerm(ctx, "numTargetRowsUpdated")
+    if (sourcePresent) {
+      val matchedUpdateMetric = metricTerm(ctx, "numTargetRowsMatchedUpdated")
+
+      s"""
+         |$updateMetric.add(1);
+         |$matchedUpdateMetric.add(1);
+       """.stripMargin
+    } else {
+      val notMatchedBySourceUpdateMetric = metricTerm(ctx, "numTargetRowsNotMatchedBySourceUpdated")
+
+      s"""
+         |$updateMetric.add(1);
+         |$notMatchedBySourceUpdateMetric.add(1);
+       """.stripMargin
+    }
+  }
+
+  /**
+   * Generates increments for the total delete metric plus either the matched or the
+   * not-matched-by-source delete metric, depending on `sourcePresent`.
+   */
+  private def generateDeleteMetricUpdateCode(ctx: CodegenContext,
+      sourcePresent: Boolean): String = {
+    val deleteMetric = metricTerm(ctx, "numTargetRowsDeleted")
+    if (sourcePresent) {
+      val matchedDeleteMetric = metricTerm(ctx, "numTargetRowsMatchedDeleted")
+
+      s"""
+         |$deleteMetric.add(1);
+         |$matchedDeleteMetric.add(1);
+       """.stripMargin
+    } else {
+      val notMatchedBySourceDeleteMetric = metricTerm(ctx, "numTargetRowsNotMatchedBySourceDeleted")
+
+      s"""
+         |$deleteMetric.add(1);
+         |$notMatchedBySourceDeleteMetric.add(1);
+       """.stripMargin
+    }
+  }
+
+  /**
+   * Runs `block` with `ctx.currentVars` set to `inputCurrentVars`.
+   *
+   * The previous values of `ctx.currentVars` and `ctx.INPUT_ROW` are captured first and put
+   * back in a `finally`, so they are restored even if `block` throws.
+   */
+  private def withCodegenContext[T](
+      ctx: CodegenContext,
+      inputCurrentVars: Seq[ExprCode])(block: => T): T = {
+    val originalCurrentVars = ctx.currentVars
+    val originalInputRow = ctx.INPUT_ROW
+    try {
+      // Generate against the caller-supplied input variables
+      ctx.currentVars = inputCurrentVars
+      block
+    } finally {
+      // Restore original context
+      ctx.currentVars = originalCurrentVars
+      ctx.INPUT_ROW = originalInputRow
+    }
+  }
+
+  /**
+   * Binds `predicate` to `inputAttrs` and generates code that stores its result in a fresh
+   * boolean variable; a null result is treated as false.
+   */
+  private def generatePredicateCode(ctx: CodegenContext,
+      predicate: Expression,
+      inputAttrs: Seq[Attribute],
+      inputCurrentVars: Seq[ExprCode]): ExprCode = {
+    withCodegenContext(ctx, inputCurrentVars) {
+      val boundPredicate = BindReferences.bindReference(predicate, inputAttrs)
+      val ev = boundPredicate.genCode(ctx)
+      val predicateVar = ctx.freshName("predicateResult")
+      val code = code"""
+      |${ev.code}
+      |boolean $predicateVar = !${ev.isNull} && ${ev.value};
+      """.stripMargin
+      ExprCode(code, FalseLiteral,
+        JavaCode.variable(predicateVar, BooleanType))
+    }
+  }
+
+  /**
+   * Binds each output expression to `child.output` and generates its code, in order.
+   */
+  private def generateProjectionCode(ctx: CodegenContext,
+      outputExprs: Seq[Expression],
+      inputCurrentVars: Seq[ExprCode]): Seq[ExprCode] = {
+    withCodegenContext(ctx, inputCurrentVars) {
+      val boundExprs = outputExprs.map(BindReferences.bindReference(_, child.output))
+      boundExprs.map(_.genCode(ctx))
+    }
+  }
+
   private def processPartition(rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
     val isSourceRowPresentPred = createPredicate(isSourceRowPresent)
     val isTargetRowPresentPred = createPredicate(isTargetRowPresent)