diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 9050a84fda56e..e6d2eccaf201c 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -81,10 +81,13 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
           dataframe).foreach(responseObserver.onNext)
       case proto.Plan.OpTypeCase.COMMAND =>
         val command = request.getPlan.getCommand
-        planner.transformCommand(command, tracker) match {
-          case Some(plan) =>
-            val qe =
-              new QueryExecution(session, plan, tracker, shuffleCleanupMode = shuffleCleanupMode)
+        planner.transformCommand(command) match {
+          case Some(transformer) =>
+            val qe = new QueryExecution(
+              session,
+              transformer(tracker),
+              tracker,
+              shuffleCleanupMode = shuffleCleanupMode)
             qe.assertCommandExecuted()
             executeHolder.eventsManager.postFinished()
           case None =>
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 7320c6e3918c8..64a6ebf1a5222 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2646,12 +2646,12 @@ class SparkConnectPlanner(
     process(command, new MockObserver())
   }
 
-  def transformCommand(
-      command: proto.Command,
-      tracker: QueryPlanningTracker): Option[LogicalPlan] = {
+  def transformCommand(command: proto.Command): Option[QueryPlanningTracker => LogicalPlan] = {
     command.getCommandTypeCase match {
       case proto.Command.CommandTypeCase.WRITE_OPERATION =>
-        Some(transformWriteOperation(command.getWriteOperation, tracker))
+        Some(transformWriteOperation(command.getWriteOperation))
+      case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
+        Some(transformWriteOperationV2(command.getWriteOperationV2))
       case _ => None
     }
   }
@@ -2660,6 +2660,11 @@ class SparkConnectPlanner(
   def process(
       command: proto.Command,
      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+    val transformerOpt = transformCommand(command)
+    if (transformerOpt.isDefined) {
+      transformAndRunCommand(transformerOpt.get)
+      return
+    }
     command.getCommandTypeCase match {
       case proto.Command.CommandTypeCase.REGISTER_FUNCTION =>
         handleRegisterUserDefinedFunction(command.getRegisterFunction)
@@ -2667,12 +2672,8 @@ class SparkConnectPlanner(
         handleRegisterUserDefinedTableFunction(command.getRegisterTableFunction)
       case proto.Command.CommandTypeCase.REGISTER_DATA_SOURCE =>
         handleRegisterUserDefinedDataSource(command.getRegisterDataSource)
-      case proto.Command.CommandTypeCase.WRITE_OPERATION =>
-        handleWriteOperation(command.getWriteOperation)
       case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
         handleCreateViewCommand(command.getCreateDataframeView)
-      case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
-        handleWriteOperationV2(command.getWriteOperationV2)
       case proto.Command.CommandTypeCase.EXTENSION =>
         handleCommandPlugin(command.getExtension)
       case proto.Command.CommandTypeCase.SQL_COMMAND =>
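The shape of the change so far: `transformCommand` no longer consumes a tracker and returns a plan; it returns a deferred `QueryPlanningTracker => LogicalPlan`, and each call site applies it with the tracker it owns. A minimal, self-contained sketch of that dispatch pattern; every type below is a simplified stand-in, not the real Spark class:

```scala
object TransformerDispatchSketch {
  final class QueryPlanningTracker
  sealed trait LogicalPlan
  final case class WritePlan(viaTracker: Boolean) extends LogicalPlan

  sealed trait Command
  case object WriteOperation extends Command
  case object RegisterFunction extends Command

  // Like the new transformCommand: commands that are pure plan rewrites
  // return a deferred transformer; everything else returns None and falls
  // through to the side-effecting handle* methods in process().
  def transformCommand(c: Command): Option[QueryPlanningTracker => LogicalPlan] =
    c match {
      case WriteOperation => Some(tracker => WritePlan(viaTracker = tracker != null))
      case _ => None
    }

  def main(args: Array[String]): Unit = {
    val tracker = new QueryPlanningTracker
    transformCommand(WriteOperation) match {
      case Some(transformer) =>
        // The caller owns the tracker and applies it at execution time.
        println(s"plan = ${transformer(tracker)}")
      case None =>
        println("dispatch to a handle* method instead")
    }
  }
}
```

The payoff is that plan construction is decoupled from tracker ownership: `SparkConnectPlanExecution` threads the tracker it already has, while `process()` lets `transformAndRunCommand` create one. The hunks below convert the write handlers themselves into curried transformers.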
@@ -3089,8 +3090,16 @@ class SparkConnectPlanner(
     executeHolder.eventsManager.postFinished()
   }
 
-  private def transformWriteOperation(
-      writeOperation: proto.WriteOperation,
+  /**
+   * Transforms the write operation.
+   *
+   * The input write operation contains a reference to the input plan and transforms it to the
+   * corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the
+   * parameters of the WriteOperation into the corresponding methods calls.
+   *
+   * @param writeOperation
+   */
+  private def transformWriteOperation(writeOperation: proto.WriteOperation)(
       tracker: QueryPlanningTracker): LogicalPlan = {
     // Transform the input plan into the logical plan.
     val plan = transformRelation(writeOperation.getInput)
@@ -3149,29 +3158,15 @@ class SparkConnectPlanner(
     }
   }
 
-  private def runCommand(command: LogicalPlan, tracker: QueryPlanningTracker): Unit = {
-    val qe = new QueryExecution(session, command, tracker)
-    qe.assertCommandExecuted()
-  }
-
-  /**
-   * Transforms the write operation and executes it.
-   *
-   * The input write operation contains a reference to the input plan and transforms it to the
-   * corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the
-   * parameters of the WriteOperation into the corresponding methods calls.
-   *
-   * @param writeOperation
-   */
-  private def handleWriteOperation(writeOperation: proto.WriteOperation): Unit = {
+  private def transformAndRunCommand(transformer: QueryPlanningTracker => LogicalPlan): Unit = {
     val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-    runCommand(transformWriteOperation(writeOperation, tracker), tracker)
-
+    val qe = new QueryExecution(session, transformer(tracker), tracker)
+    qe.assertCommandExecuted()
     executeHolder.eventsManager.postFinished()
   }
 
   /**
-   * Transforms the write operation and executes it.
+   * Transforms the write operation.
    *
    * The input write operation contains a reference to the input plan and transforms it to the
    * corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the
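The second parameter list on `transformWriteOperation` is what makes `Some(transformWriteOperation(command.getWriteOperation))` type-check in `transformCommand`: supplying only the first argument list eta-expands to exactly the function value the new signature promises. A sketch of that mechanism, again with placeholder types rather than the Spark classes:

```scala
object CurriedTransformerSketch {
  final class Tracker
  final case class Plan(desc: String)

  // Mirrors transformWriteOperation(writeOperation)(tracker): the proto
  // message goes in the first parameter list, the tracker in the second.
  def transformWrite(tableName: String)(tracker: Tracker): Plan =
    Plan(s"write to $tableName")

  def main(args: Array[String]): Unit = {
    // Supplying only the first list eta-expands to Tracker => Plan, which is
    // the Option[QueryPlanningTracker => LogicalPlan] payload seen above.
    val deferred: Tracker => Plan = transformWrite("events")
    val tracker = new Tracker
    println(deferred(tracker)) // Plan(write to events)
  }
}
```

The next hunk applies the same currying to the V2 handler.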
@@ -3179,11 +3174,11 @@ class SparkConnectPlanner(
    *
    * @param writeOperation
    */
-  private def handleWriteOperationV2(writeOperation: proto.WriteOperationV2): Unit = {
+  private def transformWriteOperationV2(writeOperation: proto.WriteOperationV2)(
+      tracker: QueryPlanningTracker): LogicalPlan = {
     // Transform the input plan into the logical plan.
     val plan = transformRelation(writeOperation.getInput)
     // And create a Dataset from the plan.
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
     val dataset = Dataset.ofRows(session, plan, tracker)
 
     val w = dataset.writeTo(table = writeOperation.getTableName)
@@ -3214,32 +3209,28 @@ class SparkConnectPlanner(
     writeOperation.getMode match {
       case proto.WriteOperationV2.Mode.MODE_CREATE =>
         if (writeOperation.hasProvider) {
-          w.using(writeOperation.getProvider).create()
-        } else {
-          w.create()
+          w.using(writeOperation.getProvider)
         }
+        w.createCommand()
       case proto.WriteOperationV2.Mode.MODE_OVERWRITE =>
-        w.overwrite(Column(transformExpression(writeOperation.getOverwriteCondition)))
+        w.overwriteCommand(Column(transformExpression(writeOperation.getOverwriteCondition)))
       case proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS =>
-        w.overwritePartitions()
+        w.overwritePartitionsCommand()
       case proto.WriteOperationV2.Mode.MODE_APPEND =>
-        w.append()
+        w.appendCommand()
       case proto.WriteOperationV2.Mode.MODE_REPLACE =>
         if (writeOperation.hasProvider) {
-          w.using(writeOperation.getProvider).replace()
-        } else {
-          w.replace()
+          w.using(writeOperation.getProvider)
         }
+        w.replaceCommand(orCreate = false)
       case proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE =>
         if (writeOperation.hasProvider) {
-          w.using(writeOperation.getProvider).createOrReplace()
-        } else {
-          w.createOrReplace()
+          w.using(writeOperation.getProvider)
         }
+        w.replaceCommand(orCreate = true)
       case other =>
         throw InvalidInputErrors.invalidEnum(other)
     }
-    executeHolder.eventsManager.postFinished()
   }
 
   private def handleWriteStreamOperationStart(
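In the V2 handler, every arm of the mode `match` now ends in a `*Command()` call, so the match itself evaluates to the `LogicalPlan` the method returns; `w.using(...)` survives only as a builder side effect rather than the head of a chained `.create()`/`.replace()`. A compact sketch of that build-don't-run shape (placeholder types, not the Spark API):

```scala
object BuildDontRunSketch {
  sealed trait Mode
  case object Create extends Mode
  case object Append extends Mode

  final case class Plan(desc: String)

  // Placeholder for DataFrameWriterV2: using() only configures the builder,
  // and the *Command() methods construct a plan instead of running one.
  final class Writer(table: String) {
    private var provider: Option[String] = None
    def using(p: String): this.type = { provider = Some(p); this }
    def createCommand(): Plan =
      Plan(s"CreateTableAsSelect($table, provider=${provider.getOrElse("default")})")
    def appendCommand(): Plan = Plan(s"AppendData($table)")
  }

  // Every arm ends in a plan value, so the whole match is the method's
  // result; nothing is executed here.
  def toPlan(w: Writer, mode: Mode, provider: Option[String]): Plan = mode match {
    case Create =>
      provider.foreach(w.using) // conditional builder call, as in the diff
      w.createCommand()
    case Append =>
      w.appendCommand()
  }

  def main(args: Array[String]): Unit =
    println(toPlan(new Writer("events"), Create, Some("parquet")))
}
```

The `DataFrameWriterV2` diff below supplies those `*Command()` builders.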
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala
index c6eacfe8f1ed9..8d5d91cb2243e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala
@@ -148,14 +148,17 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
 
   /** @inheritdoc */
   override def create(): Unit = {
-    runCommand(
-      CreateTableAsSelect(
-        UnresolvedIdentifier(tableName),
-        partitioning.getOrElse(Seq.empty) ++ clustering,
-        logicalPlan,
-        buildTableSpec(),
-        options.toMap,
-        false))
+    runCommand(createCommand())
+  }
+
+  private[sql] def createCommand(): LogicalPlan = {
+    CreateTableAsSelect(
+      UnresolvedIdentifier(tableName),
+      partitioning.getOrElse(Seq.empty) ++ clustering,
+      logicalPlan,
+      buildTableSpec(),
+      options.toMap,
+      false)
   }
 
   private def buildTableSpec(): UnresolvedTableSpec = {
@@ -186,28 +189,37 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
   /** @inheritdoc */
   @throws(classOf[NoSuchTableException])
   def append(): Unit = {
-    val append = AppendData.byName(
+    runCommand(appendCommand())
+  }
+
+  private[sql] def appendCommand(): LogicalPlan = {
+    AppendData.byName(
       UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT)),
       logicalPlan,
       options.toMap)
-    runCommand(append)
   }
 
   /** @inheritdoc */
   @throws(classOf[NoSuchTableException])
   def overwrite(condition: Column): Unit = {
-    val overwrite = OverwriteByExpression.byName(
+    runCommand(overwriteCommand(condition))
+  }
+
+  private[sql] def overwriteCommand(condition: Column): LogicalPlan = {
+    OverwriteByExpression.byName(
       UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
       logicalPlan,
       expression(condition),
       options.toMap)
-    runCommand(overwrite)
   }
 
   /** @inheritdoc */
   @throws(classOf[NoSuchTableException])
   def overwritePartitions(): Unit = {
-    val dynamicOverwrite = OverwritePartitionsDynamic.byName(
+    runCommand(overwritePartitionsCommand())
+  }
+
+  private[sql] def overwritePartitionsCommand(): LogicalPlan = {
+    OverwritePartitionsDynamic.byName(
       UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
       logicalPlan,
       options.toMap)
-    runCommand(dynamicOverwrite)
   }
 
   /**
@@ -220,13 +232,17 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
   }
 
   private def internalReplace(orCreate: Boolean): Unit = {
-    runCommand(ReplaceTableAsSelect(
+    runCommand(replaceCommand(orCreate))
+  }
+
+  private[sql] def replaceCommand(orCreate: Boolean): LogicalPlan = {
+    ReplaceTableAsSelect(
      UnresolvedIdentifier(tableName),
       partitioning.getOrElse(Seq.empty) ++ clustering,
       logicalPlan,
       buildTableSpec(),
       writeOptions = options.toMap,
-      orCreate = orCreate))
+      orCreate = orCreate)
   }
 }
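On the `DataFrameWriterV2` side the pattern is a run/build split: each public method keeps its effectful contract by delegating to a new `private[sql] *Command()` builder, and Spark Connect calls the builder directly so it can execute the resulting plan under its own tracker and `QueryExecution`. A small sketch of the split under those assumptions (all names below are placeholders):

```scala
object RunBuildSplitSketch {
  final class Tracker
  final case class Plan(desc: String)

  // Placeholder writer: the public method stays effectful, while the
  // *Command() variant only builds the plan.
  final class Writer(table: String) {
    private def run(plan: Plan): Unit = println(s"[classic] executing $plan")

    def append(): Unit = run(appendCommand()) // behavior unchanged for classic callers
    def appendCommand(): Plan = Plan(s"AppendData($table)") // pure builder
  }

  // Stand-in for transformAndRunCommand on the Connect side: it creates its
  // own tracker and executes the plan it receives.
  def transformAndRun(transformer: Tracker => Plan): Unit = {
    val tracker = new Tracker
    println(s"[connect] executing ${transformer(tracker)}")
  }

  def main(args: Array[String]): Unit = {
    val w = new Writer("events")
    w.append()                               // classic path: build and run
    transformAndRun(_ => w.appendCommand())  // connect path: reuse the builder only
  }
}
```

Keeping the builders `private[sql]` exposes the plans to the Connect planner without widening the public `DataFrameWriterV2` API.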