Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ protected Delegate CreateGetter(int col)
public ValueGetter<TValue> GetGetter<TValue>(int col)
{
Ch.Check(IsColumnActive(col), "The column must be active against the defined predicate.");
if (!(Getters[col] is ValueGetter<TValue>))
throw Ch.Except($"Invalid TValue in GetGetter: '{typeof(TValue)}'");
return Getters[col] as ValueGetter<TValue>;
}

Expand Down
62 changes: 54 additions & 8 deletions src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -473,9 +473,9 @@ public float Cost
}
}

public EntryPointNode(IHostEnvironment env, ModuleCatalog moduleCatalog, RunContext context,
private EntryPointNode(IHostEnvironment env, IChannel ch, ModuleCatalog moduleCatalog, RunContext context,
string id, string entryPointName, JObject inputs, JObject outputs, bool checkpoint = false,
string stageId = "", float cost = float.NaN)
string stageId = "", float cost = float.NaN, string label = null, string group = null, string weight = null)
{
Contracts.AssertValue(env);
env.AssertNonEmpty(id);
Expand All @@ -497,6 +497,7 @@ public EntryPointNode(IHostEnvironment env, ModuleCatalog moduleCatalog, RunCont
_inputMap = new Dictionary<ParameterBinding, VariableBinding>();
_inputBindingMap = new Dictionary<string, List<ParameterBinding>>();
_inputBuilder = new InputBuilder(_host, _entryPoint.InputType, moduleCatalog);

// REVIEW: This logic should move out of Node eventually and be delegated to
// a class that can nest to handle Components with variables.
if (inputs != null)
Expand All @@ -508,6 +509,43 @@ public EntryPointNode(IHostEnvironment env, ModuleCatalog moduleCatalog, RunCont
if (missing.Length > 0)
throw _host.Except($"The following required inputs were not provided: {String.Join(", ", missing)}");

var inputInstance = _inputBuilder.GetInstance();
var warning = "Different {0} column specified in trainer and in macro: '{1}', '{2}'." +
" Using column '{2}'. To column use '{1}' instead, please specify this name in" +
"the trainer node arguments.";
if (!string.IsNullOrEmpty(label) && inputInstance is LearnerInputBaseWithLabel)
{
var labelInputInstance = inputInstance as LearnerInputBaseWithLabel;
if (label != labelInputInstance.LabelColumn)
ch.Warning(warning, "label", label, labelInputInstance.LabelColumn);
else
labelInputInstance.LabelColumn = label;
}
if (!string.IsNullOrEmpty(group) && inputInstance is LearnerInputBaseWithGroupId)
{
var groupInputInstance = inputInstance as LearnerInputBaseWithGroupId;
if (group != groupInputInstance.GroupIdColumn)
ch.Warning(warning, "group Id", group, groupInputInstance.GroupIdColumn);
else
groupInputInstance.GroupIdColumn = group;
}
if (!string.IsNullOrEmpty(weight) && inputInstance is LearnerInputBaseWithWeight)
{
var weightInputInstance = inputInstance as LearnerInputBaseWithWeight;
if (weight != weightInputInstance.WeightColumn)
ch.Warning(warning, "weight", weight, weightInputInstance.WeightColumn);
else
weightInputInstance.WeightColumn = weight;
}
if (!string.IsNullOrEmpty(weight) && inputInstance is UnsupervisedLearnerInputBaseWithWeight)
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Jun 4, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WithWeight [](start = 94, length = 10)

Is it possible to refactor it into IHaveLabelColumn, IHaveWeightColum, IHaveGroupColumn? And check for interface instead of abstract class? #Resolved

{
var weightInputInstance = inputInstance as UnsupervisedLearnerInputBaseWithWeight;
if (weight != weightInputInstance.WeightColumn)
ch.Warning(warning, "weight", weight, weightInputInstance.WeightColumn);
else
weightInputInstance.WeightColumn = weight;
}

// Validate outputs.
_outputHelper = new OutputHelper(_host, _entryPoint.OutputType);
_outputMap = new Dictionary<string, string>();
Expand Down Expand Up @@ -550,10 +588,15 @@ public static EntryPointNode Create(
var inputBuilder = new InputBuilder(env, info.InputType, catalog);
var outputHelper = new OutputHelper(env, info.OutputType);

var entryPointNode = new EntryPointNode(env, catalog, context, context.GenerateId(entryPointName), entryPointName,
inputBuilder.GetJsonObject(arguments, inputBindingMap, inputMap),
outputHelper.GetJsonObject(outputMap), checkpoint, stageId, cost);
return entryPointNode;
using (var ch = env.Start("Create EntryPointNode"))
{
var entryPointNode = new EntryPointNode(env, ch, catalog, context, context.GenerateId(entryPointName), entryPointName,
inputBuilder.GetJsonObject(arguments, inputBindingMap, inputMap),
outputHelper.GetJsonObject(outputMap), checkpoint, stageId, cost);

ch.Done();
return entryPointNode;
}
}

public static EntryPointNode Create(
Expand Down Expand Up @@ -850,7 +893,8 @@ private object BuildParameterValue(List<ParameterBinding> bindings)
throw _host.ExceptNotImpl("Unsupported ParameterBinding");
}

public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContext context, JArray nodes, ModuleCatalog moduleCatalog)
public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContext context, JArray nodes,
ModuleCatalog moduleCatalog, string label = null, string group = null, string weight = null)
{
Contracts.AssertValue(env);
env.AssertValue(context);
Expand Down Expand Up @@ -890,8 +934,10 @@ public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContex
ch.Warning("Node '{0}' has unexpected fields that are ignored: {1}", id, string.Join(", ", unexpectedFields.Select(x => x.Name)));
}

result.Add(new EntryPointNode(env, moduleCatalog, context, id, name, inputs, outputs, checkpoint, stageId, cost));
result.Add(new EntryPointNode(env, ch, moduleCatalog, context, id, name, inputs, outputs, checkpoint, stageId, cost, label, group, weight));
}

