diff --git a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs index eba8029fdb..218da383bc 100644 --- a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs +++ b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs @@ -603,7 +603,10 @@ private void InitializeTrainingGraph(IDataView input) (string)labelType.ToString()); } - _classCount = labelCount == 1 ? 2 : (int)labelCount; + var msg = $"Only one class found in the {_options.LabelColumnName} column. To build a multiclass classification model, the number of classes needs to be 2 or greater"; + Contracts.CheckParam(labelCount > 1, nameof(labelCount), msg); + + _classCount = (int)labelCount; var imageSize = ImagePreprocessingSize[_options.Arch]; _session = LoadTensorFlowSessionFromMetaGraph(Host, _options.Arch).Session; _session.graph.as_default();