diff --git a/ZBaselines/Common/EntryPoints/core_ep-list.tsv b/ZBaselines/Common/EntryPoints/core_ep-list.tsv
index 568a6066f9..3f6639caad 100644
--- a/ZBaselines/Common/EntryPoints/core_ep-list.tsv
+++ b/ZBaselines/Common/EntryPoints/core_ep-list.tsv
@@ -1,3 +1,4 @@
+Data.DataViewReference Pass dataview from memory to experiment Microsoft.ML.Runtime.EntryPoints.DataViewReference ImportData Microsoft.ML.Runtime.EntryPoints.DataViewReference+Input Microsoft.ML.Runtime.EntryPoints.DataViewReference+Output
Data.IDataViewArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewOutput
Data.PredictorModelArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelOutput
Data.TextLoader Import a dataset from a text file Microsoft.ML.Runtime.EntryPoints.ImportTextData ImportText Microsoft.ML.Runtime.EntryPoints.ImportTextData+Input Microsoft.ML.Runtime.EntryPoints.ImportTextData+Output
diff --git a/ZBaselines/Common/EntryPoints/core_manifest.json b/ZBaselines/Common/EntryPoints/core_manifest.json
index a3778a7f7f..7529b70212 100644
--- a/ZBaselines/Common/EntryPoints/core_manifest.json
+++ b/ZBaselines/Common/EntryPoints/core_manifest.json
@@ -1,5 +1,28 @@
{
"EntryPoints": [
+ {
+ "Name": "Data.DataViewReference",
+ "Desc": "Pass dataview from memory to experiment",
+ "FriendlyName": null,
+ "ShortName": null,
+ "Inputs": [
+ {
+ "Name": "Data",
+ "Type": "DataView",
+ "Desc": "Pointer to IDataView in memory",
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ }
+ ],
+ "Outputs": [
+ {
+ "Name": "Data",
+ "Type": "DataView",
+ "Desc": "The resulting data view"
+ }
+ ]
+ },
{
"Name": "Data.IDataViewArrayConverter",
"Desc": "Create and array variable",
diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs
index 46002c5abf..2fb94c498b 100644
--- a/src/Microsoft.ML/CSharpApi.cs
+++ b/src/Microsoft.ML/CSharpApi.cs
@@ -53,11 +53,22 @@ public Microsoft.ML.Data.TextLoader.Output Add(Microsoft.ML.Data.TextLoader inpu
return output;
}
+ public Microsoft.ML.Data.DataViewReference.Output Add(Microsoft.ML.Data.DataViewReference input)
+ {
+ var output = new Microsoft.ML.Data.DataViewReference.Output();
+ Add(input, output);
+ return output;
+ }
+
public void Add(Microsoft.ML.Data.TextLoader input, Microsoft.ML.Data.TextLoader.Output output)
{
_jsonNodes.Add(Serialize("Data.TextLoader", input, output));
}
+ public void Add(Microsoft.ML.Data.DataViewReference input, Microsoft.ML.Data.DataViewReference.Output output)
+ {
+ _jsonNodes.Add(Serialize("Data.DataViewReference", input, output));
+ }
public Microsoft.ML.Models.AnomalyDetectionEvaluator.Output Add(Microsoft.ML.Models.AnomalyDetectionEvaluator input)
{
var output = new Microsoft.ML.Models.AnomalyDetectionEvaluator.Output();
@@ -1311,6 +1322,23 @@ public sealed partial class TextLoader
public string CustomSchema { get; set; }
+ public sealed class Output
+ {
+ ///
+ /// The resulting data view
+ ///
+ public Var Data { get; set; } = new Var();
+
+ }
+ }
+
+ public sealed partial class DataViewReference
+ {
+ ///
+ /// Location of the input file
+ ///
+ public Var Data { get; set; } = new Var();
+
public sealed class Output
{
///
diff --git a/src/Microsoft.ML/Data/CollectionDataSource.cs b/src/Microsoft.ML/Data/CollectionDataSource.cs
new file mode 100644
index 0000000000..56523fc994
--- /dev/null
+++ b/src/Microsoft.ML/Data/CollectionDataSource.cs
@@ -0,0 +1,101 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Utilities;
+
+namespace Microsoft.ML.Data
+{
+ ///
+ /// Creates data source for pipeline based on provided collection of data.
+ ///
+ public static class CollectionDataSource
+ {
+ ///
+ /// Creates pipeline data source. Support shuffle.
+ ///
+ public static ILearningPipelineLoader Create(IList data) where T : class
+ {
+ return new ListDataSource(data);
+ }
+
+ ///
+ /// Creates pipeline data source which can't be shuffled.
+ ///
+ public static ILearningPipelineLoader Create(IEnumerable data) where T : class
+ {
+ return new EnumerableDataSource(data);
+ }
+
+ private abstract class BaseDataSource : ILearningPipelineLoader where TInput : class
+ {
+ private Data.DataViewReference _dataViewEntryPoint;
+ private IDataView _dataView;
+
+ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
+ {
+ Contracts.Assert(previousStep == null);
+ _dataViewEntryPoint = new Data.DataViewReference();
+ var importOutput = experiment.Add(_dataViewEntryPoint);
+ return new CollectionDataSourcePipelineStep(importOutput.Data);
+ }
+
+ public void SetInput(IHostEnvironment environment, Experiment experiment)
+ {
+ _dataView = GetDataView(environment);
+ environment.CheckValue(_dataView, nameof(_dataView));
+ experiment.SetInput(_dataViewEntryPoint.Data, _dataView);
+ }
+
+ public abstract IDataView GetDataView(IHostEnvironment environment);
+ }
+
+ private class EnumerableDataSource : BaseDataSource where TInput : class
+ {
+ private readonly IEnumerable _enumerableCollection;
+
+ public EnumerableDataSource(IEnumerable collection)
+ {
+ Contracts.CheckValue(collection, nameof(collection));
+ _enumerableCollection = collection;
+ }
+
+ public override IDataView GetDataView(IHostEnvironment environment)
+ {
+ return ComponentCreation.CreateStreamingDataView(environment, _enumerableCollection);
+ }
+ }
+
+ private class ListDataSource : BaseDataSource where TInput : class
+ {
+ private readonly IList _listCollection;
+
+ public ListDataSource(IList collection)
+ {
+ Contracts.CheckParamValue(Utils.Size(collection) > 0, collection, nameof(collection), "Must be non-empty");
+ _listCollection = collection;
+ }
+
+ public override IDataView GetDataView(IHostEnvironment environment)
+ {
+ return ComponentCreation.CreateDataView(environment, _listCollection);
+ }
+ }
+
+ private class CollectionDataSourcePipelineStep : ILearningPipelineDataStep
+ {
+ public CollectionDataSourcePipelineStep(Var data)
+ {
+ Data = data;
+ }
+
+ public Var Data { get; }
+ public Var Model => null;
+ }
+ }
+}
diff --git a/src/Microsoft.ML/Runtime/EntryPoints/DataViewReference.cs b/src/Microsoft.ML/Runtime/EntryPoints/DataViewReference.cs
new file mode 100644
index 0000000000..3b1633456d
--- /dev/null
+++ b/src/Microsoft.ML/Runtime/EntryPoints/DataViewReference.cs
@@ -0,0 +1,37 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+
+[assembly: LoadableClass(typeof(void), typeof(DataViewReference), null, typeof(SignatureEntryPointModule), "DataViewReference")]
+namespace Microsoft.ML.Runtime.EntryPoints
+{
+ public class DataViewReference
+ {
+ public sealed class Input
+ {
+ [Argument(ArgumentType.Required, HelpText = "Pointer to IDataView in memory", SortOrder = 1)]
+ public IDataView Data;
+ }
+
+ public sealed class Output
+ {
+ [TlcModule.Output(Desc = "The resulting data view", SortOrder = 1)]
+ public IDataView Data;
+ }
+
+ [TlcModule.EntryPoint(Name = "Data.DataViewReference", Desc = "Pass dataview from memory to experiment")]
+ public static Output ImportData(IHostEnvironment env, Input input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register("DataViewReference");
+ env.CheckValue(input, nameof(input));
+ EntryPointUtils.CheckInputArgs(host, input);
+ return new Output { Data = input.Data };
+ }
+ }
+}
diff --git a/src/Microsoft.ML/TextLoader.cs b/src/Microsoft.ML/TextLoader.cs
index f63a14611b..cdb9df8a36 100644
--- a/src/Microsoft.ML/TextLoader.cs
+++ b/src/Microsoft.ML/TextLoader.cs
@@ -115,11 +115,10 @@ private class TextLoaderPipelineStep : ILearningPipelineDataStep
public TextLoaderPipelineStep(Var data)
{
Data = data;
- Model = null;
}
public Var Data { get; }
- public Var Model { get; }
+ public Var Model => null;
}
}
}
diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs
new file mode 100644
index 0000000000..923d4eb375
--- /dev/null
+++ b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs
@@ -0,0 +1,209 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.TestFramework;
+using Microsoft.ML.Trainers;
+using Microsoft.ML.Transforms;
+using System.Collections.Generic;
+using System.Linq;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.ML.EntryPoints.Tests
+{
+ public class CollectionDataSourceTests : BaseTestClass
+ {
+ public CollectionDataSourceTests(ITestOutputHelper output)
+ : base(output)
+ {
+ }
+
+ [Fact]
+ public void CheckConstructor()
+ {
+ Assert.NotNull(CollectionDataSource.Create(new List() { new Input { Number1 = 1, String1 = "1" } }));
+ Assert.NotNull(CollectionDataSource.Create(new Input[1] { new Input { Number1 = 1, String1 = "1" } }));
+ Assert.NotNull(CollectionDataSource.Create(new Input[1] { new Input { Number1 = 1, String1 = "1" } }.AsEnumerable()));
+
+ bool thrown = false;
+ try
+ {
+ CollectionDataSource.Create(new List());
+ }
+ catch
+ {
+ thrown = true;
+ }
+ Assert.True(thrown);
+
+ thrown = false;
+ try
+ {
+ CollectionDataSource.Create(new Input[0]);
+ }
+ catch
+ {
+ thrown = true;
+ }
+ Assert.True(thrown);
+ }
+
+ [Fact]
+ public void CanSuccessfullyApplyATransform()
+ {
+ var collection = CollectionDataSource.Create(new List() { new Input { Number1 = 1, String1 = "1" } });
+ using (var environment = new TlcEnvironment())
+ {
+ Experiment experiment = environment.CreateExperiment();
+ ILearningPipelineDataStep output = (ILearningPipelineDataStep)collection.ApplyStep(null, experiment);
+
+ Assert.NotNull(output.Data);
+ Assert.NotNull(output.Data.VarName);
+ Assert.Null(output.Model);
+ }
+ }
+
+ [Fact]
+ public void CanSuccessfullyEnumerated()
+ {
+ var collection = CollectionDataSource.Create(new List() {
+ new Input { Number1 = 1, String1 = "1" },
+ new Input { Number1 = 2, String1 = "2" },
+ new Input { Number1 = 3, String1 = "3" }
+ });
+
+ using (var environment = new TlcEnvironment())
+ {
+ Experiment experiment = environment.CreateExperiment();
+ ILearningPipelineDataStep output = collection.ApplyStep(null, experiment) as ILearningPipelineDataStep;
+
+ experiment.Compile();
+ collection.SetInput(environment, experiment);
+ experiment.Run();
+
+ IDataView data = experiment.GetOutput(output.Data);
+ Assert.NotNull(data);
+
+ using (var cursor = data.GetRowCursor((a => true)))
+ {
+ var IDGetter = cursor.GetGetter(0);
+ var TextGetter = cursor.GetGetter(1);
+
+ Assert.True(cursor.MoveNext());
+
+ float ID = 0;
+ IDGetter(ref ID);
+ Assert.Equal(1, ID);
+
+ DvText Text = new DvText();
+ TextGetter(ref Text);
+ Assert.Equal("1", Text.ToString());
+
+ Assert.True(cursor.MoveNext());
+
+ ID = 0;
+ IDGetter(ref ID);
+ Assert.Equal(2, ID);
+
+ Text = new DvText();
+ TextGetter(ref Text);
+ Assert.Equal("2", Text.ToString());
+
+ Assert.True(cursor.MoveNext());
+
+ ID = 0;
+ IDGetter(ref ID);
+ Assert.Equal(3, ID);
+
+ Text = new DvText();
+ TextGetter(ref Text);
+ Assert.Equal("3", Text.ToString());
+
+ Assert.False(cursor.MoveNext());
+ }
+ }
+ }
+
+ [Fact]
+ public void CanTrain()
+ {
+ var pipeline = new LearningPipeline();
+ var data = new List() {
+ new IrisData { SepalLength = 1f, SepalWidth = 1f ,PetalLength=0.3f, PetalWidth=5.1f, Label=1},
+ new IrisData { SepalLength = 1f, SepalWidth = 1f ,PetalLength=0.3f, PetalWidth=5.1f, Label=1},
+ new IrisData { SepalLength = 1.2f, SepalWidth = 0.5f ,PetalLength=0.3f, PetalWidth=5.1f, Label=0}
+ };
+ var collection = CollectionDataSource.Create(data);
+
+ pipeline.Add(collection);
+ pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
+ "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
+ pipeline.Add(new StochasticDualCoordinateAscentClassifier());
+ PredictionModel model = pipeline.Train();
+
+ IrisPrediction prediction = model.Predict(new IrisData()
+ {
+ SepalLength = 3.3f,
+ SepalWidth = 1.6f,
+ PetalLength = 0.2f,
+ PetalWidth = 5.1f,
+ });
+
+ pipeline = new LearningPipeline();
+ collection = CollectionDataSource.Create(data.AsEnumerable());
+ pipeline.Add(collection);
+ pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
+ "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
+ pipeline.Add(new StochasticDualCoordinateAscentClassifier());
+ model = pipeline.Train();
+
+ prediction = model.Predict(new IrisData()
+ {
+ SepalLength = 3.3f,
+ SepalWidth = 1.6f,
+ PetalLength = 0.2f,
+ PetalWidth = 5.1f,
+ });
+
+ }
+
+ public class Input
+ {
+ [Column("0")]
+ public float Number1;
+
+ [Column("1")]
+ public string String1;
+ }
+
+ public class IrisData
+ {
+ [Column("0")]
+ public float Label;
+
+ [Column("1")]
+ public float SepalLength;
+
+ [Column("2")]
+ public float SepalWidth;
+
+ [Column("3")]
+ public float PetalLength;
+
+ [Column("4")]
+ public float PetalWidth;
+ }
+
+ public class IrisPrediction
+ {
+ [ColumnName("Score")]
+ public float[] PredictedLabels;
+ }
+
+ }
+}