diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index b7e01e47b0..958aba0da1 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -298,8 +298,6 @@ private static ColumnType GetPredColType(ColumnType scoreType, ISchemaBoundRowMa } private static bool OutputTypeMatches(ColumnType scoreType) - { - return scoreType == NumberType.Float; - } + => scoreType == NumberType.Float; } } diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index bf41106198..3e9208ff5b 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -20,6 +20,9 @@ [assembly: LoadableClass(typeof(RankingPredictionTransformer>), typeof(RankingPredictionTransformer), null, typeof(SignatureLoadModel), "", RankingPredictionTransformer.LoaderSignature)] +[assembly: LoadableClass(typeof(AnomalyPredictionTransformer>), typeof(AnomalyPredictionTransformer), null, typeof(SignatureLoadModel), + "", AnomalyPredictionTransformer.LoaderSignature)] + namespace Microsoft.ML.Runtime.Data { @@ -27,7 +30,9 @@ namespace Microsoft.ML.Runtime.Data /// Base class for transformers with no feature column, or more than one feature columns. /// /// - public abstract class PredictionTransformerBase : IPredictionTransformer + /// The Scorer used by this + public abstract class PredictionTransformerBase : IPredictionTransformer + where TScorer : RowToRowScorerBase where TModel : class, IPredictor { /// @@ -41,21 +46,23 @@ public abstract class PredictionTransformerBase : IPredictionTransformer protected ISchemaBindableMapper BindableMapper; protected ISchema TrainSchema; - public abstract bool IsRowToRowMapper { get; } + public bool IsRowToRowMapper => true; + + protected abstract TScorer Scorer { get; set; } protected PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema) { Contracts.CheckValue(host, nameof(host)); - Host = host; - Host.CheckValue(trainSchema, nameof(trainSchema)); + Host.CheckValue(trainSchema, nameof(trainSchema)); Model = model; + + Host.CheckValue(trainSchema, nameof(trainSchema)); TrainSchema = trainSchema; } protected PredictionTransformerBase(IHost host, ModelLoadContext ctx) - { Host = host; @@ -91,9 +98,23 @@ protected PredictionTransformerBase(IHost host, ModelLoadContext ctx) /// /// The input data. /// The transformed - public abstract IDataView Transform(IDataView input); - public abstract IRowToRowMapper GetRowToRowMapper(ISchema inputSchema); + public IDataView Transform(IDataView input) + { + Host.CheckValue(input, nameof(input)); + return Scorer.ApplyToData(Host, input); + } + + /// + /// Gets a IRowToRowMapper instance. + /// + /// + /// + public IRowToRowMapper GetRowToRowMapper(ISchema inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + return (IRowToRowMapper)Scorer.ApplyToData(Host, new EmptyDataView(Host, inputSchema)); + } protected void SaveModel(ModelSaveContext ctx) { @@ -118,8 +139,10 @@ protected void SaveModel(ModelSaveContext ctx) /// Those are all the transformers that work with one feature column. /// /// The model used to transform the data. - public abstract class SingleFeaturePredictionTransformerBase : PredictionTransformerBase, ISingleFeaturePredictionTransformer, ICanSaveModel + /// The scorer used on this PredictionTransformer. + public abstract class SingleFeaturePredictionTransformerBase : PredictionTransformerBase, ISingleFeaturePredictionTransformer, ICanSaveModel where TModel : class, IPredictor + where TScorer: RowToRowScorerBase { /// /// The name of the feature column used by the prediction transformer. @@ -131,8 +154,17 @@ public abstract class SingleFeaturePredictionTransformerBase : Predictio /// public ColumnType FeatureColumnType { get; } + protected override TScorer Scorer { get; set; } + + /// + /// Initializes a new reference of . + /// + /// The local instance of . + /// The model used for scoring. + /// The schema of the training data. + /// The feature column name. public SingleFeaturePredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn) - :base(host, model, trainSchema) + : base(host, model, trainSchema) { FeatureColumn = featureColumn; @@ -148,7 +180,7 @@ public SingleFeaturePredictionTransformerBase(IHost host, TModel model, ISchema } internal SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx) - :base(host, ctx) + : base(host, ctx) { FeatureColumn = ctx.LoadStringOrNull(); @@ -166,7 +198,7 @@ public override ISchema GetOutputSchema(ISchema inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - if(FeatureColumn != null) + if (FeatureColumn != null) { if (!inputSchema.TryGetColumnIndex(FeatureColumn, out int col)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), null); @@ -189,34 +221,36 @@ protected virtual void SaveCore(ModelSaveContext ctx) SaveModel(ctx); ctx.SaveStringOrNull(FeatureColumn); } + + protected virtual GenericScorer GetGenericScorer() + { + var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn); + return new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); + } } /// - /// Base class for the working on binary classification tasks. + /// Base class for the working on anomaly detection tasks. /// /// An implementation of the - public sealed class BinaryPredictionTransformer : SingleFeaturePredictionTransformerBase + public sealed class AnomalyPredictionTransformer : SingleFeaturePredictionTransformerBase where TModel : class, IPredictorProducing { - private readonly BinaryClassifierScorer _scorer; - public readonly string ThresholdColumn; public readonly float Threshold; - public BinaryPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, + public AnomalyPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer)), model, inputSchema, featureColumn) { Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn)); - var schema = new RoleMappedSchema(inputSchema, null, featureColumn); Threshold = threshold; ThresholdColumn = thresholdColumn; - var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn }; - _scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, inputSchema), BindableMapper.Bind(Host, schema), schema); + SetScorer(); } - public BinaryPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) + public AnomalyPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer)), ctx) { // *** Binary format *** @@ -226,24 +260,82 @@ public BinaryPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) Threshold = ctx.Reader.ReadSingle(); ThresholdColumn = ctx.LoadString(); + SetScorer(); + } + private void SetScorer() + { var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn); var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn }; - _scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); + Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); + } + + protected override void SaveCore(ModelSaveContext ctx) + { + Contracts.AssertValue(ctx); + ctx.SetVersionInfo(GetVersionInfo()); + + // *** Binary format *** + // + // float: scorer threshold + // id of string: scorer threshold column + base.SaveCore(ctx); + + ctx.Writer.Write(Threshold); + ctx.SaveString(ThresholdColumn); } - public override IDataView Transform(IDataView input) + private static VersionInfo GetVersionInfo() { - Host.CheckValue(input, nameof(input)); - return _scorer.ApplyToData(Host, input); + return new VersionInfo( + modelSignature: "ANOMPRED", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: AnomalyPredictionTransformer.LoaderSignature, + loaderAssemblyName: typeof(AnomalyPredictionTransformer<>).Assembly.FullName); + } + } + + /// + /// Base class for the working on binary classification tasks. + /// + /// An implementation of the + public sealed class BinaryPredictionTransformer : SingleFeaturePredictionTransformerBase + where TModel : class, IPredictorProducing + { + public readonly string ThresholdColumn; + public readonly float Threshold; + + public BinaryPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, + float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer)), model, inputSchema, featureColumn) + { + Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn)); + Threshold = threshold; + ThresholdColumn = thresholdColumn; + + SetScorer(); } - public override bool IsRowToRowMapper => true; + public BinaryPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer)), ctx) + { + // *** Binary format *** + // + // float: scorer threshold + // id of string: scorer threshold column + + Threshold = ctx.Reader.ReadSingle(); + ThresholdColumn = ctx.LoadString(); + SetScorer(); + } - public override IRowToRowMapper GetRowToRowMapper(ISchema inputSchema) + private void SetScorer() { - Host.CheckValue(inputSchema, nameof(inputSchema)); - return (IRowToRowMapper)_scorer.ApplyToData(Host, new EmptyDataView(Host, inputSchema)); + var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn); + var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn }; + Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } protected override void SaveCore(ModelSaveContext ctx) @@ -277,10 +369,9 @@ private static VersionInfo GetVersionInfo() /// Base class for the working on multi-class classification tasks. /// /// An implementation of the - public sealed class MulticlassPredictionTransformer : SingleFeaturePredictionTransformerBase + public sealed class MulticlassPredictionTransformer : SingleFeaturePredictionTransformerBase where TModel : class, IPredictorProducing> { - private readonly MultiClassClassifierScorer _scorer; private readonly string _trainLabelColumn; public MulticlassPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, string labelColumn) @@ -289,9 +380,7 @@ public MulticlassPredictionTransformer(IHostEnvironment env, TModel model, ISche Host.CheckValueOrNull(labelColumn); _trainLabelColumn = labelColumn; - var schema = new RoleMappedSchema(inputSchema, labelColumn, featureColumn); - var args = new MultiClassClassifierScorer.Arguments(); - _scorer = new MultiClassClassifierScorer(Host, args, new EmptyDataView(Host, inputSchema), BindableMapper.Bind(Host, schema), schema); + SetScorer(); } public MulticlassPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) @@ -302,24 +391,14 @@ public MulticlassPredictionTransformer(IHostEnvironment env, ModelLoadContext ct // id of string: train label column _trainLabelColumn = ctx.LoadStringOrNull(); - - var schema = new RoleMappedSchema(TrainSchema, _trainLabelColumn, FeatureColumn); - var args = new MultiClassClassifierScorer.Arguments(); - _scorer = new MultiClassClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); - } - - public override IDataView Transform(IDataView input) - { - Host.CheckValue(input, nameof(input)); - return _scorer.ApplyToData(Host, input); + SetScorer(); } - public override bool IsRowToRowMapper => true; - - public override IRowToRowMapper GetRowToRowMapper(ISchema inputSchema) + private void SetScorer() { - Host.CheckValue(inputSchema, nameof(inputSchema)); - return (IRowToRowMapper)_scorer.ApplyToData(Host, new EmptyDataView(Host, inputSchema)); + var schema = new RoleMappedSchema(TrainSchema, _trainLabelColumn, FeatureColumn); + var args = new MultiClassClassifierScorer.Arguments(); + Scorer = new MultiClassClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } protected override void SaveCore(ModelSaveContext ctx) @@ -351,37 +430,19 @@ private static VersionInfo GetVersionInfo() /// Base class for the working on regression tasks. /// /// An implementation of the - public sealed class RegressionPredictionTransformer : SingleFeaturePredictionTransformerBase + public sealed class RegressionPredictionTransformer : SingleFeaturePredictionTransformerBase where TModel : class, IPredictorProducing { - private readonly GenericScorer _scorer; - public RegressionPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer)), model, inputSchema, featureColumn) { - var schema = new RoleMappedSchema(inputSchema, null, featureColumn); - _scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, inputSchema), BindableMapper.Bind(Host, schema), schema); + Scorer = GetGenericScorer(); } internal RegressionPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer)), ctx) { - var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn); - _scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); - } - - public override IDataView Transform(IDataView input) - { - Host.CheckValue(input, nameof(input)); - return _scorer.ApplyToData(Host, input); - } - - public override bool IsRowToRowMapper => true; - - public override IRowToRowMapper GetRowToRowMapper(ISchema inputSchema) - { - Host.CheckValue(inputSchema, nameof(inputSchema)); - return (IRowToRowMapper)_scorer.ApplyToData(Host, new EmptyDataView(Host, inputSchema)); + Scorer = GetGenericScorer(); } protected override void SaveCore(ModelSaveContext ctx) @@ -397,7 +458,7 @@ protected override void SaveCore(ModelSaveContext ctx) private static VersionInfo GetVersionInfo() { return new VersionInfo( - modelSignature: "MC PRED", + modelSignature: "REG PRED", verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, @@ -406,37 +467,23 @@ private static VersionInfo GetVersionInfo() } } - public sealed class RankingPredictionTransformer : SingleFeaturePredictionTransformerBase + /// + /// Base class for the working on ranking tasks. + /// + /// An implementation of the + public sealed class RankingPredictionTransformer : SingleFeaturePredictionTransformerBase where TModel : class, IPredictorProducing { - private readonly GenericScorer _scorer; - public RankingPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RankingPredictionTransformer)), model, inputSchema, featureColumn) { - var schema = new RoleMappedSchema(inputSchema, null, featureColumn); - _scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, inputSchema), BindableMapper.Bind(Host, schema), schema); + Scorer = GetGenericScorer(); } internal RankingPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RankingPredictionTransformer)), ctx) { - var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn); - _scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); - } - - public override IDataView Transform(IDataView input) - { - Host.CheckValue(input, nameof(input)); - return _scorer.ApplyToData(Host, input); - } - - public override bool IsRowToRowMapper => true; - - public override IRowToRowMapper GetRowToRowMapper(ISchema inputSchema) - { - Host.CheckValue(inputSchema, nameof(inputSchema)); - return (IRowToRowMapper)_scorer.ApplyToData(Host, new EmptyDataView(Host, inputSchema)); + Scorer = GetGenericScorer(); } protected override void SaveCore(ModelSaveContext ctx) @@ -452,7 +499,7 @@ protected override void SaveCore(ModelSaveContext ctx) private static VersionInfo GetVersionInfo() { return new VersionInfo( - modelSignature: "MC RANK", + modelSignature: "RANK PRED", verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, @@ -492,4 +539,12 @@ internal static class RankingPredictionTransformer public static RankingPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) => new RankingPredictionTransformer>(env, ctx); } + + internal static class AnomalyPredictionTransformer + { + public const string LoaderSignature = "AnomalyPredXfer"; + + public static AnomalyPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) + => new AnomalyPredictionTransformer>(env, ctx); + } } diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs index b784f1f6ee..a2c103d110 100644 --- a/src/Microsoft.ML.PCA/PcaTrainer.cs +++ b/src/Microsoft.ML.PCA/PcaTrainer.cs @@ -18,6 +18,7 @@ using Microsoft.ML.Runtime.PCA; using Microsoft.ML.Runtime.Training; using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Core.Data; [assembly: LoadableClass(RandomizedPcaTrainer.Summary, typeof(RandomizedPcaTrainer), typeof(RandomizedPcaTrainer.Arguments), new[] { typeof(SignatureAnomalyDetectorTrainer), typeof(SignatureTrainer) }, @@ -41,7 +42,7 @@ namespace Microsoft.ML.Runtime.PCA /// /// This PCA can be made into Kernel PCA by using Random Fourier Features transform /// - public sealed class RandomizedPcaTrainer : TrainerBase + public sealed class RandomizedPcaTrainer : TrainerEstimatorBase, PcaPredictor> { public const string LoadNameValue = "pcaAnomaly"; internal const string UserNameValue = "PCA Anomaly Detector"; @@ -73,6 +74,7 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight private readonly int _oversampling; private readonly bool _center; private readonly int _seed; + private readonly string _featureColumn; public override PredictionKind PredictionKind => PredictionKind.AnomalyDetection; @@ -80,21 +82,58 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight private static readonly TrainerInfo _info = new TrainerInfo(caching: false); public override TrainerInfo Info => _info; - public RandomizedPcaTrainer(IHostEnvironment env, Arguments args) - : base(env, LoadNameValue) + /// + /// Initializes a new instance of . + /// + /// The local instance of the . + /// The name of the feature column. + /// The name of the weight column. + /// The number of components in the PCA. + /// Oversampling parameter for randomized PCA training. + /// If enabled, data is centered to be zero mean. + /// The seed for random number generation. + public RandomizedPcaTrainer(IHostEnvironment env, string featureColumn, string weightColumn = null, + int rank = 20, int oversampling = 20, bool center = true, int? seed = null) + : this(env, null, featureColumn, weightColumn, rank, oversampling, center, seed) + { + + } + + internal RandomizedPcaTrainer(IHostEnvironment env, Arguments args) + :this(env, args, args.FeatureColumn, args.WeightColumn) { - Host.CheckValue(args, nameof(args)); - Host.CheckUserArg(args.Rank > 0, nameof(args.Rank), "Rank must be positive"); - Host.CheckUserArg(args.Oversampling >= 0, nameof(args.Oversampling), "Oversampling must be non-negative"); - - _rank = args.Rank; - _center = args.Center; - _oversampling = args.Oversampling; - _seed = args.Seed ?? Host.Rand.Next(); + + } + + private RandomizedPcaTrainer(IHostEnvironment env, Arguments args, string featureColumn, string weightColumn, + int rank = 20, int oversampling = 20, bool center = true, int? seed = null) + : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn), null, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + { + // if the args are not null, we got here from maml, and the internal ctor. + if (args != null) + { + _rank = args.Rank; + _center = args.Center; + _oversampling = args.Oversampling; + _seed = args.Seed ?? Host.Rand.Next(); + } + else + { + _rank = rank; + _center = center; + _oversampling = oversampling; + _seed = seed ?? Host.Rand.Next(); + } + + _featureColumn = featureColumn; + + Host.CheckUserArg(_rank > 0, nameof(_rank), "Rank must be positive"); + Host.CheckUserArg(_oversampling >= 0, nameof(_oversampling), "Oversampling must be non-negative"); + } //Note: the notations used here are the same as in https://web.stanford.edu/group/mmds/slides2010/Martinsson.pdf (pg. 9) - public override PcaPredictor Train(TrainContext context) + protected override PcaPredictor TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); @@ -108,6 +147,18 @@ public override PcaPredictor Train(TrainContext context) } } + private static SchemaShape.Column MakeWeightColumn(string weightColumn) + { + if (weightColumn == null) + return null; + return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); + } + + private static SchemaShape.Column MakeFeatureColumn(string featureColumn) + { + return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); + } + private PcaPredictor TrainCore(IChannel ch, RoleMappedData data, int dimension) { Host.AssertValue(ch); @@ -266,6 +317,27 @@ private static void PostProcess(VBuffer[] y, Float[] sigma, Float[] z, in } } + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, + SchemaShape.Column.VectorKind.Scalar, + NumberType.R4, + false, + new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())), + + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, + SchemaShape.Column.VectorKind.Scalar, + BoolType.Instance, + false, + new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) + }; + } + + protected override AnomalyPredictionTransformer MakeTransformer(PcaPredictor model, ISchema trainSchema) + => new AnomalyPredictionTransformer(Host, model, trainSchema, _featureColumn); + [TlcModule.EntryPoint(Name = "Trainers.PcaAnomalyDetector", Desc = "Train an PCA Anomaly model.", UserName = UserNameValue, diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs index ce94155b23..62f987901a 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs @@ -190,7 +190,7 @@ internal void CopyLatentWeightsTo(AlignedArray latentWeights) } } - public sealed class FieldAwareFactorizationMachinePredictionTransformer : PredictionTransformerBase, ICanSaveModel + public sealed class FieldAwareFactorizationMachinePredictionTransformer : PredictionTransformerBase, ICanSaveModel { public const string LoaderSignature = "FAFMPredXfer"; @@ -204,7 +204,7 @@ public sealed class FieldAwareFactorizationMachinePredictionTransformer : Predic /// public ColumnType[] FeatureColumnTypes { get; } - private readonly BinaryClassifierScorer _scorer; + protected override BinaryClassifierScorer Scorer { get; set; } private readonly string _thresholdColumn; private readonly float _threshold; @@ -236,7 +236,7 @@ public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host var schema = GetSchema(); var args = new BinaryClassifierScorer.Arguments { Threshold = _threshold, ThresholdColumn = _thresholdColumn }; - _scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, trainSchema), BindableMapper.Bind(Host, schema), schema); + Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, trainSchema), BindableMapper.Bind(Host, schema), schema); } public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, ModelLoadContext ctx) @@ -269,11 +269,11 @@ public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host var schema = GetSchema(); var args = new BinaryClassifierScorer.Arguments { Threshold = _threshold, ThresholdColumn = _thresholdColumn }; - _scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); + Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } /// - /// Gets the result after applying . + /// Gets the result after transformation. /// /// The of the input data. /// The post transformation . @@ -292,25 +292,6 @@ public override ISchema GetOutputSchema(ISchema inputSchema) return Transform(new EmptyDataView(Host, inputSchema)).Schema; } - /// - /// Applies the transformer to the , scoring it through the . - /// - /// The data to be scored with the . - /// The scored . - public override IDataView Transform(IDataView input) - { - Host.CheckValue(input, nameof(input)); - return _scorer.ApplyToData(Host, input); - } - - public override bool IsRowToRowMapper => true; - - public override IRowToRowMapper GetRowToRowMapper(ISchema inputSchema) - { - Host.CheckValue(inputSchema, nameof(inputSchema)); - return (IRowToRowMapper)_scorer.ApplyToData(Host, new EmptyDataView(Host, inputSchema)); - } - /// /// Saves the transformer to file. /// diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs index 74fa97807a..dcde5e65aa 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.PCA; using Microsoft.ML.Runtime.RunTests; using Xunit; using Xunit.Abstractions; @@ -17,6 +18,33 @@ public TrainerEstimators(ITestOutputHelper helper) : base(helper) { } + /// + /// FastTreeBinaryClassification TrainerEstimator test + /// + [Fact] + public void PCATrainerEstimator() + { + string featureColumn = "NumericFeatures"; + + var reader = new TextLoader(Env, new TextLoader.Arguments() + { + HasHeader = true, + Separator = "\t", + Column = new[] + { + new TextLoader.Column(featureColumn, DataKind.R4, new [] { new TextLoader.Range(1, 784) }) + } + }); + var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.mnistOneClass.trainFilename))); + + + // Pipeline. + var pipeline = new RandomizedPcaTrainer(Env, featureColumn, rank:10); + + TestEstimatorCore(pipeline, data); + Done(); + } + private (IEstimator, IDataView) GetBinaryClassificationPipeline() { var data = new TextLoader(Env,