# Review notes for this patch. These lines sit before the first "diff --git"
# header and are not part of any hunk.
#
# 1) src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs -- MoveNextCore()
#    * Removes the inner loop that blocked on
#      _toConsumeChannel.Reader.WaitToReadAsync().GetAwaiter().GetResult().
#    * The refill loop (while _liveCount < _poolRows && !_doneConsuming) now calls
#      Reader.TryRead(out int got) directly on each iteration:
#        - item read, got == 0: throws Ch.Except(...) if _producerTaskException is
#          set; otherwise sets _doneConsuming = true and breaks out of the refill
#          loop (the break previously exited the removed inner loop).
#        - item read, got != 0: adds got to _liveCount.
#        - nothing read: Thread.Sleep(1), then the refill loop retries.
#    * NOTE(review): Thread.Sleep requires System.Threading to be in scope; the
#      using directives of this file are not visible in the hunk -- confirm.
#
# 2) test/Microsoft.ML.Tests/Scenarios/RegressionTest.cs -- new file
#    * Adds [Fact] TestRegressionScenario to partial class ScenariosTests.
#    * The test loads "taxi-fare-train.csv", takes a TrainTestSplit with
#      testFraction 0.1, filters the train set on FareAmount in [1, 150], builds a
#      pipeline (CopyColumns -> three OneHotEncoding -> three NormalizeMeanVariance
#      -> Concatenate into "Features"), fits an Sdca regression trainer, and
#      evaluates predictions on the test set.
#    * Assertions: RSquared > .8 and RootMeanSquaredError > 2.
#    * NOTE(review): the RMSE assertion is a lower bound (> 2) -- confirm that
#      direction is intended.
#    * NOTE(review): LoadFromTextFile is written without a type argument or column
#      list -- confirm it resolves to an existing overload (a generic argument may
#      be missing from this copy of the patch).
#    * NOTE(review): GetDataPath is not defined in this patch; presumably supplied
#      by another part of the partial class or a base class -- verify.
#
# NOTE(review): leading whitespace inside hunk lines below was reconstructed from
# brace nesting; confirm against the target file or apply with
# whitespace-insensitive matching.
diff --git a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs
index 1dc52d81f5..fd529208c0 100644
--- a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs
@@ -634,23 +634,25 @@ protected override bool MoveNextCore()
                 while (_liveCount < _poolRows && !_doneConsuming)
                 {
                     // We are under capacity. Try to get some more.
-                    while (_toConsumeChannel.Reader.WaitToReadAsync().GetAwaiter().GetResult())
+                    var hasReadItem = _toConsumeChannel.Reader.TryRead(out int got);
+                    if (hasReadItem)
                     {
-                        var hasReadItem = _toConsumeChannel.Reader.TryRead(out int got);
-                        if (hasReadItem)
+                        if (got == 0)
                         {
-                            if (got == 0)
-                            {
-                                // We've reached the end of the Channel. There's no reason
-                                // to attempt further communication with the producer.
-                                // Check whether something horrible happened.
-                                if (_producerTaskException != null)
-                                    throw Ch.Except(_producerTaskException, "Shuffle input cursor reader failed with an exception");
-                                _doneConsuming = true;
-                                break;
-                            }
-                            _liveCount += got;
+                            // We've reached the end of the Channel. There's no reason
+                            // to attempt further communication with the producer.
+                            // Check whether something horrible happened.
+                            if (_producerTaskException != null)
+                                throw Ch.Except(_producerTaskException, "Shuffle input cursor reader failed with an exception");
+                            _doneConsuming = true;
+                            break;
                         }
+                        _liveCount += got;
+                    }
+                    else
+                    {
+                        // Sleeping for one millisecond to stop the thread from spinning while waiting for the producer.
+                        Thread.Sleep(1);
                     }
                 }
                 if (_liveCount == 0)
diff --git a/test/Microsoft.ML.Tests/Scenarios/RegressionTest.cs b/test/Microsoft.ML.Tests/Scenarios/RegressionTest.cs
new file mode 100644
index 0000000000..c179a6f341
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/RegressionTest.cs
@@ -0,0 +1,50 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Tests;
+using Xunit;
+
+namespace Microsoft.ML.Scenarios
+{
+    public partial class ScenariosTests
+    {
+        [Fact]
+        public void TestRegressionScenario()
+        {
+            var context = new MLContext();
+
+            string taxiDataPath = GetDataPath("taxi-fare-train.csv");
+
+            var taxiData =
+                context.Data.LoadFromTextFile(taxiDataPath, hasHeader: true,
+                    separatorChar: ',');
+
+            var splitData = context.Data.TrainTestSplit(taxiData, testFraction: 0.1);
+
+            IDataView trainingDataView = context.Data.FilterRowsByColumn(splitData.TrainSet, "FareAmount", lowerBound: 1, upperBound: 150);
+
+            var dataProcessPipeline = context.Transforms.CopyColumns(outputColumnName: "Label", inputColumnName: "FareAmount")
+                .Append(context.Transforms.Categorical.OneHotEncoding(outputColumnName: "VendorIdEncoded", inputColumnName: "VendorId"))
+                .Append(context.Transforms.Categorical.OneHotEncoding(outputColumnName: "RateCodeEncoded", inputColumnName: "RateCode"))
+                .Append(context.Transforms.Categorical.OneHotEncoding(outputColumnName: "PaymentTypeEncoded", inputColumnName: "PaymentType"))
+                .Append(context.Transforms.NormalizeMeanVariance(outputColumnName: "PassengerCount"))
+                .Append(context.Transforms.NormalizeMeanVariance(outputColumnName: "TripTime"))
+                .Append(context.Transforms.NormalizeMeanVariance(outputColumnName: "TripDistance"))
+                .Append(context.Transforms.Concatenate("Features", "VendorIdEncoded", "RateCodeEncoded", "PaymentTypeEncoded", "PassengerCount",
+                    "TripTime", "TripDistance"));
+
+            var trainer = context.Regression.Trainers.Sdca(labelColumnName: "Label", featureColumnName: "Features");
+            var trainingPipeline = dataProcessPipeline.Append(trainer);
+
+            var model = trainingPipeline.Fit(trainingDataView);
+
+            var predictions = model.Transform(splitData.TestSet);
+
+            var metrics = context.Regression.Evaluate(predictions);
+
+            Assert.True(metrics.RSquared > .8);
+            Assert.True(metrics.RootMeanSquaredError > 2);
+        }
+    }
+}