Skip to content

Commit dd318d8

Browse files
authored
Channels await fix (#5313)
* Update to only get result if awaiter has completed * Add thread sleep * Add regression scenario to make sure async fix works * Update from feedback
1 parent 6bae29f commit dd318d8

File tree

2 files changed

+66
-14
lines changed

2 files changed

+66
-14
lines changed

src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs

+16-14
Original file line numberDiff line numberDiff line change
@@ -634,23 +634,25 @@ protected override bool MoveNextCore()
634634
while (_liveCount < _poolRows && !_doneConsuming)
635635
{
636636
// We are under capacity. Try to get some more.
637-
while (_toConsumeChannel.Reader.WaitToReadAsync().GetAwaiter().GetResult())
637+
var hasReadItem = _toConsumeChannel.Reader.TryRead(out int got);
638+
if (hasReadItem)
638639
{
639-
var hasReadItem = _toConsumeChannel.Reader.TryRead(out int got);
640-
if (hasReadItem)
640+
if (got == 0)
641641
{
642-
if (got == 0)
643-
{
644-
// We've reached the end of the Channel. There's no reason
645-
// to attempt further communication with the producer.
646-
// Check whether something horrible happened.
647-
if (_producerTaskException != null)
648-
throw Ch.Except(_producerTaskException, "Shuffle input cursor reader failed with an exception");
649-
_doneConsuming = true;
650-
break;
651-
}
652-
_liveCount += got;
642+
// We've reached the end of the Channel. There's no reason
643+
// to attempt further communication with the producer.
644+
// Check whether something horrible happened.
645+
if (_producerTaskException != null)
646+
throw Ch.Except(_producerTaskException, "Shuffle input cursor reader failed with an exception");
647+
_doneConsuming = true;
648+
break;
653649
}
650+
_liveCount += got;
651+
}
652+
else
653+
{
654+
// Sleeping for one millisecond to stop the thread from spinning while waiting for the producer.
655+
Thread.Sleep(1);
654656
}
655657
}
656658
if (_liveCount == 0)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Tests;
6+
using Xunit;
7+
8+
namespace Microsoft.ML.Scenarios
9+
{
10+
public partial class ScenariosTests
11+
{
12+
[Fact]
13+
public void TestRegressionScenario()
14+
{
15+
var context = new MLContext();
16+
17+
string taxiDataPath = GetDataPath("taxi-fare-train.csv");
18+
19+
var taxiData =
20+
context.Data.LoadFromTextFile<FeatureContributionTests.TaxiTrip>(taxiDataPath, hasHeader: true,
21+
separatorChar: ',');
22+
23+
var splitData = context.Data.TrainTestSplit(taxiData, testFraction: 0.1);
24+
25+
IDataView trainingDataView = context.Data.FilterRowsByColumn(splitData.TrainSet, "FareAmount", lowerBound: 1, upperBound: 150);
26+
27+
var dataProcessPipeline = context.Transforms.CopyColumns(outputColumnName: "Label", inputColumnName: "FareAmount")
28+
.Append(context.Transforms.Categorical.OneHotEncoding(outputColumnName: "VendorIdEncoded", inputColumnName: "VendorId"))
29+
.Append(context.Transforms.Categorical.OneHotEncoding(outputColumnName: "RateCodeEncoded", inputColumnName: "RateCode"))
30+
.Append(context.Transforms.Categorical.OneHotEncoding(outputColumnName: "PaymentTypeEncoded", inputColumnName: "PaymentType"))
31+
.Append(context.Transforms.NormalizeMeanVariance(outputColumnName: "PassengerCount"))
32+
.Append(context.Transforms.NormalizeMeanVariance(outputColumnName: "TripTime"))
33+
.Append(context.Transforms.NormalizeMeanVariance(outputColumnName: "TripDistance"))
34+
.Append(context.Transforms.Concatenate("Features", "VendorIdEncoded", "RateCodeEncoded", "PaymentTypeEncoded", "PassengerCount",
35+
"TripTime", "TripDistance"));
36+
37+
var trainer = context.Regression.Trainers.Sdca(labelColumnName: "Label", featureColumnName: "Features");
38+
var trainingPipeline = dataProcessPipeline.Append(trainer);
39+
40+
var model = trainingPipeline.Fit(trainingDataView);
41+
42+
var predictions = model.Transform(splitData.TestSet);
43+
44+
var metrics = context.Regression.Evaluate(predictions);
45+
46+
Assert.True(metrics.RSquared > .8);
47+
Assert.True(metrics.RootMeanSquaredError > 2);
48+
}
49+
}
50+
}

0 commit comments

Comments
 (0)