Skip to content

CollectionDataSource (train on top of memory collection instead of loading data from file) #106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
May 15, 2018
1 change: 1 addition & 0 deletions ZBaselines/Common/EntryPoints/core_ep-list.tsv
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
Data.DataViewReference Pass dataview from memory to experiment Microsoft.ML.Runtime.EntryPoints.DataViewReference ImportData Microsoft.ML.Runtime.EntryPoints.DataViewReference+Input Microsoft.ML.Runtime.EntryPoints.DataViewReference+Output
Data.IDataViewArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewOutput
Data.PredictorModelArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelOutput
Data.TextLoader Import a dataset from a text file Microsoft.ML.Runtime.EntryPoints.ImportTextData ImportText Microsoft.ML.Runtime.EntryPoints.ImportTextData+Input Microsoft.ML.Runtime.EntryPoints.ImportTextData+Output
Expand Down
23 changes: 23 additions & 0 deletions ZBaselines/Common/EntryPoints/core_manifest.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,28 @@
{
"EntryPoints": [
{
"Name": "Data.DataViewReference",
"Desc": "Pass dataview from memory to experiment",
"FriendlyName": null,
"ShortName": null,
"Inputs": [
{
"Name": "Data",
"Type": "DataView",
"Desc": "Pointer to IDataView in memory",
"Required": true,
"SortOrder": 1.0,
"IsNullable": false
}
],
"Outputs": [
{
"Name": "Data",
"Type": "DataView",
"Desc": "The resulting data view"
}
]
},
{
"Name": "Data.IDataViewArrayConverter",
"Desc": "Create and array variable",
Expand Down
28 changes: 28 additions & 0 deletions src/Microsoft.ML/CSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,22 @@ public Microsoft.ML.Data.TextLoader.Output Add(Microsoft.ML.Data.TextLoader inpu
return output;
}

public Microsoft.ML.Data.DataViewReference.Output Add(Microsoft.ML.Data.DataViewReference input)
{
var output = new Microsoft.ML.Data.DataViewReference.Output();
Add(input, output);
return output;
}

public void Add(Microsoft.ML.Data.TextLoader input, Microsoft.ML.Data.TextLoader.Output output)
{
_jsonNodes.Add(Serialize("Data.TextLoader", input, output));
}

public void Add(Microsoft.ML.Data.DataViewReference input, Microsoft.ML.Data.DataViewReference.Output output)
{
_jsonNodes.Add(Serialize("Data.DataViewReference", input, output));
}
public Microsoft.ML.Models.AnomalyDetectionEvaluator.Output Add(Microsoft.ML.Models.AnomalyDetectionEvaluator input)
{
var output = new Microsoft.ML.Models.AnomalyDetectionEvaluator.Output();
Expand Down Expand Up @@ -1311,6 +1322,23 @@ public sealed partial class TextLoader
public string CustomSchema { get; set; }


public sealed class Output
{
/// <summary>
/// The resulting data view
/// </summary>
public Var<Microsoft.ML.Runtime.Data.IDataView> Data { get; set; } = new Var<Microsoft.ML.Runtime.Data.IDataView>();

}
}

public sealed partial class DataViewReference
{
/// <summary>
/// Location of the input file
/// </summary>
public Var<Microsoft.ML.Runtime.Data.IDataView> Data { get; set; } = new Var<Microsoft.ML.Runtime.Data.IDataView>();

public sealed class Output
{
/// <summary>
Expand Down
101 changes: 101 additions & 0 deletions src/Microsoft.ML/Data/CollectionDataSource.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;

namespace Microsoft.ML.Data
{
/// <summary>
/// Creates data source for pipeline based on provided collection of data.
/// </summary>
public static class CollectionDataSource
{
/// <summary>
/// Creates pipeline data source. Support shuffle.
/// </summary>
public static ILearningPipelineLoader Create<T>(IList<T> data) where T : class
{
return new ListDataSource<T>(data);
}

/// <summary>
/// Creates pipeline data source which can't be shuffled.
/// </summary>
public static ILearningPipelineLoader Create<T>(IEnumerable<T> data) where T : class
{
return new EnumerableDataSource<T>(data);
}

private abstract class BaseDataSource<TInput> : ILearningPipelineLoader where TInput : class
{
private Data.DataViewReference _dataViewEntryPoint;
private IDataView _dataView;

public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
{
Contracts.Assert(previousStep == null);
_dataViewEntryPoint = new Data.DataViewReference();
var importOutput = experiment.Add(_dataViewEntryPoint);
return new CollectionDataSourcePipelineStep(importOutput.Data);
}

public void SetInput(IHostEnvironment environment, Experiment experiment)
{
_dataView = GetDataView(environment);
environment.CheckValue(_dataView, nameof(_dataView));
experiment.SetInput(_dataViewEntryPoint.Data, _dataView);
}

public abstract IDataView GetDataView(IHostEnvironment environment);
}

private class EnumerableDataSource<TInput> : BaseDataSource<TInput> where TInput : class
{
private readonly IEnumerable<TInput> _enumerableCollection;

public EnumerableDataSource(IEnumerable<TInput> collection)
{
Contracts.CheckValue(collection, nameof(collection));
_enumerableCollection = collection;
}

public override IDataView GetDataView(IHostEnvironment environment)
{
return ComponentCreation.CreateStreamingDataView(environment, _enumerableCollection);
}
}

private class ListDataSource<TInput> : BaseDataSource<TInput> where TInput : class
{
private readonly IList<TInput> _listCollection;

public ListDataSource(IList<TInput> collection)
{
Contracts.CheckParamValue(Utils.Size(collection) > 0, collection, nameof(collection), "Must be non-empty");
_listCollection = collection;
}

public override IDataView GetDataView(IHostEnvironment environment)
{
return ComponentCreation.CreateDataView(environment, _listCollection);
}
}

private class CollectionDataSourcePipelineStep : ILearningPipelineDataStep
{
public CollectionDataSourcePipelineStep(Var<IDataView> data)
{
Data = data;
}

public Var<IDataView> Data { get; }
public Var<ITransformModel> Model => null;
}
}
}
37 changes: 37 additions & 0 deletions src/Microsoft.ML/Runtime/EntryPoints/DataViewReference.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;

[assembly: LoadableClass(typeof(void), typeof(DataViewReference), null, typeof(SignatureEntryPointModule), "DataViewReference")]
namespace Microsoft.ML.Runtime.EntryPoints
{
public class DataViewReference
{
public sealed class Input
{
[Argument(ArgumentType.Required, HelpText = "Pointer to IDataView in memory", SortOrder = 1)]
public IDataView Data;
}

public sealed class Output
{
[TlcModule.Output(Desc = "The resulting data view", SortOrder = 1)]
public IDataView Data;
}

[TlcModule.EntryPoint(Name = "Data.DataViewReference", Desc = "Pass dataview from memory to experiment")]
public static Output ImportData(IHostEnvironment env, Input input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("DataViewReference");
env.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);
return new Output { Data = input.Data };
}
}
}
3 changes: 1 addition & 2 deletions src/Microsoft.ML/TextLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,10 @@ private class TextLoaderPipelineStep : ILearningPipelineDataStep
public TextLoaderPipelineStep(Var<IDataView> data)
{
Data = data;
Model = null;
}

public Var<IDataView> Data { get; }
public Var<ITransformModel> Model { get; }
public Var<ITransformModel> Model => null;
}
}
}
Loading