
MetaMulticlassTrainer throws an exception when used in an estimator chain #1969

Closed
daholste opened this issue Dec 26, 2018 · 3 comments

daholste commented Dec 26, 2018

Issue

Setup code:

            // load data from disk
            var textLoader = new TextLoader(mlContext, new TextLoader.Arguments()
                {
                    Separator = ",",
                    HasHeader = true,
                    Column = new[]
                        {
                            new TextLoader.Column("Label", DataKind.R4, 0),
                            new TextLoader.Column("Features", DataKind.R4, 1, 784),
                        }
                });

Code that succeeds:

            var apTrainer = mlContext.BinaryClassification.Trainers.AveragedPerceptron();
            var trainer = mlContext.MulticlassClassification.Trainers.OneVersusAll(apTrainer);
            var model1 = trainer.Fit(testData);

Code for a dummy estimator chain that fails:

            IEstimator<ITransformer> pipeline = new EstimatorChain<ITransformer>();
            pipeline = pipeline.Append(trainer);
            var model2 = pipeline.Fit(trainData);

This fails because of the following line

LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);

in the MetaMulticlassTrainer class, which expects the label column to be U4, not R4.

If I try to fix this by changing the data type of the Label column to U4, the code that used to succeed fails with the following exception:

System.ArgumentOutOfRangeException: 'Training label column 'Label' type is not valid for multi-class: U4. Type must be R4 or R8.'

The exception is thrown in TrainerUtils.CheckMultiClassLabel.

@yaeldekel yaeldekel self-assigned this Dec 28, 2018
@yaeldekel

Hi @daholste, thank you for reporting this issue.
All the multi-class trainers should be able to handle both R4/R8 labels and key type labels. Key is a type used for representing class IDs.

The underlying values of a key type are of type U4. However, the key type is not identical to the numeric U4 type, which is the type you get by changing the TextLoader column definition to DataKind.U4.

So there are two issues here:

  1. When you change the data type of the Label column to U4, failing is the correct behavior. However, the exception message is misleading: it should say that the expected types are key, R4 or R8.
  2. With the original definition of the Label column as type R4, the trainer should succeed both when run independently and when run inside an estimator chain.

As a workaround until the second issue is fixed, you can add a ValueToKeyMappingEstimator (applied to the Label column) to your estimator chain right before the OVA estimator. This will convert the R4 Label column to a key type.
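
The workaround above might look roughly like the following sketch. It reuses the `mlContext`, `trainer` and `trainData` variables from the original report. The constructor arguments passed to `ValueToKeyMappingEstimator` (the context plus the input column name) are an assumption about the preview API of that period and may differ in other ML.NET versions.

```csharp
            // Sketch of the workaround, not a verified fix.
            // Assumption: ValueToKeyMappingEstimator accepts (mlContext, inputColumn)
            // and, with no output column given, converts "Label" in place.
            var labelToKey = new ValueToKeyMappingEstimator(mlContext, "Label");

            IEstimator<ITransformer> pipeline = new EstimatorChain<ITransformer>();
            // Convert the R4 Label column to a key type first...
            pipeline = pipeline.Append(labelToKey);
            // ...then append the OVA trainer, which now sees a key-typed label.
            pipeline = pipeline.Append(trainer);
            var model2 = pipeline.Fit(trainData);
```

Placing the conversion immediately before the OVA estimator means the R4 label is already a key type when the chain's schema check reaches the trainer.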

daholste commented Jan 24, 2019

Thank you a lot for the info & workaround! Any chance you would be able to prioritize this fix to ship in your next release?

codemzs commented Jun 30, 2019

This seems to have been fixed already. @daholste, if this is still an issue, please reopen.

@codemzs codemzs closed this as completed Jun 30, 2019
@ghost ghost locked as resolved and limited conversation to collaborators Mar 25, 2022