diff --git a/src/Microsoft.ML/LearningPipeline.cs b/src/Microsoft.ML/LearningPipeline.cs index 8056d03418..27afe3adf8 100644 --- a/src/Microsoft.ML/LearningPipeline.cs +++ b/src/Microsoft.ML/LearningPipeline.cs @@ -11,6 +11,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.IO; +using static Microsoft.ML.Runtime.DefaultEnvironment; namespace Microsoft.ML { @@ -48,15 +49,26 @@ public ScorerPipelineStep(Var data, Var model) [DebuggerTypeProxy(typeof(LearningPipelineDebugProxy))] public class LearningPipeline : ICollection { + readonly internal IHostEnvironment Env; private List Items { get; } = new List(); /// /// Construct an empty object. /// - public LearningPipeline() + public LearningPipeline(int? seed = null, int concurrency = 0) { + + var env = new DefaultEnvironment(seed: seed, conc: concurrency); + env.MessageRecieved += Env_MessageRecieved; + Env = env; + } + + private void Env_MessageRecieved(object sender, ChannelMessageEventArgs e) + { + MessageOccured?.Invoke(this, e); } + public event EventHandler MessageOccured; /// /// Get the count of ML components in the object /// @@ -137,80 +149,76 @@ public PredictionModel Train() where TInput : class where TOutput : class, new() { + Experiment experiment = Env.CreateExperiment(); + ILearningPipelineStep step = null; + List loaders = new List(); + List> transformModels = new List>(); + Var lastTransformModel = null; - using (var environment = new TlcEnvironment()) + foreach (ILearningPipelineItem currentItem in this) { - Experiment experiment = environment.CreateExperiment(); - ILearningPipelineStep step = null; - List loaders = new List(); - List> transformModels = new List>(); - Var lastTransformModel = null; + if (currentItem is ILearningPipelineLoader loader) + loaders.Add(loader); + + step = currentItem.ApplyStep(step, experiment); + if (step is ILearningPipelineDataStep dataStep && dataStep.Model != null) + transformModels.Add(dataStep.Model); - foreach (ILearningPipelineItem currentItem in this) + else if (step is ILearningPipelinePredictorStep predictorDataStep) { - if (currentItem is ILearningPipelineLoader loader) - loaders.Add(loader); - - step = currentItem.ApplyStep(step, experiment); - if (step is ILearningPipelineDataStep dataStep && dataStep.Model != null) - transformModels.Add(dataStep.Model); - - else if (step is ILearningPipelinePredictorStep predictorDataStep) + if (lastTransformModel != null) + transformModels.Insert(0, lastTransformModel); + + var localModelInput = new Transforms.ManyHeterogeneousModelCombiner { - if (lastTransformModel != null) - transformModels.Insert(0, lastTransformModel); - - var localModelInput = new Transforms.ManyHeterogeneousModelCombiner - { - PredictorModel = predictorDataStep.Model, - TransformModels = new ArrayVar(transformModels.ToArray()) - }; - - var localModelOutput = experiment.Add(localModelInput); - - var scorer = new Transforms.Scorer - { - PredictorModel = localModelOutput.PredictorModel - }; - - var scorerOutput = experiment.Add(scorer); - lastTransformModel = scorerOutput.ScoringTransform; - step = new ScorerPipelineStep(scorerOutput.ScoredData, scorerOutput.ScoringTransform); - transformModels.Clear(); - } - } + PredictorModel = predictorDataStep.Model, + TransformModels = new ArrayVar(transformModels.ToArray()) + }; - if (transformModels.Count > 0) - { - transformModels.Insert(0,lastTransformModel); - var modelInput = new Transforms.ModelCombiner + var localModelOutput = experiment.Add(localModelInput); + + var scorer = new Transforms.Scorer { - Models = new ArrayVar(transformModels.ToArray()) + PredictorModel = localModelOutput.PredictorModel }; - var modelOutput = experiment.Add(modelInput); - lastTransformModel = modelOutput.OutputModel; + var scorerOutput = experiment.Add(scorer); + lastTransformModel = scorerOutput.ScoringTransform; + step = new ScorerPipelineStep(scorerOutput.ScoredData, scorerOutput.ScoringTransform); + transformModels.Clear(); } + } - experiment.Compile(); - foreach (ILearningPipelineLoader loader in loaders) + if (transformModels.Count > 0) + { + transformModels.Insert(0, lastTransformModel); + var modelInput = new Transforms.ModelCombiner { - loader.SetInput(environment, experiment); - } - experiment.Run(); + Models = new ArrayVar(transformModels.ToArray()) + }; - ITransformModel model = experiment.GetOutput(lastTransformModel); - BatchPredictionEngine predictor; - using (var memoryStream = new MemoryStream()) - { - model.Save(environment, memoryStream); + var modelOutput = experiment.Add(modelInput); + lastTransformModel = modelOutput.OutputModel; + } - memoryStream.Position = 0; + experiment.Compile(); + foreach (ILearningPipelineLoader loader in loaders) + { + loader.SetInput(Env, experiment); + } + experiment.Run(); - predictor = environment.CreateBatchPredictionEngine(memoryStream); + ITransformModel model = experiment.GetOutput(lastTransformModel); + BatchPredictionEngine predictor; + using (var memoryStream = new MemoryStream()) + { + model.Save(Env, memoryStream); - return new PredictionModel(predictor, memoryStream); - } + memoryStream.Position = 0; + + predictor = Env.CreateBatchPredictionEngine(memoryStream); + + return new PredictionModel(predictor, memoryStream); } } @@ -220,9 +228,9 @@ public PredictionModel Train() /// /// The IDataView that was returned by the pipeline. /// - internal IDataView Execute(IHostEnvironment environment) + internal IDataView Execute() { - Experiment experiment = environment.CreateExperiment(); + Experiment experiment = Env.CreateExperiment(); ILearningPipelineStep step = null; List loaders = new List(); foreach (ILearningPipelineItem currentItem in this) @@ -241,7 +249,7 @@ internal IDataView Execute(IHostEnvironment environment) experiment.Compile(); foreach (ILearningPipelineLoader loader in loaders) { - loader.SetInput(environment, experiment); + loader.SetInput(Env, experiment); } experiment.Run(); diff --git a/src/Microsoft.ML/LearningPipelineDebugProxy.cs b/src/Microsoft.ML/LearningPipelineDebugProxy.cs index e9d93425e4..6c4b103aca 100644 --- a/src/Microsoft.ML/LearningPipelineDebugProxy.cs +++ b/src/Microsoft.ML/LearningPipelineDebugProxy.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Transforms; using System; @@ -25,7 +26,7 @@ internal sealed class LearningPipelineDebugProxy private const int MaxSlotNamesToDisplay = 100; private readonly LearningPipeline _pipeline; - private readonly TlcEnvironment _environment; + private readonly IHostEnvironment _environment; private IDataView _preview; private Exception _pipelineExecutionException; private PipelineItemDebugColumn[] _columns; @@ -36,10 +37,9 @@ public LearningPipelineDebugProxy(LearningPipeline pipeline) if (pipeline == null) throw new ArgumentNullException(nameof(pipeline)); - _pipeline = new LearningPipeline(); + _pipeline = new LearningPipeline(seed:42, concurrency:1); - // use a ConcurrencyFactor of 1 so other threads don't need to run in the debugger - _environment = new TlcEnvironment(conc: 1); + _environment = _pipeline.Env; foreach (ILearningPipelineItem item in pipeline) { @@ -139,7 +139,7 @@ private IDataView ExecutePipeline() { try { - _preview = _pipeline.Execute(_environment); + _preview = _pipeline.Execute(); } catch (Exception e) { diff --git a/src/Microsoft.ML/Runtime/DefaultEnvironment.cs b/src/Microsoft.ML/Runtime/DefaultEnvironment.cs new file mode 100644 index 0000000000..6a5fa0f455 --- /dev/null +++ b/src/Microsoft.ML/Runtime/DefaultEnvironment.cs @@ -0,0 +1,96 @@ +using Microsoft.ML.Runtime.Data; +using System; + +namespace Microsoft.ML.Runtime +{ + public sealed class DefaultEnvironment : HostEnvironmentBase + { + public DefaultEnvironment(int? seed = null, int conc = 0) + : this(RandomUtils.Create(seed), true, conc) + { + } + + public DefaultEnvironment(IRandom rand, bool verbose, int conc, string shortName = null, string parentFullName = null) : base(rand, verbose, conc, shortName, parentFullName) + { + EnsureDispatcher(); + AddListener(OnMessageRecieved); + } + + void OnMessageRecieved(IMessageSource sender, ChannelMessage msg) + { + ChannelMessageEventArgs eventArgs = new ChannelMessageEventArgs() { Message = msg }; + MessageRecieved?.Invoke(this, eventArgs); + } + + public event EventHandler MessageRecieved; + public class ChannelMessageEventArgs : EventArgs + { + public ChannelMessage Message { get; set; } + } + + private sealed class Channel : ChannelBase + { + public Channel(DefaultEnvironment master, ChannelProviderBase parent, string shortName, Action dispatch) + : base(master, parent, shortName, dispatch) + { + } + } + + private sealed class Host : HostBase + { + public new bool IsCancelled => Root.IsCancelled; + + public Host(HostEnvironmentBase source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc) + : base(source, shortName, parentFullName, rand, verbose, conc) + { + } + + protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name) + { + Contracts.AssertValue(parent); + Contracts.Assert(parent is Host); + Contracts.AssertNonEmpty(name); + return new Channel(Root, parent, name, GetDispatchDelegate()); + } + + protected override IPipe CreatePipe(ChannelProviderBase parent, string name) + { + Contracts.AssertValue(parent); + Contracts.Assert(parent is Host); + Contracts.AssertNonEmpty(name); + return new Pipe(parent, name, GetDispatchDelegate()); + } + + protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc) + { + return new Host(source, shortName, parentFullName, rand, verbose, conc); + } + } + + protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc) + { + Contracts.AssertValue(rand); + Contracts.AssertValueOrNull(parentFullName); + Contracts.AssertNonEmpty(shortName); + Contracts.Assert(source == this || source is Host); + return new Host(source, shortName, parentFullName, rand, verbose, conc); + } + + protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name) + { + Contracts.AssertValue(parent); + Contracts.Assert(parent is DefaultEnvironment); + Contracts.AssertNonEmpty(name); + return new Channel(this, parent, name, GetDispatchDelegate()); + } + + protected override IPipe CreatePipe(ChannelProviderBase parent, string name) + { + Contracts.AssertValue(parent); + Contracts.Assert(parent is DefaultEnvironment); + Contracts.AssertNonEmpty(name); + return new Pipe(parent, name, GetDispatchDelegate()); + } + } + +} diff --git a/test/Microsoft.ML.TestFramework/ModelHelper.cs b/test/Microsoft.ML.TestFramework/ModelHelper.cs index dca360c4e3..f9c1f060e3 100644 --- a/test/Microsoft.ML.TestFramework/ModelHelper.cs +++ b/test/Microsoft.ML.TestFramework/ModelHelper.cs @@ -12,7 +12,7 @@ namespace Microsoft.ML.TestFramework { public static class ModelHelper { - private static TlcEnvironment s_environment = new TlcEnvironment(seed: 1); + private static TlcEnvironment s_environment = new TlcEnvironment(seed: 1, conc: 1); private static ITransformModel s_housePriceModel; public static void WriteKcHousePriceModel(string dataPath, string outputModelPath) diff --git a/test/Microsoft.ML.Tests/LearningPipelineTests.cs b/test/Microsoft.ML.Tests/LearningPipelineTests.cs index bafef95040..a91188ec80 100644 --- a/test/Microsoft.ML.Tests/LearningPipelineTests.cs +++ b/test/Microsoft.ML.Tests/LearningPipelineTests.cs @@ -4,6 +4,7 @@ using Microsoft.ML; using Microsoft.ML.TestFramework; +using System; using System.Linq; using Xunit; using Xunit.Abstractions; @@ -22,12 +23,15 @@ public LearningPipelineTests(ITestOutputHelper output) public void ConstructorDoesntThrow() { Assert.NotNull(new LearningPipeline()); + Assert.NotNull(new LearningPipeline(seed:42)); + Assert.NotNull(new LearningPipeline(concurrency: 1)); + Assert.NotNull(new LearningPipeline(seed:42, concurrency: 1)); } [Fact] public void CanAddAndRemoveFromPipeline() { - var pipeline = new LearningPipeline() + var pipeline = new LearningPipeline(seed:42, concurrency: 1) { new Transforms.CategoricalOneHotVectorizer("String1", "String2"), new Transforms.ColumnConcatenator(outputColumn: "Features", "String1", "String2", "Number1", "Number2"), @@ -42,5 +46,7 @@ public void CanAddAndRemoveFromPipeline() pipeline.Add(new Trainers.StochasticDualCoordinateAscentRegressor()); Assert.Equal(3, pipeline.Count); } + + } } diff --git a/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs index 38ec6ce073..f0828dba18 100644 --- a/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs @@ -19,7 +19,7 @@ public void TrainAndPredictHousePriceModelTest() { string dataPath = GetDataPath("kc_house_data.csv"); - var pipeline = new LearningPipeline(); + var pipeline = new LearningPipeline(seed: 42, concurrency : 1); pipeline.Add(new TextLoader(dataPath, useHeader: true, separator: ",")); diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs index 30c497ccc5..c393f1e53e 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; +using System; using Xunit; namespace Microsoft.ML.Scenarios @@ -17,8 +18,8 @@ public void TrainAndPredictIrisModelTest() { string dataPath = GetDataPath("iris.txt"); - var pipeline = new LearningPipeline(); - + var pipeline = new LearningPipeline(seed: 42, concurrency: 1); + pipeline.MessageOccured += Pipeline_MessageOccured; pipeline.Add(new TextLoader(dataPath, useHeader: true, separator: "tab")); pipeline.Add(new ColumnConcatenator(outputColumn: "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")); @@ -32,7 +33,7 @@ public void TrainAndPredictIrisModelTest() SepalLength = 3.3f, SepalWidth = 1.6f, PetalLength = 0.2f, - PetalWidth= 5.1f, + PetalWidth = 5.1f, }); Assert.Equal(1, prediction.PredictedLabels[0], 2); @@ -112,6 +113,12 @@ public void TrainAndPredictIrisModelTest() Assert.Equal(49, matrix["2", "2"]); } + + private void Pipeline_MessageOccured(object sender, Runtime.DefaultEnvironment.ChannelMessageEventArgs e) + { + System.Diagnostics.Debug.WriteLine(e.Message.Message); + } + public class IrisData { [Column("0")] diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs index 79cc2fc137..e453904e6e 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -17,7 +17,7 @@ public void TrainAndPredictIrisModelWithStringLabelTest() { string dataPath = GetDataPath("iris.data"); - var pipeline = new LearningPipeline(); + var pipeline = new LearningPipeline(seed: 42, concurrency: 1); pipeline.Add(new TextLoader(dataPath, useHeader: false, separator: ",")); diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index 608cbef144..f5ac0cc4ee 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -22,7 +22,7 @@ public partial class ScenariosTests public void TrainAndPredictSentimentModelTest() { string dataPath = GetDataPath(SentimentDataPath); - var pipeline = new LearningPipeline(); + var pipeline = new LearningPipeline(seed:42, concurrency:1); pipeline.Add(new TextLoader(dataPath, useHeader: true, separator: "tab")); pipeline.Add(new TextFeaturizer("Features", "SentimentText") {