diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index 67fc91439e..923aac5c81 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -622,7 +622,7 @@ public ICursor GetRootCursor() /// public static class CursoringUtils { - private const string NeedEnvObsoleteMessage = "This method is obsolete. Please use the overload that takes an additional 'env' argument. An environment can be created via new TlcEnvironment()."; + private const string NeedEnvObsoleteMessage = "This method is obsolete. Please use the overload that takes an additional 'env' argument. An environment can be created via new LocalEnvironment()."; /// /// Generate a strongly-typed cursorable wrapper of the . diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index f04febd055..80a8b64c6b 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -18,6 +18,9 @@ [assembly: LoadableClass(typeof(RegressionPredictionTransformer>), typeof(RegressionPredictionTransformer), null, typeof(SignatureLoadModel), "", RegressionPredictionTransformer.LoaderSignature)] +[assembly: LoadableClass(typeof(RankingPredictionTransformer>), typeof(RankingPredictionTransformer), null, typeof(SignatureLoadModel), + "", RankingPredictionTransformer.LoaderSignature)] + namespace Microsoft.ML.Runtime.Data { public abstract class PredictionTransformerBase : IPredictionTransformer, ICanSaveModel @@ -301,6 +304,52 @@ private static VersionInfo GetVersionInfo() } } + public sealed class RankingPredictionTransformer : PredictionTransformerBase + where TModel : class, IPredictorProducing + { + private readonly GenericScorer _scorer; + + public RankingPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn) + : base(Contracts.CheckRef(env, 
nameof(env)).Register(nameof(RankingPredictionTransformer)), model, inputSchema, featureColumn) + { + var schema = new RoleMappedSchema(inputSchema, null, featureColumn); + _scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, inputSchema), BindableMapper.Bind(Host, schema), schema); + } + + internal RankingPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RankingPredictionTransformer)), ctx) + { + var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn); + _scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); + } + + public override IDataView Transform(IDataView input) + { + Host.CheckValue(input, nameof(input)); + return _scorer.ApplyToData(Host, input); + } + + protected override void SaveCore(ModelSaveContext ctx) + { + Contracts.AssertValue(ctx); + ctx.SetVersionInfo(GetVersionInfo()); + + // *** Binary format *** + // + base.SaveCore(ctx); + } + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "MC RANK", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: RankingPredictionTransformer.LoaderSignature); + } + } + internal static class BinaryPredictionTransformer { public const string LoaderSignature = "BinaryPredXfer"; @@ -324,4 +373,12 @@ internal static class RegressionPredictionTransformer public static RegressionPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) => new RegressionPredictionTransformer>(env, ctx); } + + internal static class RankingPredictionTransformer + { + public const string LoaderSignature = "RankingPredXfer"; + + public static RankingPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) + => new RankingPredictionTransformer>(env, ctx); + } } diff --git 
a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs index b6740fe5f6..67af075f7d 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs @@ -57,7 +57,7 @@ public Arguments() env => new Ova(env, new Ova.Arguments() { PredictorType = ComponentFactoryUtils.CreateFromFunction( - e => new AveragedPerceptronTrainer(e, new AveragedPerceptronTrainer.Arguments())) + e => new FastTreeBinaryClassificationTrainer(e, DefaultColumnNames.Label, DefaultColumnNames.Features)) })); } } diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs index 436e365b79..fa68546eb8 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs @@ -49,7 +49,7 @@ public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerF public Arguments() { BasePredictorType = ComponentFactoryUtils.CreateFromFunction( - env => new FastTreeRegressionTrainer(env, new FastTreeRegressionTrainer.Arguments())); + env => new FastTreeRegressionTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features)); } public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs index f0ced9d947..63adbd7c56 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs @@ -5,6 +5,7 @@ using System; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Ensemble.OutputCombiners; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.FastTree; @@ -46,7 +47,7 @@ public sealed 
class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFacto public Arguments() { BasePredictorType = ComponentFactoryUtils.CreateFromFunction( - env => new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments())); + env => new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features)); } public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this); diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index 6373e78215..ef91e6c688 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -6,17 +6,25 @@ using System; using System.Linq; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.FastTree.Internal; using Microsoft.ML.Runtime.Internal.Internallearn; namespace Microsoft.ML.Runtime.FastTree { - public abstract class BoostingFastTreeTrainerBase : FastTreeTrainerBase + public abstract class BoostingFastTreeTrainerBase : FastTreeTrainerBase + where TTransformer : IPredictionTransformer where TArgs : BoostedTreeArgs, new() - where TPredictor : IPredictorProducing + where TModel : IPredictorProducing { - public BoostingFastTreeTrainerBase(IHostEnvironment env, TArgs args) : base(env, args) + protected BoostingFastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label) : base(env, args, label) + { + } + + protected BoostingFastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, + string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) + : base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings) { } diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 8e5d48f260..fedf642f2e 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ 
b/src/Microsoft.ML.FastTree/FastTree.cs @@ -12,6 +12,7 @@ using System.IO; using System.Linq; using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.Conversion; @@ -43,10 +44,11 @@ internal static class FastTreeShared public static readonly object TrainLock = new object(); } - public abstract class FastTreeTrainerBase : - TrainerBase + public abstract class FastTreeTrainerBase : + TrainerEstimatorBase + where TTransformer: IPredictionTransformer where TArgs : TreeArgs, new() - where TPredictor : IPredictorProducing + where TModel : IPredictorProducing { protected readonly TArgs Args; protected readonly bool AllowGC; @@ -87,8 +89,41 @@ public abstract class FastTreeTrainerBase : private protected virtual bool NeedCalibration => false; - private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args) - : base(env, RegisterName) + /// + /// Constructor to use when instantiating the classes deriving from here through the API. + /// + private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, + string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), MakeFeatureColumn(featureColumn), label, MakeWeightColumn(weightColumn)) + { + Args = new TArgs(); + + //apply the advanced args, if the user supplied any + advancedSettings?.Invoke(Args); + Args.LabelColumn = label.Name; + + if (weightColumn != null) + Args.WeightColumn = weightColumn; + + if (groupIdColumn != null) + Args.GroupIdColumn = groupIdColumn; + + // The discretization step renders this trainer non-parametric, and therefore it does not need normalization. + // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching. 
+ // Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration. + Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration, supportValid: true); + // REVIEW: CLR 4.6 has a bug that is only exposed in Scope, and if we trigger GC.Collect in scope environment + // with memory consumption more than 5GB, GC get stuck in infinite loop. So for now let's call GC only if we call things from LocalEnvironment. + AllowGC = (env is HostEnvironmentBase); + + Initialize(env); + } + + /// + /// Legacy constructor that is used when invoking the classes deriving from this, through maml. + /// + private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), MakeFeatureColumn(args.FeatureColumn), label, MakeWeightColumn(args.WeightColumn)) { Host.CheckValue(args, nameof(args)); Args = args; @@ -96,25 +131,11 @@ private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args) // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching. // Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration. Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration, supportValid: true); - int numThreads = Args.NumThreads ?? Environment.ProcessorCount; - if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor) - { - using (var ch = Host.Start("FastTreeTrainerBase")) - { - numThreads = Host.ConcurrencyFactor; - ch.Warning("The number of threads specified in trainer arguments is larger than the concurrency factor " - + "setting of the environment. Using {0} training threads instead.", numThreads); - ch.Done(); - } - } - ParallelTraining = Args.ParallelTrainer != null ? 
Args.ParallelTrainer.CreateComponent(env) : new SingleTrainer(); - ParallelTraining.InitEnvironment(); // REVIEW: CLR 4.6 has a bug that is only exposed in Scope, and if we trigger GC.Collect in scope environment - // with memory consumption more than 5GB, GC get stuck in infinite loop. So for now let's call GC only if we call things from ConsoleEnvironment. - AllowGC = (env is HostEnvironmentBase); - Tests = new List(); + // with memory consumption more than 5GB, GC get stuck in infinite loop. So for now let's call GC only if we call things from LocalEnvironment. + AllowGC = (env is HostEnvironmentBase); - InitializeThreads(numThreads); + Initialize(env); } protected abstract void PrepareLabels(IChannel ch); @@ -133,6 +154,39 @@ protected virtual Float GetMaxLabel() return Float.PositiveInfinity; } + private static SchemaShape.Column MakeWeightColumn(string weightColumn) + { + if (weightColumn == null) + return null; + return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); + } + + private static SchemaShape.Column MakeFeatureColumn(string featureColumn) + { + return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); + } + + private void Initialize(IHostEnvironment env) + { + int numThreads = Args.NumThreads ?? Environment.ProcessorCount; + if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor) + { + using (var ch = Host.Start("FastTreeTrainerBase")) + { + numThreads = Host.ConcurrencyFactor; + ch.Warning("The number of threads specified in trainer arguments is larger than the concurrency factor " + + "setting of the environment. Using {0} training threads instead.", numThreads); + ch.Done(); + } + } + ParallelTraining = Args.ParallelTrainer != null ? 
Args.ParallelTrainer.CreateComponent(env) : new SingleTrainer(); + ParallelTraining.InitEnvironment(); + + Tests = new List(); + + InitializeThreads(numThreads); + } + protected void ConvertData(RoleMappedData trainData) { trainData.Schema.Schema.TryGetColumnIndex(DefaultColumnNames.Features, out int featureIndex); diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 4afec1787b..97de30bf75 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -2,11 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; @@ -36,7 +35,7 @@ "fastrank", "fastrankwrapper")] -[assembly: LoadableClass(typeof(IPredictorProducing), typeof(FastTreeBinaryPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(IPredictorProducing), typeof(FastTreeBinaryPredictor), null, typeof(SignatureLoadModel), "FastTree Binary Executor", FastTreeBinaryPredictor.LoaderSignature)] @@ -84,7 +83,7 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) + public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); @@ -102,26 +101,63 @@ public static IPredictorProducing Create(IHostEnvironment env, ModelLoadC /// public sealed partial class FastTreeBinaryClassificationTrainer : - BoostingFastTreeTrainerBase> + BoostingFastTreeTrainerBase>, IPredictorWithFeatureWeights> { + /// + /// The LoadName for the 
assembly containing the trainer. + /// public const string LoadNameValue = "FastTreeBinaryClassification"; internal const string UserNameValue = "FastTree (Boosted Trees) Classification"; internal const string Summary = "Uses a logit-boost boosted tree learner to perform binary classification."; internal const string ShortName = "ftc"; private bool[] _trainSetLabels; + private readonly SchemaShape.Column[] _outputColumns; + + /// + /// Initializes a new instance of + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the group ID. + /// The name for the column containing the initial weight. + /// A delegate to apply all the advanced arguments to the algorithm. + public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelColumn, string featureColumn, + string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) + : base(env, MakeLabelColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + { + // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss + _sigmoidParameter = 2.0 * Args.LearningRates; + + _outputColumns = new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())), + new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; + } private double _sigmoidParameter; + /// + /// Initializes a new instance of by using the legacy class. 
+ /// public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Arguments args) - : base(env, args) + : base(env, args, MakeLabelColumn(args.LabelColumn)) { + _outputColumns = new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())), + new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss _sigmoidParameter = 2.0 * Args.LearningRates; } public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public override IPredictorWithFeatureWeights Train(TrainContext context) + protected override IPredictorWithFeatureWeights TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; @@ -194,6 +230,11 @@ protected override void PrepareLabels(IChannel ch) //Here we set regression labels to what is in bin file if the values were not overriden with floats } + private static SchemaShape.Column MakeLabelColumn(string labelColumn) + { + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); + } + protected override Test ConstructTestForTrainingData() { return new BinaryClassificationTest(ConstructScoreTracker(TrainSet), _trainSetLabels, _sigmoidParameter); @@ -237,6 +278,12 @@ protected override void InitializeTests() } } } + + protected override BinaryPredictionTransformer> MakeTransformer(IPredictorWithFeatureWeights model, ISchema trainSchema) + => new BinaryPredictionTransformer>(Host, model, trainSchema, 
FeatureColumn.Name); + + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; + internal sealed class ObjectiveImpl : ObjectiveFunctionBase, IStepSearch { private readonly bool[] _labels; diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 7173a9f6a3..4c36e6e4cc 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Runtime.InteropServices; using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; @@ -39,10 +40,11 @@ namespace Microsoft.ML.Runtime.FastTree { /// - public sealed partial class FastTreeRankingTrainer : BoostingFastTreeTrainerBase, - IHasLabelGains + public sealed partial class FastTreeRankingTrainer + : BoostingFastTreeTrainerBase, FastTreeRankingPredictor>, + IHasLabelGains { - public const string LoadNameValue = "FastTreeRanking"; + internal const string LoadNameValue = "FastTreeRanking"; internal const string UserNameValue = "FastTree (Boosted Trees) Ranking"; internal const string Summary = "Trains gradient boosted decision trees to the LambdaRank quasi-gradient."; internal const string ShortName = "ftrank"; @@ -51,11 +53,42 @@ public sealed partial class FastTreeRankingTrainer : BoostingFastTreeTrainerBase private Test _specialTrainSetTest; private TestHistory _firstTestSetHistory; + /// + /// The prediction kind for this trainer. + /// public override PredictionKind PredictionKind => PredictionKind.Ranking; + private readonly SchemaShape.Column[] _outputColumns; + + /// + /// Initializes a new instance of + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the group ID. + /// The name for the column containing the initial weight. 
+ /// A delegate to apply all the advanced arguments to the algorithm. + public FastTreeRankingTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn, + string weightColumn = null, Action advancedSettings = null) + : base(env, MakeLabelColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) + { + _outputColumns = new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; + } + + /// + /// Initializes a new instance of by using the legacy class. + /// public FastTreeRankingTrainer(IHostEnvironment env, Arguments args) - : base(env, args) + : base(env, args, MakeLabelColumn(args.LabelColumn)) { + _outputColumns = new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; } protected override float GetMaxLabel() @@ -63,7 +96,7 @@ protected override float GetMaxLabel() return GetLabelGains().Length - 1; } - public override FastTreeRankingPredictor Train(TrainContext context) + protected override FastTreeRankingPredictor TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; @@ -125,6 +158,11 @@ protected override void CheckArgs(IChannel ch) base.CheckArgs(ch); } + private static SchemaShape.Column MakeLabelColumn(string labelColumn) + { + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); + } + protected override void Initialize(IChannel ch) { base.Initialize(ch); @@ -405,6 +443,11 @@ protected override string GetTestGraphHeader() return headerBuilder.ToString(); } + protected override RankingPredictionTransformer MakeTransformer(FastTreeRankingPredictor model, ISchema trainSchema) + => new RankingPredictionTransformer(Host, 
model, trainSchema, FeatureColumn.Name); + + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; + public sealed class LambdaRankObjectiveFunction : ObjectiveFunctionBase, IStepSearch { private readonly short[] _labels; diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 287cfe9e1c..3984582c61 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -2,8 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Linq; using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; @@ -32,7 +34,8 @@ namespace Microsoft.ML.Runtime.FastTree { /// - public sealed partial class FastTreeRegressionTrainer : BoostingFastTreeTrainerBase + public sealed partial class FastTreeRegressionTrainer + : BoostingFastTreeTrainerBase, FastTreeRegressionPredictor> { public const string LoadNameValue = "FastTreeRegression"; internal const string UserNameValue = "FastTree (Boosted Trees) Regression"; @@ -43,14 +46,45 @@ public sealed partial class FastTreeRegressionTrainer : BoostingFastTreeTrainerB private Test _trainRegressionTest; private Test _testRegressionTest; + /// + /// The type of prediction for the trainer. + /// public override PredictionKind PredictionKind => PredictionKind.Regression; + private readonly SchemaShape.Column[] _outputColumns; + + /// + /// Initializes a new instance of + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the group ID. + /// The name for the column containing the initial weight. + /// A delegate to apply all the advanced arguments to the algorithm. 
+ public FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn, string featureColumn, + string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) + : base(env, MakeLabelColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + { + _outputColumns = new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; + } + + /// + /// Initializes a new instance of by using the legacy class. + /// public FastTreeRegressionTrainer(IHostEnvironment env, Arguments args) - : base(env, args) + : base(env, args, MakeLabelColumn(args.LabelColumn)) { + _outputColumns = new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; } - public override FastTreeRegressionPredictor Train(TrainContext context) + protected override FastTreeRegressionPredictor TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; @@ -79,6 +113,11 @@ protected override void CheckArgs(IChannel ch) "earlyStoppingMetrics should be 1 or 2. 
(1: L1, 2: L2)"); } + private static SchemaShape.Column MakeLabelColumn(string labelColumn) + { + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); + } + protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch) { return new ObjectiveImpl(TrainSet, Args); @@ -124,6 +163,11 @@ protected override Test ConstructTestForTrainingData() return new RegressionTest(ConstructScoreTracker(TrainSet)); } + protected override RegressionPredictionTransformer MakeTransformer(FastTreeRegressionPredictor model, ISchema trainSchema) + => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; + private void AddFullRegressionTests() { // Always compute training L1/L2 errors. diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index d02928884f..4511ab7f1d 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -5,6 +5,7 @@ using System; using System.Linq; using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; @@ -31,7 +32,8 @@ namespace Microsoft.ML.Runtime.FastTree // Yang, Quan, and Zou. "Insurance Premium Prediction via Gradient Tree-Boosted Tweedie Compound Poisson Models." 
// https://arxiv.org/pdf/1508.06378.pdf /// - public sealed partial class FastTreeTweedieTrainer : BoostingFastTreeTrainerBase + public sealed partial class FastTreeTweedieTrainer + : BoostingFastTreeTrainerBase, FastTreeTweediePredictor> { public const string LoadNameValue = "FastTreeTweedieRegression"; public const string UserNameValue = "FastTree (Boosted Trees) Tweedie Regression"; @@ -44,13 +46,34 @@ public sealed partial class FastTreeTweedieTrainer : BoostingFastTreeTrainerBase public override PredictionKind PredictionKind => PredictionKind.Regression; + private SchemaShape.Column[] _outputColumns; + + /// + /// Initializes a new instance of + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the group ID. + /// The name for the column containing the initial weight. + /// A delegate to apply all the advanced arguments to the algorithm. + public FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn, string featureColumn, + string groupIdColumn, string weightColumn = null, Action advancedSettings = null) + : base(env, MakeLabelColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + { + Initialize(); + } + + /// + /// Initializes a new instance of by using the legacy class. 
+ /// public FastTreeTweedieTrainer(IHostEnvironment env, Arguments args) - : base(env, args) + : base(env, args, MakeLabelColumn(args.LabelColumn)) { - Host.CheckUserArg(1 <= Args.Index && Args.Index <= 2, nameof(Args.Index), "Must be in the range [1, 2]"); + Initialize(); } - public override FastTreeTweediePredictor Train(TrainContext context) + protected override FastTreeTweediePredictor TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; @@ -134,6 +157,16 @@ protected override Test ConstructTestForTrainingData() return new RegressionTest(ConstructScoreTracker(TrainSet)); } + private void Initialize() + { + Host.CheckUserArg(1 <= Args.Index && Args.Index <= 2, nameof(Args.Index), "Must be in the range [1, 2]"); + + _outputColumns = new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) + }; + } + private void AddFullRegressionTests() { // Always compute training L1/L2 errors. 
@@ -269,6 +302,16 @@ protected override void Train(IChannel ch) PrintTestGraph(ch); } + private static SchemaShape.Column MakeLabelColumn(string labelColumn) + { + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); + } + + protected override RegressionPredictionTransformer MakeTransformer(FastTreeTweediePredictor model, ISchema trainSchema) + => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; + private sealed class ObjectiveImpl : ObjectiveFunctionBase, IStepSearch { private readonly float[] _labels; diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs index 88676754d5..5ce40742ca 100644 --- a/src/Microsoft.ML.FastTree/RandomForest.cs +++ b/src/Microsoft.ML.FastTree/RandomForest.cs @@ -2,20 +2,34 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - +using System; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.FastTree.Internal; namespace Microsoft.ML.Runtime.FastTree { - public abstract class RandomForestTrainerBase : FastTreeTrainerBase + public abstract class RandomForestTrainerBase : FastTreeTrainerBase where TArgs : FastForestArgumentsBase, new() - where TPredictor : IPredictorProducing + where TModel : IPredictorProducing + where TTransformer: IPredictionTransformer { private readonly bool _quantileEnabled; - protected RandomForestTrainerBase(IHostEnvironment env, TArgs args, bool quantileEnabled = false) - : base(env, args) + /// + /// Constructor invoked by the maml code-path. 
+ /// + protected RandomForestTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label, bool quantileEnabled = false) + : base(env, args, label) + { + _quantileEnabled = quantileEnabled; + } + + /// + /// Constructor invoked by the API code-path. + /// + protected RandomForestTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, + string weightColumn = null, string groupIdColumn = null, bool quantileEnabled = false, Action advancedSettings = null) + : base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings) { _quantileEnabled = quantileEnabled; } diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index ae79c991d3..c5222d4e87 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -2,10 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; using System.Linq; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -25,7 +24,7 @@ FastForestClassification.ShortName, "ffc")] -[assembly: LoadableClass(typeof(IPredictorProducing), typeof(FastForestClassificationPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(IPredictorProducing), typeof(FastForestClassificationPredictor), null, typeof(SignatureLoadModel), "FastForest Binary Executor", FastForestClassificationPredictor.LoaderSignature)] @@ -73,13 +72,15 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010006; + /// + /// The type of prediction for this trainer. 
+ /// public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; internal FastForestClassificationPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) - { - } + { } private FastForestClassificationPredictor(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) @@ -92,7 +93,7 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) + public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); @@ -108,7 +109,7 @@ public static IPredictorProducing Create(IHostEnvironment env, ModelLoadC /// public sealed partial class FastForestClassification : - RandomForestTrainerBase> + RandomForestTrainerBase>, IPredictorWithFeatureWeights> { public sealed class Arguments : FastForestArgumentsBase { @@ -123,21 +124,52 @@ public sealed class Arguments : FastForestArgumentsBase } internal const string LoadNameValue = "FastForestClassification"; - public const string UserNameValue = "Fast Forest Classification"; - public const string Summary = "Uses a random forest learner to perform binary classification."; - public const string ShortName = "ff"; + internal const string UserNameValue = "Fast Forest Classification"; + internal const string Summary = "Uses a random forest learner to perform binary classification."; + internal const string ShortName = "ff"; private bool[] _trainSetLabels; public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; private protected override bool NeedCalibration => true; + private readonly SchemaShape.Column[] _outputColumns; + + /// + /// Initializes a new instance of + /// + /// The private instance of . 
+ /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the group ID. + /// The name for the column containing the initial weight. + /// A delegate to apply all the advanced arguments to the algorithm. + public FastForestClassification(IHostEnvironment env, string labelColumn, string featureColumn, + string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) + : base(env, MakeLabelColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) + { + _outputColumns = new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) + }; + } + /// + /// Initializes a new instance of by using the legacy class. 
+ /// public FastForestClassification(IHostEnvironment env, Arguments args) - : base(env, args) + : base(env, args, MakeLabelColumn(args.LabelColumn)) { + _outputColumns = new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false) + }; } - public override IPredictorWithFeatureWeights Train(TrainContext context) + protected override IPredictorWithFeatureWeights TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; @@ -175,11 +207,21 @@ protected override void PrepareLabels(IChannel ch) _trainSetLabels = TrainSet.Ratings.Select(x => x >= 1).ToArray(TrainSet.NumDocs); } + private static SchemaShape.Column MakeLabelColumn(string labelColumn) + { + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); + } + protected override Test ConstructTestForTrainingData() { return new BinaryClassificationTest(ConstructScoreTracker(TrainSet), _trainSetLabels, 1); } + protected override BinaryPredictionTransformer> MakeTransformer(IPredictorWithFeatureWeights model, ISchema trainSchema) + => new BinaryPredictionTransformer>(Host, model, trainSchema, FeatureColumn.Name); + + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; + private sealed class ObjectiveFunctionImpl : RandomForestObjectiveFunction { private readonly bool[] _labels; diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index c580851534..510c3b0ec6 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -2,8 +2,6 
@@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; @@ -15,6 +13,7 @@ using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Core.Data; [assembly: LoadableClass(FastForestRegression.Summary, typeof(FastForestRegression), typeof(FastForestRegression.Arguments), new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) }, @@ -59,7 +58,7 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010006; - internal FastForestRegressionPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, + public FastForestRegressionPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { @@ -101,32 +100,32 @@ public static FastForestRegressionPredictor Create(IHostEnvironment env, ModelLo public override PredictionKind PredictionKind => PredictionKind.Regression; - protected override void Map(ref VBuffer src, ref Float dst) + protected override void Map(ref VBuffer src, ref float dst) { if (InputType.VectorSize > 0) Host.Check(src.Length == InputType.VectorSize); else Host.Check(src.Length > MaxSplitFeatIdx); - dst = (Float)TrainedEnsemble.GetOutput(ref src) / TrainedEnsemble.NumTrees; + dst = (float)TrainedEnsemble.GetOutput(ref src) / TrainedEnsemble.NumTrees; } - public ValueMapper, VBuffer> GetMapper(Float[] quantiles) + public ValueMapper, VBuffer> GetMapper(float[] quantiles) { return - (ref VBuffer src, ref VBuffer dst) => + (ref VBuffer src, ref VBuffer dst) => { // REVIEW: Should make this more efficient - it 
repeatedly allocates too much stuff. - Float[] weights = null; + float[] weights = null; var distribution = TrainedEnsemble.GetDistribution(ref src, _quantileSampleCount, out weights); var qdist = new QuantileStatistics(distribution, weights); var values = dst.Values; if (Utils.Size(values) < quantiles.Length) - values = new Float[quantiles.Length]; + values = new float[quantiles.Length]; for (int i = 0; i < quantiles.Length; i++) - values[i] = qdist.GetQuantile((Float)quantiles[i]); - dst = new VBuffer(quantiles.Length, values, dst.Indices); + values[i] = qdist.GetQuantile((float)quantiles[i]); + dst = new VBuffer(quantiles.Length, values, dst.Indices); }; } @@ -138,7 +137,8 @@ public ISchemaBindableMapper CreateMapper(Double[] quantiles) } /// - public sealed partial class FastForestRegression : RandomForestTrainerBase + public sealed partial class FastForestRegression + : RandomForestTrainerBase, FastForestRegressionPredictor> { public sealed class Arguments : FastForestArgumentsBase { @@ -147,20 +147,47 @@ public sealed class Arguments : FastForestArgumentsBase public bool ShuffleLabels; } - internal const string Summary = "Trains a random forest to fit target values using least-squares."; + public override PredictionKind PredictionKind => PredictionKind.Regression; + internal const string Summary = "Trains a random forest to fit target values using least-squares."; internal const string LoadNameValue = "FastForestRegression"; internal const string UserNameValue = "Fast Forest Regression"; internal const string ShortName = "ffr"; - public FastForestRegression(IHostEnvironment env, Arguments args) - : base(env, args, true) + private readonly SchemaShape.Column[] _outputColumns; + + /// + /// Initializes a new instance of + /// + /// The private instance of . + /// The name of the label column. + /// The name of the feature column. + /// The name for the column containing the group ID. + /// The name for the column containing the initial weight. 
+ /// A delegate to apply all the advanced arguments to the algorithm. + public FastForestRegression(IHostEnvironment env, string labelColumn, string featureColumn, + string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) + : base(env, MakeLabelColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, true, advancedSettings) { + _outputColumns = new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) + }; } - public override PredictionKind PredictionKind => PredictionKind.Regression; + /// + /// Initializes a new instance of by using the legacy class. + /// + public FastForestRegression(IHostEnvironment env, Arguments args) + : base(env, args, MakeLabelColumn(args.LabelColumn), true) + { + _outputColumns = new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) + }; + } - public override FastForestRegressionPredictor Train(TrainContext context) + protected override FastForestRegressionPredictor TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; @@ -194,6 +221,16 @@ protected override Test ConstructTestForTrainingData() return new RegressionTest(ConstructScoreTracker(TrainSet)); } + protected override RegressionPredictionTransformer MakeTransformer(FastForestRegressionPredictor model, ISchema trainSchema) + => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + + private static SchemaShape.Column MakeLabelColumn(string labelColumn) + { + return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); + } + + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) => _outputColumns; + private abstract class ObjectiveFunctionImplBase : RandomForestObjectiveFunction { private readonly float[] _labels; diff --git 
a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CodeGen/TransformGenerators.cs b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CodeGen/TransformGenerators.cs index 26d4c94dda..b17efbbbe8 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CodeGen/TransformGenerators.cs +++ b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CodeGen/TransformGenerators.cs @@ -151,7 +151,7 @@ protected override void GenerateImplCall(IndentingTextWriter w, string prefix, C var argumentInfo = CmdParser.GetArgInfo(component.ArgType, component.CreateArguments()); foreach (var arg in argumentInfo.Args.Where(a => !a.IsHidden)) GenerateImplCall(w, arg, ""); - w.WriteLine("var env = new TlcEnvironment(1, verbose: true);"); + w.WriteLine("var env = new LocalEnvironment(1, verbose: true);"); w.WriteLine("var view = builder.Create{0}{1}Impl(env, data);", prefix, component.LoadNames[0]); w.WriteLine("return new Tuple(view, new DataTransform(view));"); } diff --git a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs index b5614e3cbf..2aaa3e0d62 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs @@ -137,13 +137,13 @@ private FastForestRegressionPredictor FitModel(IEnumerable previousR using (IChannel ch = _host.Start("Single training")) { // Set relevant random forest arguments. - FastForestRegression.Arguments args = new FastForestRegression.Arguments(); - args.FeatureFraction = _args.SplitRatio; - args.NumTrees = _args.NumOfTrees; - args.MinDocumentsInLeafs = _args.NMinForSplit; - // Train random forest. - var trainer = new FastForestRegression(_host, args); + var trainer = new FastForestRegression(_host, DefaultColumnNames.Label, DefaultColumnNames.Features, advancedSettings: s => + { + s.FeatureFraction = _args.SplitRatio; + s.NumTrees = _args.NumOfTrees; + s.MinDocumentsInLeafs = _args.NMinForSplit; + }); var predictor = trainer.Train(data); // Return random forest predictor. 
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index b8d2ee4b77..1b4e233542 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -699,7 +699,7 @@ public void EntryPointCalibrate() calibratedLrModel = Calibrate.Pav(Env, input).PredictorModel; // This tests that the SchemaBindableCalibratedPredictor doesn't get confused if its sub-predictor is already calibrated. - var fastForest = new FastForestClassification(Env, new FastForestClassification.Arguments()); + var fastForest = new FastForestClassification(Env, "Label", "Features"); var rmd = new RoleMappedData(splitOutput.TrainData[0], "Label", "Features"); var ffModel = new PredictorModel(Env, rmd, splitOutput.TrainData[0], fastForest.Train(rmd)); var calibratedFfModel = Calibrate.Platt(Env, diff --git a/test/Microsoft.ML.Predictor.Tests/Test-API.cs b/test/Microsoft.ML.Predictor.Tests/Test-API.cs index 40efdf25df..acd01cd151 100644 --- a/test/Microsoft.ML.Predictor.Tests/Test-API.cs +++ b/test/Microsoft.ML.Predictor.Tests/Test-API.cs @@ -317,7 +317,7 @@ public void FactoryExampleTest() List originalOutputs = new List(); List originalProbabilities = new List(); - var env = new TlcEnvironment(SysRandom.Wrap(RunExperiments.GetRandom(cmd))); + var env = new LocalEnvironment(SysRandom.Wrap(RunExperiments.GetRandom(cmd))); Instances instances = RunExperiments.CreateTestData(cmd, testDataFilename, dataModel, null, env); foreach (Instance instance in instances) { diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs index 206ac86779..0b2d7691d5 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs @@ -1532,7 +1532,7 @@ public void OneClassSvmLibsvmWrapperDenseTest() [TestCategory("Anomaly")] public void 
CompareSvmPredictorResultsToLibSvm() { - using (var env = new TlcEnvironment(1, conc: 1)) + using (var env = new LocalEnvironment(1, conc: 1)) { IDataView trainView = new TextLoader(env, new TextLoader.Arguments(), new MultiFileSource(GetDataPath(TestDatasets.mnistOneClass.trainFilename))); trainView = diff --git a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs index 99fb602490..f3c21feabf 100644 --- a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs +++ b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs @@ -133,7 +133,7 @@ protected virtual void InitializeCore() } // This method is used by subclass to dispose of disposable objects - // such as TlcEnvironment. + // such as LocalEnvironment. // It is called as a first step in test clean up. protected virtual void CleanupCore() { diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index 42766d3934..40b29a498e 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -8,6 +8,7 @@ using System.Linq; using System.Reflection; using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; @@ -780,6 +781,36 @@ protected bool SaveLoadTransposed(IDataView view, IHostEnvironment env, string s public abstract partial class TestDataViewBase : BaseTestBaseline { + + public class SentimentData + { + [ColumnName("Label")] + public bool Sentiment; + public string SentimentText; + } + + public class SentimentPrediction + { + [ColumnName("PredictedLabel")] + public bool Sentiment; + + public float Score; + } + + private static TextLoader.Arguments MakeSentimentTextLoaderArgs() + { + return new TextLoader.Arguments() + { + Separator = "tab", + HasHeader = true, + Column = new[] + { + new 
TextLoader.Column("Label", DataKind.BL, 0), + new TextLoader.Column("SentimentText", DataKind.Text, 1) + } + }; + } + protected bool Failed() { Contracts.Assert(!IsPassing); diff --git a/test/Microsoft.ML.TestFramework/Datasets.cs b/test/Microsoft.ML.TestFramework/Datasets.cs index 6de4938433..888a06afce 100644 --- a/test/Microsoft.ML.TestFramework/Datasets.cs +++ b/test/Microsoft.ML.TestFramework/Datasets.cs @@ -2,11 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - -using System; -using Microsoft.ML.Runtime.Numeric; - namespace Microsoft.ML.Runtime.RunTests { public class TestDataset @@ -167,6 +162,13 @@ public static class TestDatasets testFilename = "external/WikiDetoxAnnotated160kRows.tsv" }; + public static TestDataset Sentiment = new TestDataset + { + name = "sentiment", + trainFilename = "wikipedia-detox-250-line-data.tsv", + testFilename = "wikipedia-detox-250-line-test.tsv" + }; + public static TestDataset generatedRegressionDataset = new TestDataset { name = "generatedRegressionDataset", @@ -304,6 +306,13 @@ public static class TestDatasets extraSettings = @"/inst Text{header+ sep=, label=14 attr=5-9,1,13,3 threads-}" }; + public static TestDataset adultRanking = new TestDataset + { + name = "adultRanking", + trainFilename = "adult.tiny.with-schema.txt", + loaderSettings = "loader=Text{header+ sep=tab, col=Label:R4:0 col=Workclass:TX:1 col=Categories:TX:2-8 col=NumericFeatures:R4:9-14}", + }; + public static TestDataset displayPoisson = new TestDataset { name = "DisplayPoisson", @@ -362,6 +371,13 @@ public static class TestDatasets mamlExtraSettings = new[] { "xf=Term{col=Label}" }, }; + public static TestDataset irisData = new TestDataset() + { + name = "iris", + trainFilename = @"iris.data", + loaderSettings = "loader=Text{col=Label:TX:4 col=Features:0-3}" + }; + public static TestDataset irisLabelName = new TestDataset() { 
name = "iris-label-name", diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/ApiScenariosTests.cs b/test/Microsoft.ML.Tests/Scenarios/Api/ApiScenariosTests.cs index a8b7aaa7ad..efc4587bd2 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/ApiScenariosTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/ApiScenariosTests.cs @@ -17,10 +17,6 @@ public ApiScenariosTests(ITestOutputHelper output) : base(output) { } - public const string IrisDataPath = "iris.data"; - public const string SentimentDataPath = "wikipedia-detox-250-line-data.tsv"; - public const string SentimentTestPath = "wikipedia-detox-250-line-test.tsv"; - public class IrisData : IrisDataNoLabel { public string Label; diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/AspirationalExamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/AspirationalExamples.cs index fd2c5f0142..ebf6a24e7e 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/AspirationalExamples.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/AspirationalExamples.cs @@ -168,7 +168,7 @@ public class GithubClassification { public void ClassifyGithubIssues() { - var env = new TlcEnvironment(new SysRandom(0), verbose: true); + var env = new LocalEnvironment(new SysRandom(0), verbose: true); string dataPath = "corefx-issues-train.tsv"; @@ -228,7 +228,7 @@ public class SimpleTransform { public void ScaleData() { - var env = new TlcEnvironment(new SysRandom(0), verbose: true); + var env = new LocalEnvironment(new SysRandom(0), verbose: true); string dataPath = "iris.txt"; diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/AutoNormalizationAndCaching.cs b/test/Microsoft.ML.Tests/Scenarios/Api/AutoNormalizationAndCaching.cs index 00fb528307..54f53528f0 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/AutoNormalizationAndCaching.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/AutoNormalizationAndCaching.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using Xunit; namespace 
Microsoft.ML.Tests.Scenarios.Api @@ -18,13 +19,12 @@ public partial class ApiScenariosTests [Fact] public void AutoNormalizationAndCaching() { - var dataPath = GetDataPath(SentimentDataPath); - var testDataPath = GetDataPath(SentimentTestPath); + var data = GetDataPath(TestDatasets.Sentiment.trainFilename); using (var env = new LocalEnvironment(seed: 1, conc: 1)) { // Pipeline. - var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); + var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(data)); var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CrossValidation.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CrossValidation.cs index e55fe53743..0b74889515 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CrossValidation.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CrossValidation.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Legacy.Models; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using System; using System.Collections.Generic; using System.Linq; @@ -25,14 +26,13 @@ public partial class ApiScenariosTests [Fact] void CrossValidation() { - var dataPath = GetDataPath(SentimentDataPath); - var testDataPath = GetDataPath(SentimentTestPath); + var dataset = TestDatasets.Sentiment; int numFolds = 5; using (var env = new LocalEnvironment(seed: 1, conc: 1)) { // Pipeline. 
- var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); + var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(dataset.trainFilename))); var text = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader); IDataView trans = new GenerateNumberTransform(env, text, "StratificationColumn"); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs index c883151590..e6a7b57944 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using System.Linq; using Xunit; @@ -24,10 +25,9 @@ public partial class ApiScenariosTests [Fact] void DecomposableTrainAndPredict() { - var dataPath = GetDataPath(IrisDataPath); using (var env = new LocalEnvironment()) { - var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); + var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename))); var term = TermTransform.Create(env, loader, "Label"); var concat = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(term); var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }); @@ -47,8 +47,7 @@ void DecomposableTrainAndPredict() var keyToValue = new KeyToValueTransform(env, "PredictedLabel").Transform(newScorer); var model = env.CreatePredictionEngine(keyToValue); - var testLoader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); - var testData = 
testLoader.AsEnumerable(env, false); + var testData = loader.AsEnumerable(env, false); foreach (var input in testData.Take(20)) { var prediction = model.Predict(input); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs index 4073054051..2d0cb435bd 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -21,14 +22,11 @@ public partial class ApiScenariosTests [Fact] void New_CrossValidation() { - var dataPath = GetDataPath(SentimentDataPath); - var testDataPath = GetDataPath(SentimentTestPath); - using (var env = new LocalEnvironment(seed: 1, conc: 1)) { var data = new TextLoader(env, MakeSentimentTextLoaderArgs()) - .Read(new MultiFileSource(dataPath)); + .Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); // Pipeline. 
var pipeline = new TextTransform(env, "SentimentText", "Features") .Append(new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs index 9d1d68dc4b..79b66e43bc 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using System.Linq; using Xunit; @@ -23,26 +24,35 @@ public partial class ApiScenariosTests [Fact] void New_DecomposableTrainAndPredict() { - var dataPath = GetDataPath(IrisDataPath); + var dataPath = GetDataPath(TestDatasets.irisData.trainFilename); using (var env = new LocalEnvironment()) { - var data = new TextLoader(env, MakeIrisTextLoaderArgs()) - .Read(new MultiFileSource(dataPath)); + var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); + var term = TermTransform.Create(env, loader, "Label"); + var concat = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(term); + var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }); - var pipeline = new ConcatEstimator(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") - .Append(new TermEstimator(env, "Label"), TransformerScope.TrainTest) - .Append(new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label")) - .Append(new KeyToValueEstimator(env, "PredictedLabel")); + IDataView trainData = trainer.Info.WantCaching ? 
(IDataView)new CacheDataView(env, concat, prefetch: null) : concat; + var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features"); - var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring); - var engine = model.MakePredictionFunction(env); + // Auto-normalization. + NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer); + var predictor = trainer.Train(new Runtime.TrainContext(trainRoles)); + + var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features"); + IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema); + + // Cut out term transform from pipeline. + var newScorer = ApplyTransformUtils.ApplyAllTransformsToData(env, scorer, loader, term); + var keyToValue = new KeyToValueTransform(env, "PredictedLabel").Transform(newScorer); + var model = env.CreatePredictionEngine(keyToValue); var testLoader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); - var testData = testLoader.AsEnumerable(env, false); + var testData = testLoader.AsEnumerable(env, false); foreach (var input in testData.Take(20)) { - var prediction = engine.Predict(input); - Assert.True(prediction.PredictedLabel == input.Label); + var prediction = model.Predict(input); + Assert.True(prediction.PredictedLabel == "Iris-setosa"); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs index cf62d68912..41deccc3fb 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -19,9 +20,6 @@ public partial class ApiScenariosTests [Fact] public void New_Evaluation() { - var dataPath = GetDataPath(SentimentDataPath); 
- var testDataPath = GetDataPath(SentimentTestPath); - using (var env = new LocalEnvironment(seed: 1, conc: 1)) { var reader = new TextLoader(env, MakeSentimentTextLoaderArgs()); @@ -32,10 +30,10 @@ public void New_Evaluation() .Append(new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label")); // Train. - var readerModel = pipeline.Fit(new MultiFileSource(dataPath)); + var readerModel = pipeline.Fit(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); // Evaluate on the test set. - var dataEval = readerModel.Read(new MultiFileSource(testDataPath)); + var dataEval = readerModel.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.validFilename))); var evaluator = new MyBinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { }); var metrics = evaluator.Evaluate(dataEval); } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs index 4c5cddf767..ee4bef02e1 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using System; using System.Linq; using Xunit; @@ -22,7 +23,8 @@ public partial class ApiScenariosTests [Fact] void New_Extensibility() { - var dataPath = GetDataPath(IrisDataPath); + var dataPath = GetDataPath(TestDatasets.irisData.trainFilename); + using (var env = new LocalEnvironment()) { var data = new TextLoader(env, MakeIrisTextLoaderArgs()) diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs index c291b37ced..bfeb59506a 100644 --- 
a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -22,14 +23,11 @@ public partial class ApiScenariosTests [Fact] void New_FileBasedSavingOfData() { - var dataPath = GetDataPath(SentimentDataPath); - var testDataPath = GetDataPath(SentimentTestPath); - using (var env = new LocalEnvironment(seed: 1, conc: 1)) { var trainData = new TextLoader(env, MakeSentimentTextLoaderArgs()) .Append(new TextTransform(env, "SentimentText", "Features")) - .FitAndRead(new MultiFileSource(dataPath)); + .FitAndRead(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); using (var file = env.CreateOutputFile("i.idv")) trainData.SaveAsBinary(env, file.CreateWriteStream()); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs index 17f452ff09..68eba74127 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.Runtime.TextAnalytics; using System.Collections.Generic; using Xunit; @@ -32,13 +33,10 @@ public partial class ApiScenariosTests [Fact] public void New_IntrospectiveTraining() { - var dataPath = GetDataPath(SentimentDataPath); - var testDataPath = GetDataPath(SentimentTestPath); - using (var env = new LocalEnvironment(seed: 1, conc: 1)) { var data = new TextLoader(env, MakeSentimentTextLoaderArgs()) - .Read(new 
MultiFileSource(dataPath)); + .Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); var pipeline = new TextTransform(env, "SentimentText", "Features") .Append(new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label")); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs index 05d7b33445..e69de29bb2 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs @@ -1,39 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Calibration; -using Microsoft.ML.Runtime.Learners; -using System.Linq; -using Xunit; - -namespace Microsoft.ML.Tests.Scenarios.Api -{ - public partial class ApiScenariosTests - { - /// - /// Meta-components: Meta-components (e.g., components that themselves instantiate components) should not be booby-trapped. - /// When specifying what trainer OVA should use, a user will be able to specify any binary classifier. - /// If they specify a regression or multi-class classifier ideally that should be a compile error. 
- /// - [Fact] - public void New_Metacomponents() - { - var dataPath = GetDataPath(IrisDataPath); - using (var env = new LocalEnvironment()) - { - var data = new TextLoader(env, MakeIrisTextLoaderArgs()) - .Read(new MultiFileSource(dataPath)); - - var sdcaTrainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label"); - var pipeline = new ConcatEstimator(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") - .Append(new TermEstimator(env, "Label"), TransformerScope.TrainTest) - .Append(new Ova(env, sdcaTrainer)) - .Append(new KeyToValueEstimator(env, "PredictedLabel")); - - var model = pipeline.Fit(data); - } - } - } -} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs index 2e35acd71a..dcabc9ec43 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using System.Threading.Tasks; using Xunit; @@ -24,13 +25,10 @@ public partial class ApiScenariosTests [Fact] void New_MultithreadedPrediction() { - var dataPath = GetDataPath(SentimentDataPath); - var testDataPath = GetDataPath(SentimentTestPath); - using (var env = new LocalEnvironment(seed: 1, conc: 1)) { var reader = new TextLoader(env, MakeSentimentTextLoaderArgs()); - var data = reader.Read(new MultiFileSource(dataPath)); + var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); // Pipeline. 
var pipeline = new TextTransform(env, "SentimentText", "Features") @@ -43,7 +41,7 @@ void New_MultithreadedPrediction() var engine = model.MakePredictionFunction(env); // Take a couple examples out of the test data and run predictions on top. - var testData = reader.Read(new MultiFileSource(GetDataPath(SentimentTestPath))) + var testData = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename))) .AsEnumerable(env, false); Parallel.ForEach(testData, (input) => diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs index 42697089c3..0516e5ed4f 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -20,15 +21,12 @@ public partial class ApiScenariosTests [Fact] public void New_ReconfigurablePrediction() { - var dataPath = GetDataPath(SentimentDataPath); - var testDataPath = GetDataPath(SentimentTestPath); - using (var env = new LocalEnvironment(seed: 1, conc: 1)) { var dataReader = new TextLoader(env, MakeSentimentTextLoaderArgs()); - var data = dataReader.Read(new MultiFileSource(dataPath)); - var testData = dataReader.Read(new MultiFileSource(testDataPath)); + var data = dataReader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); + var testData = dataReader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename))); // Pipeline. 
var pipeline = new TextTransform(env, "SentimentText", "Features") diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs index 1257196253..565255305a 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Learners; using Xunit; using System.Linq; +using Microsoft.ML.Runtime.RunTests; namespace Microsoft.ML.Tests.Scenarios.Api { @@ -21,13 +22,10 @@ public partial class ApiScenariosTests [Fact] public void New_SimpleTrainAndPredict() { - var dataPath = GetDataPath(SentimentDataPath); - var testDataPath = GetDataPath(SentimentTestPath); - using (var env = new LocalEnvironment(seed: 1, conc: 1)) { var reader = new TextLoader(env, MakeSentimentTextLoaderArgs()); - var data = reader.Read(new MultiFileSource(dataPath)); + var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); // Pipeline. var pipeline = new TextTransform(env, "SentimentText", "Features") .Append(new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label")); @@ -39,7 +37,7 @@ public void New_SimpleTrainAndPredict() var engine = model.MakePredictionFunction(env); // Take a couple examples out of the test data and run predictions on top. 
- var testData = reader.Read(new MultiFileSource(GetDataPath(SentimentTestPath))) + var testData = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename))) .AsEnumerable(env, false); foreach (var input in testData.Take(5)) { diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs index 013c269e92..b14e26b66c 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using System.Linq; using Xunit; @@ -22,13 +23,10 @@ public partial class ApiScenariosTests [Fact] public void New_TrainSaveModelAndPredict() { - var dataPath = GetDataPath(SentimentDataPath); - var testDataPath = GetDataPath(SentimentTestPath); - using (var env = new LocalEnvironment(seed: 1, conc: 1)) { var reader = new TextLoader(env, MakeSentimentTextLoaderArgs()); - var data = reader.Read(new MultiFileSource(dataPath)); + var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); // Pipeline. var pipeline = new TextTransform(env, "SentimentText", "Features") @@ -52,7 +50,7 @@ public void New_TrainSaveModelAndPredict() var engine = loadedModel.MakePredictionFunction(env); // Take a couple examples out of the test data and run predictions on top. 
- var testData = reader.Read(new MultiFileSource(GetDataPath(SentimentTestPath))) + var testData = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename))) .AsEnumerable(env, false); foreach (var input in testData.Take(5)) { diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs index 09bca124e0..afc456dac6 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -18,11 +19,10 @@ public partial class ApiScenariosTests [Fact] public void New_TrainWithInitialPredictor() { - var dataPath = GetDataPath(SentimentDataPath); using (var env = new LocalEnvironment(seed: 1, conc: 1)) { - var data = new TextLoader(env, MakeSentimentTextLoaderArgs()).Read(new MultiFileSource(dataPath)); + var data = new TextLoader(env, MakeSentimentTextLoaderArgs()).Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); // Pipeline. 
var pipeline = new TextTransform(env, "SentimentText", "Features"); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs index 74431f7279..a689cab502 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -17,9 +18,6 @@ public partial class ApiScenariosTests [Fact] public void New_TrainWithValidationSet() { - var dataPath = GetDataPath(SentimentDataPath); - var validationDataPath = GetDataPath(SentimentTestPath); - using (var env = new LocalEnvironment(seed: 1, conc: 1)) { // Pipeline. @@ -27,10 +25,10 @@ public void New_TrainWithValidationSet() var pipeline = new TextTransform(env, "SentimentText", "Features"); // Train the pipeline, prepare train and validation set. - var data = reader.Read(new MultiFileSource(dataPath)); + var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); var preprocess = pipeline.Fit(data); var trainData = preprocess.Transform(data); - var validData = preprocess.Transform(reader.Read(new MultiFileSource(validationDataPath))); + var validData = preprocess.Transform(reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename)))); // Train model with validation set. 
var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments(), "Features", "Label"); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs index 7d0e0e7236..eec9a8a430 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.RunTests; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -20,13 +21,12 @@ public partial class ApiScenariosTests [Fact] void New_Visibility() { - var dataPath = GetDataPath(SentimentDataPath); using (var env = new LocalEnvironment(seed: 1, conc: 1)) { var pipeline = new TextLoader(env, MakeSentimentTextLoaderArgs()) .Append(new TextTransform(env, "SentimentText", "Features", s => s.OutputTokens = true)); - var data = pipeline.FitAndRead(new MultiFileSource(dataPath)); + var data = pipeline.FitAndRead(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); // In order to find out available column names, you can go through schema and check // column names and appropriate type for getter. 
for (int i = 0; i < data.Schema.ColumnCount; i++) diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Evaluation.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Evaluation.cs index 3004155784..e90af95d3c 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Evaluation.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Evaluation.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -20,13 +21,10 @@ public partial class ApiScenariosTests [Fact] public void Evaluation() { - var dataPath = GetDataPath(SentimentDataPath); - var testDataPath = GetDataPath(SentimentTestPath); - using (var env = new LocalEnvironment(seed: 1, conc: 1)) { // Pipeline - var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); + var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader); @@ -46,7 +44,7 @@ public void Evaluation() var model = env.CreatePredictionEngine(scorer); // Take a couple examples out of the test data and run predictions on top. 
- var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(SentimentTestPath))); + var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename))); var testData = testLoader.AsEnumerable(env, false); var dataEval = new RoleMappedData(scorer, label: "Label", feature: "Features", opt: true); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs index 00d0ee2a34..f893e5de47 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs @@ -1,6 +1,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using System; using System.Linq; using Xunit; @@ -18,10 +19,9 @@ public partial class ApiScenariosTests [Fact] void Extensibility() { - var dataPath = GetDataPath(IrisDataPath); using (var env = new LocalEnvironment()) { - var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); + var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename))); Action action = (i, j) => { j.Label = i.Label; @@ -50,7 +50,7 @@ void Extensibility() var keyToValue = new KeyToValueTransform(env, "PredictedLabel").Transform(scorer); var model = env.CreatePredictionEngine(keyToValue); - var testLoader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); + var testLoader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename))); var testData = testLoader.AsEnumerable(env, false); foreach (var input in testData.Take(20)) { diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/FileBasedSavingOfData.cs
b/test/Microsoft.ML.Tests/Scenarios/Api/FileBasedSavingOfData.cs index da6c4fcedb..facba8eeb3 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/FileBasedSavingOfData.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/FileBasedSavingOfData.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -21,13 +22,10 @@ public partial class ApiScenariosTests [Fact] void FileBasedSavingOfData() { - var dataPath = GetDataPath(SentimentDataPath); - var testDataPath = GetDataPath(SentimentTestPath); - using (var env = new LocalEnvironment(seed: 1, conc: 1)) { // Pipeline - var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); + var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader); var saver = new BinarySaver(env, new BinarySaver.Arguments()); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/IntrospectiveTraining.cs b/test/Microsoft.ML.Tests/Scenarios/Api/IntrospectiveTraining.cs index 254485cebc..198b3c312b 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/IntrospectiveTraining.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/IntrospectiveTraining.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.Runtime.TextAnalytics; using System.Collections.Generic; using Xunit; @@ -40,12 +41,11 @@ private TOut GetValue(Dictionary keyValues, string key) [Fact] public void IntrospectiveTraining() { - var dataPath = GetDataPath(SentimentDataPath); using (var env = new LocalEnvironment(seed: 1, conc: 1)) { // Pipeline - var loader = TextLoader.ReadFile(env, 
MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); + var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); var words = WordBagTransform.Create(env, new WordBagTransform.Arguments() { @@ -75,12 +75,8 @@ public void IntrospectiveTraining() linearPredictor.GetFeatureWeights(ref weights); var topicSummary = lda.GetTopicSummary(); - var treeTrainer = new FastTreeBinaryClassificationTrainer(env, - new FastTreeBinaryClassificationTrainer.Arguments - { - NumTrees = 2 - } - ); + var treeTrainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, + advancedSettings: s =>{ s.NumTrees = 2; }); var ftPredictor = treeTrainer.Train(new Runtime.TrainContext(trainRoles)); FastTreeBinaryPredictor treePredictor; if (ftPredictor is CalibratedPredictorBase calibrator) diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs index ed48171f4d..994b47310b 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs @@ -4,8 +4,10 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.FastTree; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -20,10 +22,9 @@ public partial class ApiScenariosTests [Fact] public void Metacomponents() { - var dataPath = GetDataPath(IrisDataPath); using (var env = new LocalEnvironment()) { - var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); + var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename))); var term = TermTransform.Create(env, loader, "Label"); var concat = new ConcatTransform(env, 
"Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(term); var trainer = new Ova(env, new Ova.Arguments diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/MultithreadedPrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/MultithreadedPrediction.cs index cb6b4dd6f0..e70c46a73f 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/MultithreadedPrediction.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/MultithreadedPrediction.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using System.Linq; using System.Threading.Tasks; using Xunit; @@ -25,13 +26,10 @@ public partial class ApiScenariosTests [Fact] void MultithreadedPrediction() { - var dataPath = GetDataPath(SentimentDataPath); - var testDataPath = GetDataPath(SentimentTestPath); - using (var env = new LocalEnvironment(seed: 1, conc: 1)) { // Pipeline - var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); + var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader); @@ -52,7 +50,7 @@ void MultithreadedPrediction() var model = env.CreatePredictionEngine(scorer); // Take a couple examples out of the test data and run predictions on top. 
- var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(SentimentTestPath))); + var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename))); var testData = testLoader.AsEnumerable(env, false); Parallel.ForEach(testData, (input) => diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/ReconfigurablePrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/ReconfigurablePrediction.cs index 307f5e086a..fdfcad6653 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/ReconfigurablePrediction.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/ReconfigurablePrediction.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -21,13 +22,12 @@ public partial class ApiScenariosTests [Fact] void ReconfigurablePrediction() { - var dataPath = GetDataPath(SentimentDataPath); - var testDataPath = GetDataPath(SentimentTestPath); + var data = GetDataPath(TestDatasets.Sentiment.trainFilename); using (var env = new LocalEnvironment(seed: 1, conc: 1)) { // Pipeline - var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); + var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(data)); var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs index 26b9869aa8..2bb45ed61a 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs @@ -7,6 +7,8 @@ using Microsoft.ML.Runtime.Learners; using Xunit; using System.Linq; +using Microsoft.ML.Runtime.FastTree; +using 
Microsoft.ML.Runtime.RunTests; namespace Microsoft.ML.Tests.Scenarios.Api { @@ -21,13 +23,12 @@ public partial class ApiScenariosTests [Fact] public void SimpleTrainAndPredict() { - var dataPath = GetDataPath(SentimentDataPath); - var testDataPath = GetDataPath(SentimentTestPath); + var dataset = TestDatasets.Sentiment; using (var env = new LocalEnvironment(seed: 1, conc: 1)) { // Pipeline - var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); + var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(dataset.trainFilename))); var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader); @@ -48,7 +49,7 @@ public void SimpleTrainAndPredict() var model = env.CreatePredictionEngine(scorer); // Take a couple examples out of the test data and run predictions on top. - var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(SentimentTestPath))); + var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(dataset.testFilename))); var testData = testLoader.AsEnumerable(env, false); foreach (var input in testData.Take(5)) { diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TrainSaveModelAndPredict.cs index 7cfe2aa73b..150f72936b 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TrainSaveModelAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TrainSaveModelAndPredict.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using System.Linq; using Xunit; @@ -21,13 +22,10 @@ public partial class ApiScenariosTests [Fact] public void TrainSaveModelAndPredict() { - var dataPath = GetDataPath(SentimentDataPath); - var testDataPath = GetDataPath(SentimentTestPath); - using (var env = new 
LocalEnvironment(seed: 1, conc: 1)) { // Pipeline - var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); + var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader); @@ -55,7 +53,7 @@ public void TrainSaveModelAndPredict() } // Take a couple examples out of the test data and run predictions on top. - var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(SentimentTestPath))); + var testLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename))); var testData = testLoader.AsEnumerable(env, false); foreach (var input in testData.Take(5)) { diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TrainWithInitialPredictor.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TrainWithInitialPredictor.cs index 9683af241c..6a37b1f408 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TrainWithInitialPredictor.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TrainWithInitialPredictor.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.RunTests; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -18,12 +19,11 @@ public partial class ApiScenariosTests [Fact] public void TrainWithInitialPredictor() { - var dataPath = GetDataPath(SentimentDataPath); using (var env = new LocalEnvironment(seed: 1, conc: 1)) { // Pipeline - var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); + var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader); var trainData = trans; 
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TrainWithValidationSet.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TrainWithValidationSet.cs index c2fabcd970..f2460df8f7 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TrainWithValidationSet.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TrainWithValidationSet.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.FastTree; +using Microsoft.ML.Runtime.RunTests; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -17,13 +18,11 @@ public partial class ApiScenariosTests [Fact] public void TrainWithValidationSet() { - var dataPath = GetDataPath(SentimentDataPath); - var validationDataPath = GetDataPath(SentimentTestPath); using (var env = new LocalEnvironment(seed: 1, conc: 1)) { // Pipeline - var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); + var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader); var trainData = trans; @@ -33,7 +32,7 @@ public void TrainWithValidationSet() // to create another loader, or to save the loader to model file and then reload. // A new one is not always feasible, but this time it is. - var validLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(validationDataPath)); + var validLoader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.testFilename))); var validData = ApplyTransformUtils.ApplyAllTransformsToData(env, trainData, validLoader); // Cache both datasets. @@ -41,10 +40,7 @@ public void TrainWithValidationSet() var cachedValid = new CacheDataView(env, validData, prefetch: null); // Train. 
- var trainer = new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments - { - NumTrees = 3 - }); + var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, advancedSettings:s => { s.NumTrees = 3; }); var trainRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features"); var validRoles = new RoleMappedData(cachedValid, label: "Label", feature: "Features"); trainer.Train(new Runtime.TrainContext(trainRoles, validRoles)); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Visibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Visibility.cs index 47a91e44c2..aee89c3e7f 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Visibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Visibility.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.RunTests; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -20,13 +21,10 @@ public partial class ApiScenariosTests [Fact] void Visibility() { - var dataPath = GetDataPath(SentimentDataPath); - var testDataPath = GetDataPath(SentimentTestPath); - using (var env = new LocalEnvironment(seed: 1, conc: 1)) { // Pipeline. 
- var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); + var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs index 95b476b361..f3e4bd5865 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs @@ -55,11 +55,11 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTest() loader); // Train - var trainer = new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments() + var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, advancedSettings: s=> { - NumLeaves = 5, - NumTrees = 5, - MinDocumentsInLeafs = 2 + s.NumLeaves = 5; + s.NumTrees = 5; + s.MinDocumentsInLeafs = 2; }); var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features"); @@ -136,12 +136,12 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTestWithWordE ModelKind = WordEmbeddingsTransform.PretrainedModelKind.Sswe, }, text); // Train - var trainer = new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments() - { - NumLeaves = 5, - NumTrees = 5, - MinDocumentsInLeafs = 2 - }); + var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, advancedSettings: s=> + { + s.NumLeaves = 5; + s.NumTrees = 5; + s.MinDocumentsInLeafs = 2; + }); var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features"); var pred = 
trainer.Train(trainRoles); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs new file mode 100644 index 0000000000..fa6784bd69 --- /dev/null +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -0,0 +1,123 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.FastTree; +using Microsoft.ML.Runtime.RunTests; +using System.Linq; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests.TrainerEstimators +{ + public partial class TreeEstimators : TestDataPipeBase + { + // Tests that hand FastTree binary classification, ranking and regression trainer pipelines to TestEstimatorCore. + public TreeEstimators(ITestOutputHelper output) : base(output) + { + } + + + /// <summary> + /// FastTreeBinaryClassification TrainerEstimator test: a TextTransform followed by FastTreeBinaryClassificationTrainer, read from the Sentiment training file, is passed to TestEstimatorCore. + /// </summary> + [Fact] + public void FastTreeBinaryEstimator() + { + using (var env = new LocalEnvironment(seed: 1, conc: 1)) + { + var reader = new TextLoader(env, + new TextLoader.Arguments() + { + Separator = "\t", + HasHeader = true, + Column = new[] + { + new TextLoader.Column("Label", DataKind.BL, 0), + new TextLoader.Column("SentimentText", DataKind.Text, 1) + } + }); + + var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename))); + + // Pipeline: TextTransform over SentimentText/Features, then the FastTree binary trainer set to 10 trees, 5 leaves and 1 thread.
+ var pipeline = new TextTransform(env, "SentimentText", "Features") + .Append(new FastTreeBinaryClassificationTrainer(env, "Label", "Features", advancedSettings: s => { + s.NumTrees = 10; + s.NumThreads = 1; + s.NumLeaves = 5; + })); + + TestEstimatorCore(pipeline, data); + } + } + + /// <summary> + /// FastTreeRanking TrainerEstimator test: a TermEstimator followed by FastTreeRankingTrainer, read from the adultRanking training file, is passed to TestEstimatorCore. + /// </summary> + [Fact] + public void FastTreeRankerEstimator() + { + using (var env = new LocalEnvironment(seed: 1, conc: 1)) + { + var reader = new TextLoader(env, new TextLoader.Arguments + { + HasHeader = true, + Separator ="\t", + Column = new[] + { + new TextLoader.Column("Label", DataKind.R4, 0), + new TextLoader.Column("Workclass", DataKind.Text, 1), + new TextLoader.Column("NumericFeatures", DataKind.R4, new [] { new TextLoader.Range(9, 14) }) + } + }); + var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.adultRanking.trainFilename))); + + + // Pipeline: TermEstimator supplies the Group and Label0 columns (from Workclass and Label) that the FastTree ranking trainer, set to 10 trees, is given. + var pipeline = new TermEstimator(env, new[]{ + new TermTransform.ColumnInfo("Workclass", "Group"), + new TermTransform.ColumnInfo("Label", "Label0") }) + .Append(new FastTreeRankingTrainer(env, "Label0", "NumericFeatures", "Group", + advancedSettings: s => { s.NumTrees = 10; })); + + TestEstimatorCore(pipeline, data); + } + } + + /// <summary> + /// FastTreeRegressor TrainerEstimator test: FastTreeRegressionTrainer, read from the generatedRegressionDatasetmacro training file, is passed to TestEstimatorCore. + /// </summary> + [Fact] + public void FastTreeRegressorEstimator() + { + using (var env = new LocalEnvironment(seed: 1, conc: 1)) + { + // "loader=Text{col=Label:R4:11 col=Features:R4:0-10 sep=; header+}" + var reader = new TextLoader(env, + new TextLoader.Arguments() + { + Separator = ";", + HasHeader = true, + Column = new[] + { + new TextLoader.Column("Label", DataKind.R4, 11), + new TextLoader.Column("Features", DataKind.R4, new [] { new TextLoader.Range(0, 10) } ) + } + }); + + var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.generatedRegressionDatasetmacro.trainFilename))); + + // Pipeline: the FastTree regression trainer on its own, set to 10 trees, 5 leaves and 1 thread.
+ var pipeline = new FastTreeRegressionTrainer(env, "Label", "Features", advancedSettings: s => { + s.NumTrees = 10; + s.NumThreads = 1; + s.NumLeaves = 5; + }); + + TestEstimatorCore(pipeline, data); + } + } + } +}