
Simple API to go from a trainer to something that can make predictions #560

Closed
@eerhardt


With the API proposal change in #371, the current proposed API looks something like:

```csharp
... // load data and make transforms

// Train.
var trainer = new SdcaRegressionTrainer(env, new SdcaRegressionTrainer.Arguments());
var cached = new CacheDataView(env, trans, prefetch: null);
var trainRoles = TrainUtils.CreateExamples(cached, label: "Label", feature: "Features");
var pred = trainer.Train(trainRoles);

// Score.
IDataView scoredData = ScoreUtils.GetScorer(pred, trainRoles, env, trainRoles.Schema);

// Do a simple prediction.
var engine = env.CreatePredictionEngine<HousePriceData, HousePricePrediction>(scoredData);

HousePricePrediction prediction = engine.Predict(new HousePriceData()
....
```

Compare and contrast this with the similar code in the LearningPipeline API:

```csharp
... // load data and make transforms

pipeline.Add(new StochasticDualCoordinateAscentRegressor());

PredictionModel<HousePriceData, HousePricePrediction> model = pipeline.Train<HousePriceData, HousePricePrediction>();

HousePricePrediction prediction = model.Predict(new HousePriceData()
....
```

As you can see, the proposed API has what feels like boilerplate code: create a cache data view, create examples, call train, get a scorer, and create an engine. The LearningPipeline API, by contrast, reduces this to roughly one call: call train and get back something that can make predictions.
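To make the "call train, get something that can make predictions" shape concrete, here is a minimal, hypothetical C# sketch. None of these types (`HouseData`, `IModel<TIn>`, `ITrainer<TIn>`, `LinearTrainer<TIn>`, `LinearModel<TIn>`) exist in ML.NET; they are invented for illustration, and a toy one-feature least-squares fit stands in for a real learner such as SDCA. The only point is that the boilerplate steps listed above (caching, role mapping, scoring, engine creation) could live behind a single `Fit` call that returns an object with a `Predict` method.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical input type, loosely modeled on HousePriceData. Not an ML.NET type.
public sealed class HouseData
{
    public float Size { get; set; }
    public float Price { get; set; }
}

// Hypothetical "something that can make predictions".
public interface IModel<TIn>
{
    float Predict(TIn input);
}

// Hypothetical trainer: one call goes from data to a model. Caching,
// label/feature role mapping, and scoring stay internal details.
public interface ITrainer<TIn>
{
    IModel<TIn> Fit(IReadOnlyList<TIn> data, Func<TIn, float> label, Func<TIn, float> feature);
}

// Toy one-feature ordinary least-squares learner standing in for a real algorithm.
public sealed class LinearTrainer<TIn> : ITrainer<TIn>
{
    public IModel<TIn> Fit(IReadOnlyList<TIn> data, Func<TIn, float> label, Func<TIn, float> feature)
    {
        if (data.Count == 0) throw new ArgumentException("No training data.", nameof(data));

        // Materialize features and labels once (the "cache" step, hidden from the user).
        double[] xs = data.Select(r => (double)feature(r)).ToArray();
        double[] ys = data.Select(r => (double)label(r)).ToArray();

        double meanX = xs.Average(), meanY = ys.Average();
        double sxy = 0, sxx = 0;
        for (int i = 0; i < xs.Length; i++)
        {
            sxy += (xs[i] - meanX) * (ys[i] - meanY);
            sxx += (xs[i] - meanX) * (xs[i] - meanX);
        }

        double slope = sxx == 0 ? 0 : sxy / sxx;
        double intercept = meanY - slope * meanX;
        return new LinearModel<TIn>(feature, slope, intercept);
    }
}

// The returned model reuses the feature mapping, so callers never build a scorer or engine.
public sealed class LinearModel<TIn> : IModel<TIn>
{
    private readonly Func<TIn, float> _feature;
    private readonly double _slope;
    private readonly double _intercept;

    public LinearModel(Func<TIn, float> feature, double slope, double intercept)
    {
        _feature = feature;
        _slope = slope;
        _intercept = intercept;
    }

    public float Predict(TIn input) => (float)(_intercept + _slope * _feature(input));
}
```

With this shape, user code is two steps: `var model = new LinearTrainer<HouseData>().Fit(data, h => h.Price, h => h.Size);` and then `model.Predict(new HouseData { Size = 4 })`. For training rows (1, 2), (2, 4), (3, 6) that prediction is 8. The lower-level pieces can still exist for advanced users; they just are not required for the first example.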

I don't think our simplest API example should have so many concepts in it. In my mind, the main concepts a new user needs to know about are:

  • Load data
  • Do transforms
  • Pick a learning algorithm
  • Train
  • Predict

However, in the current proposed API, they also need to think/learn about:

  • Whether or not they need a cached data view
  • Creating roles/examples
    • I'm not sure which it is. The type is RoleMappedData, but the method is named CreateExamples.
  • An IPredictor object
    • which doesn't make predictions
  • Calling GetScorer, which returns an IDataView that we call scoredData.
    • Is this object really data, or is it something that does scoring, as the method name GetScorer implies?

In my opinion, this API is too complex and non-intuitive for first-time users. We should investigate ways to make it simpler and see if we can come up with a design that has fewer concepts to learn when first interacting with ML.NET.

/cc @ericstj @TomFinley @Zruty0 @terrajobst

Labels: API (Issues pertaining to the friendly API)