Skip to content

Commit 93388b6

Browse files
committed
Added extraction of score column before node creation
1 parent ea71828 commit 93388b6

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema
320320
if (!ctx.ContainsColumn(featName))
321321
return false;
322322
Contracts.Assert(ctx.ContainsColumn(featName));
323-
return mapper.SaveAsOnnx(ctx, new[] { outputNames[1] }, ctx.GetVariableName(featName));
323+
return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(featName));
324324
}
325325

326326
private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) =>

src/Microsoft.ML.FastTree/FastTree.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3111,7 +3111,8 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
31113111
}
31123112

31133113
string opType = "TreeEnsembleRegressor";
3114-
var node = ctx.CreateNode(opType, new[] { featureColumn }, outputNames, ctx.GetNodeName(opType));
3114+
string scoreVarName = (Utils.Size(outputNames) == 2) ? outputNames[1] : outputNames[0]; // Get Score from PredictedLabel and/or Score columns
3115+
var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { scoreVarName }, ctx.GetNodeName(opType));
31153116

31163117
node.AddAttribute("post_transform", PostTransform.None.GetDescription());
31173118
node.AddAttribute("n_targets", 1);

src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,10 @@ JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input)
240240
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn)
241241
{
242242
Host.CheckValue(ctx, nameof(ctx));
243-
Host.Check(Utils.Size(outputs) == 1);
244-
245243
string opType = "LinearRegressor";
246-
var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType));
244+
string scoreVarName = (Utils.Size(outputs) == 2) ? outputs[1] : outputs[0]; // Get Score from PredictedLabel and/or Score columns
245+
246+
var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { scoreVarName }, ctx.GetNodeName(opType));
247247
// Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT}
248248
node.AddAttribute("post_transform", "NONE");
249249
node.AddAttribute("targets", 1);

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,6 @@ public void binaryClassificationTrainersOnnxConversionTest()
240240
CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3);
241241
CompareSelectedScalarColumns<Boolean>(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult);
242242
}
243-
244243
}
245244
Done();
246245
}

0 commit comments

Comments
 (0)