Skip to content

Commit 5f53809

Browse files
authored
Expose DataStreamWriter.Foreach API (#387)
1 parent a8db985 commit 5f53809

File tree

17 files changed

+624
-97
lines changed

17 files changed

+624
-97
lines changed

src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamWriterTests.cs

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,16 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
56
using System.Collections.Generic;
7+
using System.IO;
8+
using System.Linq;
9+
using Microsoft.Spark.E2ETest.Utils;
610
using Microsoft.Spark.Sql;
711
using Microsoft.Spark.Sql.Streaming;
12+
using Microsoft.Spark.Sql.Types;
813
using Xunit;
14+
using static Microsoft.Spark.Sql.Functions;
915

1016
namespace Microsoft.Spark.E2ETest.IpcTests
1117
{
@@ -59,5 +65,213 @@ public void TestSignaturesV2_3_X()
5965

6066
Assert.IsType<DataStreamWriter>(dsw.Trigger(Trigger.Once()));
6167
}
68+
69+
[SkipIfSparkVersionIsLessThan(Versions.V2_4_0)]
70+
public void TestForeach()
71+
{
72+
// Temporary folder to put our test stream input.
73+
using var srcTempDirectory = new TemporaryDirectory();
74+
string streamInputPath = Path.Combine(srcTempDirectory.Path, "streamInput");
75+
76+
Func<Column, Column> intToStrUdf = Udf<int, string>(i => i.ToString());
77+
78+
// id column: [1, 2, ..., 99]
79+
// idStr column: "id" column converted to string
80+
// idAndIdStr column: Struct column composed from the "id" and "idStr" column.
81+
_spark.Range(1, 100)
82+
.WithColumn("idStr", intToStrUdf(Col("id")))
83+
.WithColumn("idAndIdStr", Struct("id", "idStr"))
84+
.Write()
85+
.Json(streamInputPath);
86+
87+
// Test a scenario where IForeachWriter runs without issues.
88+
// If everything is working as expected, then:
89+
// - Triggering stream will not throw an exception
90+
// - 3 CSV files will be created in the temporary directory.
91+
// - 0 Exception files will be created in the temporary directory.
92+
// - The CSV files will contain valid data to read, where the
93+
// expected entries will contain [1111, 2222, ..., 99999999]
94+
TestAndValidateForeach(
95+
streamInputPath,
96+
new TestForeachWriter(),
97+
3,
98+
0,
99+
Enumerable.Range(1, 99).Select(i => Convert.ToInt32($"{i}{i}{i}{i}")));
100+
101+
// Test scenario where IForeachWriter.Open returns false.
102+
// When IForeachWriter.Open returns false, then IForeachWriter.Process
103+
// is not called. Verify that:
104+
// - Triggering stream will not throw an exception
105+
// - 3 CSV files will be created in the temporary directory.
106+
// - 0 Exception files will be created in the temporary directory.
107+
// - The CSV files will not contain valid data to read.
108+
TestAndValidateForeach(
109+
streamInputPath,
110+
new TestForeachWriterOpenFailure(),
111+
3,
112+
0,
113+
Enumerable.Empty<int>());
114+
115+
// Test scenario where IForeachWriter.Process throws an Exception.
116+
// When IForeachWriter.Process throws an Exception, then the exception
117+
// is rethrown by ForeachWriterWrapper. We will limit the partitions
118+
// to 1 to make validating this scenario simpler. Verify that:
119+
// - Triggering stream throws an exception.
120+
// - 1 CSV file will be created in the temporary directory.
121+
// - 1 Exception will be created in the temporary directory. The
122+
// thrown exception from Process() will be sent to Close().
123+
// - The CSV file will not contain valid data to read.
124+
TestAndValidateForeach(
125+
streamInputPath,
126+
new TestForeachWriterProcessFailure(),
127+
1,
128+
1,
129+
Enumerable.Empty<int>());
130+
}
131+
132+
private void TestAndValidateForeach(
133+
string streamInputPath,
134+
TestForeachWriter foreachWriter,
135+
int expectedCSVFiles,
136+
int expectedExceptionFiles,
137+
IEnumerable<int> expectedOutput)
138+
{
139+
// Temporary folder the TestForeachWriter will write to.
140+
using var dstTempDirectory = new TemporaryDirectory();
141+
foreachWriter.WritePath = dstTempDirectory.Path;
142+
143+
// Read streamInputPath, repartition data, then
144+
// call TestForeachWriter on the data.
145+
DataStreamWriter dsw = _spark
146+
.ReadStream()
147+
.Schema(new StructType(new[]
148+
{
149+
new StructField("id", new IntegerType()),
150+
new StructField("idStr", new StringType()),
151+
new StructField("idAndIdStr", new StructType(new[]
152+
{
153+
new StructField("id", new IntegerType()),
154+
new StructField("idStr", new StringType())
155+
}))
156+
}))
157+
.Json(streamInputPath)
158+
.Repartition(expectedCSVFiles)
159+
.WriteStream()
160+
.Foreach(foreachWriter);
161+
162+
// Trigger the stream batch once.
163+
if (expectedExceptionFiles > 0)
164+
{
165+
Assert.Throws<Exception>(
166+
() => dsw.Trigger(Trigger.Once()).Start().AwaitTermination());
167+
}
168+
else
169+
{
170+
dsw.Trigger(Trigger.Once()).Start().AwaitTermination();
171+
}
172+
173+
// Verify that TestForeachWriter created a unique .csv when
174+
// ForeachWriter.Open was called on each partitionId.
175+
Assert.Equal(
176+
expectedCSVFiles,
177+
Directory.GetFiles(dstTempDirectory.Path, "*.csv").Length);
178+
179+
// Only if ForeachWriter.Process(Row) throws an exception, will
180+
// ForeachWriter.Close(Exception) create a file with the
181+
// .exception extension.
182+
Assert.Equal(
183+
expectedExceptionFiles,
184+
Directory.GetFiles(dstTempDirectory.Path, "*.exception").Length);
185+
186+
// Read in the *.csv file(s) generated by the TestForeachWriter.
187+
// If there are multiple input files, sorting by "id" will make
188+
// validation simpler. Contents of the *.csv will only be populated
189+
// on successful calls to the ForeachWriter.Process method.
190+
DataFrame foreachWriterOutputDF = _spark
191+
.Read()
192+
.Schema("id INT")
193+
.Csv(dstTempDirectory.Path)
194+
.Sort("id");
195+
196+
// Validate expected *.csv data.
197+
Assert.Equal(
198+
expectedOutput.Select(i => new object[] { i }),
199+
foreachWriterOutputDF.Collect().Select(r => r.Values));
200+
}
201+
202+
[Serializable]
203+
private class TestForeachWriter : IForeachWriter
204+
{
205+
[NonSerialized]
206+
private StreamWriter _streamWriter;
207+
208+
private long _partitionId;
209+
210+
private long _epochId;
211+
212+
internal string WritePath { get; set; }
213+
214+
public void Close(Exception errorOrNull)
215+
{
216+
if (errorOrNull != null)
217+
{
218+
FileStream fs = File.Create(
219+
Path.Combine(
220+
WritePath,
221+
$"Close-{_partitionId}-{_epochId}.exception"));
222+
fs.Dispose();
223+
}
224+
225+
_streamWriter?.Dispose();
226+
}
227+
228+
public virtual bool Open(long partitionId, long epochId)
229+
{
230+
_partitionId = partitionId;
231+
_epochId = epochId;
232+
try
233+
{
234+
_streamWriter = new StreamWriter(
235+
Path.Combine(
236+
WritePath,
237+
$"sink-foreachWriter-{_partitionId}-{_epochId}.csv"));
238+
return true;
239+
}
240+
catch
241+
{
242+
return false;
243+
}
244+
}
245+
246+
public virtual void Process(Row value)
247+
{
248+
Row idAndIdStr = value.GetAs<Row>("idAndIdStr");
249+
_streamWriter.WriteLine(
250+
string.Format("{0}{1}{2}{3}",
251+
value.GetAs<int>("id"),
252+
value.GetAs<string>("idStr"),
253+
idAndIdStr.GetAs<int>("id"),
254+
idAndIdStr.GetAs<string>("idStr")));
255+
}
256+
}
257+
258+
[Serializable]
259+
private class TestForeachWriterOpenFailure : TestForeachWriter
260+
{
261+
public override bool Open(long partitionId, long epochId)
262+
{
263+
base.Open(partitionId, epochId);
264+
return false;
265+
}
266+
}
267+
268+
[Serializable]
269+
private class TestForeachWriterProcessFailure : TestForeachWriter
270+
{
271+
public override void Process(Row value)
272+
{
273+
throw new Exception("TestForeachWriterProcessFailure Process(Row) failure.");
274+
}
275+
}
62276
}
63277
}

