Skip to content

Scrubbing PkPd #2749

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,26 @@
using Microsoft.ML.Model;
using Microsoft.ML.Trainers;

[assembly: LoadableClass(Pkpd.Summary, typeof(Pkpd), typeof(Pkpd.Options),
[assembly: LoadableClass(PairwiseCouplingTrainer.Summary, typeof(PairwiseCouplingTrainer), typeof(PairwiseCouplingTrainer.Options),
new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer) },
Pkpd.UserNameValue, Pkpd.LoadNameValue, DocName = "trainer/OvaPkpd.md")]
PairwiseCouplingTrainer.UserNameValue, PairwiseCouplingTrainer.LoadNameValue, DocName = "trainer/OvaPkpd.md")]

[assembly: LoadableClass(typeof(PkpdModelParameters), null, typeof(SignatureLoadModel),
[assembly: LoadableClass(typeof(PairwiseCouplingModelParameters), null, typeof(SignatureLoadModel),
"PKPD Executor",
PkpdModelParameters.LoaderSignature)]
PairwiseCouplingModelParameters.LoaderSignature)]

namespace Microsoft.ML.Trainers
{
using CR = RoleMappedSchema.ColumnRole;
using TDistPredictor = IDistPredictorProducing<float, float>;
using TScalarTrainer = ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
using TTransformer = MulticlassPredictionTransformer<PkpdModelParameters>;
using TTransformer = MulticlassPredictionTransformer<PairwiseCouplingModelParameters>;

/// <summary>
/// In this strategy, a binary classification algorithm is trained on each pair of classes.
/// The pairs are unordered but created with replacement: so, if there were three classes, 0, 1,
/// 2, we would train classifiers for the pairs (0,0), (0,1), (0,2), (1,1), (1,2),
/// and(2,2). For each binary classifier, an input data point is considered a
/// and (2,2). For each binary classifier, an input data point is considered a
/// positive example if it is in either of the two classes in the pair, and a
/// negative example otherwise. At prediction time, the probability for each
/// pair of classes is considered as the probability of being in either class of
Expand All @@ -42,17 +42,17 @@ namespace Microsoft.ML.Trainers
/// pair.
///
/// These two can allow you to exploit trainers that do not naturally have a
/// multiclass option, for example, using the Runtime.FastTree.FastTreeBinaryClassificationTrainer
/// multiclass option, for example, using the FastTree binary classification trainer
/// to solve a multiclass problem.
/// Alternately, it can allow ML.NET to solve a "simpler" problem even in the cases
/// where the trainer has a multiclass option, but using it directly is not
/// practical due to, usually, memory constraints.For example, while a multiclass
/// practical due to, usually, memory constraints. For example, while a multiclass
/// logistic regression is a more principled way to solve a multiclass problem, it
/// requires that the learner store a lot more intermediate state in the form of
/// L-BFGS history for all classes *simultaneously*, rather than just one-by-one
/// as would be needed for a one-versus-all classification model.
/// </summary>
public sealed class Pkpd : MetaMulticlassTrainer<MulticlassPredictionTransformer<PkpdModelParameters>, PkpdModelParameters>
public sealed class PairwiseCouplingTrainer : MetaMulticlassTrainer<MulticlassPredictionTransformer<PairwiseCouplingModelParameters>, PairwiseCouplingModelParameters>
{
internal const string LoadNameValue = "PKPD";
internal const string UserNameValue = "Pairwise coupling (PKPD)";
Expand All @@ -61,49 +61,48 @@ public sealed class Pkpd : MetaMulticlassTrainer<MulticlassPredictionTransformer
+ "classifiers predicted it. The prediction is the class with the highest score.";

/// <summary>
/// Options passed to PKPD.
/// Options passed to <see cref="Microsoft.ML.Trainers.PairwiseCouplingTrainer"/>.
/// </summary>
internal sealed class Options : OptionsBase
{
}

/// <summary>
/// Legacy constructor that builds the <see cref="Pkpd"/> trainer supplying the base trainer to use, for the classification task
/// Constructs a <see cref="PairwiseCouplingTrainer"/> trainer supplying the base trainer to use, for the classification task
/// through the <see cref="Options"/>.
/// Developers should instantiate <see cref="Pkpd"/> by supplying the trainer argument directly to the <see cref="Pkpd"/> constructor
/// using the other public constructor.
/// </summary>
internal Pkpd(IHostEnvironment env, Options options)
internal PairwiseCouplingTrainer(IHostEnvironment env, Options options)
: base(env, options, LoadNameValue)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="Pkpd"/>
/// Initializes a new instance of the <see cref="PairwiseCouplingTrainer"/>
/// </summary>
/// <param name="env">The <see cref="IHostEnvironment"/> instance.</param>
/// <param name="binaryEstimator">An instance of a binary <see cref="ITrainerEstimator{TTransformer, TPredictor}"/> used as the base trainer.</param>
/// <param name="labelColumn">The name of the label colum.</param>
/// <param name="labelColumnName">The name of the label column.</param>
/// <param name="imputeMissingLabelsAsNegative">Whether to treat missing labels as having negative labels, instead of keeping them missing.</param>
/// <param name="calibrator">The calibrator to use for each model instance. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
/// <param name="maxCalibrationExamples">Number of instances to train the calibrator.</param>
internal Pkpd(IHostEnvironment env,
/// <param name="maximumCalibrationExampleCount">Number of instances to train the calibrator.</param>
internal PairwiseCouplingTrainer(IHostEnvironment env,
TScalarTrainer binaryEstimator,
string labelColumn = DefaultColumnNames.Label,
string labelColumnName = DefaultColumnNames.Label,
bool imputeMissingLabelsAsNegative = false,
ICalibratorTrainer calibrator = null,
int maxCalibrationExamples = 1000000000)
int maximumCalibrationExampleCount = 1000000000)
: base(env,
new Options
{
ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative,
MaxCalibrationExamples = maxCalibrationExamples,
MaxCalibrationExamples = maximumCalibrationExampleCount,
},
LoadNameValue, labelColumn, binaryEstimator, calibrator)
LoadNameValue, labelColumnName, binaryEstimator, calibrator)
{
Host.CheckValue(labelColumn, nameof(labelColumn), "Label column should not be null.");
Host.CheckValue(labelColumnName, nameof(labelColumnName), "Label column should not be null.");
}

private protected override PkpdModelParameters TrainCore(IChannel ch, RoleMappedData data, int count)
private protected override PairwiseCouplingModelParameters TrainCore(IChannel ch, RoleMappedData data, int count)
{
// Train M * (M+1) / 2 models arranged as a lower triangular matrix.
var predModels = new TDistPredictor[count][];
Expand All @@ -119,7 +118,7 @@ private protected override PkpdModelParameters TrainCore(IChannel ch, RoleMapped
}
}

return new PkpdModelParameters(Host, predModels);
return new PairwiseCouplingModelParameters(Host, predModels);
}

