From 162e8063bd010f286b249f2d2d849064331cd8eb Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Wed, 15 Jan 2020 15:56:47 -0800 Subject: [PATCH 1/4] Throw exception when dataset contains only 1 label --- src/Microsoft.ML.Vision/ImageClassificationTrainer.cs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs index eba8029fdb..b51acda77f 100644 --- a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs +++ b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs @@ -603,7 +603,10 @@ private void InitializeTrainingGraph(IDataView input) (string)labelType.ToString()); } - _classCount = labelCount == 1 ? 2 : (int)labelCount; + if (labelCount == 1) + throw new InvalidOperationException("Dataset contains only 1 class. ImageClassificationTrainer requires more than 1"); + + _classCount = (int)labelCount; var imageSize = ImagePreprocessingSize[_options.Arch]; _session = LoadTensorFlowSessionFromMetaGraph(Host, _options.Arch).Session; _session.graph.as_default(); From 5bf129b9e4f9fd2c64ffcac13e14dc83d6317d10 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Fri, 17 Jan 2020 09:37:47 -0800 Subject: [PATCH 2/4] Changed type of Exception to ArgumentOutOfRange --- src/Microsoft.ML.Vision/ImageClassificationTrainer.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs index b51acda77f..f976727401 100644 --- a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs +++ b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs @@ -603,8 +603,7 @@ private void InitializeTrainingGraph(IDataView input) (string)labelType.ToString()); } - if (labelCount == 1) - throw new InvalidOperationException("Dataset contains only 1 class. ImageClassificationTrainer requires more than 1"); + Contracts.CheckParam(labelCount > 1, nameof(labelCount), "ImageClassificationTrainer requires more than 1 class in the training dataset"); _classCount = (int)labelCount; var imageSize = ImagePreprocessingSize[_options.Arch]; From 5b8ebd6e4c62380eba33abaea1d475c682045319 Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Tue, 21 Jan 2020 13:10:40 -0800 Subject: [PATCH 3/4] Updated exception message --- src/Microsoft.ML.Vision/ImageClassificationTrainer.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs index f976727401..c7f3624ec2 100644 --- a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs +++ b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs @@ -603,7 +603,8 @@ private void InitializeTrainingGraph(IDataView input) (string)labelType.ToString()); } - Contracts.CheckParam(labelCount > 1, nameof(labelCount), "ImageClassificationTrainer requires more than 1 class in the training dataset"); + var msg = $"Only one class found in {_options.LabelColumnName} column. To build a multiclass classification model, the number of classes needs to be 2 or greater"; + Contracts.CheckParam(labelCount > 1, nameof(labelCount), msg); _classCount = (int)labelCount; var imageSize = ImagePreprocessingSize[_options.Arch]; From bf113b77878f2ccdea1cd548fbfed43a0659469b Mon Sep 17 00:00:00 2001 From: Antonio Velazquez Date: Tue, 21 Jan 2020 13:22:25 -0800 Subject: [PATCH 4/4] Updated exception message --- src/Microsoft.ML.Vision/ImageClassificationTrainer.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs index c7f3624ec2..218da383bc 100644 --- a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs +++ b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs @@ -603,7 +603,7 @@ private void InitializeTrainingGraph(IDataView input) (string)labelType.ToString()); } - var msg = $"Only one class found in {_options.LabelColumnName} column. To build a multiclass classification model, the number of classes needs to be 2 or greater"; + var msg = $"Only one class found in the {_options.LabelColumnName} column. To build a multiclass classification model, the number of classes needs to be 2 or greater"; Contracts.CheckParam(labelCount > 1, nameof(labelCount), msg); _classCount = (int)labelCount;