diff --git a/ZBaselines/Common/EntryPoints/core_ep-list.tsv b/ZBaselines/Common/EntryPoints/core_ep-list.tsv
index 568a6066f9..3f6639caad 100644
--- a/ZBaselines/Common/EntryPoints/core_ep-list.tsv
+++ b/ZBaselines/Common/EntryPoints/core_ep-list.tsv
@@ -1,3 +1,4 @@
+Data.DataViewReference	Pass dataview from memory to experiment	Microsoft.ML.Runtime.EntryPoints.DataViewReference	ImportData	Microsoft.ML.Runtime.EntryPoints.DataViewReference+Input	Microsoft.ML.Runtime.EntryPoints.DataViewReference+Output
 Data.IDataViewArrayConverter	Create and array variable	Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro	MakeArray	Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewInput	Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewOutput
 Data.PredictorModelArrayConverter	Create and array variable	Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro	MakeArray	Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelInput	Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelOutput
 Data.TextLoader	Import a dataset from a text file	Microsoft.ML.Runtime.EntryPoints.ImportTextData	ImportText	Microsoft.ML.Runtime.EntryPoints.ImportTextData+Input	Microsoft.ML.Runtime.EntryPoints.ImportTextData+Output
diff --git a/ZBaselines/Common/EntryPoints/core_manifest.json b/ZBaselines/Common/EntryPoints/core_manifest.json
index a3778a7f7f..7529b70212 100644
--- a/ZBaselines/Common/EntryPoints/core_manifest.json
+++ b/ZBaselines/Common/EntryPoints/core_manifest.json
@@ -1,5 +1,28 @@
 {
   "EntryPoints": [
+    {
+      "Name": "Data.DataViewReference",
+      "Desc": "Pass dataview from memory to experiment",
+      "FriendlyName": null,
+      "ShortName": null,
+      "Inputs": [
+        {
+          "Name": "Data",
+          "Type": "DataView",
+          "Desc": "Pointer to IDataView in memory",
+          "Required": true,
+          "SortOrder": 1.0,
+          "IsNullable": false
+        }
+      ],
+      "Outputs": [
+        {
+          "Name": "Data",
+          "Type": "DataView",
+          "Desc": "The resulting data view"
+        }
+      ]
+    },
     {
       "Name": "Data.IDataViewArrayConverter",
       "Desc": "Create and array variable",
diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs
index 46002c5abf..2fb94c498b 100644
--- a/src/Microsoft.ML/CSharpApi.cs
+++ b/src/Microsoft.ML/CSharpApi.cs
@@ -53,11 +53,22 @@ public Microsoft.ML.Data.TextLoader.Output Add(Microsoft.ML.Data.TextLoader inpu
                 return output;
             }
 
+            public Microsoft.ML.Data.DataViewReference.Output Add(Microsoft.ML.Data.DataViewReference input)
+            {
+                var output = new Microsoft.ML.Data.DataViewReference.Output();
+                Add(input, output);
+                return output;
+            }
+
             public void Add(Microsoft.ML.Data.TextLoader input, Microsoft.ML.Data.TextLoader.Output output)
             {
                 _jsonNodes.Add(Serialize("Data.TextLoader", input, output));
             }
 
+            public void Add(Microsoft.ML.Data.DataViewReference input, Microsoft.ML.Data.DataViewReference.Output output)
+            {
+                _jsonNodes.Add(Serialize("Data.DataViewReference", input, output));
+            }
             public Microsoft.ML.Models.AnomalyDetectionEvaluator.Output Add(Microsoft.ML.Models.AnomalyDetectionEvaluator input)
             {
                 var output = new Microsoft.ML.Models.AnomalyDetectionEvaluator.Output();
@@ -1311,6 +1322,23 @@ public sealed partial class TextLoader
             public string CustomSchema { get; set; }
 
 
+            public sealed class Output
+            {
+                /// <summary>
+                /// The resulting data view
+                /// </summary>
+                public Var<Microsoft.ML.Runtime.Data.IDataView> Data { get; set; } = new Var<Microsoft.ML.Runtime.Data.IDataView>();
+
+            }
+        }
+
+        public sealed partial class DataViewReference
+        {
+            /// <summary>
+            /// Pointer to IDataView in memory
+            /// </summary>
+            public Var<Microsoft.ML.Runtime.Data.IDataView> Data { get; set; } = new Var<Microsoft.ML.Runtime.Data.IDataView>();
+
             public sealed class Output
             {
                 /// <summary>
diff --git a/src/Microsoft.ML/Data/CollectionDataSource.cs b/src/Microsoft.ML/Data/CollectionDataSource.cs
new file mode 100644
index 0000000000..56523fc994
--- /dev/null
+++ b/src/Microsoft.ML/Data/CollectionDataSource.cs
@@ -0,0 +1,122 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Utilities;
+
+namespace Microsoft.ML.Data
+{
+    /// <summary>
+    /// Creates a data source for a learning pipeline from an in-memory collection of data.
+    /// </summary>
+    public static class CollectionDataSource
+    {
+        /// <summary>
+        /// Creates a pipeline data source backed by a list. Supports shuffling.
+        /// </summary>
+        /// <param name="data">The examples to feed into the pipeline; must be non-empty.</param>
+        public static ILearningPipelineLoader Create<T>(IList<T> data) where T : class
+        {
+            return new ListDataSource<T>(data);
+        }
+
+        /// <summary>
+        /// Creates a pipeline data source backed by an enumerable, which can't be shuffled.
+        /// </summary>
+        /// <param name="data">The examples to feed into the pipeline; must not be null.</param>
+        public static ILearningPipelineLoader Create<T>(IEnumerable<T> data) where T : class
+        {
+            return new EnumerableDataSource<T>(data);
+        }
+
+        /// <summary>
+        /// Shared pipeline plumbing: adds a DataViewReference entry point to the experiment and
+        /// later binds the data view returned by <see cref="GetDataView"/> to its input.
+        /// </summary>
+        private abstract class BaseDataSource<TInput> : ILearningPipelineLoader where TInput : class
+        {
+            // Assigned by ApplyStep; SetInput binds the data view to this entry point's Data variable.
+            private Data.DataViewReference _dataViewEntryPoint;
+
+            public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
+            {
+                // This loader is expected to be the first step, so nothing may precede it.
+                Contracts.Assert(previousStep == null);
+                _dataViewEntryPoint = new Data.DataViewReference();
+                var importOutput = experiment.Add(_dataViewEntryPoint);
+                return new CollectionDataSourcePipelineStep(importOutput.Data);
+            }
+
+            public void SetInput(IHostEnvironment environment, Experiment experiment)
+            {
+                // Without this guard, calling SetInput before ApplyStep dereferences a null entry point.
+                environment.Check(_dataViewEntryPoint != null, "ApplyStep must be called before SetInput");
+                var dataView = GetDataView(environment);
+                environment.CheckValue(dataView, nameof(dataView));
+                experiment.SetInput(_dataViewEntryPoint.Data, dataView);
+            }
+
+            /// <summary>
+            /// Wraps the underlying collection in an <see cref="IDataView"/>.
+            /// </summary>
+            public abstract IDataView GetDataView(IHostEnvironment environment);
+        }
+
+        /// <summary>
+        /// Data source over an <see cref="IEnumerable{T}"/>, exposed through a streaming data view.
+        /// </summary>
+        private sealed class EnumerableDataSource<TInput> : BaseDataSource<TInput> where TInput : class
+        {
+            private readonly IEnumerable<TInput> _enumerableCollection;
+
+            public EnumerableDataSource(IEnumerable<TInput> collection)
+            {
+                Contracts.CheckValue(collection, nameof(collection));
+                _enumerableCollection = collection;
+            }
+
+            public override IDataView GetDataView(IHostEnvironment environment)
+            {
+                return ComponentCreation.CreateStreamingDataView(environment, _enumerableCollection);
+            }
+        }
+
+        /// <summary>
+        /// Data source over an <see cref="IList{T}"/> that must contain at least one item.
+        /// </summary>
+        private sealed class ListDataSource<TInput> : BaseDataSource<TInput> where TInput : class
+        {
+            private readonly IList<TInput> _listCollection;
+
+            public ListDataSource(IList<TInput> collection)
+            {
+                Contracts.CheckParamValue(Utils.Size(collection) > 0, collection, nameof(collection), "Must be non-empty");
+                _listCollection = collection;
+            }
+
+            public override IDataView GetDataView(IHostEnvironment environment)
+            {
+                return ComponentCreation.CreateDataView(environment, _listCollection);
+            }
+        }
+
+        /// <summary>
+        /// Pipeline step that exposes only a data variable; its model is always null.
+        /// </summary>
+        private sealed class CollectionDataSourcePipelineStep : ILearningPipelineDataStep
+        {
+            public CollectionDataSourcePipelineStep(Var<IDataView> data)
+            {
+                Data = data;
+            }
+
+            public Var<IDataView> Data { get; }
+            public Var<ITransformModel> Model => null;
+        }
+    }
+}
diff --git a/src/Microsoft.ML/Runtime/EntryPoints/DataViewReference.cs b/src/Microsoft.ML/Runtime/EntryPoints/DataViewReference.cs
new file mode 100644
index 0000000000..3b1633456d
--- /dev/null
+++ b/src/Microsoft.ML/Runtime/EntryPoints/DataViewReference.cs
@@ -0,0 +1,37 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+
+[assembly: LoadableClass(typeof(void), typeof(DataViewReference), null, typeof(SignatureEntryPointModule), "DataViewReference")]
+namespace Microsoft.ML.Runtime.EntryPoints
+{
+    public class DataViewReference
+    {
+        public sealed class Input
+        {
+            [Argument(ArgumentType.Required, HelpText = "Pointer to IDataView in memory", SortOrder = 1)]
+            public IDataView Data;
+        }
+
+        public sealed class Output
+        {
+            [TlcModule.Output(Desc = "The resulting data view", SortOrder = 1)]
+            public IDataView Data;
+        }
+
+        [TlcModule.EntryPoint(Name = "Data.DataViewReference", Desc = "Pass dataview from memory to experiment")]
+        public static Output ImportData(IHostEnvironment env, Input input)
+        {
+            Contracts.CheckValue(env, nameof(env));
+            var host = env.Register("DataViewReference");
+            env.CheckValue(input, nameof(input));
+            EntryPointUtils.CheckInputArgs(host, input);
+            return new Output { Data = input.Data };
+        }
+    }
+}
diff --git a/src/Microsoft.ML/TextLoader.cs b/src/Microsoft.ML/TextLoader.cs
index f63a14611b..cdb9df8a36 100644
--- a/src/Microsoft.ML/TextLoader.cs
+++ b/src/Microsoft.ML/TextLoader.cs
@@ -115,11 +115,10 @@ private class TextLoaderPipelineStep : ILearningPipelineDataStep
             public TextLoaderPipelineStep(Var<IDataView> data)
             {
                 Data = data;
-                Model = null;
             }
 
             public Var<IDataView> Data { get; }
-            public Var<ITransformModel> Model { get; }
+            public Var<ITransformModel> Model => null;
         }
     }
 }
diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs
new file mode 100644
index 0000000000..923d4eb375
--- /dev/null
+++ b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs
@@ -0,0 +1,186 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.TestFramework;
+using Microsoft.ML.Trainers;
+using Microsoft.ML.Transforms;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.ML.EntryPoints.Tests
+{
+    /// <summary>
+    /// Tests for creating learning-pipeline data sources from in-memory collections.
+    /// </summary>
+    public class CollectionDataSourceTests : BaseTestClass
+    {
+        public CollectionDataSourceTests(ITestOutputHelper output)
+            : base(output)
+        {
+        }
+
+        /// <summary>
+        /// Non-empty lists, arrays and enumerables are accepted; empty lists and arrays are rejected.
+        /// </summary>
+        [Fact]
+        public void CheckConstructor()
+        {
+            Assert.NotNull(CollectionDataSource.Create(new List<Input>() { new Input { Number1 = 1, String1 = "1" } }));
+            Assert.NotNull(CollectionDataSource.Create(new Input[1] { new Input { Number1 = 1, String1 = "1" } }));
+            Assert.NotNull(CollectionDataSource.Create(new Input[1] { new Input { Number1 = 1, String1 = "1" } }.AsEnumerable()));
+
+            // Empty lists and arrays must be rejected when the data source is created.
+            Assert.ThrowsAny<Exception>(() => CollectionDataSource.Create(new List<Input>()));
+            Assert.ThrowsAny<Exception>(() => CollectionDataSource.Create(new Input[0]));
+        }
+
+        /// <summary>
+        /// Applying the data source as the first pipeline step yields a data variable and no model.
+        /// </summary>
+        [Fact]
+        public void CanSuccessfullyApplyATransform()
+        {
+            var collection = CollectionDataSource.Create(new List<Input>() { new Input { Number1 = 1, String1 = "1" } });
+            using (var environment = new TlcEnvironment())
+            {
+                Experiment experiment = environment.CreateExperiment();
+                ILearningPipelineDataStep output = (ILearningPipelineDataStep)collection.ApplyStep(null, experiment);
+
+                Assert.NotNull(output.Data);
+                Assert.NotNull(output.Data.VarName);
+                Assert.Null(output.Model);
+            }
+        }
+
+        /// <summary>
+        /// Rows read back from the experiment output match the source collection, in order.
+        /// </summary>
+        [Fact]
+        public void CanSuccessfullyEnumerated()
+        {
+            var collection = CollectionDataSource.Create(new List<Input>() {
+                new Input { Number1 = 1, String1 = "1" },
+                new Input { Number1 = 2, String1 = "2" },
+                new Input { Number1 = 3, String1 = "3" }
+            });
+
+            using (var environment = new TlcEnvironment())
+            {
+                Experiment experiment = environment.CreateExperiment();
+                // Direct cast: an unexpected step type fails here instead of as a null dereference below.
+                ILearningPipelineDataStep output = (ILearningPipelineDataStep)collection.ApplyStep(null, experiment);
+
+                experiment.Compile();
+                collection.SetInput(environment, experiment);
+                experiment.Run();
+
+                IDataView data = experiment.GetOutput(output.Data);
+                Assert.NotNull(data);
+
+                using (var cursor = data.GetRowCursor((a => true)))
+                {
+                    var idGetter = cursor.GetGetter<float>(0);
+                    var textGetter = cursor.GetGetter<DvText>(1);
+
+                    // The source rows are (1, "1"), (2, "2"), (3, "3").
+                    for (int expected = 1; expected <= 3; expected++)
+                    {
+                        Assert.True(cursor.MoveNext());
+
+                        float id = 0;
+                        idGetter(ref id);
+                        Assert.Equal((float)expected, id);
+
+                        DvText text = new DvText();
+                        textGetter(ref text);
+                        Assert.Equal(expected.ToString(), text.ToString());
+                    }
+
+                    Assert.False(cursor.MoveNext());
+                }
+            }
+        }
+
+        /// <summary>
+        /// A pipeline can be trained and used for prediction from both list and enumerable sources.
+        /// </summary>
+        [Fact]
+        public void CanTrain()
+        {
+            var data = new List<IrisData>() {
+                new IrisData { SepalLength = 1f, SepalWidth = 1f, PetalLength = 0.3f, PetalWidth = 5.1f, Label = 1 },
+                new IrisData { SepalLength = 1f, SepalWidth = 1f, PetalLength = 0.3f, PetalWidth = 5.1f, Label = 1 },
+                new IrisData { SepalLength = 1.2f, SepalWidth = 0.5f, PetalLength = 0.3f, PetalWidth = 5.1f, Label = 0 }
+            };
+
+            // Exercise both factory overloads: the list source and the enumerable source.
+            AssertCanTrainAndPredict(CollectionDataSource.Create(data));
+            AssertCanTrainAndPredict(CollectionDataSource.Create(data.AsEnumerable()));
+        }
+
+        /// <summary>
+        /// Trains a model on <paramref name="loader"/> and checks that a prediction carries scores.
+        /// </summary>
+        private static void AssertCanTrainAndPredict(ILearningPipelineLoader loader)
+        {
+            var pipeline = new LearningPipeline();
+            pipeline.Add(loader);
+            pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
+                "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
+            pipeline.Add(new StochasticDualCoordinateAscentClassifier());
+            PredictionModel<IrisData, IrisPrediction> model = pipeline.Train<IrisData, IrisPrediction>();
+
+            IrisPrediction prediction = model.Predict(new IrisData()
+            {
+                SepalLength = 3.3f,
+                SepalWidth = 1.6f,
+                PetalLength = 0.2f,
+                PetalWidth = 5.1f,
+            });
+
+            Assert.NotNull(prediction);
+            Assert.NotNull(prediction.PredictedLabels);
+        }
+
+        public class Input
+        {
+            [Column("0")]
+            public float Number1;
+
+            [Column("1")]
+            public string String1;
+        }
+
+        public class IrisData
+        {
+            [Column("0")]
+            public float Label;
+
+            [Column("1")]
+            public float SepalLength;
+
+            [Column("2")]
+            public float SepalWidth;
+
+            [Column("3")]
+            public float PetalLength;
+
+            [Column("4")]
+            public float PetalWidth;
+        }
+
+        public class IrisPrediction
+        {
+            [ColumnName("Score")]
+            public float[] PredictedLabels;
+        }
+    }
+}