private ISingleFeaturePredictionTransformer<TDistPredictor> TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls1, int cls2)
Expand Down Expand Up @@ -167,7 +166,7 @@ private IDataView MapLabels(RoleMappedData data, int cls1, int cls2)
return MapLabelsCore(NumberDataViewType.Double, (in double val) => val == key1 || val == key2, data);
}

throw Host.ExceptNotSupp($"Label column type is not supported by PKPD: {lab.Type}");
throw Host.ExceptNotSupp($"Label column type is not supported by nameof(PairwiseCouplingTrainer): {lab.Type.RawType}");
}

/// <summary>
Expand Down Expand Up @@ -209,11 +208,14 @@ public override TTransformer Fit(IDataView input)
}
}

return new MulticlassPredictionTransformer<PkpdModelParameters>(Host, new PkpdModelParameters(Host, predictors), input.Schema, featureColumn, LabelColumn.Name);
return new MulticlassPredictionTransformer<PairwiseCouplingModelParameters>(Host, new PairwiseCouplingModelParameters(Host, predictors), input.Schema, featureColumn, LabelColumn.Name);
}
}

public sealed class PkpdModelParameters :
/// <summary>
/// Contains the model parameters and prediction functions for the PairwiseCouplingTrainer.
/// </summary>
public sealed class PairwiseCouplingModelParameters :
ModelParametersBase<VBuffer<float>>,
IValueMapper
{
Expand All @@ -228,7 +230,7 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(PkpdModelParameters).Assembly.FullName);
loaderAssemblyName: typeof(PairwiseCouplingModelParameters).Assembly.FullName);
}

