diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamWriterTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamWriterTests.cs index 09cec8ec7..4e87dc6c6 100644 --- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamWriterTests.cs +++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamWriterTests.cs @@ -2,10 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; +using System.IO; +using System.Linq; +using Microsoft.Spark.E2ETest.Utils; using Microsoft.Spark.Sql; using Microsoft.Spark.Sql.Streaming; +using Microsoft.Spark.Sql.Types; using Xunit; +using static Microsoft.Spark.Sql.Functions; namespace Microsoft.Spark.E2ETest.IpcTests { @@ -59,5 +65,213 @@ public void TestSignaturesV2_3_X() Assert.IsType(dsw.Trigger(Trigger.Once())); } + + [SkipIfSparkVersionIsLessThan(Versions.V2_4_0)] + public void TestForeach() + { + // Temporary folder to put our test stream input. + using var srcTempDirectory = new TemporaryDirectory(); + string streamInputPath = Path.Combine(srcTempDirectory.Path, "streamInput"); + + Func intToStrUdf = Udf(i => i.ToString()); + + // id column: [1, 2, ..., 99] + // idStr column: "id" column converted to string + // idAndIdStr column: Struct column composed from the "id" and "idStr" column. + _spark.Range(1, 100) + .WithColumn("idStr", intToStrUdf(Col("id"))) + .WithColumn("idAndIdStr", Struct("id", "idStr")) + .Write() + .Json(streamInputPath); + + // Test a scenario where IForeachWriter runs without issues. + // If everything is working as expected, then: + // - Triggering stream will not throw an exception + // - 3 CSV files will be created in the temporary directory. + // - 0 Exception files will be created in the temporary directory. + // - The CSV files will contain valid data to read, where the + // expected entries will contain [1111, 2222, ..., 99999999] + TestAndValidateForeach( + streamInputPath, + new TestForeachWriter(), + 3, + 0, + Enumerable.Range(1, 99).Select(i => Convert.ToInt32($"{i}{i}{i}{i}"))); + + // Test scenario where IForeachWriter.Open returns false. + // When IForeachWriter.Open returns false, then IForeachWriter.Process + // is not called. Verify that: + // - Triggering stream will not throw an exception + // - 3 CSV files will be created in the temporary directory. + // - 0 Exception files will be created in the temporary directory. + // - The CSV files will not contain valid data to read. + TestAndValidateForeach( + streamInputPath, + new TestForeachWriterOpenFailure(), + 3, + 0, + Enumerable.Empty()); + + // Test scenario where IForeachWriter.Process throws an Exception. + // When IForeachWriter.Process throws an Exception, then the exception + // is rethrown by ForeachWriterWrapper. We will limit the partitions + // to 1 to make validating this scenario simpler. Verify that: + // - Triggering stream throws an exception. + // - 1 CSV file will be created in the temporary directory. + // - 1 Exception will be created in the temporary directory. The + // thrown exception from Process() will be sent to Close(). + // - The CSV file will not contain valid data to read. 
+ TestAndValidateForeach( + streamInputPath, + new TestForeachWriterProcessFailure(), + 1, + 1, + Enumerable.Empty()); + } + + private void TestAndValidateForeach( + string streamInputPath, + TestForeachWriter foreachWriter, + int expectedCSVFiles, + int expectedExceptionFiles, + IEnumerable expectedOutput) + { + // Temporary folder the TestForeachWriter will write to. + using var dstTempDirectory = new TemporaryDirectory(); + foreachWriter.WritePath = dstTempDirectory.Path; + + // Read streamInputPath, repartition data, then + // call TestForeachWriter on the data. + DataStreamWriter dsw = _spark + .ReadStream() + .Schema(new StructType(new[] + { + new StructField("id", new IntegerType()), + new StructField("idStr", new StringType()), + new StructField("idAndIdStr", new StructType(new[] + { + new StructField("id", new IntegerType()), + new StructField("idStr", new StringType()) + })) + })) + .Json(streamInputPath) + .Repartition(expectedCSVFiles) + .WriteStream() + .Foreach(foreachWriter); + + // Trigger the stream batch once. + if (expectedExceptionFiles > 0) + { + Assert.Throws( + () => dsw.Trigger(Trigger.Once()).Start().AwaitTermination()); + } + else + { + dsw.Trigger(Trigger.Once()).Start().AwaitTermination(); + } + + // Verify that TestForeachWriter created a unique .csv when + // ForeachWriter.Open was called on each partitionId. + Assert.Equal( + expectedCSVFiles, + Directory.GetFiles(dstTempDirectory.Path, "*.csv").Length); + + // Only if ForeachWriter.Process(Row) throws an exception, will + // ForeachWriter.Close(Exception) create a file with the + // .exception extension. + Assert.Equal( + expectedExceptionFiles, + Directory.GetFiles(dstTempDirectory.Path, "*.exception").Length); + + // Read in the *.csv file(s) generated by the TestForeachWriter. + // If there are multiple input files, sorting by "id" will make + // validation simpler. Contents of the *.csv will only be populated + // on successful calls to the ForeachWriter.Process method. + DataFrame foreachWriterOutputDF = _spark + .Read() + .Schema("id INT") + .Csv(dstTempDirectory.Path) + .Sort("id"); + + // Validate expected *.csv data. 
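+ // Row.Values returns the row's column values as an object[], so each expected
+ // integer is wrapped in a single-element object[] to give both sides of the
+ // comparison the same shape.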
+ Assert.Equal( + expectedOutput.Select(i => new object[] { i }), + foreachWriterOutputDF.Collect().Select(r => r.Values)); + } + + [Serializable] + private class TestForeachWriter : IForeachWriter + { + [NonSerialized] + private StreamWriter _streamWriter; + + private long _partitionId; + + private long _epochId; + + internal string WritePath { get; set; } + + public void Close(Exception errorOrNull) + { + if (errorOrNull != null) + { + FileStream fs = File.Create( + Path.Combine( + WritePath, + $"Close-{_partitionId}-{_epochId}.exception")); + fs.Dispose(); + } + + _streamWriter?.Dispose(); + } + + public virtual bool Open(long partitionId, long epochId) + { + _partitionId = partitionId; + _epochId = epochId; + try + { + _streamWriter = new StreamWriter( + Path.Combine( + WritePath, + $"sink-foreachWriter-{_partitionId}-{_epochId}.csv")); + return true; + } + catch + { + return false; + } + } + + public virtual void Process(Row value) + { + Row idAndIdStr = value.GetAs("idAndIdStr"); + _streamWriter.WriteLine( + string.Format("{0}{1}{2}{3}", + value.GetAs("id"), + value.GetAs("idStr"), + idAndIdStr.GetAs("id"), + idAndIdStr.GetAs("idStr"))); + } + } + + [Serializable] + private class TestForeachWriterOpenFailure : TestForeachWriter + { + public override bool Open(long partitionId, long epochId) + { + base.Open(partitionId, epochId); + return false; + } + } + + [Serializable] + private class TestForeachWriterProcessFailure : TestForeachWriter + { + public override void Process(Row value) + { + throw new Exception("TestForeachWriterProcessFailure Process(Row) failure."); + } + } } } diff --git a/src/csharp/Microsoft.Spark.Worker/Payload.cs b/src/csharp/Microsoft.Spark.Worker/Payload.cs index e483ccf25..d3a709300 100644 --- a/src/csharp/Microsoft.Spark.Worker/Payload.cs +++ b/src/csharp/Microsoft.Spark.Worker/Payload.cs @@ -3,81 +3,10 @@ // See the LICENSE file in the project root for more information. using System.Collections.Generic; -using System.Linq; using Microsoft.Spark.Utils; namespace Microsoft.Spark.Worker { - /// - /// TaskContext stores information related to a task. 
- /// - internal class TaskContext - { - internal int StageId { get; set; } - - internal int PartitionId { get; set; } - - internal int AttemptNumber { get; set; } - - internal long AttemptId { get; set; } - - internal bool IsBarrier { get; set; } - - internal int Port { get; set; } - - internal string Secret { get; set; } - - internal IEnumerable Resources { get; set; } = new List(); - - internal Dictionary LocalProperties { get; set; } = - new Dictionary(); - - public override bool Equals(object obj) - { - if (!(obj is TaskContext other)) - { - return false; - } - - return (StageId == other.StageId) && - (PartitionId == other.PartitionId) && - (AttemptNumber == other.AttemptNumber) && - (AttemptId == other.AttemptId) && - Resources.SequenceEqual(other.Resources) && - (LocalProperties.Count == other.LocalProperties.Count) && - !LocalProperties.Except(other.LocalProperties).Any(); - } - - public override int GetHashCode() - { - return StageId; - } - - internal class Resource - { - internal string Key { get; set; } - internal string Value { get; set; } - internal IEnumerable Addresses { get; set; } = new List(); - - public override bool Equals(object obj) - { - if (!(obj is Resource other)) - { - return false; - } - - return (Key == other.Key) && - (Value == other.Value) && - Addresses.SequenceEqual(Addresses); - } - - public override int GetHashCode() - { - return Key.GetHashCode(); - } - } - } - /// /// BroadcastVariables stores information on broadcast variables. /// diff --git a/src/csharp/Microsoft.Spark.Worker/Processor/PayloadProcessor.cs b/src/csharp/Microsoft.Spark.Worker/Processor/PayloadProcessor.cs index 2cb7f09a3..5e2f809f6 100644 --- a/src/csharp/Microsoft.Spark.Worker/Processor/PayloadProcessor.cs +++ b/src/csharp/Microsoft.Spark.Worker/Processor/PayloadProcessor.cs @@ -54,7 +54,10 @@ internal Payload Process(Stream stream) payload.SplitIndex = BinaryPrimitives.ReadInt32BigEndian(splitIndexBytes); payload.Version = SerDe.ReadString(stream); + payload.TaskContext = new TaskContextProcessor(_version).Process(stream); + TaskContextHolder.Set(payload.TaskContext); + payload.SparkFilesDir = SerDe.ReadString(stream); if (Utils.SettingUtils.IsDatabricks) diff --git a/src/csharp/Microsoft.Spark/Attributes.cs b/src/csharp/Microsoft.Spark/Attributes.cs index 6f68607a7..bc7bb5a4d 100644 --- a/src/csharp/Microsoft.Spark/Attributes.cs +++ b/src/csharp/Microsoft.Spark/Attributes.cs @@ -73,4 +73,12 @@ public DeprecatedAttribute(string version) { } } + + /// + /// Custom attribute to denote that a class is a Udf Wrapper. + /// + [AttributeUsage(AttributeTargets.Class)] + internal sealed class UdfWrapperAttribute : Attribute + { + } } diff --git a/src/csharp/Microsoft.Spark/RDD.cs b/src/csharp/Microsoft.Spark/RDD.cs index 7eda57c61..baa4855ac 100644 --- a/src/csharp/Microsoft.Spark/RDD.cs +++ b/src/csharp/Microsoft.Spark/RDD.cs @@ -261,16 +261,14 @@ public RDD Sample(bool withReplacement, double fraction, long? 
seed = null) public IEnumerable Collect() { (int port, string secret) = CollectAndServe(); - using (ISocketWrapper socket = SocketFactory.CreateSocket()) - { - socket.Connect(IPAddress.Loopback, port, secret); + using ISocketWrapper socket = SocketFactory.CreateSocket(); + socket.Connect(IPAddress.Loopback, port, secret); - var collector = new RDD.Collector(); - System.IO.Stream stream = socket.InputStream; - foreach (T element in collector.Collect(stream, _serializedMode).Cast()) - { - yield return element; - } + var collector = new RDD.Collector(); + System.IO.Stream stream = socket.InputStream; + foreach (T element in collector.Collect(stream, _serializedMode).Cast()) + { + yield return element; } } @@ -341,6 +339,7 @@ private JvmObjectReference GetJvmRef() /// /// Input type /// Output type + [UdfWrapper] internal sealed class MapUdfWrapper { private readonly Func _func; @@ -361,6 +360,7 @@ internal IEnumerable Execute(int pid, IEnumerable input) /// /// Input type /// Output type + [UdfWrapper] internal sealed class FlatMapUdfWrapper { private readonly Func> _func; @@ -382,6 +382,7 @@ internal IEnumerable Execute(int pid, IEnumerable input) /// /// Input type /// Output type + [UdfWrapper] internal sealed class MapPartitionsUdfWrapper { private readonly Func, IEnumerable> _func; @@ -403,6 +404,7 @@ internal IEnumerable Execute(int pid, IEnumerable input) /// /// Input type /// Output type + [UdfWrapper] internal sealed class MapPartitionsWithIndexUdfWrapper { private readonly Func, IEnumerable> _func; @@ -423,9 +425,11 @@ internal IEnumerable Execute(int pid, IEnumerable input) /// Helper to map the UDF for Filter() to /// . /// + [UdfWrapper] internal class FilterUdfWrapper { private readonly Func _func; + internal FilterUdfWrapper(Func func) { _func = func; diff --git a/src/csharp/Microsoft.Spark/RDD/Collector.cs b/src/csharp/Microsoft.Spark/RDD/Collector.cs index 489fe3b75..f9442d988 100644 --- a/src/csharp/Microsoft.Spark/RDD/Collector.cs +++ b/src/csharp/Microsoft.Spark/RDD/Collector.cs @@ -5,8 +5,11 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Runtime.Serialization.Formatters.Binary; using Microsoft.Spark.Interop.Ipc; +using Microsoft.Spark.Sql; +using Microsoft.Spark.Utils; using static Microsoft.Spark.Utils.CommandSerDe; namespace Microsoft.Spark.RDD @@ -47,6 +50,8 @@ internal static IDeserializer GetDeserializer(SerializedMode mode) return new BinaryDeserializer(); case SerializedMode.String: return new StringDeserializer(); + case SerializedMode.Row: + return new RowDeserializer(); default: throw new ArgumentException($"Unsupported mode found {mode}"); } @@ -83,5 +88,21 @@ public object Deserialize(Stream stream, int length) return SerDe.ReadString(stream, length); } } + + /// + /// Deserializer for Pickled Rows. + /// + private sealed class RowDeserializer : IDeserializer + { + public object Deserialize(Stream stream, int length) + { + // Refer to the AutoBatchedPickler class in spark/core/src/main/scala/org/apache/ + // spark/api/python/SerDeUtil.scala regarding how the Rows may be batched. 
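+ // Each unpickled object is a RowConstructor; calling GetRow() materializes the
+ // actual Row, so one Deserialize call may return several Rows when they were
+ // auto-batched on the JVM side.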
+ return PythonSerDe.GetUnpickledObjects(stream, length) + .Cast() + .Select(rc => rc.GetRow()) + .ToArray(); + } + } } } diff --git a/src/csharp/Microsoft.Spark/RDD/WorkerFunction.cs b/src/csharp/Microsoft.Spark/RDD/WorkerFunction.cs index c1f556ef3..bf54bdbbb 100644 --- a/src/csharp/Microsoft.Spark/RDD/WorkerFunction.cs +++ b/src/csharp/Microsoft.Spark/RDD/WorkerFunction.cs @@ -37,7 +37,7 @@ internal static WorkerFunction Chain( WorkerFunction outerFunction) { return new WorkerFunction( - new WrokerFuncChainHelper( + new WorkerFuncChainHelper( innerFunction.Func, outerFunction.Func).Execute); } @@ -45,12 +45,13 @@ internal static WorkerFunction Chain( /// /// Helper to chain two delegates. /// - private sealed class WrokerFuncChainHelper + [UdfWrapper] + private sealed class WorkerFuncChainHelper { private readonly ExecuteDelegate _innerFunc; private readonly ExecuteDelegate _outerFunc; - internal WrokerFuncChainHelper(ExecuteDelegate innerFunc, ExecuteDelegate outerFunc) + internal WorkerFuncChainHelper(ExecuteDelegate innerFunc, ExecuteDelegate outerFunc) { _innerFunc = innerFunc; _outerFunc = outerFunc; diff --git a/src/csharp/Microsoft.Spark/Sql/ArrowGroupedMapUdfWrapper.cs b/src/csharp/Microsoft.Spark/Sql/ArrowGroupedMapUdfWrapper.cs index 63c042cc7..cc4b134e8 100644 --- a/src/csharp/Microsoft.Spark/Sql/ArrowGroupedMapUdfWrapper.cs +++ b/src/csharp/Microsoft.Spark/Sql/ArrowGroupedMapUdfWrapper.cs @@ -13,6 +13,7 @@ namespace Microsoft.Spark.Sql /// /// UDF serialization requires a "wrapper" object in order to serialize/deserialize. /// + [UdfWrapper] internal sealed class ArrowGroupedMapUdfWrapper { private readonly Func _func; diff --git a/src/csharp/Microsoft.Spark/Sql/ArrowUdfWrapper.cs b/src/csharp/Microsoft.Spark/Sql/ArrowUdfWrapper.cs index 538e05f85..bece86531 100644 --- a/src/csharp/Microsoft.Spark/Sql/ArrowUdfWrapper.cs +++ b/src/csharp/Microsoft.Spark/Sql/ArrowUdfWrapper.cs @@ -13,6 +13,7 @@ namespace Microsoft.Spark.Sql /// /// Specifies the type of the first argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal sealed class ArrowUdfWrapper where T : IArrowArray where TResult : IArrowArray @@ -45,6 +46,7 @@ internal IArrowArray Execute(ReadOnlyMemory input, int[] argOffsets /// Specifies the type of the first argument to the UDF. /// Specifies the type of the second argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal sealed class ArrowUdfWrapper where T1 : IArrowArray where T2 : IArrowArray @@ -80,6 +82,7 @@ internal IArrowArray Execute(ReadOnlyMemory input, int[] argOffsets /// Specifies the type of the second argument to the UDF. /// Specifies the type of the third argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal sealed class ArrowUdfWrapper where T1 : IArrowArray where T2 : IArrowArray @@ -118,6 +121,7 @@ internal IArrowArray Execute(ReadOnlyMemory input, int[] argOffsets /// Specifies the type of the third argument to the UDF. /// Specifies the type of the fourth argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal sealed class ArrowUdfWrapper where T1 : IArrowArray where T2 : IArrowArray @@ -159,6 +163,7 @@ internal IArrowArray Execute(ReadOnlyMemory input, int[] argOffsets /// Specifies the type of the fourth argument to the UDF. /// Specifies the type of the fifth argument to the UDF. /// Specifies the return type of the UDF. 
+ [UdfWrapper] internal sealed class ArrowUdfWrapper where T1 : IArrowArray where T2 : IArrowArray @@ -203,6 +208,7 @@ internal IArrowArray Execute(ReadOnlyMemory input, int[] argOffsets /// Specifies the type of the fifth argument to the UDF. /// Specifies the type of the sixth argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal sealed class ArrowUdfWrapper where T1 : IArrowArray where T2 : IArrowArray @@ -250,6 +256,7 @@ internal IArrowArray Execute(ReadOnlyMemory input, int[] argOffsets /// Specifies the type of the sixth argument to the UDF. /// Specifies the type of the seventh argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal sealed class ArrowUdfWrapper where T1 : IArrowArray where T2 : IArrowArray @@ -300,6 +307,7 @@ internal IArrowArray Execute(ReadOnlyMemory input, int[] argOffsets /// Specifies the type of the seventh argument to the UDF. /// Specifies the type of the eighth argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal sealed class ArrowUdfWrapper where T1 : IArrowArray where T2 : IArrowArray @@ -353,6 +361,7 @@ internal IArrowArray Execute(ReadOnlyMemory input, int[] argOffsets /// Specifies the type of the eighth argument to the UDF. /// Specifies the type of the ninth argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal sealed class ArrowUdfWrapper where T1 : IArrowArray where T2 : IArrowArray @@ -409,6 +418,7 @@ internal IArrowArray Execute(ReadOnlyMemory input, int[] argOffsets /// Specifies the type of the ninth argument to the UDF. /// Specifies the type of the tenth argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal sealed class ArrowUdfWrapper where T1 : IArrowArray where T2 : IArrowArray diff --git a/src/csharp/Microsoft.Spark/Sql/DataFrame.cs b/src/csharp/Microsoft.Spark/Sql/DataFrame.cs index 53bf85c7f..40608cbf7 100644 --- a/src/csharp/Microsoft.Spark/Sql/DataFrame.cs +++ b/src/csharp/Microsoft.Spark/Sql/DataFrame.cs @@ -874,7 +874,7 @@ public DataFrameWriter Write() => /// /// DataStreamWriter object public DataStreamWriter WriteStream() => - new DataStreamWriter((JvmObjectReference)_jvmObject.Invoke("writeStream")); + new DataStreamWriter((JvmObjectReference)_jvmObject.Invoke("writeStream"), this); /// /// Returns row objects based on the function (either "toPythonIterator" or diff --git a/src/csharp/Microsoft.Spark/Sql/ForeachWriter.cs b/src/csharp/Microsoft.Spark/Sql/ForeachWriter.cs new file mode 100644 index 000000000..27b47f7d0 --- /dev/null +++ b/src/csharp/Microsoft.Spark/Sql/ForeachWriter.cs @@ -0,0 +1,193 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.Spark.Sql +{ + /// + /// Interface for writing custom logic to process data generated by a query. This is + /// often used to write the output of a streaming query to arbitrary storage systems. + /// + /// + /// + /// Any implementation of this interface will be used by Spark in the following way: + /// + /// + /// + /// A single instance of this class is responsible of all the data generated by a single task + /// in a query. In other words, one instance is responsible for processing one partition of the + /// data generated in a distributed manner. 
+ /// + /// + /// + /// + /// Any implementation of this class must be because each + /// task will get a fresh serialized-deserialized copy of the provided object. Hence, it is + /// strongly recommended that any initialization for writing data (e.g.opening a connection or + /// starting a transaction) is done after the method has been + /// called, which signifies that the task is ready to generate data. + /// + /// + /// + /// + /// The lifecycle of the methods are as follows: + /// + /// For each partition with partitionId: + /// ... For each batch/epoch of streaming data(if its streaming query) with epochId: + /// ....... Method Open(partitionId, epochId) is called. + /// ....... If Open returns true: + /// ........... For each row in the partition and batch/epoch, method Process(row) is called. + /// ....... Method Close(errorOrNull) is called with error(if any) seen while processing rows. + /// + /// + /// + /// + /// + /// + /// + /// Important points to note: + /// + /// + /// + /// The partitionId and epochId can be used to deduplicate generated data + /// when failures cause reprocessing of some input data. This depends on the execution + /// mode of the query. If the streaming query is being executed in the micro-batch + /// mode, then every partition represented by a unique tuple(partition_id, epoch_id) + /// is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used + /// to deduplicate and/or transactionally commit data and achieve exactly-once + /// guarantees. However, if the streaming query is being executed in the continuous + /// mode, then this guarantee does not hold and therefore should not be used for + /// deduplication. + /// + /// + /// + /// + /// + public interface IForeachWriter + { + /// + /// Called when starting to process one partition of new data in the executor. + /// + /// The partition id. + /// A unique id for data deduplication. + /// True if successful, false otherwise. + bool Open(long partitionId, long epochId); + + /// + /// Called to process each in the executor side. This method + /// will be called only if Open returns true. + /// + /// The row to process. + void Process(Row row); + + /// + /// Called when stopping to process one partition of new data in the executor side. This is + /// guaranteed to be called either returns true or + /// false. However, won't be called in the following + /// cases: + /// + /// + /// + /// CLR/JVM crashes without throwing a . + /// + /// + /// + /// + /// throws an . + /// + /// + /// + /// + /// + /// The thrown during processing or null if there was no error. + /// + void Close(Exception errorOrNull); + } + + /// + /// Wraps a and calls the appropriate methods as decribed in + /// the lifecycle documentation for the interface. 
+ /// + internal class ForeachWriterWrapper + { + private readonly IForeachWriter _foreachWriter; + + internal ForeachWriterWrapper(IForeachWriter foreachWriter) => + _foreachWriter = foreachWriter; + + internal IEnumerable Process(int partitionId, IEnumerable rows) + { + if (!TaskContextHolder.Get().LocalProperties.TryGetValue( + "streaming.sql.batchId", + out string epochIdStr) || !long.TryParse(epochIdStr, out long epochId)) + { + throw new Exception( + $"Could not get or parse batch id from TaskContext - batchId: {epochIdStr}"); + } + + Exception error = null; + bool opened = _foreachWriter.Open(partitionId, epochId); + + try + { + if (opened) + { + foreach (Row row in rows) + { + _foreachWriter.Process(row); + } + } + } + catch (Exception e) + { + error = e; + } + finally + { + _foreachWriter.Close(error); + + if (error != null) + { + throw error; + } + } + + // An empty IEnumerable is returned because ForEach is a sink operation, + // but something needs to be returned to work within the UDF framework. + return Enumerable.Empty(); + } + } + + /// + /// Wraps the given Func object, which represents a UDF. + /// When this UdfWrapper is processed, the PythonEvalType is + /// . The CommandExecutor expects the + /// method to match the + /// delegate. This UdfWrapper helps map + /// the UDF for to . + /// + [UdfWrapper] + internal class ForeachWriterWrapperUdfWrapper + { + private readonly Func, IEnumerable> _func; + + internal ForeachWriterWrapperUdfWrapper( + Func, IEnumerable> func) + { + _func = func; + } + + internal IEnumerable Execute(int pid, IEnumerable input) + { + // input is an IEnumerable, where each Row[] is batched using the + // org.apache.spark.api.python.SerDeUtil.AutoBatchedPickler algorithm. + // Refer to spark/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ + // PythonForeachWriter.scala for more information. + return _func(pid, input.Cast().SelectMany(row => row)); + } + } +} diff --git a/src/csharp/Microsoft.Spark/Sql/PicklingUdfWrapper.cs b/src/csharp/Microsoft.Spark/Sql/PicklingUdfWrapper.cs index 7c4d5d02c..721dc049b 100644 --- a/src/csharp/Microsoft.Spark/Sql/PicklingUdfWrapper.cs +++ b/src/csharp/Microsoft.Spark/Sql/PicklingUdfWrapper.cs @@ -10,6 +10,7 @@ namespace Microsoft.Spark.Sql /// Wraps the given Func object, which represents a UDF. /// /// Specifies the return type of the UDF. + [UdfWrapper] internal class PicklingUdfWrapper { private readonly Func _func; @@ -30,6 +31,7 @@ internal object Execute(int splitIndex, object[] input, int[] argOffsets) /// /// Specifies the type of the first argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal class PicklingUdfWrapper { private readonly Func _func; @@ -51,6 +53,7 @@ internal object Execute(int splitIndex, object[] input, int[] argOffsets) /// Specifies the type of the first argument to the UDF. /// Specifies the type of the second argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal class PicklingUdfWrapper { private readonly Func _func; @@ -73,6 +76,7 @@ internal object Execute(int splitIndex, object[] input, int[] argOffsets) /// Specifies the type of the second argument to the UDF. /// Specifies the type of the third argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal class PicklingUdfWrapper { private readonly Func _func; @@ -99,6 +103,7 @@ internal object Execute(int splitIndex, object[] input, int[] argOffsets) /// Specifies the type of the third argument to the UDF. 
/// Specifies the type of the fourth argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal class PicklingUdfWrapper { private readonly Func _func; @@ -127,6 +132,7 @@ internal object Execute(int splitIndex, object[] input, int[] argOffsets) /// Specifies the type of the fourth argument to the UDF. /// Specifies the type of the fifth argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal class PicklingUdfWrapper { private readonly Func _func; @@ -157,6 +163,7 @@ internal object Execute(int splitIndex, object[] input, int[] argOffsets) /// Specifies the type of the fifth argument to the UDF. /// Specifies the type of the sixth argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal class PicklingUdfWrapper { private readonly Func _func; @@ -189,6 +196,7 @@ internal object Execute(int splitIndex, object[] input, int[] argOffsets) /// Specifies the type of the sixth argument to the UDF. /// Specifies the type of the seventh argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal class PicklingUdfWrapper { private readonly Func _func; @@ -223,6 +231,7 @@ internal object Execute(int splitIndex, object[] input, int[] argOffsets) /// Specifies the type of the seventh argument to the UDF. /// Specifies the type of the eighth argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal class PicklingUdfWrapper { private readonly Func _func; @@ -259,6 +268,7 @@ internal object Execute(int splitIndex, object[] input, int[] argOffsets) /// Specifies the type of the eighth argument to the UDF. /// Specifies the type of the ninth argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal class PicklingUdfWrapper { private readonly Func _func; @@ -296,6 +306,7 @@ internal object Execute(int splitIndex, object[] input, int[] argOffsets) /// Specifies the type of the ninth argument to the UDF. /// Specifies the type of the tenth argument to the UDF. /// Specifies the return type of the UDF. + [UdfWrapper] internal class PicklingUdfWrapper { private readonly Func _func; diff --git a/src/csharp/Microsoft.Spark/Sql/RowConstructor.cs b/src/csharp/Microsoft.Spark/Sql/RowConstructor.cs index 083e94cb5..2ad613061 100644 --- a/src/csharp/Microsoft.Spark/Sql/RowConstructor.cs +++ b/src/csharp/Microsoft.Spark/Sql/RowConstructor.cs @@ -13,7 +13,7 @@ namespace Microsoft.Spark.Sql /// /// RowConstructor is a custom unpickler for GenericRowWithSchema in Spark. /// Refer to spark/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ - /// EvaluatePython.scala how GenericRowWithSchema is being pickeld. + /// EvaluatePython.scala how GenericRowWithSchema is being pickled. 
/// internal sealed class RowConstructor : IObjectConstructor { diff --git a/src/csharp/Microsoft.Spark/Sql/Streaming/DataStreamWriter.cs b/src/csharp/Microsoft.Spark/Sql/Streaming/DataStreamWriter.cs index d29c80e7a..2cf752459 100644 --- a/src/csharp/Microsoft.Spark/Sql/Streaming/DataStreamWriter.cs +++ b/src/csharp/Microsoft.Spark/Sql/Streaming/DataStreamWriter.cs @@ -4,6 +4,8 @@ using System.Collections.Generic; using Microsoft.Spark.Interop.Ipc; +using Microsoft.Spark.Sql.Types; +using Microsoft.Spark.Utils; namespace Microsoft.Spark.Sql.Streaming { @@ -14,8 +16,13 @@ namespace Microsoft.Spark.Sql.Streaming public sealed class DataStreamWriter : IJvmObjectReferenceProvider { private readonly JvmObjectReference _jvmObject; + private readonly DataFrame _df; - internal DataStreamWriter(JvmObjectReference jvmObject) => _jvmObject = jvmObject; + internal DataStreamWriter(JvmObjectReference jvmObject, DataFrame df) + { + _jvmObject = jvmObject; + _df = df; + } JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject; @@ -76,7 +83,7 @@ public DataStreamWriter PartitionBy(params string[] colNames) /// /// Name of the option /// Value of the option - /// This DataStreamReader object + /// This DataStreamWriter object public DataStreamWriter Option(string key, string value) { OptionInternal(key, value); @@ -88,7 +95,7 @@ public DataStreamWriter Option(string key, string value) /// /// Name of the option /// Value of the option - /// This DataStreamReader object + /// This DataStreamWriter object public DataStreamWriter Option(string key, bool value) { OptionInternal(key, value); @@ -100,7 +107,7 @@ public DataStreamWriter Option(string key, bool value) /// /// Name of the option /// Value of the option - /// This DataStreamReader object + /// This DataStreamWriter object public DataStreamWriter Option(string key, long value) { OptionInternal(key, value); @@ -112,7 +119,7 @@ public DataStreamWriter Option(string key, long value) /// /// Name of the option /// Value of the option - /// This DataStreamReader object + /// This DataStreamWriter object public DataStreamWriter Option(string key, double value) { OptionInternal(key, value); @@ -123,7 +130,7 @@ public DataStreamWriter Option(string key, double value) /// Adds output options for the underlying data source. /// /// Key/value options - /// This DataStreamReader object + /// This DataStreamWriter object public DataStreamWriter Options(Dictionary options) { _jvmObject.Invoke("options", options); @@ -134,7 +141,7 @@ public DataStreamWriter Options(Dictionary options) /// Sets the trigger for the stream query. /// /// Trigger object - /// This DataStreamReader object + /// This DataStreamWriter object public DataStreamWriter Trigger(Trigger trigger) { _jvmObject.Invoke("trigger", trigger); @@ -148,7 +155,7 @@ public DataStreamWriter Trigger(Trigger trigger) /// in the associated SQLContext. /// /// Query name - /// This DataStreamReader object + /// This DataStreamWriter object public DataStreamWriter QueryName(string queryName) { _jvmObject.Invoke("queryName", queryName); @@ -169,12 +176,41 @@ public StreamingQuery Start(string path = null) return new StreamingQuery((JvmObjectReference)_jvmObject.Invoke("start")); } + /// + /// Sets the output of the streaming query to be processed using the provided + /// writer object. See for more details on the + /// lifecycle and semantics. 
+ /// + /// + /// This DataStreamWriter object + [Since(Versions.V2_4_0)] + public DataStreamWriter Foreach(IForeachWriter writer) + { + RDD.WorkerFunction.ExecuteDelegate wrapper = + new ForeachWriterWrapperUdfWrapper( + new ForeachWriterWrapper(writer).Process).Execute; + + _jvmObject.Invoke( + "foreach", + _jvmObject.Jvm.CallConstructor( + "org.apache.spark.sql.execution.python.PythonForeachWriter", + UdfUtils.CreatePythonFunction( + _jvmObject.Jvm, + CommandSerDe.Serialize( + wrapper, + CommandSerDe.SerializedMode.Row, + CommandSerDe.SerializedMode.Row)), + DataType.FromJson(_jvmObject.Jvm, _df.Schema().Json))); + + return this; + } + /// /// Helper function to add given key/value pair as a new option. /// /// Name of the option /// Value of the option - /// This DataFrameReader object + /// This DataStreamWriter object private DataStreamWriter OptionInternal(string key, object value) { _jvmObject.Invoke("option", key, value); diff --git a/src/csharp/Microsoft.Spark/TaskContext.cs b/src/csharp/Microsoft.Spark/TaskContext.cs new file mode 100644 index 000000000..01f756a36 --- /dev/null +++ b/src/csharp/Microsoft.Spark/TaskContext.cs @@ -0,0 +1,95 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.Spark +{ + /// + /// TaskContext stores information related to a task. + /// + internal class TaskContext + { + internal int StageId { get; set; } + + internal int PartitionId { get; set; } + + internal int AttemptNumber { get; set; } + + internal long AttemptId { get; set; } + + internal bool IsBarrier { get; set; } + + internal int Port { get; set; } + + internal string Secret { get; set; } + + internal IEnumerable Resources { get; set; } = new List(); + + internal Dictionary LocalProperties { get; set; } = + new Dictionary(); + + public override bool Equals(object obj) + { + if (!(obj is TaskContext other)) + { + return false; + } + + return (StageId == other.StageId) && + (PartitionId == other.PartitionId) && + (AttemptNumber == other.AttemptNumber) && + (AttemptId == other.AttemptId) && + Resources.SequenceEqual(other.Resources) && + (LocalProperties.Count == other.LocalProperties.Count) && + !LocalProperties.Except(other.LocalProperties).Any(); + } + + public override int GetHashCode() + { + return StageId; + } + + internal class Resource + { + internal string Key { get; set; } + internal string Value { get; set; } + internal IEnumerable Addresses { get; set; } = new List(); + + public override bool Equals(object obj) + { + if (!(obj is Resource other)) + { + return false; + } + + return (Key == other.Key) && + (Value == other.Value) && + Addresses.SequenceEqual(Addresses); + } + + public override int GetHashCode() + { + return Key.GetHashCode(); + } + } + } + + // TaskContextHolder contains the TaskContext for the current Thread. + internal static class TaskContextHolder + { + // Multiple Tasks can be assigned to a Worker process. Each + // Task will run in its own thread until completion. Therefore + // we set this field as a thread local variable, where each + // thread will have its own copy of the TaskContext. 
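+ // PayloadProcessor.Process calls Set() after reading the TaskContext off the
+ // stream, and ForeachWriterWrapper calls Get() to read the streaming batch id
+ // ("streaming.sql.batchId") from LocalProperties.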
+ [ThreadStatic] + internal static TaskContext s_taskContext; + + internal static TaskContext Get() => s_taskContext; + + internal static void Set(TaskContext tc) => s_taskContext = tc; + } +} diff --git a/src/csharp/Microsoft.Spark/Utils/CommandSerDe.cs b/src/csharp/Microsoft.Spark/Utils/CommandSerDe.cs index 9c6d0f53f..53c87c527 100644 --- a/src/csharp/Microsoft.Spark/Utils/CommandSerDe.cs +++ b/src/csharp/Microsoft.Spark/Utils/CommandSerDe.cs @@ -44,7 +44,7 @@ internal enum SerializedMode /// - RDD: * /// * /// * - /// * + /// * /// [Serializable] private sealed class UdfWrapperNode @@ -181,7 +181,8 @@ private static void SerializeUdfs( List udfs) { UdfSerDe.UdfData udfData = UdfSerDe.Serialize(func); - if (udfData.MethodName != UdfWrapperMethodName) + if ((udfData.MethodName != UdfWrapperMethodName) || + !Attribute.IsDefined(func.Target.GetType(), typeof(UdfWrapperAttribute))) { // Found the actual UDF. if (parent != null) diff --git a/src/csharp/Microsoft.Spark/Versions.cs b/src/csharp/Microsoft.Spark/Versions.cs index 5955ffda8..162bb2178 100644 --- a/src/csharp/Microsoft.Spark/Versions.cs +++ b/src/csharp/Microsoft.Spark/Versions.cs @@ -17,6 +17,6 @@ internal static class Versions // The following is used to check the compatibility of UDFs between // the driver and worker side. This needs to be updated only when there // is a breaking change on the UDF contract. - internal const string CurrentVersion = "0.4.0"; + internal const string CurrentVersion = "0.9.0"; } }
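For context on how the new sink is consumed from user code, the following is a minimal, hypothetical sketch (not part of this change). The writer mirrors the shape of TestForeachWriter above; the class name, output directory, input path, and column name are illustrative only. If Open returned false for a partition, Process would be skipped for that partition but Close would still run.

using System;
using System.IO;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Streaming;
using Microsoft.Spark.Sql.Types;

// Hypothetical writer: one output file per (partitionId, epochId). The writer must be
// [Serializable] because each task receives a serialized/deserialized copy; the
// StreamWriter is [NonSerialized] and only created in Open(), on the executor.
[Serializable]
public sealed class FileSinkWriter : IForeachWriter
{
    [NonSerialized]
    private StreamWriter _writer;

    public string OutputDirectory { get; set; }

    public bool Open(long partitionId, long epochId)
    {
        // Keying the file name by (partitionId, epochId) keeps reprocessed
        // micro-batches from producing duplicate output.
        _writer = new StreamWriter(
            Path.Combine(OutputDirectory, $"part-{partitionId}-{epochId}.txt"));
        return true;
    }

    public void Process(Row row) => _writer.WriteLine(row.GetAs<int>("id"));

    public void Close(Exception errorOrNull) => _writer?.Dispose();
}

public static class ForeachExample
{
    public static void Run()
    {
        SparkSession spark = SparkSession.Builder().AppName("ForeachExample").GetOrCreate();

        DataFrame stream = spark
            .ReadStream()
            .Schema(new StructType(new[] { new StructField("id", new IntegerType()) }))
            .Json("/tmp/stream-input");

        // Wire the custom writer into the streaming query and run a single batch.
        stream
            .WriteStream()
            .Foreach(new FileSinkWriter { OutputDirectory = "/tmp/stream-output" })
            .Trigger(Trigger.Once())
            .Start()
            .AwaitTermination();
    }
}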