Skip to content

Throw exception in ImageClassificationTrainer when dataset contains only 1 class #4662

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 25, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/Microsoft.ML.Vision/ImageClassificationTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,10 @@ private void InitializeTrainingGraph(IDataView input)
(string)labelType.ToString());
}

_classCount = labelCount == 1 ? 2 : (int)labelCount;
var msg = $"Only one class found in the {_options.LabelColumnName} column. To build a multiclass classification model, the number of classes needs to be 2 or greater";
Contracts.CheckParam(labelCount > 1, nameof(labelCount), msg);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've implemented the message that @natke suggested.

I had to use the "var msg" to store the string, because I couldn't use string interpolation inside the CheckParam() argument, because of this:

/// <summary>
/// Looks up a localized string similar to Since C# has no concept of lazy evaluation of parameters, we prefer Contracts.Check&apos;s message arguments to not involve string formatting, or other complex operations, since such operations will happen always, whether the check fails or not. If you want to have detailed messages that&apos;s great, but use Contracts.Except instead. That is instead of something like &apos;Check(c, msg)&apos;, prefer &apos;if (!c) throw Except(msg)&apos;..
/// </summary>
internal static string ContractsCheckMessageNotLiteralOrIdentifier {
get {
return ResourceManager.GetString("ContractsCheckMessageNotLiteralOrIdentifier", resourceCulture);
}
}

So the suggestion there is to use Contracts.Except() instead of Contracts.Param() if I want to use string interpolation. Problem is Contracts.Except() throws an InvalidOperationException (which @justinormont and @harishsk had suggested I shouldn't use). So I guess it's best to simply create the "msg" string before passing it to the Contracts.CheckParam()...


_classCount = (int)labelCount;
var imageSize = ImagePreprocessingSize[_options.Arch];
_session = LoadTensorFlowSessionFromMetaGraph(Host, _options.Arch).Session;
_session.graph.as_default();
Expand Down