private const string SubPredictorFmt = "SubPredictor_{0:000}";
Expand All @@ -248,7 +250,7 @@ private static VersionInfo GetVersionInfo()
DataViewType IValueMapper.InputType => _inputType;
DataViewType IValueMapper.OutputType => _outputType;

internal PkpdModelParameters(IHostEnvironment env, TDistPredictor[][] predictors) :
internal PairwiseCouplingModelParameters(IHostEnvironment env, TDistPredictor[][] predictors) :
base(env, RegistrationName)
{
Host.Assert(Utils.Size(predictors) > 0);
Expand All @@ -272,7 +274,7 @@ internal PkpdModelParameters(IHostEnvironment env, TDistPredictor[][] predictors
_outputType = new VectorType(NumberDataViewType.Single, _numClasses);
}

private PkpdModelParameters(IHostEnvironment env, ModelLoadContext ctx)
private PairwiseCouplingModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, RegistrationName, ctx)
{
// *** Binary format ***
Expand Down Expand Up @@ -331,12 +333,12 @@ private bool IsValid(IValueMapperDist mapper, ref VectorType inputType)
return true;
}

private static PkpdModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
private static PairwiseCouplingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
return new PkpdModelParameters(env, ctx);
return new PairwiseCouplingModelParameters(env, ctx);
}

private protected override void SaveCore(ModelSaveContext ctx)
Expand Down Expand Up @@ -478,4 +480,4 @@ ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
return (ValueMapper<TIn, TOut>)(Delegate)del;
}
}
}
}
11 changes: 6 additions & 5 deletions src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ public static OneVersusAllTrainer OneVersusAll<TModel>(this MulticlassClassifica
}

/// <summary>
/// Predicts a target using a linear multiclass classification model trained with the <see cref="Pkpd"/>.
/// Predicts a target using a linear multiclass classification model trained with the <see cref="PairwiseCouplingTrainer"/>.
/// </summary>
/// <remarks>
/// <para>
Expand All @@ -619,21 +619,22 @@ public static OneVersusAllTrainer OneVersusAll<TModel>(this MulticlassClassifica
/// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
/// <param name="labelColumnName">The name of the label column.</param>
/// <param name="imputeMissingLabelsAsNegative">Whether to treat missing labels as having negative labels, instead of keeping them missing.</param>
/// <param name="maxCalibrationExamples">Number of instances to train the calibrator.</param>
/// <param name="maximumCalibrationExampleCount">Number of instances to train the calibrator.</param>
/// <typeparam name="TModel">The type of the model. This type parameter will usually be inferred automatically from <paramref name="binaryEstimator"/>.</typeparam>
public static Pkpd PairwiseCoupling<TModel>(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
public static PairwiseCouplingTrainer PairwiseCoupling<TModel>(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
ITrainerEstimator<ISingleFeaturePredictionTransformer<TModel>, TModel> binaryEstimator,
string labelColumnName = DefaultColumnNames.Label,
bool imputeMissingLabelsAsNegative = false,
IEstimator<ISingleFeaturePredictionTransformer<ICalibrator>> calibrator = null,
int maxCalibrationExamples = 1_000_000_000)
int maximumCalibrationExampleCount = 1_000_000_000)
where TModel : class
{
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
if (!(binaryEstimator is ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>> est))
throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model.");
return new Pkpd(env, est, labelColumnName, imputeMissingLabelsAsNegative, GetCalibratorTrainerOrThrow(env, calibrator), maxCalibrationExamples);
return new PairwiseCouplingTrainer(env, est, labelColumnName, imputeMissingLabelsAsNegative,
GetCalibratorTrainerOrThrow(env, calibrator), maximumCalibrationExampleCount);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ public void OVAUncalibrated()
}

/// <summary>
/// Pkpd trainer
/// Pairwise Coupling trainer
/// </summary>
[Fact]
public void Pkpd()
public void PairwiseCouplingTrainer()
{
var (pipeline, data) = GetMultiClassPipeline();

Expand Down