CacheDataView and PredictionEngine don't interact well #580

Closed
eerhardt opened this issue Jul 24, 2018 · 2 comments

@eerhardt
Member

System information

  • OS version/distro: all
  • .NET Version (e.g., dotnet --info): all

Issue

        public class IrisData
        {
            [Column("0")]
            public float SepalLength;

            [Column("1")]
            public float SepalWidth;

            [Column("2")]
            public float PetalLength;

            [Column("3")]
            public float PetalWidth;

            [Column("4")]
            [ColumnName("Label")]
            public string Label;
        }

        public class IrisPrediction
        {
            [ColumnName("PredictedLabel")]
            [KeyType]
            public uint PredictedLabels;
        }

        static void Main(string[] args)
        {
            using (var env = new TlcEnvironment(seed: 0))
            {
                string dataPath = "iris-data.txt";

                var loader = new TextLoader(env, new TextLoader.Arguments()
                {
                    HasHeader = false,
                    SeparatorChars = new char[] { ',' },
                    Column = new[] {
                        ScalarCol("SepalLength", 0),
                        ScalarCol("SepalWidth", 1),
                        ScalarCol("PetalLength", 2),
                        ScalarCol("PetalWidth", 3),
                        ScalarCol("Label", 4, DataKind.Text),
                        }
                }, new MultiFileSource(dataPath));

                IDataTransform trans = new TermTransform(env, loader, "Label");

                trans = new ConcatTransform(env, trans, "Features",
                    "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");

                var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments());

                var cached = new CacheDataView(env, trans, prefetch: null);
                var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
                var pred = trainer.Train(trainRoles);

                // Score.
                IDataView scoredData = ScoreUtils.GetScorer(pred, trainRoles, env, trainRoles.Schema);

                // Do a simple prediction.
                var engine = env.CreatePredictionEngine<IrisData, IrisPrediction>(scoredData);

                var prediction = engine.Predict(new IrisData()
                {
                    SepalLength = 3.3f,
                    SepalWidth = 1.6f,
                    PetalLength = 0.2f,
                    PetalWidth = 5.1f,
                });
                Console.WriteLine($"Predicted flower type is: {prediction.PredictedLabels}");
            }
        }
  • What happened?
Unhandled Exception: System.ArgumentOutOfRangeException: Feature column 'Features' not found
Parameter name: name
   at Microsoft.ML.Runtime.Data.ColumnInfo.CreateFromName(ISchema schema, String name, String descName)
   at Microsoft.ML.Runtime.Data.RoleMappedSchema.MapFromNames(ISchema schema, IEnumerable`1 roles, Boolean opt)
   at Microsoft.ML.Runtime.Data.RoleMappedSchema..ctor(ISchema schema, IEnumerable`1 roles, Boolean opt)
   at Microsoft.ML.Runtime.Data.PredictedLabelScorerBase.BindingsImpl.ApplyToSchema(ISchema input, ISchemaBindableMapper bindable, IHostEnvironment env)
   at Microsoft.ML.Runtime.Data.PredictedLabelScorerBase..ctor(IHostEnvironment env, PredictedLabelScorerBase transform, IDataView newSource, String registrationName)
   at Microsoft.ML.Runtime.Data.MultiClassClassifierScorer..ctor(IHostEnvironment env, MultiClassClassifierScorer transform, IDataView newSource)
   at Microsoft.ML.Runtime.Data.MultiClassClassifierScorer.ApplyToData(IHostEnvironment env, IDataView newSource)
   at Microsoft.ML.Runtime.Data.ApplyTransformUtils.ApplyTransformToData(IHostEnvironment env, IDataTransform transform, IDataView newSource)
   at Microsoft.ML.Runtime.Data.ApplyTransformUtils.ApplyAllTransformsToData(IHostEnvironment env, IDataView chain, IDataView newSource, IDataView oldSource)
   at Microsoft.ML.Runtime.Api.BatchPredictionEngine`2..ctor(IHostEnvironment env, IDataView dataPipeline, Boolean ignoreMissingColumns, SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition)
   at Microsoft.ML.Runtime.Api.PredictionEngine`2..ctor(IHostEnvironment env, IDataView dataPipe, Boolean ignoreMissingColumns, SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition)
   at Microsoft.ML.Runtime.Api.ComponentCreation.CreatePredictionEngine[TSrc,TDst](IHostEnvironment env, IDataView dataPipe, Boolean ignoreMissingColumns, SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition)
   at myApp.Program.Main(String[] args) in C:\Users\eerhardt\source\repos\MLNetCore30Test\Program.cs:line 182
  • What did you expect?
    I expected it to work.

Notes

The cause (AFAICT) is the CacheDataView usage. When PredictionEngine tries to apply all the transforms, it walks the chain with this loop:

while ((xf = chain as IDataTransform) != null)

The loop reaches the CacheDataView, which isn’t an IDataTransform, and exits. As a result, the only transform that gets applied is the Scorer transform, and none of the transforms used before it (like the one that adds the “Features” column).
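To make the failure mode concrete, here is a minimal, self-contained sketch of that kind of chain walk. Every type in it (`IView`, `ITransformView`, `SourceView`, `TransformView`, `CacheView`, `ChainWalker`) is a hypothetical stand-in invented for illustration, not an ML.NET type. Only the shape of the `while` loop mirrors the line quoted above. It assumes the cache wraps its source without exposing itself as a transform.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-ins for illustration only; these are NOT the ML.NET interfaces.
interface IView { }

interface ITransformView : IView
{
    string Name { get; }
    IView Source { get; }
}

sealed class SourceView : IView { }

sealed class TransformView : ITransformView
{
    public TransformView(string name, IView source) { Name = name; Source = source; }
    public string Name { get; }
    public IView Source { get; }
}

// Wraps another view but does not implement ITransformView,
// which is the property the issue attributes to the cache.
sealed class CacheView : IView
{
    public CacheView(IView source) { Source = source; }
    public IView Source { get; }
}

static class ChainWalker
{
    // Walks backwards from the end of the chain and stops at the first
    // view that is not a transform. Returns the names in pipeline order.
    public static List<string> CollectTransforms(IView chain)
    {
        var names = new List<string>();
        ITransformView xf;
        while ((xf = chain as ITransformView) != null)
        {
            names.Add(xf.Name);
            chain = xf.Source;
        }
        names.Reverse();
        return names;
    }
}

static class Program
{
    static void Main()
    {
        IView loader = new SourceView();
        IView term = new TransformView("Term", loader);
        IView concat = new TransformView("Concat", term);

        // Scorer built on top of the cached view: the walk stops at the cache.
        IView scorerOverCache = new TransformView("Scorer", new CacheView(concat));
        // Scorer built directly on the transform chain: the walk sees everything.
        IView scorerDirect = new TransformView("Scorer", concat);

        Console.WriteLine(string.Join(", ", ChainWalker.CollectTransforms(scorerOverCache))); // prints "Scorer"
        Console.WriteLine(string.Join(", ", ChainWalker.CollectTransforms(scorerDirect)));    // prints "Term, Concat, Scorer"
    }
}
```

In this toy model, the walk that starts above the cache never sees "Term" or "Concat", which corresponds to the missing “Features” column in the exception above.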

We work around this in the tests by serializing the IDataView (IDV) out and then reading it back in:

        private IDataScorerTransform GetScorer(IHostEnvironment env, IDataView transforms, IPredictor pred, string testDataPath = null)
        {
            using (var ch = env.Start("Saving model"))
            using (var memoryStream = new MemoryStream())
            {
                var trainRoles = new RoleMappedData(transforms, label: "Label", feature: "Features");

                // Model cannot be saved with CacheDataView
                TrainUtils.SaveModel(env, ch, memoryStream, pred, trainRoles);
                memoryStream.Position = 0;
                using (var rep = RepositoryReader.Open(memoryStream, ch))
                {
                    IDataLoader testPipe = ModelFileUtils.LoadLoader(env, rep, new MultiFileSource(testDataPath), true);
                    RoleMappedData testRoles = new RoleMappedData(testPipe, label: "Label", feature: "Features");
                    return ScoreUtils.GetScorer(pred, testRoles, env, testRoles.Schema);
                }
            }
        }

I would not expect a user to have to do this. Any thoughts on how to make this better?

Removing the CacheDataView from my pipeline makes the code work, but training gets super slow, so that seems to be a non-starter.

/cc @TomFinley @Zruty0

@Zruty0
Contributor

Zruty0 commented Jul 24, 2018

As for this particular use case, there exists a 'proper' workaround:

                var cached = new CacheDataView(env, trans, prefetch: null);
                var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
                var pred = trainer.Train(trainRoles);

                // -----
                trainRoles = new RoleMappedData(trans, feature: "Features");
                // -----

                // Score.
                IDataView scoredData = ScoreUtils.GetScorer(pred, trainRoles, env, trainRoles.Schema);

This 'short-circuits' our CacheDataView so that it is used only for training and is not part of the scoring pipeline.

Obviously, we should think about how to make this arrangement less of a 'trap of failure' for new users. I think it goes back to the idea of 'smart' training (a training process that would cache if needed, normalize if needed and calibrate if needed).

@codemzs
Member

codemzs commented Jun 30, 2019

I believe this is not an issue with the current stable API.

codemzs closed this as completed Jun 30, 2019
ghost locked as resolved and limited conversation to collaborators Mar 29, 2022