ch.Done();
}
return result;
}
Expand Down
147 changes: 75 additions & 72 deletions src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Utilities;
using Newtonsoft.Json.Linq;

namespace Microsoft.ML.Runtime.EntryPoints.JsonUtils
Expand Down Expand Up @@ -405,7 +404,11 @@ private static object ParseJsonValue(IExceptionContext ectx, Type type, Attribut
return null;

if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Optional<>) || type.GetGenericTypeDefinition() == typeof(Nullable<>)))
{
if (type.GetGenericTypeDefinition() == typeof(Optional<>) && value.HasValues)
value = value.Values().FirstOrDefault();
type = type.GetGenericArguments()[0];
}

if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Var<>)))
{
Expand All @@ -426,81 +429,81 @@ private static object ParseJsonValue(IExceptionContext ectx, Type type, Attribut
{
switch (dt)
{
case TlcModule.DataKind.Bool:
return value.Value<bool>();
case TlcModule.DataKind.String:
return value.Value<string>();
case TlcModule.DataKind.Char:
return value.Value<char>();
case TlcModule.DataKind.Enum:
if (!Enum.IsDefined(type, value.Value<string>()))
throw ectx.Except($"Requested value '{value.Value<string>()}' is not a member of the Enum type '{type.Name}'");
return Enum.Parse(type, value.Value<string>());
case TlcModule.DataKind.Float:
if (type == typeof(double))
return value.Value<double>();
else if (type == typeof(float))
return value.Value<float>();
else
{
ectx.Assert(false);
throw ectx.ExceptNotSupp();
}
case TlcModule.DataKind.Array:
var ja = value as JArray;
ectx.Check(ja != null, "Expected array value");
Func<IExceptionContext, JArray, Attributes, ModuleCatalog, object> makeArray = MakeArray<int>;
return Utils.MarshalInvoke(makeArray, type.GetElementType(), ectx, ja, attributes, catalog);
case TlcModule.DataKind.Int:
if (type == typeof(long))
return value.Value<long>();
if (type == typeof(int))
return value.Value<int>();
ectx.Assert(false);
throw ectx.ExceptNotSupp();
case TlcModule.DataKind.UInt:
if (type == typeof(ulong))
return value.Value<ulong>();
if (type == typeof(uint))
return value.Value<uint>();
case TlcModule.DataKind.Bool:
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Jun 4, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

            [](start = 0, length = 16)

can you return old indentation? it's really hard to read changes #Closed

return value.Value<bool>();
case TlcModule.DataKind.String:
return value.Value<string>();
case TlcModule.DataKind.Char:
return value.Value<char>();
case TlcModule.DataKind.Enum:
if (!Enum.IsDefined(type, value.Value<string>()))
throw ectx.Except($"Requested value '{value.Value<string>()}' is not a member of the Enum type '{type.Name}'");
return Enum.Parse(type, value.Value<string>());
case TlcModule.DataKind.Float:
if (type == typeof(double))
return value.Value<double>();
else if (type == typeof(float))
return value.Value<float>();
else
{
ectx.Assert(false);
throw ectx.ExceptNotSupp();
case TlcModule.DataKind.Dictionary:
ectx.Check(value is JObject, "Expected object value");
Func<IExceptionContext, JObject, Attributes, ModuleCatalog, object> makeDict = MakeDictionary<int>;
return Utils.MarshalInvoke(makeDict, type.GetGenericArguments()[1], ectx, (JObject)value, attributes, catalog);
case TlcModule.DataKind.Component:
var jo = value as JObject;
ectx.Check(jo != null, "Expected object value");
// REVIEW: consider accepting strings alone.
var jName = jo[FieldNames.Name];
ectx.Check(jName != null, "Field '" + FieldNames.Name + "' is required for component.");
ectx.Check(jName is JValue, "Expected '" + FieldNames.Name + "' field to be a string.");
var name = jName.Value<string>();
ectx.Check(jo[FieldNames.Settings] == null || jo[FieldNames.Settings] is JObject,
"Expected '" + FieldNames.Settings + "' field to be an object");
return GetComponentJson(ectx, type, name, jo[FieldNames.Settings] as JObject, catalog);
default:
var settings = value as JObject;
ectx.Check(settings != null, "Expected object value");
var inputBuilder = new InputBuilder(ectx, type, catalog);

if (inputBuilder._fields.Length == 0)
throw ectx.Except($"Unsupported input type: {dt}");

if (settings != null)
}
case TlcModule.DataKind.Array:
var ja = value as JArray;
ectx.Check(ja != null, "Expected array value");
Func<IExceptionContext, JArray, Attributes, ModuleCatalog, object> makeArray = MakeArray<int>;
return Utils.MarshalInvoke(makeArray, type.GetElementType(), ectx, ja, attributes, catalog);
case TlcModule.DataKind.Int:
if (type == typeof(long))
return value.Value<long>();
if (type == typeof(int))
return value.Value<int>();
ectx.Assert(false);
throw ectx.ExceptNotSupp();
case TlcModule.DataKind.UInt:
if (type == typeof(ulong))
return value.Value<ulong>();
if (type == typeof(uint))
return value.Value<uint>();
ectx.Assert(false);
throw ectx.ExceptNotSupp();
case TlcModule.DataKind.Dictionary:
ectx.Check(value is JObject, "Expected object value");
Func<IExceptionContext, JObject, Attributes, ModuleCatalog, object> makeDict = MakeDictionary<int>;
return Utils.MarshalInvoke(makeDict, type.GetGenericArguments()[1], ectx, (JObject)value, attributes, catalog);
case TlcModule.DataKind.Component:
var jo = value as JObject;
ectx.Check(jo != null, "Expected object value");
// REVIEW: consider accepting strings alone.
var jName = jo[FieldNames.Name];
ectx.Check(jName != null, "Field '" + FieldNames.Name + "' is required for component.");
ectx.Check(jName is JValue, "Expected '" + FieldNames.Name + "' field to be a string.");
var name = jName.Value<string>();
ectx.Check(jo[FieldNames.Settings] == null || jo[FieldNames.Settings] is JObject,
"Expected '" + FieldNames.Settings + "' field to be an object");
return GetComponentJson(ectx, type, name, jo[FieldNames.Settings] as JObject, catalog);
default:
var settings = value as JObject;
ectx.Check(settings != null, "Expected object value");
var inputBuilder = new InputBuilder(ectx, type, catalog);

if (inputBuilder._fields.Length == 0)
throw ectx.Except($"Unsupported input type: {dt}");

if (settings != null)
{
foreach (var pair in settings)
{
foreach (var pair in settings)
{
if (!inputBuilder.TrySetValueJson(pair.Key, pair.Value))
throw ectx.Except($"Unexpected value for component '{type}', field '{pair.Key}': '{pair.Value}'");
}
if (!inputBuilder.TrySetValueJson(pair.Key, pair.Value))
throw ectx.Except($"Unexpected value for component '{type}', field '{pair.Key}': '{pair.Value}'");
}
}

var missing = inputBuilder.GetMissingValues().ToArray();
if (missing.Length > 0)
throw ectx.Except($"The following required inputs were not provided for component '{type}': {string.Join(", ", missing)}");
return inputBuilder.GetInstance();
var missing = inputBuilder.GetMissingValues().ToArray();
if (missing.Length > 0)
throw ectx.Except($"The following required inputs were not provided for component '{type}': {string.Join(", ", missing)}");
return inputBuilder.GetInstance();
}
}
catch (FormatException ex)
Expand Down
Loading