src/csharp/Microsoft.Spark.Worker/Payload.cs

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -3,81 +3,10 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System.Collections.Generic;
6-
using System.Linq;
76
using Microsoft.Spark.Utils;
87

98
namespace Microsoft.Spark.Worker
109
{
11-
/// <summary>
12-
/// TaskContext stores information related to a task.
13-
/// </summary>
14-
internal class TaskContext
15-
{
16-
internal int StageId { get; set; }
17-
18-
internal int PartitionId { get; set; }
19-
20-
internal int AttemptNumber { get; set; }
21-
22-
internal long AttemptId { get; set; }
23-
24-
internal bool IsBarrier { get; set; }
25-
26-
internal int Port { get; set; }
27-
28-
internal string Secret { get; set; }
29-
30-
internal IEnumerable<Resource> Resources { get; set; } = new List<Resource>();
31-
32-
internal Dictionary<string, string> LocalProperties { get; set; } =
33-
new Dictionary<string, string>();
34-
35-
public override bool Equals(object obj)
36-
{
37-
if (!(obj is TaskContext other))
38-
{
39-
return false;
40-
}
41-
42-
return (StageId == other.StageId) &&
43-
(PartitionId == other.PartitionId) &&
44-
(AttemptNumber == other.AttemptNumber) &&
45-
(AttemptId == other.AttemptId) &&
46-
Resources.SequenceEqual(other.Resources) &&
47-
(LocalProperties.Count == other.LocalProperties.Count) &&
48-
!LocalProperties.Except(other.LocalProperties).Any();
49-
}
50-
51-
public override int GetHashCode()
52-
{
53-
return StageId;
54-
}
55-
56-
internal class Resource
57-
{
58-
internal string Key { get; set; }
59-
internal string Value { get; set; }
60-
internal IEnumerable<string> Addresses { get; set; } = new List<string>();
61-
62-
public override bool Equals(object obj)
63-
{
64-
if (!(obj is Resource other))
65-
{
66-
return false;
67-
}
68-
69-
return (Key == other.Key) &&
70-
(Value == other.Value) &&
71-
Addresses.SequenceEqual(Addresses);
72-
}
73-
74-
public override int GetHashCode()
75-
{
76-
return Key.GetHashCode();
77-
}
78-
}
79-
}
80-
8110
/// <summary>
8211
/// BroadcastVariables stores information on broadcast variables.
8312
/// </summary>

src/csharp/Microsoft.Spark.Worker/Processor/PayloadProcessor.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ internal Payload Process(Stream stream)
5454

5555
payload.SplitIndex = BinaryPrimitives.ReadInt32BigEndian(splitIndexBytes);
5656
payload.Version = SerDe.ReadString(stream);
57+
5758
payload.TaskContext = new TaskContextProcessor(_version).Process(stream);
59+
TaskContextHolder.Set(payload.TaskContext);
60+
5861
payload.SparkFilesDir = SerDe.ReadString(stream);
5962

6063
if (Utils.SettingUtils.IsDatabricks)

src/csharp/Microsoft.Spark/Attributes.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,12 @@ public DeprecatedAttribute(string version)
7373
{
7474
}
7575
}
76+
77+
/// <summary>
78+
/// Custom attribute to denote that a class is a Udf Wrapper.
79+
/// </summary>
80+
[AttributeUsage(AttributeTargets.Class)]
81+
internal sealed class UdfWrapperAttribute : Attribute
82+
{
83+
}
7684
}

0 commit comments

Comments (0)