diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln
index fb15345497..bad42afd5d 100644
--- a/Microsoft.ML.sln
+++ b/Microsoft.ML.sln
@@ -17,8 +17,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CpuMath", "src
 EndProject
 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.PipelineInference", "src\Microsoft.ML.PipelineInference\Microsoft.ML.PipelineInference.csproj", "{2D7391C9-8254-4B8F-BF26-FADAF8F02F44}"
 EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.InferenceTesting", "test\Microsoft.ML.InferenceTesting\Microsoft.ML.InferenceTesting.csproj", "{E278EC99-E6EE-49FE-92E6-0A309A478D98}"
-EndProject
 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Data", "src\Microsoft.ML.Data\Microsoft.ML.Data.csproj", "{AD92D96B-0E96-4F22-8DCE-892E13B1F282}"
 EndProject
 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Onnx", "src\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj", "{65D0603E-B96C-4DFC-BDD1-705891B88C18}"
@@ -175,14 +173,6 @@ Global
     {2D7391C9-8254-4B8F-BF26-FADAF8F02F44}.Release|Any CPU.Build.0 = Release|Any CPU
     {2D7391C9-8254-4B8F-BF26-FADAF8F02F44}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
     {2D7391C9-8254-4B8F-BF26-FADAF8F02F44}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
-    {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
-    {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Debug|Any CPU.Build.0 = Debug|Any CPU
-    {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
-    {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
-    {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Release|Any CPU.ActiveCfg = Release|Any CPU
-    {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Release|Any CPU.Build.0 = Release|Any CPU
-    {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
-    {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
     {AD92D96B-0E96-4F22-8DCE-892E13B1F282}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
     {AD92D96B-0E96-4F22-8DCE-892E13B1F282}.Debug|Any CPU.Build.0 = Debug|Any CPU
     {AD92D96B-0E96-4F22-8DCE-892E13B1F282}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
@@ -520,7 +510,6 @@ Global
     {EC743D1D-7691-43B7-B9B0-5F2F7018A8F6} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
     {46F2F967-C23F-4076-858D-33F7DA9BD2DA} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
     {2D7391C9-8254-4B8F-BF26-FADAF8F02F44} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
-    {E278EC99-E6EE-49FE-92E6-0A309A478D98} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
     {AD92D96B-0E96-4F22-8DCE-892E13B1F282} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
     {65D0603E-B96C-4DFC-BDD1-705891B88C18} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
     {707BB22C-7E5F-497A-8C2F-74578F675705} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
diff --git a/src/Microsoft.ML.Api/GenerateCodeCommand.cs b/src/Microsoft.ML.Api/GenerateCodeCommand.cs
index 26136971af..4855a94219 100644
--- a/src/Microsoft.ML.Api/GenerateCodeCommand.cs
+++ b/src/Microsoft.ML.Api/GenerateCodeCommand.cs
@@ -24,7 +24,7 @@ namespace Microsoft.ML.Runtime.Api
     ///
     /// REVIEW: Consider adding support for generating VBuffers instead of arrays, maybe for high dimensionality vectors.
     ///
-    public sealed class GenerateCodeCommand : ICommand
+    internal sealed class GenerateCodeCommand : ICommand
     {
         public const string LoadName = "GenerateSamplePredictionCode";
         private const string CodeTemplatePath = "Microsoft.ML.Api.GeneratedCodeTemplate.csresource";
diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs
index 3450dd2bbf..7110565d12 100644
--- a/src/Microsoft.ML.Api/TypedCursor.cs
+++ b/src/Microsoft.ML.Api/TypedCursor.cs
@@ -528,8 +528,6 @@ public ICursor GetRootCursor()
     ///
     public static class CursoringUtils
     {
-        private const string NeedEnvObsoleteMessage = "This method is obsolete. Please use the overload that takes an additional 'env' argument. An environment can be created via new LocalEnvironment().";
-
         ///
         /// Generate a strongly-typed cursorable wrapper of the IDataView.
         ///
@@ -550,24 +548,6 @@ public static ICursorable AsCursorable(this IDataView data, IHostEnv
             return TypedCursorable.Create(env, data, ignoreMissingColumns, schemaDefinition);
         }
 
-        ///
-        /// Generate a strongly-typed cursorable wrapper of the IDataView.
-        ///
-        /// The user-defined row type.
-        /// The underlying data view.
-        /// Whether to ignore the case when a requested column is not present in the data view.
-        /// Optional user-provided schema definition. If it is not present, the schema is inferred from the definition of T.
-        /// The cursorable wrapper of the data.
-        [Obsolete(NeedEnvObsoleteMessage)]
-        public static ICursorable AsCursorable(this IDataView data, bool ignoreMissingColumns = false,
-            SchemaDefinition schemaDefinition = null)
-            where TRow : class, new()
-        {
-            // REVIEW: Take an env as a parameter.
-            var env = new ConsoleEnvironment();
-            return data.AsCursorable(env, ignoreMissingColumns, schemaDefinition);
-        }
-
         ///
         /// Convert an IDataView into a strongly-typed IEnumerable.
         ///
@@ -589,24 +569,5 @@ public static IEnumerable AsEnumerable(this IDataView data, IHostEnv
             var engine = new PipeEngine(env, data, ignoreMissingColumns, schemaDefinition);
             return engine.RunPipe(reuseRowObject);
         }
-
-        ///
-        /// Convert an IDataView into a strongly-typed IEnumerable.
-        ///
-        /// The user-defined row type.
-        /// The underlying data view.
-        /// Whether to return the same object on every row, or allocate a new one per row.
-        /// Whether to ignore the case when a requested column is not present in the data view.
-        /// Optional user-provided schema definition. If it is not present, the schema is inferred from the definition of T.
-        /// The IEnumerable that holds the data in the data view. It can be enumerated multiple times.
-        [Obsolete(NeedEnvObsoleteMessage)]
-        public static IEnumerable AsEnumerable(this IDataView data, bool reuseRowObject,
-            bool ignoreMissingColumns = false, SchemaDefinition schemaDefinition = null)
-            where TRow : class, new()
-        {
-            // REVIEW: Take an env as a parameter.
-            var env = new ConsoleEnvironment();
-            return data.AsEnumerable(env, reuseRowObject, ignoreMissingColumns, schemaDefinition);
-        }
     }
 }
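The two overloads removed above silently created a ConsoleEnvironment on the caller's behalf; the surviving overloads make the caller supply the environment. A minimal migration sketch, assuming a hypothetical user-defined MyRow class and the LocalEnvironment suggested by the removed obsolete message:

    // MyRow is a hypothetical row type; the schema is inferred from its public fields.
    public sealed class MyRow
    {
        public float Label;
        public string Text;
    }

    // The caller now owns the environment rather than the library creating one implicitly.
    var env = new LocalEnvironment();
    IDataView data = ...; // any data view, e.g. produced by a loader
    foreach (var row in data.AsEnumerable<MyRow>(env, reuseRowObject: false))
        Console.WriteLine($"{row.Label}: {row.Text}");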
diff --git a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
index 28c623f3c2..81018dad3e 100644
--- a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
+++ b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
@@ -40,7 +40,8 @@ internal enum SettingsFlags
     ///
     /// This allows components to be created by name, signature type, and a settings string.
     ///
-    public interface ICommandLineComponentFactory : IComponentFactory
+    [BestFriend]
+    internal interface ICommandLineComponentFactory : IComponentFactory
     {
         Type SignatureType { get; }
         string Name { get; }
diff --git a/src/Common/AssemblyLoadingUtils.cs b/src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs
similarity index 97%
rename from src/Common/AssemblyLoadingUtils.cs
rename to src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs
index ab2f2c563a..e947776bc9 100644
--- a/src/Common/AssemblyLoadingUtils.cs
+++ b/src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs
@@ -6,11 +6,13 @@
 using System;
 using System.IO;
 using System.IO.Compression;
-using System.Linq;
 using System.Reflection;
 
 namespace Microsoft.ML.Runtime
 {
+    [Obsolete("The usage for this is intended for the internal command line utilities and is not intended for anything related to the API. " +
+        "Please consider another way of doing whatever it is you're attempting to accomplish.")]
+    [BestFriend]
     internal static class AssemblyLoadingUtils
     {
         ///
diff --git a/src/Microsoft.ML.Core/Data/ICommand.cs b/src/Microsoft.ML.Core/Data/ICommand.cs
index c300b4f3a5..44d4c7340b 100644
--- a/src/Microsoft.ML.Core/Data/ICommand.cs
+++ b/src/Microsoft.ML.Core/Data/ICommand.cs
@@ -11,9 +11,11 @@ namespace Microsoft.ML.Runtime.Command
     ///
     /// The signature for commands.
     ///
-    public delegate void SignatureCommand();
+    [BestFriend]
+    internal delegate void SignatureCommand();
 
-    public interface ICommand
+    [BestFriend]
+    internal interface ICommand
     {
         void Run();
     }
diff --git a/src/Microsoft.ML.Core/Data/IFileHandle.cs b/src/Microsoft.ML.Core/Data/IFileHandle.cs
index 2b57187250..37b871b7b6 100644
--- a/src/Microsoft.ML.Core/Data/IFileHandle.cs
+++ b/src/Microsoft.ML.Core/Data/IFileHandle.cs
@@ -61,7 +61,7 @@ public sealed class SimpleFileHandle : IFileHandle
         // handle has been disposed.
         private List _streams;
 
-        private bool IsDisposed { get { return _streams == null; } }
+        private bool IsDisposed => _streams == null;
 
         public SimpleFileHandle(IExceptionContext ectx, string path, bool needsWrite, bool autoDelete)
         {
@@ -84,15 +84,9 @@ public SimpleFileHandle(IExceptionContext ectx, string path, bool needsWrite, bo
             _streams = new List();
         }
 
-        public bool CanWrite
-        {
-            get { return !_wrote && !IsDisposed; }
-        }
+        public bool CanWrite => !_wrote && !IsDisposed;
 
-        public bool CanRead
-        {
-            get { return _wrote && !IsDisposed; }
-        }
+        public bool CanRead => _wrote && !IsDisposed;
 
         public void Dispose()
         {
diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
index 0b8c097e7d..eb0d57845c 100644
--- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
+++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
@@ -72,6 +72,8 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
     /// The suffix and prefix are optional. A common use for suffix is to specify an extension, e.g., ".txt".
     /// The use of suffix and prefix, including whether they have any effect, is up to the host environment.
     ///
+    [Obsolete("The host environment is not disposable, so it is inappropriate to use this method. " +
+        "Please handle your own temporary files within the component yourself, including their proper disposal and deletion.")]
     IFileHandle CreateTempFile(string suffix = null, string prefix = null);
 
     ///
@@ -188,7 +190,8 @@ public readonly struct ChannelMessage
         ///
         public string Message => _args != null ?
             string.Format(_message, _args) : _message;
 
-        public ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string message)
+        [BestFriend]
+        internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string message)
         {
             Contracts.CheckNonEmpty(message, nameof(message));
             Kind = kind;
@@ -197,7 +200,8 @@ public ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, s
             _args = null;
         }
 
-        public ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string fmt, params object[] args)
+        [BestFriend]
+        internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string fmt, params object[] args)
         {
             Contracts.CheckNonEmpty(fmt, nameof(fmt));
             Contracts.CheckNonEmpty(args, nameof(args));
diff --git a/src/Microsoft.ML.Core/Data/ProgressReporter.cs b/src/Microsoft.ML.Core/Data/ProgressReporter.cs
index f7741b462c..191364e2a3 100644
--- a/src/Microsoft.ML.Core/Data/ProgressReporter.cs
+++ b/src/Microsoft.ML.Core/Data/ProgressReporter.cs
@@ -14,7 +14,8 @@ namespace Microsoft.ML.Runtime.Data
     ///
     /// The progress reporting classes used by descendants.
     ///
-    public static class ProgressReporting
+    [BestFriend]
+    internal static class ProgressReporting
     {
         ///
         /// The progress channel for .
diff --git a/src/Microsoft.ML.Core/Data/ServerChannel.cs b/src/Microsoft.ML.Core/Data/ServerChannel.cs
index b11e962ab2..a9b33d1986 100644
--- a/src/Microsoft.ML.Core/Data/ServerChannel.cs
+++ b/src/Microsoft.ML.Core/Data/ServerChannel.cs
@@ -19,7 +19,8 @@ namespace Microsoft.ML.Runtime
     /// delegates will be published in some fashion, with the target scenario being
     /// that the library will publish some sort of restful API.
     ///
-    public sealed class ServerChannel : ServerChannel.IPendingBundleNotification, IDisposable
+    [BestFriend]
+    internal sealed class ServerChannel : ServerChannel.IPendingBundleNotification, IDisposable
    {
         // See ServerChannel.md for a more elaborate discussion of high level usage and design.
         private readonly IChannelProvider _chp;
@@ -250,7 +251,8 @@ public void AddDoneAction(Action onDone)
         }
     }
 
-    public static class ServerChannelUtilities
+    [BestFriend]
+    internal static class ServerChannelUtilities
     {
         ///
         /// Convenience method for that looks more idiomatic to typical
diff --git a/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs b/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs
index 79a8c028ef..0163222fc1 100644
--- a/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs
@@ -9,13 +9,15 @@ namespace Microsoft.ML.Runtime.EntryPoints
     ///
     /// This is a signature for classes that are 'holders' of entry points and components.
     ///
-    public delegate void SignatureEntryPointModule();
+    [BestFriend]
+    internal delegate void SignatureEntryPointModule();
 
     ///
     /// A simplified assembly attribute for marking EntryPoint modules.
     ///
     [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
-    public sealed class EntryPointModuleAttribute : LoadableClassAttributeBase
+    [BestFriend]
+    internal sealed class EntryPointModuleAttribute : LoadableClassAttributeBase
     {
         public EntryPointModuleAttribute(Type loaderType)
             : base(null, typeof(void), loaderType, null, new[] { typeof(SignatureEntryPointModule) }, loaderType.FullName)
diff --git a/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs b/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
index f64c8d0758..c4d4325f79 100644
--- a/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
@@ -12,7 +12,8 @@
 namespace Microsoft.ML.Runtime.EntryPoints
 {
-    public static class EntryPointUtils
+    [BestFriend]
+    internal static class EntryPointUtils
     {
         private static bool IsValueWithinRange(TlcModule.RangeAttribute range, object obj)
         {
diff --git a/src/Microsoft.ML.Core/EntryPoints/IMlState.cs b/src/Microsoft.ML.Core/EntryPoints/IMlState.cs
index 52b0828256..41ea062861 100644
--- a/src/Microsoft.ML.Core/EntryPoints/IMlState.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/IMlState.cs
@@ -10,5 +10,5 @@ namespace Microsoft.ML.Runtime.EntryPoints
     /// black box to the graph. The macro itself will then cast to the concrete type.
     ///
     public interface IMlState
-    {}
+    { }
 }
\ No newline at end of file
diff --git a/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs b/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs
index 4dc02993d2..d538f636ce 100644
--- a/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs
@@ -18,7 +18,8 @@ namespace Microsoft.ML.Runtime.EntryPoints
     /// This class defines attributes to annotate module inputs, outputs, entry points etc. when defining
     /// the module interface.
     ///
-    public static class TlcModule
+    [BestFriend]
+    internal static class TlcModule
     {
         ///
         /// An attribute used to annotate the component.
diff --git a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
index e683a42413..3e27ce2516 100644
--- a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
+++ b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
@@ -5,8 +5,6 @@
 #pragma warning disable 420 // volatile with Interlocked.CompareExchange
 
 using System;
-using System.Collections.Concurrent;
-using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Threading;
@@ -15,7 +13,12 @@ namespace Microsoft.ML.Runtime.Data
 {
     using Stopwatch = System.Diagnostics.Stopwatch;
 
-    public sealed class ConsoleEnvironment : HostEnvironmentBase
+    ///
+    /// The console environment. As its name suggests, its use should be limited to those applications that deliberately want
+    /// console functionality.
+    ///
+    [BestFriend]
+    internal sealed class ConsoleEnvironment : HostEnvironmentBase
     {
         public const string ComponentHistoryKey = "ComponentHistory";
 
diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
index 6cfb4157be..31ab23e28f 100644
--- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
+++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
@@ -15,7 +15,8 @@ namespace Microsoft.ML.Runtime.Data
     /// Base class for channel providers. This is a common base class for HostEnvironmentBase.
     /// The ParentFullName, ShortName, and FullName may be null or empty.
     ///
-    public abstract class ChannelProviderBase : IExceptionContext
+    [BestFriend]
+    internal abstract class ChannelProviderBase : IExceptionContext
     {
         ///
         /// Data keys that are attached to the exception thrown via the exception context.
@@ -79,42 +80,22 @@ public virtual TException Process(TException ex)
     ///
     /// Message source (a channel) that generated the message being dispatched.
     ///
-    public interface IMessageSource
+    [BestFriend]
+    internal interface IMessageSource
     {
         string ShortName { get; }
         string FullName { get; }
         bool Verbose { get; }
     }
 
-    ///
-    /// A host environment that is also a channel listener can attach
-    /// listeners for messages, as sent through channels.
-    ///
-    public interface IMessageDispatcher : IHostEnvironment
-    {
-        ///
-        /// Listen on this environment to messages of a particular type.
-        ///
-        /// The message type
-        /// The action to perform when a message of the
-        /// appropriate type is received.
-        void AddListener(Action listenerFunc);
-
-        ///
-        /// Removes a previously added listener.
-        ///
-        /// The message type
-        /// The previous listener function that is now being removed.
-        void RemoveListener(Action listenerFunc);
-    }
-
     ///
     /// A basic host environment suited for many environments.
     /// This also supports modifying the concurrency factor, provides the ability to subscribe to pipes via the
     /// AddListener/RemoveListener methods, and exposes the progress tracker to
     /// query progress.
     ///
-    public abstract class HostEnvironmentBase : ChannelProviderBase, IHostEnvironment, IDisposable, IChannelProvider, IMessageDispatcher
+    [BestFriend]
+    internal abstract class HostEnvironmentBase : ChannelProviderBase, IHostEnvironment, IDisposable, IChannelProvider
         where TEnv : HostEnvironmentBase
     {
         ///
diff --git a/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs b/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs
index ebd7d41486..72b08e2715 100644
--- a/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs
+++ b/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs
@@ -13,7 +13,8 @@ namespace Microsoft.ML.Runtime
     ///
     /// A telemetry message.
     ///
-    public abstract class TelemetryMessage
+    [BestFriend]
+    internal abstract class TelemetryMessage
     {
         public static TelemetryMessage CreateCommand(string commandName, string commandText)
         {
@@ -40,7 +41,8 @@ public static TelemetryMessage CreateException(Exception exception)
     ///
     /// Message with one long text and a bunch of small properties (limit on value is ~1020 chars)
     ///
-    public sealed class TelemetryTrace : TelemetryMessage
+    [BestFriend]
+    internal sealed class TelemetryTrace : TelemetryMessage
     {
         public readonly string Text;
         public readonly string Name;
@@ -57,7 +59,8 @@ public TelemetryTrace(string text, string name, string type)
     ///
     /// Message with an exception
     ///
-    public sealed class TelemetryException : TelemetryMessage
+    [BestFriend]
+    internal sealed class TelemetryException : TelemetryMessage
     {
         public readonly Exception Exception;
         public TelemetryException(Exception exception)
@@ -70,7 +73,8 @@ public TelemetryException(Exception exception)
     ///
     /// Message with a metric value and its properties
     ///
-    public sealed class TelemetryMetric : TelemetryMessage
+    [BestFriend]
+    internal sealed class TelemetryMetric : TelemetryMessage
     {
         public readonly string Name;
         public readonly double Value;
diff --git a/src/Microsoft.ML.Core/Prediction/ITrainer.cs b/src/Microsoft.ML.Core/Prediction/ITrainer.cs
index 6647f77592..5e796aa602 100644
--- a/src/Microsoft.ML.Core/Prediction/ITrainer.cs
+++ b/src/Microsoft.ML.Core/Prediction/ITrainer.cs
@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 using Microsoft.ML.Runtime.Data;
+using System;
 
 namespace Microsoft.ML.Runtime
 {
@@ -39,7 +40,8 @@ namespace Microsoft.ML.Runtime
     /// The base interface for trainers. Implementors should not implement this interface directly,
     /// but rather implement the more specific ITrainer<TPredictor>.
     ///
-    public interface ITrainer
+    [BestFriend]
+    internal interface ITrainer
     {
         ///
         /// Auxiliary information about the trainer in terms of its capabilities
@@ -66,7 +68,8 @@ public interface ITrainer
     /// and produces a predictor.
     ///
     /// Type of predictor produced
-    public interface ITrainer : ITrainer
+    [BestFriend]
+    internal interface ITrainer : ITrainer
         where TPredictor : IPredictor
     {
         ///
@@ -77,7 +80,8 @@ public interface ITrainer : ITrainer
         new TPredictor Train(TrainContext context);
     }
 
-    public static class TrainerExtensions
+    [BestFriend]
+    internal static class TrainerExtensions
     {
         ///
         /// Convenience train extension for the case where one has only a training set with no auxiliary information.
diff --git a/src/Microsoft.ML.Core/Prediction/TrainContext.cs b/src/Microsoft.ML.Core/Prediction/TrainContext.cs
index 7bd1509bb5..e5e4bbad1c 100644
--- a/src/Microsoft.ML.Core/Prediction/TrainContext.cs
+++ b/src/Microsoft.ML.Core/Prediction/TrainContext.cs
@@ -11,7 +11,8 @@ namespace Microsoft.ML.Runtime
     /// into ITrainer.Train or ITrainer<TPredictor>.Train.
     /// This holds at least a training set, as well as optionally a predictor.
     ///
-    public sealed class TrainContext
+    [BestFriend]
+    internal sealed class TrainContext
     {
         ///
         /// The training set. Cannot be null.
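With ITrainer, ITrainer<TPredictor>, and TrainContext all becoming internal (and [BestFriend]-visible), the raw training surface can only be driven from inside the assembly or from a friend assembly. A minimal sketch of that surface, assuming trainer is some concrete ITrainer<TPredictor> and trainData is an already-constructed RoleMappedData:

    // TrainContext bundles the required training set with optional extras
    // (validation set, initial predictor); only the training set is supplied here.
    var context = new TrainContext(trainData);
    var predictor = trainer.Train(context);

The TrainerExtensions class above wraps essentially this pattern for the common training-set-only case.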
diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
index 24b84d38c1..2f20462290 100644
--- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
@@ -21,7 +21,8 @@
 namespace Microsoft.ML.Runtime.Data
 {
-    public sealed class CrossValidationCommand : DataCommand.ImplBase
+    [BestFriend]
+    internal sealed class CrossValidationCommand : DataCommand.ImplBase
     {
         // REVIEW: We need a way to specify different data sets, not just LabeledExamples.
         public sealed class Arguments : DataCommand.ArgumentsBase
diff --git a/src/Microsoft.ML.Data/Commands/DataCommand.cs b/src/Microsoft.ML.Data/Commands/DataCommand.cs
index f617fb206f..5c5ceef6f7 100644
--- a/src/Microsoft.ML.Data/Commands/DataCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/DataCommand.cs
@@ -17,7 +17,8 @@ namespace Microsoft.ML.Runtime.Data
     ///
     /// This holds useful base classes for commands that ingest a primary dataset and deal with associated model files.
     ///
-    public static class DataCommand
+    [BestFriend]
+    internal static class DataCommand
     {
         public abstract class ArgumentsBase
         {
@@ -56,7 +57,8 @@ public abstract class ArgumentsBase
             public KeyValuePair>[] Transform;
         }
 
-        public abstract class ImplBase : ICommand
+        [BestFriend]
+        internal abstract class ImplBase : ICommand
             where TArgs : ArgumentsBase
         {
             protected readonly IHost Host;
diff --git a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs
index 937f019c37..7ed4144262 100644
--- a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs
@@ -162,7 +162,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
         }
     }
 
-    public sealed class EvaluateCommand : DataCommand.ImplBase
+    internal sealed class EvaluateCommand : DataCommand.ImplBase
     {
         public sealed class Arguments : DataCommand.ArgumentsBase
         {
diff --git a/src/Microsoft.ML.Data/Commands/SaveDataCommand.cs b/src/Microsoft.ML.Data/Commands/SaveDataCommand.cs
index 577f51a4db..a721aee15a 100644
--- a/src/Microsoft.ML.Data/Commands/SaveDataCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/SaveDataCommand.cs
@@ -22,7 +22,7 @@
 namespace Microsoft.ML.Runtime.Data
 {
-    public sealed class SaveDataCommand : DataCommand.ImplBase
+    internal sealed class SaveDataCommand : DataCommand.ImplBase
     {
         public sealed class Arguments : DataCommand.ArgumentsBase
         {
@@ -87,7 +87,7 @@ private void RunCore(IChannel ch)
         }
     }
 
-    public sealed class ShowDataCommand : DataCommand.ImplBase
+    internal sealed class ShowDataCommand : DataCommand.ImplBase
     {
         public sealed class Arguments : DataCommand.ArgumentsBase
         {
diff --git a/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs b/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs
index e1057d18b4..633f871dd0 100644
--- a/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs
@@ -20,7 +20,7 @@
 namespace Microsoft.ML.Runtime.Tools
 {
-    public sealed class SavePredictorCommand : ICommand
+    internal sealed class SavePredictorCommand : ICommand
     {
         public sealed class Arguments
         {
diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
index d4c737f7b1..c0b2d7ecc5 100644
--- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
@@ -37,7 +37,7 @@ public interface IDataScorerTransform : IDataTransform, ITransformTemplate
 
     public delegate void SignatureBindableMapper(IPredictor predictor);
 
-    public sealed class ScoreCommand : DataCommand.ImplBase
+    internal sealed class ScoreCommand : DataCommand.ImplBase
     {
         public sealed class Arguments : DataCommand.ArgumentsBase
         {
@@ -232,7 +232,8 @@ private bool ShouldAddColumn(Schema schema, int i, uint scoreSet, bool outputNam
         }
     }
 
-    public static class ScoreUtils
+    [BestFriend]
+    internal static class ScoreUtils
    {
         public static IDataScorerTransform GetScorer(IPredictor predictor, RoleMappedData data, IHostEnvironment env, RoleMappedSchema trainSchema)
         {
diff --git a/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs b/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs
index 5fbab5176a..20a9857a8f 100644
--- a/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs
@@ -20,7 +20,7 @@
 namespace Microsoft.ML.Runtime.Data
 {
-    public sealed class ShowSchemaCommand : DataCommand.ImplBase
+    internal sealed class ShowSchemaCommand : DataCommand.ImplBase
     {
         public sealed class Arguments : DataCommand.ArgumentsBase
         {
diff --git a/src/Microsoft.ML.Data/Commands/TestCommand.cs b/src/Microsoft.ML.Data/Commands/TestCommand.cs
index 574bd4e00a..407f7e713d 100644
--- a/src/Microsoft.ML.Data/Commands/TestCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TestCommand.cs
@@ -14,8 +14,12 @@
 namespace Microsoft.ML.Runtime.Data
 {
-    // This command is essentially chaining together Score and Evaluate, without the need to save the intermediary scored data.
-    public sealed class TestCommand : DataCommand.ImplBase
+    ///
+    /// This command is essentially chaining together <see cref="ScoreCommand"/> and
+    /// <see cref="EvaluateCommand"/>, without the need to save the intermediary scored data.
+    ///
+    [BestFriend]
+    internal sealed class TestCommand : DataCommand.ImplBase
     {
         public sealed class Arguments : DataCommand.ArgumentsBase
         {
diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs
index 53b025ef39..ff709de143 100644
--- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs
@@ -32,7 +32,8 @@ public enum NormalizeOption
         Yes
     }
 
-    public sealed class TrainCommand : DataCommand.ImplBase
+    [BestFriend]
+    internal sealed class TrainCommand : DataCommand.ImplBase
     {
         public sealed class Arguments : DataCommand.ArgumentsBase
         {
@@ -202,7 +203,8 @@ private void RunCore(IChannel ch, string cmd)
         }
     }
 
-    public static class TrainUtils
+    [BestFriend]
+    internal static class TrainUtils
     {
         public static void CheckTrainer(IExceptionContext ectx, IComponentFactory trainer, string dataFile)
         {
diff --git a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
index 450d680765..49f25375f5 100644
--- a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 using System.Collections.Generic;
+using System.IO;
 using Microsoft.ML.Runtime;
 using Microsoft.ML.Runtime.Command;
 using Microsoft.ML.Runtime.CommandLine;
@@ -16,7 +17,8 @@
 namespace Microsoft.ML.Runtime.Data
 {
-    public sealed class TrainTestCommand : DataCommand.ImplBase
+    [BestFriend]
+    internal sealed class TrainTestCommand : DataCommand.ImplBase
     {
         public sealed class Arguments : DataCommand.ArgumentsBase
         {
@@ -184,11 +186,12 @@ private void RunCore(IChannel ch, string cmd)
                 Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor, testDataUsedInTrainer);
 
             IDataLoader testPipe;
-            using (var file = !string.IsNullOrEmpty(Args.OutputModelFile) ?
-                Host.CreateOutputFile(Args.OutputModelFile) : Host.CreateTempFile(".zip"))
+            bool hasOutfile = !string.IsNullOrEmpty(Args.OutputModelFile);
+            var tempFilePath = hasOutfile ? null : Path.GetTempFileName();
+
+            using (var file = new SimpleFileHandle(ch, hasOutfile ? Args.OutputModelFile : tempFilePath, true, !hasOutfile))
             {
                 TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);
-
                 ch.Trace("Constructing the testing pipeline");
                 using (var stream = file.OpenReadStream())
                 using (var rep = RepositoryReader.Open(stream, ch))
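The TrainTestCommand change above is the template for avoiding the now-obsolete IHostEnvironment.CreateTempFile: create the temporary path directly and let SimpleFileHandle own its lifetime. A condensed sketch of the pattern, with outputModelFile standing in for Args.OutputModelFile and ch for the ambient IChannel:

    bool hasOutfile = !string.IsNullOrEmpty(outputModelFile);
    var tempFilePath = hasOutfile ? null : Path.GetTempFileName();

    // needsWrite: the handle is written once, then re-opened for reading.
    // autoDelete: only the temp file is cleaned up; a user-specified output file stays on disk.
    using (var file = new SimpleFileHandle(ch, hasOutfile ? outputModelFile : tempFilePath,
        needsWrite: true, autoDelete: !hasOutfile))
    {
        // write the model here, then call file.OpenReadStream() to read it back
    }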
diff --git a/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs b/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs
index 79c3295017..e9db784f1f 100644
--- a/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs
@@ -17,7 +17,7 @@
 namespace Microsoft.ML.Data.Commands
 {
-    public sealed class TypeInfoCommand : ICommand
+    internal sealed class TypeInfoCommand : ICommand
     {
         internal const string LoadName = "TypeInfo";
         internal const string Summary = "Displays information about the standard primitive " +
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
index cf1e5f7a25..73dd62152b 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
@@ -2137,7 +2137,7 @@ public override ValueGetter GetIdGetter()
             }
         }
 
-        public sealed class InfoCommand : ICommand
+        internal sealed class InfoCommand : ICommand
         {
             public const string LoadName = "IdvInfo";
 
diff --git a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs
index 6db939a877..f866ca4817 100644
--- a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs
+++ b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs
@@ -173,11 +173,6 @@ public interface IPredictorWithFeatureWeights : IHaveFeatureWeights
     {
     }
 
-    public interface IHasLabelGains : ITrainer
-    {
-        Double[] GetLabelGains();
-    }
-
     ///
     /// Interface for mapping input values to corresponding feature contributions.
     /// This interface is commonly implemented by predictors.
diff --git a/src/Microsoft.ML.Data/EntryPoints/Cache.cs b/src/Microsoft.ML.Data/EntryPoints/Cache.cs
index 1bcbcc0bf8..aa2693f4aa 100644
--- a/src/Microsoft.ML.Data/EntryPoints/Cache.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/Cache.cs
@@ -68,9 +68,11 @@ public static CacheOutput CacheData(IHostEnvironment env, CacheInput input)
                 cols.Add(i);
             }
 
+#pragma warning disable CS0618 // This ought to be addressed. See #1287.
             // We are not disposing the fileHandle because we want it to stay around for the execution of the graph.
             // It will be disposed when the environment is disposed.
             var fileHandle = host.CreateTempFile();
+#pragma warning restore CS0618
 
             using (var stream = fileHandle.CreateWriteStream())
                 saver.SaveData(stream, input.Data, cols.ToArray());
diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
index 550191b157..3254289f12 100644
--- a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
@@ -99,7 +99,8 @@ public abstract class LearnerInputBaseWithGroupId : LearnerInputBaseWithWeight
         public Optional GroupIdColumn = Optional.Implicit(DefaultColumnNames.GroupId);
     }
 
-    public static class LearnerEntryPointsUtils
+    [BestFriend]
+    internal static class LearnerEntryPointsUtils
     {
         public static string FindColumn(IExceptionContext ectx, ISchema schema, Optional value)
         {
diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
index b6c650d77c..20d87fcbf6 100644
--- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs
+++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -690,7 +690,8 @@ public ValueMapper> GetWhatTheFeatureMapper(int
         }
     }
 
-    public static class CalibratorUtils
+    [BestFriend]
+    internal static class CalibratorUtils
     {
         private static bool NeedCalibration(IHostEnvironment env, IChannel ch, ICalibratorTrainer calibrator,
             ITrainer trainer, IPredictor predictor, RoleMappedSchema schema)
diff --git a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
index dde426f0d6..849318c1e0 100644
--- a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
+++ b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
@@ -7,12 +7,13 @@
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TestFramework" + PublicKey.TestValue)]
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.InferenceTesting" + PublicKey.TestValue)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransformTest" + PublicKey.TestValue)]
 
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)]
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Maml" + PublicKey.Value)]
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.ResultProcessor" + PublicKey.Value)]
-[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Data" + PublicKey.Value)]
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Api" + PublicKey.Value)]
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Ensemble" + PublicKey.Value)]
 [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.FastTree" + PublicKey.Value)]
diff --git a/src/Microsoft.ML.Data/Training/TrainerBase.cs b/src/Microsoft.ML.Data/Training/TrainerBase.cs
index ca2f2c7b64..d82dfb7ba6 100644
--- a/src/Microsoft.ML.Data/Training/TrainerBase.cs
+++ b/src/Microsoft.ML.Data/Training/TrainerBase.cs
@@ -19,7 +19,8 @@ public abstract class TrainerBase : ITrainer
         public abstract PredictionKind PredictionKind { get; }
         public abstract TrainerInfo Info { get; }
 
-        protected TrainerBase(IHostEnvironment env, string name)
+        [BestFriend]
+        private protected TrainerBase(IHostEnvironment env, string name)
         {
             Contracts.CheckValue(env, nameof(env));
             env.CheckNonEmpty(name, nameof(name));
@@ -30,6 +31,9 @@ protected TrainerBase(IHostEnvironment env, string name)
 
         IPredictor ITrainer.Train(TrainContext context) => Train(context);
 
-        public abstract TPredictor Train(TrainContext context);
+        TPredictor ITrainer.Train(TrainContext context) =>
+            Train(context);
+
+        [BestFriend]
+        private protected abstract TPredictor Train(TrainContext context);
     }
 }
diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
index 035d66c054..1e49c32ed3 100644
--- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
+++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
@@ -51,7 +51,8 @@ public abstract class TrainerEstimatorBase : ITrainerEstim
 
         public abstract PredictionKind PredictionKind { get; }
 
-        public TrainerEstimatorBase(IHost host,
+        [BestFriend]
+        private protected TrainerEstimatorBase(IHost host,
             SchemaShape.Column feature,
             SchemaShape.Column label,
             SchemaShape.Column weight = null)
@@ -87,7 +88,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
         ///
         protected abstract SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema);
 
-        public TModel Train(TrainContext context)
+        TModel ITrainer.Train(TrainContext context)
         {
             Host.CheckValue(context, nameof(context));
             return TrainModelCore(context);
@@ -149,14 +150,15 @@ protected TTransformer TrainTransformer(IDataView trainSet,
             return MakeTransformer(pred, trainSet.Schema);
         }
 
-        protected abstract TModel TrainModelCore(TrainContext trainContext);
+        [BestFriend]
+        private protected abstract TModel TrainModelCore(TrainContext trainContext);
 
         protected abstract TTransformer MakeTransformer(TModel model, Schema trainSchema);
 
         protected virtual RoleMappedData MakeRoles(IDataView data) =>
             new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, weight: WeightColumn?.Name);
 
-        IPredictor ITrainer.Train(TrainContext context) => Train(context);
+        IPredictor ITrainer.Train(TrainContext context) => ((ITrainer)this).Train(context);
     }
 
     ///
diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
index 612631fae7..95bc4d3a10 100644
--- a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
@@ -263,33 +263,6 @@ public static IDataTransform CreateMinMaxNormalizer(IHostEnvironment env, IDataV
             return normalizer.Fit(input).MakeDataTransform(input);
         }
 
-        ///
-        /// Potentially apply a min-max normalizer to the data's feature column, keeping all existing role
-        /// mappings except for the feature role mapping.
-        ///
-        /// The host environment to use to potentially instantiate the transform
-        /// The role-mapped data that is potentially going to be modified by this method.
-        /// The trainer to query as to whether it wants normalization. If the
-        /// trainer's TrainerInfo.NeedNormalization is true
-        /// True if the normalizer was applied and the data was modified
-        public static bool CreateIfNeeded(IHostEnvironment env, ref RoleMappedData data, ITrainer trainer)
-        {
-            Contracts.CheckValue(env, nameof(env));
-            env.CheckValue(data, nameof(data));
-            env.CheckValue(trainer, nameof(trainer));
-
-            // If the trainer does not need normalization, or if the features either don't exist
-            // or are not normalized, return false.
-            if (!trainer.Info.NeedNormalization || data.Schema.FeaturesAreNormalized() != false)
-                return false;
-            var featInfo = data.Schema.Feature;
-            env.AssertValue(featInfo); // Should be defined, if FeaturesAreNormalized returned a definite value.
-
-            var view = CreateMinMaxNormalizer(env, data.Data, name: featInfo.Name);
-            data = new RoleMappedData(view, data.Schema.GetColumnRoleNames());
-            return true;
-        }
-
         ///
         /// Public create method corresponding to SignatureDataTransform.
         ///
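The deleted CreateIfNeeded helper is subsumed by the estimator-based path visible at the top of this hunk: CreateMinMaxNormalizer already constructs a normalizing estimator and immediately fits it. A sketch of what a caller does instead, assuming env is an IHostEnvironment, view is an IDataView with a "Features" column, and CreateMinMaxNormalizer is invoked on its containing transform class as shown in the hunk:

    // Fit and apply in one step; internally this is normalizer.Fit(input).MakeDataTransform(input).
    IDataTransform normalized = CreateMinMaxNormalizer(env, view, name: "Features");

A trainer that wants normalization is now expected to be composed with such an estimator up front, rather than mutating a RoleMappedData by reference.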
diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs
index 9aa364b386..e8ec5f4e53 100644
--- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs
@@ -9,6 +9,7 @@
 using Microsoft.ML.Runtime.Internal.Calibration;
 using Microsoft.ML.Runtime.Model;
 using Microsoft.ML.Transforms;
+using System;
 using System.Collections.Generic;
 
 [assembly: LoadableClass(ScoringTransformer.Summary, typeof(IDataTransform), typeof(ScoringTransformer), typeof(ScoringTransformer.Arguments), typeof(SignatureDataTransform),
@@ -98,7 +99,11 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
         }
     }
 
-    public static class TrainAndScoreTransformer
+    // Essentially, all trainer estimators when fitted return a transformer that produces scores -- which is to say, all machine
+    // learning algorithms actually behave more or less as this transform used to, so its presence is no longer necessary or helpful
+    // from an API perspective, but this is still how things are invoked from the command line for now.
+    [BestFriend]
+    internal static class TrainAndScoreTransformer
     {
         public abstract class ArgumentsBase : TransformInputBase
         {
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs
index 728cccb1f6..e5e9b79afc 100644
--- a/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs
@@ -11,7 +11,7 @@
 namespace Microsoft.ML.Ensemble.EntryPoints
 {
-    public static class Ensemble
+    internal static class Ensemble
     {
         [TlcModule.EntryPoint(Name = "Trainers.EnsembleBinaryClassifier", Desc = "Train binary ensemble.", UserName = EnsembleTrainer.UserNameValue)]
         public static CommonOutputs.BinaryClassificationOutput CreateBinaryEnsemble(IHostEnvironment env, EnsembleTrainer.Arguments input)
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs
index dbe1517f22..257499cd13 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs
@@ -9,7 +9,7 @@
 namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
 {
-    public abstract class BaseScalarStacking : BaseStacking
+    internal abstract class BaseScalarStacking : BaseStacking
     {
         internal BaseScalarStacking(IHostEnvironment env, string name, ArgumentsBase args)
             : base(env, name, args)
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
index eb6dc3a650..d62650b7c2 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
@@ -16,7 +16,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
 {
     using ColumnRole = RoleMappedSchema.ColumnRole;
 
-    public abstract class BaseStacking : IStackingTrainer
+    internal abstract class BaseStacking : IStackingTrainer
     {
         public abstract class ArgumentsBase
         {
@@ -28,13 +28,13 @@ public abstract class ArgumentsBase
             internal abstract IComponentFactory>> GetPredictorFactory();
         }
 
-        protected readonly IComponentFactory>> BasePredictorType;
-        protected readonly IHost Host;
-        protected IPredictorProducing Meta;
+        private protected readonly IComponentFactory>> BasePredictorType;
+        private protected readonly IHost Host;
+        private protected IPredictorProducing Meta;
 
         public Single ValidationDatasetProportion { get; }
 
-        internal BaseStacking(IHostEnvironment env, string name, ArgumentsBase args)
+        private protected BaseStacking(IHostEnvironment env, string name, ArgumentsBase args)
         {
             Contracts.AssertValue(env);
             env.AssertNonWhiteSpace(name);
@@ -49,7 +49,7 @@ internal BaseStacking(IHostEnvironment env, string name, ArgumentsBase args)
             Host.CheckValue(BasePredictorType, nameof(BasePredictorType));
         }
 
-        internal BaseStacking(IHostEnvironment env, string name, ModelLoadContext ctx)
+        private protected BaseStacking(IHostEnvironment env, string name, ModelLoadContext ctx)
         {
             Contracts.AssertValue(env);
             env.AssertNonWhiteSpace(name);
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
index 4e352e8265..f9e3b246f7 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
@@ -21,7 +21,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
 {
     using TVectorPredictor = IPredictorProducing>;
 
-    public sealed class MultiStacking : BaseStacking>, ICanSaveModel, IMultiClassOutputCombiner
+    internal sealed class MultiStacking : BaseStacking>, ICanSaveModel, IMultiClassOutputCombiner
     {
         public const string LoadName = "MultiStacking";
         public const string LoaderSignature = "MultiStackingCombiner";
@@ -37,9 +37,11 @@ private static VersionInfo GetVersionInfo()
                 loaderAssemblyName: typeof(MultiStacking).Assembly.FullName);
         }
 
+#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
         [TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)]
         public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory
         {
+            // REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer.
             [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
                 Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))]
             [TGUI(Label = "Base predictor")]
 
             public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiStacking(env, this);
         }
+#pragma warning restore CS0649
 
         public MultiStacking(IHostEnvironment env, Arguments args)
             : base(env, LoaderSignature, args)
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
index 9adb9ba799..8c984613db 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
@@ -19,7 +19,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
 {
     using TScalarPredictor = IPredictorProducing;
 
-    public sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner, ICanSaveModel
+    internal sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner, ICanSaveModel
     {
         public const string LoadName = "RegressionStacking";
         public const string LoaderSignature = "RegressionStacking";
@@ -35,9 +35,11 @@ private static VersionInfo GetVersionInfo()
                 loaderAssemblyName: typeof(RegressionStacking).Assembly.FullName);
         }
 
+#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
         [TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)]
         public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerFactory
         {
+            // REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer.
             [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
                 Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))]
             [TGUI(Label = "Base predictor")]
 
             public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this);
         }
+#pragma warning restore CS0649
 
         public RegressionStacking(IHostEnvironment env, Arguments args)
             : base(env, LoaderSignature, args)
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
index 8c36cc866e..f44f987b05 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
@@ -16,7 +16,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
 {
     using TScalarPredictor = IPredictorProducing;
 
-    public sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner, ICanSaveModel
+    internal sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner, ICanSaveModel
     {
         public const string UserName = "Stacking";
         public const string LoadName = "Stacking";
@@ -33,9 +33,11 @@ private static VersionInfo GetVersionInfo()
                 loaderAssemblyName: typeof(Stacking).Assembly.FullName);
         }
 
+#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
         [TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
         public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFactory
         {
+            // REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer.
             [Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
                 Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
             [TGUI(Label = "Base predictor")]
 
             public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this);
         }
+#pragma warning restore CS0649
 
         public Stacking(IHostEnvironment env, Arguments args)
             : base(env, LoaderSignature, args)
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
index cc03de02d4..a1befc112b 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
@@ -29,7 +29,7 @@ namespace Microsoft.ML.Runtime.Ensemble
     ///
     /// A generic ensemble trainer for binary classification.
     ///
-    public sealed class EnsembleTrainer : EnsembleTrainerBase,
+    internal sealed class EnsembleTrainer : EnsembleTrainerBase,
         IModelCombiner
     {
@@ -48,6 +48,7 @@ public sealed class Arguments : ArgumentsBase
         [TGUI(Label = "Output combiner", Description = "Output combiner type")]
         public ISupportBinaryOutputCombinerFactory OutputCombiner = new MedianFactory();
 
+        // REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer.
         [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1,
             Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
         public IComponentFactory>[] BasePredictors;
diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
index 9d7c6fab40..a8fc896c5b 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
@@ -101,7 +101,7 @@ private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env,
             }
         }
 
-        public sealed override TPredictor Train(TrainContext context)
+        private protected sealed override TPredictor Train(TrainContext context)
         {
             Host.CheckValue(context, nameof(context));
 
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
index 3909fb1b07..1961e6a785 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
@@ -30,7 +30,7 @@ namespace Microsoft.ML.Runtime.Ensemble
     ///
     /// A generic ensemble classifier for multi-class classification
     ///
-    public sealed class MulticlassDataPartitionEnsembleTrainer :
+    internal sealed class MulticlassDataPartitionEnsembleTrainer :
         EnsembleTrainerBase, EnsembleMultiClassPredictor, IMulticlassSubModelSelector, IMultiClassOutputCombiner,
         IModelCombiner
     {
@@ -49,6 +49,7 @@ public sealed class Arguments : ArgumentsBase
         [TGUI(Label = "Output combiner", Description = "Output combiner type")]
         public ISupportMulticlassOutputCombinerFactory OutputCombiner = new MultiMedian.Arguments();
 
+        // REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer.
         [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1,
             Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))]
         public IComponentFactory>[] BasePredictors;
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
index b7e63b8862..09d394d596 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
@@ -26,7 +26,7 @@ namespace Microsoft.ML.Runtime.Ensemble
 {
     using TScalarPredictor = IPredictorProducing;
 
-    public sealed class RegressionEnsembleTrainer : EnsembleTrainerBase,
+    internal sealed class RegressionEnsembleTrainer : EnsembleTrainerBase,
         IModelCombiner
     {
@@ -43,6 +43,7 @@ public sealed class Arguments : ArgumentsBase
         [TGUI(Label = "Output combiner", Description = "Output combiner type")]
         public ISupportRegressionOutputCombinerFactory OutputCombiner = new MedianFactory();
 
+        // REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer.
         [Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1,
             Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))]
         public IComponentFactory>[] BasePredictors;
diff --git a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs
index b56e92e4f7..f68108d2e8 100644
--- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs
@@ -17,7 +17,7 @@
 namespace Microsoft.ML.Trainers.FastTree
 {
     [TlcModule.ComponentKind("FastTreeTrainer")]
-    public interface IFastTreeTrainerFactory : IComponentFactory
+    internal interface IFastTreeTrainerFactory : IComponentFactory
     {
     }
 
@@ -31,7 +31,7 @@ public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory
             [TGUI(Label = "Optimize for unbalanced")]
             public bool UnbalancedSets = false;
 
-            public ITrainer CreateComponent(IHostEnvironment env) => new FastTreeBinaryClassificationTrainer(env, this);
+            ITrainer IComponentFactory.CreateComponent(IHostEnvironment env) => new FastTreeBinaryClassificationTrainer(env, this);
         }
     }
 
@@ -45,7 +45,7 @@ public Arguments()
                 EarlyStoppingMetrics = 1; // Use L1 by default.
             }
 
-            public ITrainer CreateComponent(IHostEnvironment env) => new FastTreeRegressionTrainer(env, this);
+            ITrainer IComponentFactory.CreateComponent(IHostEnvironment env) => new FastTreeRegressionTrainer(env, this);
         }
     }
 
@@ -62,7 +62,7 @@ public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory
                 "and intermediate values are compound Poisson loss.")]
             public Double Index = 1.5;
 
-            public ITrainer CreateComponent(IHostEnvironment env) => new FastTreeTweedieTrainer(env, this);
+            ITrainer IComponentFactory.CreateComponent(IHostEnvironment env) => new FastTreeTweedieTrainer(env, this);
         }
     }
 
@@ -111,7 +111,7 @@ public Arguments()
                 EarlyStoppingMetrics = 1;
             }
 
-            public ITrainer CreateComponent(IHostEnvironment env) => new FastTreeRankingTrainer(env, this);
+            ITrainer IComponentFactory.CreateComponent(IHostEnvironment env) => new FastTreeRankingTrainer(env, this);
 
             internal override void Check(IExceptionContext ectx)
             {
diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs
index 3b278a4e56..433cafa908 100644
--- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs
@@ -154,7 +154,7 @@ internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Arguments arg
 
         public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
 
-        protected override IPredictorWithFeatureWeights TrainModelCore(TrainContext context)
+        private protected override IPredictorWithFeatureWeights TrainModelCore(TrainContext context)
         {
             Host.CheckValue(context, nameof(context));
             var trainData = context.TrainingSet;
diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs
index 3eb2b83b84..2663c9df51 100644
--- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs
@@ -42,8 +42,7 @@ namespace Microsoft.ML.Trainers.FastTree
 {
     ///
     public sealed partial class FastTreeRankingTrainer
-        : BoostingFastTreeTrainerBase, FastTreeRankingPredictor>,
-        IHasLabelGains
+        : BoostingFastTreeTrainerBase, FastTreeRankingPredictor>
     {
         internal const string LoadNameValue = "FastTreeRanking";
         internal const string UserNameValue = "FastTree (Boosted Trees) Ranking";
@@ -100,7 +99,7 @@ protected override float GetMaxLabel()
             return GetLabelGains().Length - 1;
         }
 
-        protected override FastTreeRankingPredictor TrainModelCore(TrainContext context)
+        private protected override FastTreeRankingPredictor TrainModelCore(TrainContext context)
         {
             Host.CheckValue(context, nameof(context));
             var trainData = context.TrainingSet;
@@ -117,7 +116,7 @@ protected override FastTreeRankingPredictor TrainModelCore(TrainContext context)
             return new FastTreeRankingPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs);
         }
 
-        public Double[] GetLabelGains()
+        private Double[] GetLabelGains()
         {
             try
             {
diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs
index 2b8e1b60c3..133f9c2bd4 100644
--- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs
@@ -84,7 +84,7 @@ internal FastTreeRegressionTrainer(IHostEnvironment env, Arguments args)
         {
         }
 
-        protected override FastTreeRegressionPredictor TrainModelCore(TrainContext context)
+        private protected override FastTreeRegressionPredictor TrainModelCore(TrainContext context)
         {
             Host.CheckValue(context, nameof(context));
             var trainData = context.TrainingSet;
diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
index 7b70bf8169..f49798aa22 100644
--- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
@@ -86,7 +86,7 @@ internal FastTreeTweedieTrainer(IHostEnvironment env, Arguments args)
             Initialize();
         }
 
-        protected override FastTreeTweediePredictor TrainModelCore(TrainContext context)
+        private protected override FastTreeTweediePredictor TrainModelCore(TrainContext context)
         {
             Host.CheckValue(context, nameof(context));
             var trainData = context.TrainingSet;
diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs
index 4afebf83ba..3f45cb7ed3 100644
--- a/src/Microsoft.ML.FastTree/GamClassification.cs
+++ b/src/Microsoft.ML.FastTree/GamClassification.cs
@@ -102,7 +102,7 @@ private static bool[] ConvertTargetsToBool(double[] targets)
             return boolArray;
         }
 
-        protected override IPredictorProducing TrainModelCore(TrainContext context)
+        private protected override IPredictorProducing TrainModelCore(TrainContext context)
         {
             TrainBase(context);
             var predictor = new BinaryClassGamPredictor(Host, InputLength, TrainSet,
diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs
index 9669bff805..575c1d9e24 100644
--- a/src/Microsoft.ML.FastTree/GamRegression.cs
+++ b/src/Microsoft.ML.FastTree/GamRegression.cs
@@ -70,7 +70,7 @@ internal override void CheckLabel(RoleMappedData data)
             data.CheckRegressionLabel();
         }
 
-        protected override RegressionGamPredictor TrainModelCore(TrainContext context)
+        private protected override RegressionGamPredictor TrainModelCore(TrainContext context)
         {
             TrainBase(context);
             return new RegressionGamPredictor(Host, InputLength, TrainSet, MeanEffect, BinEffects, FeatureMap);
diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs
index ffe8516b20..d77000d7ae 100644
--- a/src/Microsoft.ML.FastTree/GamTrainer.cs
+++ b/src/Microsoft.ML.FastTree/GamTrainer.cs
@@ -187,7 +187,7 @@ private protected GamTrainerBase(IHostEnvironment env, TArgs args, string name,
             InitializeThreads();
         }
 
-        protected void TrainBase(TrainContext context)
+        private protected void TrainBase(TrainContext context)
         {
             using (var ch = Host.Start("Training"))
             {
-981,7 +981,7 @@ public void SaveSummary(TextWriter writer, RoleMappedSchema schema) /// , it is convenient to have the command itself nested within the base /// predictor class. /// - public sealed class VisualizationCommand : DataCommand.ImplBase + internal sealed class VisualizationCommand : DataCommand.ImplBase { public const string Summary = "Loads a model trained with a GAM learner, and starts an interactive web session to visualize it."; public const string LoadName = "GamVisualization"; diff --git a/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs b/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..a03d7bdab6 --- /dev/null +++ b/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.CompilerServices; +using Microsoft.ML; + +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)] + +[assembly: WantsToBeBestFriends] diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index b5f85bf919..a7aedc692a 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -168,7 +168,7 @@ public FastForestClassification(IHostEnvironment env, Arguments args) { } - protected override IPredictorWithFeatureWeights TrainModelCore(TrainContext context) + private protected override IPredictorWithFeatureWeights TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index d805e9c750..4fe842214c 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -189,7 +189,7 @@ public FastForestRegression(IHostEnvironment env, Arguments args) { } - protected override FastForestRegressionPredictor TrainModelCore(TrainContext context) + private protected override FastForestRegressionPredictor TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; diff --git a/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs b/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs index 4d02097941..14fcad759e 100644 --- a/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs +++ b/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs @@ -29,7 +29,7 @@ namespace Microsoft.ML.Trainers.FastTree /// /// This is an internal utility command to measure the performance of the IntArray sumup operation. 
/// - public sealed class SumupPerformanceCommand : ICommand + internal sealed class SumupPerformanceCommand : ICommand { public sealed class Arguments { diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index 252ac42abe..b2a2973da7 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -547,8 +547,10 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) } /// - public static class TreeEnsembleFeaturizerTransform + [BestFriend] + internal static class TreeEnsembleFeaturizerTransform { +#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms. public sealed class Arguments : TrainAndScoreTransformer.ArgumentsBase { [Argument(ArgumentType.Multiple, HelpText = "Trainer to use", ShortName = "tr", NullName = "", SortOrder = 1, SignatureType = typeof(SignatureTreeEnsembleTrainer))] @@ -586,6 +588,7 @@ public sealed class ArgumentsForEntryPoint : TransformInputBase [Argument(ArgumentType.Required, HelpText = "Trainer to use", SortOrder = 10, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] public IPredictorModel PredictorModel; } +#pragma warning restore CS0649 internal const string TreeEnsembleSummary = "Trains a tree ensemble, or loads it from a file, then maps a numeric feature vector " + @@ -801,8 +804,9 @@ private static IDataView AppendLabelTransform(IHostEnvironment env, IChannel ch, } } - public static partial class TreeFeaturize + internal static partial class TreeFeaturize { +#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms. [TlcModule.EntryPoint(Name = "Transforms.TreeLeafFeaturizer", Desc = TreeEnsembleFeaturizerTransform.TreeEnsembleSummary, UserName = TreeEnsembleFeaturizerTransform.UserName, @@ -818,5 +822,6 @@ public static CommonOutputs.TransformOutput Featurizer(IHostEnvironment env, Tre var xf = TreeEnsembleFeaturizerTransform.CreateForEntryPoint(env, input, input.Data); return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf }; } +#pragma warning restore CS0649 } } diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs index 48ff08b67b..c1dfab7957 100644 --- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs +++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs @@ -130,7 +130,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc private static Double ProbClamp(Double p) => Math.Max(0, Math.Min(p, 1)); - protected override OlsLinearRegressionPredictor TrainModelCore(TrainContext context) + private protected override OlsLinearRegressionPredictor TrainModelCore(TrainContext context) { using (var ch = Host.Start("Training")) { diff --git a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs index a37dac9338..590aa06e27 100644 --- a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs +++ b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs @@ -133,7 +133,7 @@ private RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedDa return examplesToFeedTrain; } - protected override TPredictor TrainModelCore(TrainContext context) + private protected override TPredictor TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); using (var ch = 
Host.Start("Training")) diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs index 2682e50301..7e2bb4f807 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs @@ -150,7 +150,7 @@ private KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args, Action diff --git a/src/Microsoft.ML.Legacy/LearningPipeline.cs b/src/Microsoft.ML.Legacy/LearningPipeline.cs index e56e49f6a3..7544d2f112 100644 --- a/src/Microsoft.ML.Legacy/LearningPipeline.cs +++ b/src/Microsoft.ML.Legacy/LearningPipeline.cs @@ -161,86 +161,84 @@ public PredictionModel Train() where TInput : class where TOutput : class, new() { - using (var environment = new ConsoleEnvironment(seed: _seed, conc: _conc)) + var environment = new MLContext(seed: _seed, conc: _conc); + Experiment experiment = environment.CreateExperiment(); + ILearningPipelineStep step = null; + List loaders = new List(); + List> transformModels = new List>(); + Var lastTransformModel = null; + + foreach (ILearningPipelineItem currentItem in this) { - Experiment experiment = environment.CreateExperiment(); - ILearningPipelineStep step = null; - List loaders = new List(); - List> transformModels = new List>(); - Var lastTransformModel = null; + if (currentItem is ILearningPipelineLoader loader) + loaders.Add(loader); - foreach (ILearningPipelineItem currentItem in this) + step = currentItem.ApplyStep(step, experiment); + if (step is ILearningPipelineDataStep dataStep && dataStep.Model != null) + transformModels.Add(dataStep.Model); + else if (step is ILearningPipelinePredictorStep predictorDataStep) { - if (currentItem is ILearningPipelineLoader loader) - loaders.Add(loader); + if (lastTransformModel != null) + transformModels.Insert(0, lastTransformModel); - step = currentItem.ApplyStep(step, experiment); - if (step is ILearningPipelineDataStep dataStep && dataStep.Model != null) - transformModels.Add(dataStep.Model); - else if (step is ILearningPipelinePredictorStep predictorDataStep) + Var predictorModel; + if (transformModels.Count != 0) { - if (lastTransformModel != null) - transformModels.Insert(0, lastTransformModel); - - Var predictorModel; - if (transformModels.Count != 0) + var localModelInput = new Transforms.ManyHeterogeneousModelCombiner { - var localModelInput = new Transforms.ManyHeterogeneousModelCombiner - { - PredictorModel = predictorDataStep.Model, - TransformModels = new ArrayVar(transformModels.ToArray()) - }; - var localModelOutput = experiment.Add(localModelInput); - predictorModel = localModelOutput.PredictorModel; - } - else - predictorModel = predictorDataStep.Model; - - var scorer = new Transforms.Scorer - { - PredictorModel = predictorModel + PredictorModel = predictorDataStep.Model, + TransformModels = new ArrayVar(transformModels.ToArray()) }; - - var scorerOutput = experiment.Add(scorer); - lastTransformModel = scorerOutput.ScoringTransform; - step = new ScorerPipelineStep(scorerOutput.ScoredData, scorerOutput.ScoringTransform); - transformModels.Clear(); + var localModelOutput = experiment.Add(localModelInput); + predictorModel = localModelOutput.PredictorModel; } - } + else + predictorModel = predictorDataStep.Model; - if (transformModels.Count > 0) - { - if (lastTransformModel != null) - transformModels.Insert(0, lastTransformModel); - - var modelInput = new Transforms.ModelCombiner + var scorer = new Transforms.Scorer { - Models = new 
ArrayVar(transformModels.ToArray()) + PredictorModel = predictorModel }; - var modelOutput = experiment.Add(modelInput); - lastTransformModel = modelOutput.OutputModel; + var scorerOutput = experiment.Add(scorer); + lastTransformModel = scorerOutput.ScoringTransform; + step = new ScorerPipelineStep(scorerOutput.ScoredData, scorerOutput.ScoringTransform); + transformModels.Clear(); } + } - experiment.Compile(); - foreach (ILearningPipelineLoader loader in loaders) - { - loader.SetInput(environment, experiment); - } - experiment.Run(); + if (transformModels.Count > 0) + { + if (lastTransformModel != null) + transformModels.Insert(0, lastTransformModel); - ITransformModel model = experiment.GetOutput(lastTransformModel); - BatchPredictionEngine predictor; - using (var memoryStream = new MemoryStream()) + var modelInput = new Transforms.ModelCombiner { - model.Save(environment, memoryStream); + Models = new ArrayVar(transformModels.ToArray()) + }; - memoryStream.Position = 0; + var modelOutput = experiment.Add(modelInput); + lastTransformModel = modelOutput.OutputModel; + } - predictor = environment.CreateBatchPredictionEngine(memoryStream); + experiment.Compile(); + foreach (ILearningPipelineLoader loader in loaders) + { + loader.SetInput(environment, experiment); + } + experiment.Run(); - return new PredictionModel(predictor, memoryStream); - } + ITransformModel model = experiment.GetOutput(lastTransformModel); + BatchPredictionEngine predictor; + using (var memoryStream = new MemoryStream()) + { + model.Save(environment, memoryStream); + + memoryStream.Position = 0; + + predictor = environment.CreateBatchPredictionEngine(memoryStream); + + return new PredictionModel(predictor, memoryStream); } } diff --git a/src/Microsoft.ML.Legacy/LearningPipelineDebugProxy.cs b/src/Microsoft.ML.Legacy/LearningPipelineDebugProxy.cs index af99d8ff3c..586dbf99e3 100644 --- a/src/Microsoft.ML.Legacy/LearningPipelineDebugProxy.cs +++ b/src/Microsoft.ML.Legacy/LearningPipelineDebugProxy.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Legacy.Transforms; using System; @@ -25,7 +26,7 @@ internal sealed class LearningPipelineDebugProxy private const int MaxSlotNamesToDisplay = 100; private readonly LearningPipeline _pipeline; - private readonly ConsoleEnvironment _environment; + private readonly IHostEnvironment _environment; private IDataView _preview; private Exception _pipelineExecutionException; private PipelineItemDebugColumn[] _columns; @@ -39,7 +40,7 @@ public LearningPipelineDebugProxy(LearningPipeline pipeline) _pipeline = new LearningPipeline(); // use a ConcurrencyFactor of 1 so other threads don't need to run in the debugger - _environment = new ConsoleEnvironment(conc: 1); + _environment = new MLContext(conc: 1); foreach (ILearningPipelineItem item in pipeline) { diff --git a/src/Microsoft.ML.Legacy/Microsoft.ML.Legacy.csproj b/src/Microsoft.ML.Legacy/Microsoft.ML.Legacy.csproj index 6e4034aa24..1ccf77d19d 100644 --- a/src/Microsoft.ML.Legacy/Microsoft.ML.Legacy.csproj +++ b/src/Microsoft.ML.Legacy/Microsoft.ML.Legacy.csproj @@ -6,10 +6,6 @@ CORECLR - - - - diff --git a/src/Microsoft.ML.Legacy/Models/BinaryClassificationEvaluator.cs b/src/Microsoft.ML.Legacy/Models/BinaryClassificationEvaluator.cs index 71b32a323a..2f2bce8016 100644 --- a/src/Microsoft.ML.Legacy/Models/BinaryClassificationEvaluator.cs +++ b/src/Microsoft.ML.Legacy/Models/BinaryClassificationEvaluator.cs @@ -24,54 +24,52 @@ public sealed partial class BinaryClassificationEvaluator /// public BinaryClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData) { - using (var environment = new ConsoleEnvironment()) - { - environment.CheckValue(model, nameof(model)); - environment.CheckValue(testData, nameof(testData)); + var environment = new MLContext(); + environment.CheckValue(model, nameof(model)); + environment.CheckValue(testData, nameof(testData)); - Experiment experiment = environment.CreateExperiment(); + Experiment experiment = environment.CreateExperiment(); - ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment); - if (!(testDataStep is ILearningPipelineDataStep testDataOutput)) - { - throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep."); - } + ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment); + if (!(testDataStep is ILearningPipelineDataStep testDataOutput)) + { + throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep."); + } - var datasetScorer = new DatasetTransformScorer - { - Data = testDataOutput.Data - }; - DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer); + var datasetScorer = new DatasetTransformScorer + { + Data = testDataOutput.Data + }; + DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer); - Data = scoreOutput.ScoredData; - Output evaluteOutput = experiment.Add(this); + Data = scoreOutput.ScoredData; + Output evaluteOutput = experiment.Add(this); - experiment.Compile(); + experiment.Compile(); - experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel); - testData.SetInput(environment, experiment); + experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel); + testData.SetInput(environment, experiment); - experiment.Run(); + experiment.Run(); - IDataView overallMetrics = 
experiment.GetOutput(evaluteOutput.OverallMetrics); - if (overallMetrics == null) - { - throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate."); - } + IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics); + if (overallMetrics == null) + { + throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate."); + } - IDataView confusionMatrix = experiment.GetOutput(evaluteOutput.ConfusionMatrix); - if (confusionMatrix == null) - { - throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate."); - } + IDataView confusionMatrix = experiment.GetOutput(evaluteOutput.ConfusionMatrix); + if (confusionMatrix == null) + { + throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate."); + } - var metric = BinaryClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix); + var metric = BinaryClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix); - if (metric.Count != 1) - throw environment.Except($"Exactly one metric set was expected but found {metric.Count} metrics"); + if (metric.Count != 1) + throw environment.Except($"Exactly one metric set was expected but found {metric.Count} metrics"); - return metric[0]; - } + return metric[0]; } } } diff --git a/src/Microsoft.ML.Legacy/Models/ClassificationEvaluator.cs b/src/Microsoft.ML.Legacy/Models/ClassificationEvaluator.cs index 63e7f8055e..5d644baf32 100644 --- a/src/Microsoft.ML.Legacy/Models/ClassificationEvaluator.cs +++ b/src/Microsoft.ML.Legacy/Models/ClassificationEvaluator.cs @@ -25,54 +25,52 @@ public sealed partial class ClassificationEvaluator /// public ClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData) { - using (var environment = new ConsoleEnvironment()) - { - environment.CheckValue(model, nameof(model)); - environment.CheckValue(testData, nameof(testData)); + var environment = new MLContext(); + environment.CheckValue(model, nameof(model)); + environment.CheckValue(testData, nameof(testData)); - Experiment experiment = environment.CreateExperiment(); + Experiment experiment = environment.CreateExperiment(); - ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment); - if (!(testDataStep is ILearningPipelineDataStep testDataOutput)) - { - throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep."); - } + ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment); + if (!(testDataStep is ILearningPipelineDataStep testDataOutput)) + { + throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep."); + } - var datasetScorer = new DatasetTransformScorer - { - Data = testDataOutput.Data, - }; - DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer); + var datasetScorer = new DatasetTransformScorer + { + Data = testDataOutput.Data, + }; + DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer); - Data = scoreOutput.ScoredData; - Output evaluteOutput = experiment.Add(this); + Data = scoreOutput.ScoredData; + Output evaluteOutput = experiment.Add(this); - experiment.Compile(); + experiment.Compile(); - 
experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel); - testData.SetInput(environment, experiment); + experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel); + testData.SetInput(environment, experiment); - experiment.Run(); + experiment.Run(); - IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics); - if (overallMetrics == null) - { - throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(ClassificationEvaluator)} Evaluate."); - } + IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics); + if (overallMetrics == null) + { + throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(ClassificationEvaluator)} Evaluate."); + } - IDataView confusionMatrix = experiment.GetOutput(evaluteOutput.ConfusionMatrix); - if (confusionMatrix == null) - { - throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(ClassificationEvaluator)} Evaluate."); - } + IDataView confusionMatrix = experiment.GetOutput(evaluteOutput.ConfusionMatrix); + if (confusionMatrix == null) + { + throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(ClassificationEvaluator)} Evaluate."); + } - var metric = ClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix); + var metric = ClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix); - if (metric.Count != 1) - throw environment.Except($"Exactly one metric set was expected but found {metric.Count} metrics"); + if (metric.Count != 1) + throw environment.Except($"Exactly one metric set was expected but found {metric.Count} metrics"); - return metric[0]; - } + return metric[0]; } } } diff --git a/src/Microsoft.ML.Legacy/Models/ClusterEvaluator.cs b/src/Microsoft.ML.Legacy/Models/ClusterEvaluator.cs index 411c85b176..5d12ad85f9 100644 --- a/src/Microsoft.ML.Legacy/Models/ClusterEvaluator.cs +++ b/src/Microsoft.ML.Legacy/Models/ClusterEvaluator.cs @@ -24,48 +24,46 @@ public sealed partial class ClusterEvaluator /// public ClusterMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData) { - using (var environment = new ConsoleEnvironment()) - { - environment.CheckValue(model, nameof(model)); - environment.CheckValue(testData, nameof(testData)); + var environment = new MLContext(); + environment.CheckValue(model, nameof(model)); + environment.CheckValue(testData, nameof(testData)); - Experiment experiment = environment.CreateExperiment(); + Experiment experiment = environment.CreateExperiment(); - ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment); - if (!(testDataStep is ILearningPipelineDataStep testDataOutput)) - { - throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep."); - } + ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment); + if (!(testDataStep is ILearningPipelineDataStep testDataOutput)) + { + throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep."); + } - var datasetScorer = new DatasetTransformScorer - { - Data = testDataOutput.Data, - }; - DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer); + var datasetScorer = new DatasetTransformScorer + { + Data = testDataOutput.Data, + }; + DatasetTransformScorer.Output scoreOutput = 
experiment.Add(datasetScorer); - Data = scoreOutput.ScoredData; - Output evaluteOutput = experiment.Add(this); + Data = scoreOutput.ScoredData; + Output evaluteOutput = experiment.Add(this); - experiment.Compile(); + experiment.Compile(); - experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel); - testData.SetInput(environment, experiment); + experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel); + testData.SetInput(environment, experiment); - experiment.Run(); + experiment.Run(); - IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics); + IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics); - if (overallMetrics == null) - { - throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(ClusterEvaluator)} Evaluate."); - } + if (overallMetrics == null) + { + throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(ClusterEvaluator)} Evaluate."); + } - var metric = ClusterMetrics.FromOverallMetrics(environment, overallMetrics); + var metric = ClusterMetrics.FromOverallMetrics(environment, overallMetrics); - Contracts.Assert(metric.Count == 1, $"Exactly one metric set was expected but found {metric.Count} metrics"); + Contracts.Assert(metric.Count == 1, $"Exactly one metric set was expected but found {metric.Count} metrics"); - return metric[0]; - } + return metric[0]; } } } diff --git a/src/Microsoft.ML.Legacy/Models/CrossValidator.cs b/src/Microsoft.ML.Legacy/Models/CrossValidator.cs index d9d133a779..96be9db419 100644 --- a/src/Microsoft.ML.Legacy/Models/CrossValidator.cs +++ b/src/Microsoft.ML.Legacy/Models/CrossValidator.cs @@ -27,7 +27,7 @@ public CrossValidationOutput CrossValidate(Lea where TInput : class where TOutput : class, new() { - using (var environment = new ConsoleEnvironment()) + var environment = new MLContext(); { Experiment subGraph = environment.CreateExperiment(); ILearningPipelineStep step = null; diff --git a/src/Microsoft.ML.Legacy/Models/OneVersusAll.cs b/src/Microsoft.ML.Legacy/Models/OneVersusAll.cs index acd756556a..3269556f1f 100644 --- a/src/Microsoft.ML.Legacy/Models/OneVersusAll.cs +++ b/src/Microsoft.ML.Legacy/Models/OneVersusAll.cs @@ -52,26 +52,24 @@ public OvaPipelineItem(ITrainerInputWithLabel trainer, bool useProbabilities) public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - using (var env = new ConsoleEnvironment()) + var env = new MLContext(); + var subgraph = env.CreateExperiment(); + subgraph.Add(_trainer); + var ova = new OneVersusAll(); + if (previousStep != null) { - var subgraph = env.CreateExperiment(); - subgraph.Add(_trainer); - var ova = new OneVersusAll(); - if (previousStep != null) + if (!(previousStep is ILearningPipelineDataStep dataStep)) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) - { - throw new InvalidOperationException($"{ nameof(OneVersusAll)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } - - _data = dataStep.Data; - ova.TrainingData = dataStep.Data; - ova.UseProbabilities = _useProbabilities; - ova.Nodes = subgraph; + throw new InvalidOperationException($"{ nameof(OneVersusAll)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); } - Output output = experiment.Add(ova); - return new OvaPipelineStep(output); + + _data = dataStep.Data; + ova.TrainingData = dataStep.Data; + ova.UseProbabilities = _useProbabilities; + ova.Nodes = subgraph; } + Output output = 
experiment.Add(ova); + return new OvaPipelineStep(output); } public Var GetInputData() => _data; diff --git a/src/Microsoft.ML.Legacy/Models/OnnxConverter.cs b/src/Microsoft.ML.Legacy/Models/OnnxConverter.cs index c49c71cf71..0c9eca3ee8 100644 --- a/src/Microsoft.ML.Legacy/Models/OnnxConverter.cs +++ b/src/Microsoft.ML.Legacy/Models/OnnxConverter.cs @@ -70,16 +70,14 @@ public sealed partial class OnnxConverter /// Model that needs to be converted to ONNX format. public void Convert(PredictionModel model) { - using (var environment = new ConsoleEnvironment()) - { - environment.CheckValue(model, nameof(model)); + var environment = new MLContext(); + environment.CheckValue(model, nameof(model)); - Experiment experiment = environment.CreateExperiment(); - experiment.Add(this); - experiment.Compile(); - experiment.SetInput(Model, model.PredictorModel); - experiment.Run(); - } + Experiment experiment = environment.CreateExperiment(); + experiment.Add(this); + experiment.Compile(); + experiment.SetInput(Model, model.PredictorModel); + experiment.Run(); } } } diff --git a/src/Microsoft.ML.Legacy/Models/RegressionEvaluator.cs b/src/Microsoft.ML.Legacy/Models/RegressionEvaluator.cs index 35a7fd7500..ffee6108c6 100644 --- a/src/Microsoft.ML.Legacy/Models/RegressionEvaluator.cs +++ b/src/Microsoft.ML.Legacy/Models/RegressionEvaluator.cs @@ -24,49 +24,47 @@ public sealed partial class RegressionEvaluator /// public RegressionMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData) { - using (var environment = new ConsoleEnvironment()) - { - environment.CheckValue(model, nameof(model)); - environment.CheckValue(testData, nameof(testData)); + var environment = new MLContext(); + environment.CheckValue(model, nameof(model)); + environment.CheckValue(testData, nameof(testData)); - Experiment experiment = environment.CreateExperiment(); + Experiment experiment = environment.CreateExperiment(); - ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment); - if (!(testDataStep is ILearningPipelineDataStep testDataOutput)) - { - throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep."); - } + ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment); + if (!(testDataStep is ILearningPipelineDataStep testDataOutput)) + { + throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep."); + } - var datasetScorer = new DatasetTransformScorer - { - Data = testDataOutput.Data, - }; - DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer); + var datasetScorer = new DatasetTransformScorer + { + Data = testDataOutput.Data, + }; + DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer); - Data = scoreOutput.ScoredData; - Output evaluteOutput = experiment.Add(this); + Data = scoreOutput.ScoredData; + Output evaluteOutput = experiment.Add(this); - experiment.Compile(); + experiment.Compile(); - experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel); - testData.SetInput(environment, experiment); + experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel); + testData.SetInput(environment, experiment); - experiment.Run(); + experiment.Run(); - IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics); + IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics); - if (overallMetrics == null) - { - 
throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(RegressionEvaluator)} Evaluate."); - } + if (overallMetrics == null) + { + throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(RegressionEvaluator)} Evaluate."); + } - var metric = RegressionMetrics.FromOverallMetrics(environment, overallMetrics); + var metric = RegressionMetrics.FromOverallMetrics(environment, overallMetrics); - if (metric.Count != 1) - throw environment.Except($"Exactly one metric set was expected but found {metric.Count} metrics"); + if (metric.Count != 1) + throw environment.Except($"Exactly one metric set was expected but found {metric.Count} metrics"); - return metric[0]; - } + return metric[0]; } } } diff --git a/src/Microsoft.ML.Legacy/Models/TrainTestEvaluator.cs b/src/Microsoft.ML.Legacy/Models/TrainTestEvaluator.cs index 6972b8cf3e..dbf5df1b50 100644 --- a/src/Microsoft.ML.Legacy/Models/TrainTestEvaluator.cs +++ b/src/Microsoft.ML.Legacy/Models/TrainTestEvaluator.cs @@ -30,7 +30,7 @@ public TrainTestEvaluatorOutput TrainTestEvaluate /// Returns labels that correspond to indices of the score array in the case of @@ -40,7 +36,7 @@ internal TransformModel PredictorModel public bool TryGetScoreLabelNames(out string[] names, string scoreColumnName = DefaultColumnNames.Score) { names = null; - var schema = _predictorModel.OutputSchema; + var schema = PredictorModel.OutputSchema; int colIndex = -1; if (!schema.TryGetColumnIndex(scoreColumnName, out colIndex)) return false; @@ -125,15 +121,13 @@ public static Task> ReadAsync( if (stream == null) throw new ArgumentNullException(nameof(stream)); - using (var environment = new ConsoleEnvironment()) - { - AssemblyRegistration.RegisterAssemblies(environment); + var environment = new MLContext(); + AssemblyRegistration.RegisterAssemblies(environment); - BatchPredictionEngine predictor = - environment.CreateBatchPredictionEngine(stream); + BatchPredictionEngine predictor = + environment.CreateBatchPredictionEngine(stream); - return Task.FromResult(new PredictionModel(predictor, stream)); - } + return Task.FromResult(new PredictionModel(predictor, stream)); } /// @@ -141,7 +135,7 @@ public static Task> ReadAsync( /// /// Incoming IDataView /// IDataView which contains predictions - public IDataView Predict(IDataView input) => _predictorModel.Apply(_env, input); + public IDataView Predict(IDataView input) => PredictorModel.Apply(_env, input); /// /// Save model to file. @@ -168,7 +162,7 @@ public Task WriteAsync(Stream stream) { if (stream == null) throw new ArgumentNullException(nameof(stream)); - _predictorModel.Save(_env, stream); + PredictorModel.Save(_env, stream); return Task.CompletedTask; } } diff --git a/src/Microsoft.ML.Legacy/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Legacy/Properties/AssemblyInfo.cs index 52835b2f81..70fe21adb6 100644 --- a/src/Microsoft.ML.Legacy/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.Legacy/Properties/AssemblyInfo.cs @@ -2,8 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System.Reflection; +using Microsoft.ML; using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; -[assembly: InternalsVisibleTo("Microsoft.ML.Tests, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")] +[assembly: InternalsVisibleTo("Microsoft.ML.Tests" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo("Microsoft.ML.Core.Tests" + PublicKey.TestValue)] diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CodeGen/ModuleGenerator.cs b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CodeGen/ModuleGenerator.cs index c612197f31..817706522b 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CodeGen/ModuleGenerator.cs +++ b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CodeGen/ModuleGenerator.cs @@ -21,7 +21,7 @@ namespace Microsoft.ML.Runtime.EntryPoints.CodeGen { - public class ModuleGenerator : IGenerator + internal sealed class ModuleGenerator : IGenerator { private readonly string _modulePrefix; private readonly bool _generateModule; diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/ExecuteGraphCommand.cs b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/ExecuteGraphCommand.cs index a3d02d10e4..d6e95961ec 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/ExecuteGraphCommand.cs +++ b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/ExecuteGraphCommand.cs @@ -21,7 +21,7 @@ namespace Microsoft.ML.Runtime.EntryPoints.JsonUtils { - public sealed class ExecuteGraphCommand : ICommand + internal sealed class ExecuteGraphCommand : ICommand { public sealed class Arguments { diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/GraphRunner.cs b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/GraphRunner.cs index 9ebb25e301..af6e2b818d 100644 --- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/GraphRunner.cs +++ b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/GraphRunner.cs @@ -140,7 +140,7 @@ public void SetInput(string name, TInput input) /// /// Get the data kind of a particular port. 
///
-        public TlcModule.DataKind GetPortDataKind(string name)
+        internal TlcModule.DataKind GetPortDataKind(string name)
        {
            _host.CheckNonEmpty(name, nameof(name));
            EntryPointVariable variable;
diff --git a/src/Microsoft.ML.Legacy/Runtime/Internal/Tools/CSharpApiGenerator.cs b/src/Microsoft.ML.Legacy/Runtime/Internal/Tools/CSharpApiGenerator.cs
index 0bcf5a6776..0fd2aa9aa0 100644
--- a/src/Microsoft.ML.Legacy/Runtime/Internal/Tools/CSharpApiGenerator.cs
+++ b/src/Microsoft.ML.Legacy/Runtime/Internal/Tools/CSharpApiGenerator.cs
@@ -22,7 +22,7 @@ namespace Microsoft.ML.Runtime.Internal.Tools
 {
-    public sealed class CSharpApiGenerator : IGenerator
+    internal sealed class CSharpApiGenerator : IGenerator
    {
        public sealed class Arguments
        {
diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
index 3a0c496868..53e6ba543b 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
@@ -102,7 +102,7 @@ private protected LightGbmTrainerBase(IHostEnvironment env, string name, LightGb
            InitParallelTraining();
        }
-        protected override TModel TrainModelCore(TrainContext context)
+        private protected override TModel TrainModelCore(TrainContext context)
        {
            Host.CheckValue(context, nameof(context));
diff --git a/src/Microsoft.ML.Maml/ChainCommand.cs b/src/Microsoft.ML.Maml/ChainCommand.cs
index e49162b45b..c51a78952b 100644
--- a/src/Microsoft.ML.Maml/ChainCommand.cs
+++ b/src/Microsoft.ML.Maml/ChainCommand.cs
@@ -15,7 +15,8 @@ namespace Microsoft.ML.Runtime.Tools
 {
    using Stopwatch = System.Diagnostics.Stopwatch;
-    public sealed class ChainCommand : ICommand
+    [BestFriend]
+    internal sealed class ChainCommand : ICommand
    {
        public sealed class Arguments
        {
diff --git a/src/Microsoft.ML.Maml/HelpCommand.cs b/src/Microsoft.ML.Maml/HelpCommand.cs
index ee2adbb495..80ddd63e47 100644
--- a/src/Microsoft.ML.Maml/HelpCommand.cs
+++ b/src/Microsoft.ML.Maml/HelpCommand.cs
@@ -23,14 +23,15 @@ namespace Microsoft.ML.Runtime.Tools
 {
-    public interface IGenerator
+    [BestFriend]
+    internal interface IGenerator
    {
        void Generate(IEnumerable<HelpCommand.Component> infos);
    }
    public delegate void SignatureModuleGenerator(string regenerate);
-    public sealed class HelpCommand : ICommand
+    internal sealed class HelpCommand : ICommand
    {
        public sealed class Arguments
        {
@@ -100,7 +101,9 @@ public void Run()
        public void Run(int? columns)
        {
+#pragma warning disable CS0618 // The help command should be entirely within the command line anyway.
            AssemblyLoadingUtils.LoadAndRegister(_env, _extraAssemblies);
+#pragma warning restore CS0618
            using (var ch = _env.Start("Help"))
            using (var sw = new StringWriter(CultureInfo.InvariantCulture))
@@ -423,7 +426,7 @@ private void GenerateModule(List<Component> components)
        }
    }
-    public sealed class XmlGenerator : IGenerator
+    internal sealed class XmlGenerator : IGenerator
    {
        public sealed class Arguments
        {
diff --git a/src/Microsoft.ML.Maml/MAML.cs b/src/Microsoft.ML.Maml/MAML.cs
index ff207b38e4..a4eaa43491 100644
--- a/src/Microsoft.ML.Maml/MAML.cs
+++ b/src/Microsoft.ML.Maml/MAML.cs
@@ -58,7 +58,9 @@ private static int MainWithProgress(string args)
            string currentDirectory = Path.GetDirectoryName(typeof(Maml).Module.FullyQualifiedName);
            using (var env = CreateEnvironment())
+#pragma warning disable CS0618 // This is the command line project, so the usage here is OK.
using (AssemblyLoadingUtils.CreateAssemblyRegistrar(env, currentDirectory)) +#pragma warning restore CS0618 using (var progressCancel = new CancellationTokenSource()) { var progressTrackerTask = Task.Run(() => TrackProgress(env, progressCancel.Token)); @@ -107,7 +109,7 @@ private static ConsoleEnvironment CreateEnvironment() /// so we always write . If set to true though, this executable will also print stack traces from the /// marked exceptions as well. /// - internal static int MainCore(ConsoleEnvironment env, string args, bool alwaysPrintStacktrace) + internal static int MainCore(IHostEnvironment env, string args, bool alwaysPrintStacktrace) { // REVIEW: How should extra dlls, tracking, etc be handled? Should the args objects for // all commands derive from a common base? diff --git a/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj b/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj index 219b6bf0b8..a366fb6f9c 100644 --- a/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj +++ b/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj @@ -7,10 +7,6 @@ netstandard2.0 - - - - diff --git a/src/Microsoft.ML.Maml/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Maml/Properties/AssemblyInfo.cs index 2ddc9c4ffa..2226c6d7fc 100644 --- a/src/Microsoft.ML.Maml/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.Maml/Properties/AssemblyInfo.cs @@ -2,9 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Reflection; using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; +using Microsoft.ML; -[assembly: InternalsVisibleTo("Microsoft.ML.TestFramework, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")] -[assembly: InternalsVisibleTo("Microsoft.ML.Benchmarks, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")] \ No newline at end of file +[assembly: InternalsVisibleTo("Microsoft.ML.TestFramework" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo("Microsoft.ML.Benchmarks" + PublicKey.TestValue)] + +[assembly: InternalsVisibleTo("Microsoft.ML.Legacy" + PublicKey.Value)] +[assembly: InternalsVisibleTo("Microsoft.ML.ResultProcessor" + PublicKey.Value)] diff --git a/src/Microsoft.ML.Maml/VersionCommand.cs b/src/Microsoft.ML.Maml/VersionCommand.cs index a59f01e58b..7e20db2260 100644 --- a/src/Microsoft.ML.Maml/VersionCommand.cs +++ b/src/Microsoft.ML.Maml/VersionCommand.cs @@ -12,7 +12,7 @@ namespace Microsoft.ML.Runtime.Tools { - public sealed class VersionCommand : ICommand + internal sealed class VersionCommand : ICommand { internal const string Summary = "Prints the TLC version."; diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs index 486bf10eeb..6ee3c8a37e 100644 --- a/src/Microsoft.ML.PCA/PcaTrainer.cs +++ b/src/Microsoft.ML.PCA/PcaTrainer.cs @@ -136,7 +136,7 @@ private RandomizedPcaTrainer(IHostEnvironment env, Arguments args, string featur } //Note: the notations used here are the same as in 
https://web.stanford.edu/group/mmds/slides2010/Martinsson.pdf (pg. 9) - protected override PcaPredictor TrainModelCore(TrainContext context) + private protected override PcaPredictor TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); diff --git a/src/Microsoft.ML.PipelineInference/AutoMlUtils.cs b/src/Microsoft.ML.PipelineInference/AutoMlUtils.cs index ba4b1e3872..8837ac945b 100644 --- a/src/Microsoft.ML.PipelineInference/AutoMlUtils.cs +++ b/src/Microsoft.ML.PipelineInference/AutoMlUtils.cs @@ -338,7 +338,7 @@ public static AutoInference.LevelDependencyMap ComputeColumnResponsibilities(IDa return mapping; } - public static TlcModule.SweepableParamAttribute[] GetSweepRanges(Type learnerInputType) + internal static TlcModule.SweepableParamAttribute[] GetSweepRanges(Type learnerInputType) { var paramSet = new List(); foreach (var prop in learnerInputType.GetProperties(BindingFlags.Instance | @@ -370,7 +370,7 @@ public static TlcModule.SweepableParamAttribute[] GetSweepRanges(Type learnerInp return paramSet.ToArray(); } - public static IValueGenerator ToIValueGenerator(TlcModule.SweepableParamAttribute attr) + internal static IValueGenerator ToIValueGenerator(TlcModule.SweepableParamAttribute attr) { if (attr is TlcModule.SweepableLongParamAttribute sweepableLongParamAttr) { @@ -430,7 +430,7 @@ private static void SetValue(PropertyInfo pi, IComparable value, object entryPoi /// /// Updates properties of entryPointObj instance based on the values in sweepParams /// - public static bool UpdateProperties(object entryPointObj, TlcModule.SweepableParamAttribute[] sweepParams) + internal static bool UpdateProperties(object entryPointObj, TlcModule.SweepableParamAttribute[] sweepParams) { bool result = true; foreach (var param in sweepParams) @@ -501,7 +501,7 @@ public static void PopulateSweepableParams(RecipeInference.SuggestedRecipe.Sugge } } - public static bool CheckEntryPointStateMatchesParamValues(object entryPointObj, + internal static bool CheckEntryPointStateMatchesParamValues(object entryPointObj, TlcModule.SweepableParamAttribute[] sweepParams) { foreach (var param in sweepParams) @@ -584,7 +584,7 @@ public static IRunResult[] ConvertToRunResults(PipelinePattern[] history, bool i /// Method to convert set of sweepable hyperparameters into instances used /// by the current smart hyperparameter sweepers. 
///
-        public static IComponentFactory<IValueGenerator>[] ConvertToComponentFactories(TlcModule.SweepableParamAttribute[] hps)
+        internal static IComponentFactory<IValueGenerator>[] ConvertToComponentFactories(TlcModule.SweepableParamAttribute[] hps)
        {
            var results = new IComponentFactory<IValueGenerator>[hps.Length];
diff --git a/test/Microsoft.ML.InferenceTesting/GenerateSweepCandidatesCommand.cs b/src/Microsoft.ML.PipelineInference/GenerateSweepCandidatesCommand.cs
similarity index 96%
rename from test/Microsoft.ML.InferenceTesting/GenerateSweepCandidatesCommand.cs
rename to src/Microsoft.ML.PipelineInference/GenerateSweepCandidatesCommand.cs
index 61aad5729a..8d03868719 100644
--- a/test/Microsoft.ML.InferenceTesting/GenerateSweepCandidatesCommand.cs
+++ b/src/Microsoft.ML.PipelineInference/GenerateSweepCandidatesCommand.cs
@@ -10,23 +10,23 @@
 using Microsoft.ML.Runtime.CommandLine;
 using Microsoft.ML.Runtime.Data;
 using Microsoft.ML.Runtime.Internal.Utilities;
-using Microsoft.ML.Runtime.MLTesting.Inference;
 using Microsoft.ML.Runtime.PipelineInference;
 using Microsoft.ML.Runtime.Sweeper;
 [assembly: LoadableClass(typeof(GenerateSweepCandidatesCommand), typeof(GenerateSweepCandidatesCommand.Arguments), typeof(SignatureCommand),
    "Generate Experiment Candidates", "GenerateSweepCandidates", DocName = "command/GenerateSweepCandidates.md")]
-namespace Microsoft.ML.Runtime.MLTesting.Inference
+namespace Microsoft.ML.Runtime.PipelineInference
 {
    ///
    /// This is a command that takes as an input:
    /// 1- the schema of a dataset, in the format produced by InferSchema
-    /// 2- the path to the datafile
-    /// and generates experiment candidates by combining all the transform recipes suggested for the dataset, with all the learners available for the task.
+    /// 2- the path to the datafile
+    /// and generates experiment candidates by combining all the transform recipes suggested for the dataset, with all the learners available for the task.
    ///
-    public sealed class GenerateSweepCandidatesCommand : ICommand
+    internal sealed class GenerateSweepCandidatesCommand : ICommand
    {
+#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
        public sealed class Arguments
        {
            [Argument(ArgumentType.Required, HelpText = "Text file with data to analyze.", ShortName = "data")]
            public string DataFile;
@@ -53,6 +53,7 @@ public sealed class Arguments
            [Argument(ArgumentType.AtMostOnce, HelpText = "If this option is provided, the result RSP are indented and written in separate files for easier visual inspection. Otherwise all the generated RSPs are written to a single file, one RSP per line.")]
            public bool Indent;
        }
+#pragma warning restore CS0649
        private readonly IHost _host;
        private readonly string _dataFile;
diff --git a/src/Microsoft.ML.PipelineInference/InferenceUtils.cs b/src/Microsoft.ML.PipelineInference/InferenceUtils.cs
index acff2197d8..9510e94eb3 100644
--- a/src/Microsoft.ML.PipelineInference/InferenceUtils.cs
+++ b/src/Microsoft.ML.PipelineInference/InferenceUtils.cs
@@ -18,7 +18,7 @@ public static IDataView Take(this IDataView data, int count)
        {
            Contracts.CheckValue(data, nameof(data));
            // REVIEW: This should take an env as a parameter, not create one.
-            var env = new ConsoleEnvironment(0);
+            var env = new MLContext(seed: 0);
            var take = SkipTakeFilter.Create(env, new SkipTakeFilter.TakeArguments { Count = count }, data);
            return CacheCore(take, env);
        }
@@ -27,7 +27,7 @@ public static IDataView Cache(this IDataView data)
        {
            Contracts.CheckValue(data, nameof(data));
            // REVIEW: This should take an env as a parameter, not create one.
- return CacheCore(data, new ConsoleEnvironment(0)); + return CacheCore(data, new MLContext(0)); } private static IDataView CacheCore(IDataView data, IHostEnvironment env) diff --git a/src/Microsoft.ML.PipelineInference/Interfaces/IPipelineNode.cs b/src/Microsoft.ML.PipelineInference/Interfaces/IPipelineNode.cs index 30f93d2eb4..a8bee5d046 100644 --- a/src/Microsoft.ML.PipelineInference/Interfaces/IPipelineNode.cs +++ b/src/Microsoft.ML.PipelineInference/Interfaces/IPipelineNode.cs @@ -60,7 +60,7 @@ protected string GetEpName(Type type) return epName; } - protected void PropagateParamSetValues(ParameterSet hyperParams, + private protected void PropagateParamSetValues(ParameterSet hyperParams, TlcModule.SweepableParamAttribute[] sweepParams) { var spMap = sweepParams.ToDictionary(sp => sp.Name); @@ -79,9 +79,9 @@ public sealed class TransformPipelineNode : PipelineNodeBase, IPipelineNode sweepParams = null, CommonInputs.ITrainerInput subTrainerObj = null) { @@ -136,9 +136,10 @@ public sealed class TrainerPipelineNode : PipelineNodeBase, IPipelineNode sweepParams = null, ParameterSet hyperParameterSet = null) { diff --git a/src/Microsoft.ML.PipelineInference/Properties/AssemblyInfo.cs b/src/Microsoft.ML.PipelineInference/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..db1d151fa2 --- /dev/null +++ b/src/Microsoft.ML.PipelineInference/Properties/AssemblyInfo.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.CompilerServices; +using Microsoft.ML; + +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests" + PublicKey.TestValue)] + +[assembly: WantsToBeBestFriends] diff --git a/src/Microsoft.ML.PipelineInference/TextFileContents.cs b/src/Microsoft.ML.PipelineInference/TextFileContents.cs index f3f45cc8b9..df2dbc4a7f 100644 --- a/src/Microsoft.ML.PipelineInference/TextFileContents.cs +++ b/src/Microsoft.ML.PipelineInference/TextFileContents.cs @@ -114,7 +114,7 @@ private static bool TryParseFile(IChannel ch, TextLoader.Arguments args, IMultiS try { // No need to provide information from unsuccessful loader, so we create temporary environment and get information from it in case of success - using (var loaderEnv = new ConsoleEnvironment(0, true)) + using (var loaderEnv = new ConsoleEnvironment(0, verbose: true)) { var messages = new ConcurrentBag(); loaderEnv.AddListener( diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs index 2d545bb9a5..de68f9be1e 100644 --- a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs +++ b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs @@ -244,7 +244,7 @@ public MatrixFactorizationTrainer(IHostEnvironment env, /// Train a matrix factorization model based on training data, validation data, and so on in the given context. /// /// The information collection needed for training. for details. 
- public override MatrixFactorizationPredictor Train(TrainContext context) + private protected override MatrixFactorizationPredictor Train(TrainContext context) { Host.CheckValue(context, nameof(context)); diff --git a/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj b/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj index e5610126df..e0f084d70b 100644 --- a/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj +++ b/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj @@ -7,10 +7,6 @@ true - - - - diff --git a/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs b/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs index b896e37bf6..85b9234856 100644 --- a/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs +++ b/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs @@ -151,7 +151,9 @@ public void GetDefaultSettingValues(IHostEnvironment env, string predictorName, /// private Dictionary GetDefaultSettings(IHostEnvironment env, string predictorName, string[] extraAssemblies = null) { +#pragma warning disable CS0618 // The result processor is an internal command line processing utility anyway, so this is, while not great, OK. AssemblyLoadingUtils.LoadAndRegister(env, extraAssemblies); +#pragma warning restore CS0618 var cls = env.ComponentCatalog.GetLoadableClassInfo(predictorName); if (cls == null) @@ -1154,7 +1156,9 @@ public static int Main(string[] args) { string currentDirectory = Path.GetDirectoryName(typeof(ResultProcessor).Module.FullyQualifiedName); using (var env = new ConsoleEnvironment(42)) +#pragma warning disable CS0618 // The result processor is an internal command line processing utility anyway, so this is, while not great, OK. using (AssemblyLoadingUtils.CreateAssemblyRegistrar(env, currentDirectory)) +#pragma warning restore CS0618 return Main(env, args); } @@ -1197,7 +1201,9 @@ protected static void Run(IHostEnvironment env, string[] args) if (cmd.IncludePerFoldResults) cmd.PerFoldResultSeparator = "" + PredictionUtil.SepCharFromString(cmd.PerFoldResultSeparator); +#pragma warning disable CS0618 // The result processor is an internal command line processing utility anyway, so this is, while not great, OK. 
AssemblyLoadingUtils.LoadAndRegister(env, cmd.ExtraAssemblies);
+#pragma warning restore CS0618
            if (cmd.Metrics.Length == 0)
                cmd.Metrics = null;
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
index b730050a10..756c1964a6 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
@@ -429,7 +429,7 @@ private FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgress
            return new FieldAwareFactorizationMachinePredictor(Host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned);
        }
-        public override FieldAwareFactorizationMachinePredictor Train(TrainContext context)
+        private protected override FieldAwareFactorizationMachinePredictor Train(TrainContext context)
        {
            Host.CheckValue(context, nameof(context));
            var initPredictor = context.InitialPredictor as FieldAwareFactorizationMachinePredictor;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
index 04f272683c..14dd0692cb 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
@@ -378,7 +378,7 @@ protected virtual void PreTrainingProcessInstance(float label, in VBuffer<float> feat, float weight)
        ///
        /// The basic training calls the optimizer
        ///
-        protected override TModel TrainModelCore(TrainContext context)
+        private protected override TModel TrainModelCore(TrainContext context)
        {
            Contracts.CheckValue(context, nameof(context));
            Host.CheckParam(context.InitialPredictor == null || context.InitialPredictor is TModel, nameof(context.InitialPredictor));
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
index 286fad1063..9da2a2b23d 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
@@ -135,7 +135,7 @@ protected TScalarTrainer GetTrainer()
        ///
        /// The training context for this learner.
        /// The trained model.
-        public TModel Train(TrainContext context)
+        TModel ITrainer<TModel>.Train(TrainContext context)
        {
            Host.CheckValue(context, nameof(context));
            var data = context.TrainingSet;
@@ -216,7 +216,7 @@ private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
            return cols;
        }
-        IPredictor ITrainer.Train(TrainContext context) => Train(context);
+        IPredictor ITrainer.Train(TrainContext context) => ((ITrainer<TModel>)this).Train(context);
        ///
        /// Fits the data to the trainer.
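
The recurring signature change in this diff, protected override → private protected override on TrainModelCore and Train, relies on the C# 7.2 private protected modifier: the member is accessible only to derived classes compiled into the same assembly, or into a friend assembly named by InternalsVisibleTo. A minimal sketch of those semantics; TrainerBase and MyTrainer below are hypothetical stand-ins, not types from this PR:

    // Sketch only: demonstrates the accessibility rules this diff depends on.
    public abstract class TrainerBase
    {
        // Overridable and callable by subclasses, but only subclasses that live
        // in this assembly (or in a friend assembly via InternalsVisibleTo).
        private protected abstract int TrainModelCore();

        // The public surface stays usable: external code still trains through
        // the base class; it just cannot override or call the core hook.
        public int Train() => TrainModelCore();
    }

    public sealed class MyTrainer : TrainerBase
    {
        // Legal here because MyTrainer is compiled into the same assembly.
        private protected override int TrainModelCore() => 42;
    }

A non-friend external assembly sees Train() but not TrainModelCore(), which is how these trainers keep a public training surface while internalizing their implementation hooks.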
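
The MetaMulticlassTrainer and IFastTreeTrainerFactory hunks use a second technique: explicit interface implementation. The method leaves the public class surface but still satisfies the interface, so interface-typed callers keep working. A sketch under simplified, assumed interface shapes (the real ITrainer interfaces carry more members):

    using System;

    public interface ITrainer
    {
        object Train();
    }

    public interface ITrainer<out TModel> : ITrainer
    {
        new TModel Train();
    }

    public sealed class ExampleTrainer : ITrainer<string>
    {
        // Explicit implementation: invisible on ExampleTrainer itself,
        // reachable only through an interface-typed reference.
        string ITrainer<string>.Train() => "model";

        // The non-generic interface forwards to the generic one, mirroring
        // `IPredictor ITrainer.Train(...) => ((ITrainer<TModel>)this).Train(...)`.
        object ITrainer.Train() => ((ITrainer<string>)this).Train();
    }

    internal static class Usage
    {
        internal static void Demo()
        {
            var trainer = new ExampleTrainer();
            // trainer.Train();             // does not compile: no public Train
            ITrainer<string> t = trainer;   // fine: the interface still exposes it
            Console.WriteLine(t.Train());
        }
    }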
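
The new AssemblyInfo.cs files build friend declarations as "AssemblyName" + PublicKey.TestValue instead of pasting the full key into every attribute. The PublicKey constants class itself is not shown in this diff; a sketch of the shape such a helper presumably takes, with placeholder key text rather than the real keys:

    namespace Microsoft.ML
    {
        // Sketch only: the real constants hold the complete ", PublicKey=..."
        // suffixes for the production and test signing keys.
        internal static class PublicKey
        {
            internal const string Value = ", PublicKey=00240000...";     // placeholder
            internal const string TestValue = ", PublicKey=00240000..."; // placeholder
        }
    }

    // Usage, as in the AssemblyInfo.cs files added by this diff:
    // [assembly: InternalsVisibleTo("Microsoft.ML.Core.Tests" + PublicKey.TestValue)]

Because InternalsVisibleTo takes a constant string, the concatenation is folded at compile time, so the attribute still receives a single literal of the form "Name, PublicKey=<hex>".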
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs index 756607deea..ce6b0566f0 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs @@ -91,7 +91,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc protected override MulticlassPredictionTransformer MakeTransformer(MultiClassNaiveBayesPredictor model, Schema trainSchema) => new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name); - protected override MultiClassNaiveBayesPredictor TrainModelCore(TrainContext context) + private protected override MultiClassNaiveBayesPredictor TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var data = context.TrainingSet; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index 1c25986c77..93f2361325 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -254,7 +254,7 @@ private protected static TArgs InvokeAdvanced(Action advancedSetti return args; } - protected sealed override TModel TrainModelCore(TrainContext context) + private protected sealed override TModel TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var initPredictor = context.InitialPredictor; diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs index 3ac918501d..283d4272b9 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs @@ -68,7 +68,7 @@ private protected LinearTrainerBase(IHostEnvironment env, string featureColumn, { } - protected override TModel TrainModelCore(TrainContext context) + private protected override TModel TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); using (var ch = Host.Start("Training")) diff --git a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs index 2640144cc6..ec39fc1e1c 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs @@ -75,7 +75,7 @@ public BinaryPredictionTransformer Fit(IDataView input) return new BinaryPredictionTransformer(Host, pred, input.Schema, featureColumn: null); } - public override RandomPredictor Train(TrainContext context) + private protected override RandomPredictor Train(TrainContext context) { Host.CheckValue(context, nameof(context)); return new RandomPredictor(Host, Host.Rand.Next()); @@ -273,7 +273,7 @@ public BinaryPredictionTransformer Fit(IDataView input) return new BinaryPredictionTransformer(Host, pred, input.Schema, featureColumn: null); } - public override PriorPredictor Train(TrainContext context) + private protected override PriorPredictor Train(TrainContext context) { Contracts.CheckValue(context, nameof(context)); var data = context.TrainingSet; diff --git a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs index 
fe2e6b752b..d419a3a8a1 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs @@ -28,7 +28,7 @@ public StochasticTrainerBase(IHost host, SchemaShape.Column feature, SchemaShape private static readonly TrainerInfo _info = new TrainerInfo(); public override TrainerInfo Info => _info; - protected override TModel TrainModelCore(TrainContext context) + private protected override TModel TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); using (var ch = Host.Start("Training")) diff --git a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs index d78b6eaa7d..f3bc7a362d 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs @@ -163,20 +163,20 @@ private FastForestRegressionPredictor FitModel(IEnumerable previousR /// An array of ParameterSets which are the candidate configurations to sweep. private ParameterSet[] GenerateCandidateConfigurations(int numOfCandidates, IEnumerable previousRuns, FastForestRegressionPredictor forest) { - ParameterSet[] configs = new ParameterSet[numOfCandidates]; - // Get k best previous runs ParameterSets. ParameterSet[] bestKParamSets = GetKBestConfigurations(previousRuns, forest, _args.LocalSearchParentCount); // Perform local searches using the k best previous run configurations. ParameterSet[] eiChallengers = GreedyPlusRandomSearch(bestKParamSets, forest, (int)Math.Ceiling(numOfCandidates / 2.0F), previousRuns); - // Generate another set of random configurations to interleave + // Generate another set of random configurations to combine with the challengers. ParameterSet[] randomChallengers = _randomSweeper.ProposeSweeps(numOfCandidates - eiChallengers.Length, previousRuns); - // Return interleaved challenger candidates with random candidates - for (int j = 0; j < configs.Length; j++) - configs[j] = j % 2 == 0 ? eiChallengers[j / 2] : randomChallengers[j / 2]; + // Return the challenger candidates concatenated with the random candidates. Since only unique candidates are generated, + // either source can return fewer than asked for (and the counts may vary considerably), so the result is sized by what was actually produced. + ParameterSet[] configs = new ParameterSet[eiChallengers.Length + randomChallengers.Length]; + Array.Copy(eiChallengers, 0, configs, 0, eiChallengers.Length); + Array.Copy(randomChallengers, 0, configs, eiChallengers.Length, randomChallengers.Length); return configs; } diff --git a/src/Microsoft.ML.Sweeper/ConfigRunner.cs b/src/Microsoft.ML.Sweeper/ConfigRunner.cs index 3219d691b1..806af2f5e6 100644 --- a/src/Microsoft.ML.Sweeper/ConfigRunner.cs +++ b/src/Microsoft.ML.Sweeper/ConfigRunner.cs @@ -110,7 +110,9 @@ public virtual void Finish() string currentDirectory = Path.GetDirectoryName(typeof(ExeConfigRunnerBase).Module.FullyQualifiedName); using (var ch = Host.Start("Finish")) +#pragma warning disable CS0618 // As this deals with invoking command lines, this may be OK, though this code has some other problems. 
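// Aside on the SmacSweeper.GenerateCandidateConfigurations change above: with a
// fixed-size output array, the old interleaving loop reads past the end whenever
// local search or the random sweeper yields fewer unique candidates than asked
// for. A standalone sketch, with strings standing in for ParameterSet:
using System;

public static class SmacConcatSketch
{
    public static void Main()
    {
        const int numOfCandidates = 10;
        string[] eiChallengers = { "e0", "e1", "e2" }; // short of ceil(10/2): duplicates dropped
        string[] randomChallengers = { "r0", "r1" };   // also short

        // Old pattern (buggy): j runs to numOfCandidates - 1, so eiChallengers[j / 2]
        // or randomChallengers[j / 2] eventually throws IndexOutOfRangeException:
        //     var configs = new string[numOfCandidates];
        //     for (int j = 0; j < configs.Length; j++)
        //         configs[j] = j % 2 == 0 ? eiChallengers[j / 2] : randomChallengers[j / 2];

        // New pattern: size the result by what was actually produced, then concatenate.
        var configs = new string[eiChallengers.Length + randomChallengers.Length];
        Array.Copy(eiChallengers, 0, configs, 0, eiChallengers.Length);
        Array.Copy(randomChallengers, 0, configs, eiChallengers.Length, randomChallengers.Length);

        Console.WriteLine(string.Join(",", configs)); // e0,e1,e2,r0,r1
    }
}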
using (AssemblyLoadingUtils.CreateAssemblyRegistrar(Host, currentDirectory)) +#pragma warning restore CS0618 { var runs = RunNums.ToArray(); var args = Utils.BuildArray(RunNums.Count + 2, diff --git a/src/Microsoft.ML.Sweeper/Microsoft.ML.Sweeper.csproj b/src/Microsoft.ML.Sweeper/Microsoft.ML.Sweeper.csproj index d48a762104..9ed5d25e0e 100644 --- a/src/Microsoft.ML.Sweeper/Microsoft.ML.Sweeper.csproj +++ b/src/Microsoft.ML.Sweeper/Microsoft.ML.Sweeper.csproj @@ -13,11 +13,6 @@ - - - - - diff --git a/src/Microsoft.ML.Sweeper/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Sweeper/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..160c67eb55 --- /dev/null +++ b/src/Microsoft.ML.Sweeper/Properties/AssemblyInfo.cs @@ -0,0 +1,9 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.CompilerServices; +using Microsoft.ML; + +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.PipelineInference" + PublicKey.Value)] diff --git a/src/Microsoft.ML.Sweeper/SweepCommand.cs b/src/Microsoft.ML.Sweeper/SweepCommand.cs index 51dc28559d..fe61dc3c6e 100644 --- a/src/Microsoft.ML.Sweeper/SweepCommand.cs +++ b/src/Microsoft.ML.Sweeper/SweepCommand.cs @@ -16,8 +16,10 @@ namespace Microsoft.ML.Runtime.Sweeper { - public sealed class SweepCommand : ICommand + [BestFriend] + internal sealed class SweepCommand : ICommand { +#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms. public sealed class Arguments { [Argument(ArgumentType.Multiple, HelpText = "Config runner", ShortName = "run,ev,evaluator", SignatureType = typeof(SignatureConfigRunner))] @@ -39,6 +41,7 @@ public sealed class Arguments [Argument(ArgumentType.AtMostOnce, HelpText = "Random seed", ShortName = "seed")] public int? RandomSeed; } +#pragma warning restore CS0649 internal const string Summary = "Given a command line template and sweep ranges, creates and runs a sweep."; diff --git a/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs b/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs index 09a054be25..e9a057feb6 100644 --- a/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs @@ -21,10 +21,11 @@ namespace Microsoft.ML.Transforms /// is greater than a threshold. /// Instantiates a DropSlots transform to actually drop the slots. /// - public static class LearnerFeatureSelectionTransform + internal static class LearnerFeatureSelectionTransform { internal const string Summary = "Selects the slots for which the absolute value of the corresponding weight in a linear learner is greater than a threshold."; +#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms. public sealed class Arguments { [Argument(ArgumentType.LastOccurenceWins, HelpText = "If the corresponding absolute value of the weight for a slot is greater than this threshold, the slot is preserved", ShortName = "ft", SortOrder = 2)] @@ -33,6 +34,8 @@ public sealed class Arguments [Argument(ArgumentType.AtMostOnce, HelpText = "The number of slots to preserve", ShortName = "topk", SortOrder = 1)] public int? 
NumSlotsToKeep; + // If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer, but the utility + // of this would be limited because estimators and transformers now act more or less like this transform used to. [Argument(ArgumentType.Multiple, HelpText = "Filter", ShortName = "f", SortOrder = 1, SignatureType = typeof(SignatureFeatureScorerTrainer))] public IComponentFactory>> Filter = ComponentFactoryUtils.CreateFromFunction(env => @@ -74,6 +77,7 @@ internal void Check(IExceptionContext ectx) ectx.CheckUserArg((NumSlotsToKeep ?? int.MaxValue) > 0, nameof(NumSlotsToKeep), "Must be positive"); } } +#pragma warning restore CS0649 internal static string RegistrationName = "LearnerFeatureSelectionTransform"; diff --git a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.json b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.json index 880cd67637..d74ebe1c3f 100644 --- a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.json +++ b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.json @@ -166,24 +166,24 @@ ], "dataType": "FLOAT", "floatData": [ - 0.6232001, - 0.482400715, - 0.495733529, - 0.416533619, - 0.425866365, - 0.5456011, - 0.477333277, - 0.4338674, - 0.20399937, - 0.225407124, - 0.111400835, - 0.109446481, - 0.120521255, - 0.198697388, - 0.121824168, - 0.182410166, - 0.108143583, - 0.107166387 + 0.5522167, + 0.3039403, + 0.319211155, + 0.261575729, + 0.320196062, + 0.344088882, + 0.293349, + 0.273151934, + 0.15763472, + 0.285144627, + 0.332245946, + 0.325724274, + 0.315217048, + 0.328623, + 0.3706516, + 0.41992715, + 0.307970464, + 0.164492577 ], "name": "C" }, @@ -193,8 +193,8 @@ ], "dataType": "FLOAT", "floatData": [ - 1.97708726, - 0.200497344 + 0.9740776, + 0.940771043 ], "name": "C2" }, diff --git a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.onnx b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.onnx deleted file mode 100644 index e1dc568fc4..0000000000 Binary files a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.onnx and /dev/null differ diff --git a/test/BaselineOutput/SingleDebug/PCA/pca.tsv b/test/BaselineOutput/SingleDebug/PCA/pca.tsv index ece1e59164..328ff1bf86 100644 --- a/test/BaselineOutput/SingleDebug/PCA/pca.tsv +++ b/test/BaselineOutput/SingleDebug/PCA/pca.tsv @@ -2,7 +2,7 @@ #@ sep=tab #@ col=pca:R4:0-4 #@ } -2.085487 0.09400085 2.58366132 -1.721405 -0.732070744 -0.9069792 0.7748574 0.6097196 1.07868779 0.453838825 --0.167718172 -0.92723 -0.19140324 0.243479848 -1.060547 -0.548309 0.5576686 -0.587472439 -1.38610959 0.9422219 +-2.085465 -0.09400512 -2.58367229 1.72141707 0.732049346 +-0.906982958 -0.774861753 -0.609727442 -1.07867944 -0.453824759 +0.167715371 0.927231133 0.191398591 -0.243467987 1.06056178 +-0.54830873 -0.5576661 0.587476134 1.38609958 -0.9422357 diff --git a/test/BaselineOutput/SingleRelease/PCA/pca.tsv b/test/BaselineOutput/SingleRelease/PCA/pca.tsv index ece1e59164..328ff1bf86 100644 --- a/test/BaselineOutput/SingleRelease/PCA/pca.tsv +++ b/test/BaselineOutput/SingleRelease/PCA/pca.tsv @@ -2,7 +2,7 @@ #@ sep=tab #@ col=pca:R4:0-4 #@ } -2.085487 0.09400085 2.58366132 -1.721405 -0.732070744 -0.9069792 0.7748574 0.6097196 1.07868779 0.453838825 --0.167718172 -0.92723 -0.19140324 0.243479848 -1.060547 -0.548309 0.5576686 -0.587472439 -1.38610959 0.9422219 +-2.085465 -0.09400512 -2.58367229 1.72141707 0.732049346 +-0.906982958 -0.774861753 -0.609727442 -1.07867944 -0.453824759 +0.167715371 0.927231133 
0.191398591 -0.243467987 1.06056178 +-0.54830873 -0.5576661 0.587476134 1.38609958 -0.9422357 diff --git a/test/Microsoft.ML.Benchmarks/Helpers/EnvironmentFactory.cs b/test/Microsoft.ML.Benchmarks/Helpers/EnvironmentFactory.cs index ff2ddd1bbe..3b0b055d30 100644 --- a/test/Microsoft.ML.Benchmarks/Helpers/EnvironmentFactory.cs +++ b/test/Microsoft.ML.Benchmarks/Helpers/EnvironmentFactory.cs @@ -5,33 +5,36 @@ using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Training; using Microsoft.ML.Transforms; namespace Microsoft.ML.Benchmarks { internal static class EnvironmentFactory { - internal static ConsoleEnvironment CreateClassificationEnvironment() + internal static MLContext CreateClassificationEnvironment() where TLoader : IDataReader where TTransformer : ITransformer - where TTrainer : ITrainer + where TTrainer : ITrainerEstimator, IPredictor> { - var environment = new ConsoleEnvironment(verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance); + var ctx = new MLContext(); + IHostEnvironment environment = ctx; environment.ComponentCatalog.RegisterAssembly(typeof(TLoader).Assembly); environment.ComponentCatalog.RegisterAssembly(typeof(TTransformer).Assembly); environment.ComponentCatalog.RegisterAssembly(typeof(TTrainer).Assembly); - return environment; + return ctx; } - internal static ConsoleEnvironment CreateRankingEnvironment() + internal static MLContext CreateRankingEnvironment() where TEvaluator : IEvaluator where TLoader : IDataReader where TTransformer : ITransformer - where TTrainer : ITrainer + where TTrainer : ITrainerEstimator, IPredictor> { - var environment = new ConsoleEnvironment(verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance); + var ctx = new MLContext(); + IHostEnvironment environment = ctx; environment.ComponentCatalog.RegisterAssembly(typeof(TEvaluator).Assembly); environment.ComponentCatalog.RegisterAssembly(typeof(TLoader).Assembly); @@ -40,7 +43,7 @@ internal static ConsoleEnvironment CreateRankingEnvironment + { + s.HasHeader = true; + s.Separator = ","; + }); - trans = NormalizeTransform.CreateMinMaxNormalizer(env, trans, "NumFeatures"); - trans = new ColumnConcatenatingTransformer(env, "Features", "NumFeatures", "CatFeatures").Transform(trans); - trans = TrainAndScoreTransformer.Create(env, new TrainAndScoreTransformer.Arguments - { - Trainer = ComponentFactoryUtils.CreateFromFunction(host => - new KMeansPlusPlusTrainer(host, "Features", advancedSettings: s=> - { - s.K = 100; - })), - FeatureColumn = "Features" - }, trans); - trans = new ColumnConcatenatingTransformer(env, "Features", "Features", "Score").Transform(trans); + var estimatorPipeline = ml.Transforms.Categorical.OneHotEncoding("CatFeatures") + .Append(ml.Transforms.Normalize("NumFeatures")) + .Append(ml.Transforms.Concatenate("Features", "NumFeatures", "CatFeatures")) + .Append(ml.Clustering.Trainers.KMeans("Features")) + .Append(ml.Transforms.Concatenate("Features", "Features", "Score")) + .Append(ml.BinaryClassification.Trainers.LogisticRegression(advancedSettings: args => { args.EnforceNonNegativity = true; args.OptTol = 1e-3f; })); - // Train - var trainer = new LogisticRegression(env, "Label", "Features", advancedSettings: args => { args.EnforceNonNegativity = true; args.OptTol = 1e-3f; }); - var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features"); - return trainer.Train(trainRoles); - } + var model = estimatorPipeline.Fit(input); + // 
Return the last model in the chain. + return model.LastTransformer.Model; } } } \ No newline at end of file diff --git a/test/Microsoft.ML.Benchmarks/Numeric/Ranking.cs b/test/Microsoft.ML.Benchmarks/Numeric/Ranking.cs index 42ea0c040f..adc1bccdce 100644 --- a/test/Microsoft.ML.Benchmarks/Numeric/Ranking.cs +++ b/test/Microsoft.ML.Benchmarks/Numeric/Ranking.cs @@ -24,7 +24,7 @@ public void SetupTrainingSpeedTests() { _mslrWeb10k_Validate = Path.GetFullPath(TestDatasets.MSLRWeb.validFilename); _mslrWeb10k_Train = Path.GetFullPath(TestDatasets.MSLRWeb.trainFilename); - + if (!File.Exists(_mslrWeb10k_Validate)) throw new FileNotFoundException(string.Format(Errors.DatasetNotFound, _mslrWeb10k_Validate)); @@ -42,10 +42,8 @@ public void TrainTest_Ranking_MSLRWeb10K_RawNumericFeatures_FastTreeRanking() " xf=HashTransform{col=GroupId} xf=NAHandleTransform{col=Features}" + " tr=FastTreeRanking{}"; - using (var environment = EnvironmentFactory.CreateRankingEnvironment()) - { - Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); - } + var environment = EnvironmentFactory.CreateRankingEnvironment(); + Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); } [Benchmark] @@ -59,10 +57,8 @@ public void TrainTest_Ranking_MSLRWeb10K_RawNumericFeatures_LightGBMRanking() " xf=NAHandleTransform{col=Features}" + " tr=LightGBMRanking{}"; - using (var environment = EnvironmentFactory.CreateRankingEnvironment()) - { - Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); - } + var environment = EnvironmentFactory.CreateRankingEnvironment(); + Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); } } @@ -100,10 +96,8 @@ public void SetupScoringSpeedTests() " tr=FastTreeRanking{}" + " out={" + _modelPath_MSLR + "}"; - using (var environment = EnvironmentFactory.CreateRankingEnvironment()) - { - Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); - } + var environment = EnvironmentFactory.CreateRankingEnvironment(); + Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); } [Benchmark] @@ -112,10 +106,8 @@ public void Test_Ranking_MSLRWeb10K_RawNumericFeatures_FastTreeRanking() // This benchmark is profiling bulk scoring speed and not training speed. 
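// Aside on the KMeansAndLogisticRegressionBench rewrite above: the estimator
// chain replaces the hand-rolled transform/RoleMappedData/ITrainer.Train
// sequence. An abbreviated sketch of that pattern (column names are the
// benchmark's; `input` is assumed to be an IDataView produced by a reader):
using Microsoft.ML;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;

internal static class EstimatorPatternSketch
{
    public static ITransformer TrainSketch(IDataView input)
    {
        var ml = new MLContext(seed: 1); // not IDisposable, hence no using block
        var pipeline = ml.Transforms.Categorical.OneHotEncoding("CatFeatures")
            .Append(ml.Transforms.Normalize("NumFeatures"))
            .Append(ml.Transforms.Concatenate("Features", "NumFeatures", "CatFeatures"))
            .Append(ml.BinaryClassification.Trainers.LogisticRegression());

        // Fit trains every stage in order; LastTransformer.Model on the result
        // is the trained predictor, which is what the benchmark above returns.
        return pipeline.Fit(input);
    }
}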
string cmd = @"Test data=" + _mslrWeb10k_Test + " in=" + _modelPath_MSLR; - using (var environment = EnvironmentFactory.CreateRankingEnvironment()) - { - Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); - } + var environment = EnvironmentFactory.CreateRankingEnvironment(); + Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); } } } diff --git a/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs b/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs index 4d728ab45b..39342cf224 100644 --- a/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs +++ b/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs @@ -36,32 +36,30 @@ public void SetupIrisPipeline() string _irisDataPath = Program.GetInvariantCultureDataPath("iris.txt"); - using (var env = new ConsoleEnvironment(seed: 1, conc: 1, verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance)) - { - var reader = new TextLoader(env, - new TextLoader.Arguments() + var env = new MLContext(seed: 1, conc: 1); + var reader = new TextLoader(env, + new TextLoader.Arguments() + { + Separator = "\t", + HasHeader = true, + Column = new[] { - Separator = "\t", - HasHeader = true, - Column = new[] - { new TextLoader.Column("Label", DataKind.R4, 0), new TextLoader.Column("SepalLength", DataKind.R4, 1), new TextLoader.Column("SepalWidth", DataKind.R4, 2), new TextLoader.Column("PetalLength", DataKind.R4, 3), new TextLoader.Column("PetalWidth", DataKind.R4, 4), - } - }); + } + }); - IDataView data = reader.Read(_irisDataPath); + IDataView data = reader.Read(_irisDataPath); - var pipeline = new ColumnConcatenatingEstimator (env, "Features", new[] { "SepalLength", "SepalWidth", "PetalLength", "PetalWidth" }) - .Append(new SdcaMultiClassTrainer(env, "Label", "Features", advancedSettings: (s) => { s.NumThreads = 1; s.ConvergenceTolerance = 1e-2f; })); + var pipeline = new ColumnConcatenatingEstimator(env, "Features", new[] { "SepalLength", "SepalWidth", "PetalLength", "PetalWidth" }) + .Append(new SdcaMultiClassTrainer(env, "Label", "Features", advancedSettings: (s) => { s.NumThreads = 1; s.ConvergenceTolerance = 1e-2f; })); - var model = pipeline.Fit(data); + var model = pipeline.Fit(data); - _irisModel = model.MakePredictionFunction(env); - } + _irisModel = model.MakePredictionFunction(env); } [GlobalSetup(Target = nameof(MakeSentimentPredictions))] @@ -74,9 +72,8 @@ public void SetupSentimentPipeline() string _sentimentDataPath = Program.GetInvariantCultureDataPath("wikipedia-detox-250-line-data.tsv"); - using (var env = new ConsoleEnvironment(seed: 1, conc: 1, verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance)) - { - var reader = new TextLoader(env, + var env = new MLContext(seed: 1, conc: 1); + var reader = new TextLoader(env, new TextLoader.Arguments() { Separator = "\t", @@ -88,15 +85,14 @@ public void SetupSentimentPipeline() } }); - IDataView data = reader.Read(_sentimentDataPath); + IDataView data = reader.Read(_sentimentDataPath); - var pipeline = new TextFeaturizingEstimator(env, "SentimentText", "Features") - .Append(new SdcaBinaryTrainer(env, "Label", "Features", advancedSettings: (s) => { s.NumThreads = 1; s.ConvergenceTolerance = 1e-2f; })); + var pipeline = new TextFeaturizingEstimator(env, "SentimentText", "Features") + .Append(new SdcaBinaryTrainer(env, "Label", "Features", advancedSettings: (s) => { s.NumThreads = 1; s.ConvergenceTolerance = 1e-2f; })); - var model = pipeline.Fit(data); + var model = pipeline.Fit(data); - _sentimentModel = 
model.MakePredictionFunction(env); - } + _sentimentModel = model.MakePredictionFunction(env); } [GlobalSetup(Target = nameof(MakeBreastCancerPredictions))] @@ -109,9 +105,8 @@ public void SetupBreastCancerPipeline() string _breastCancerDataPath = Program.GetInvariantCultureDataPath("breast-cancer.txt"); - using (var env = new ConsoleEnvironment(seed: 1, conc: 1, verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance)) - { - var reader = new TextLoader(env, + var env = new MLContext(seed: 1, conc: 1); + var reader = new TextLoader(env, new TextLoader.Arguments() { Separator = "\t", @@ -123,14 +118,13 @@ public void SetupBreastCancerPipeline() } }); - IDataView data = reader.Read(_breastCancerDataPath); + IDataView data = reader.Read(_breastCancerDataPath); - var pipeline = new SdcaBinaryTrainer(env, "Label", "Features", advancedSettings: (s) => { s.NumThreads = 1; s.ConvergenceTolerance = 1e-2f; }); + var pipeline = new SdcaBinaryTrainer(env, "Label", "Features", advancedSettings: (s) => { s.NumThreads = 1; s.ConvergenceTolerance = 1e-2f; }); - var model = pipeline.Fit(data); + var model = pipeline.Fit(data); - _breastCancerModel = model.MakePredictionFunction(env); - } + _breastCancerModel = model.MakePredictionFunction(env); } [Benchmark] diff --git a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs index 4227c8708f..7e2e5f8702 100644 --- a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs +++ b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs @@ -63,18 +63,17 @@ private Legacy.PredictionModel Train(string dataPath) [Benchmark] public void TrainSentiment() { - using (var env = new ConsoleEnvironment(seed: 1)) - { - // Pipeline - var loader = TextLoader.ReadFile(env, - new TextLoader.Arguments() + var env = new MLContext(seed: 1); + // Pipeline + var loader = TextLoader.ReadFile(env, + new TextLoader.Arguments() + { + AllowQuoting = false, + AllowSparse = false, + Separator = "tab", + HasHeader = true, + Column = new[] { - AllowQuoting = false, - AllowSparse = false, - Separator = "tab", - HasHeader = true, - Column = new[] - { new TextLoader.Column() { Name = "Label", @@ -88,14 +87,14 @@ public void TrainSentiment() Source = new [] { new TextLoader.Range() { Min=1, Max=1} }, Type = DataKind.Text } - } - }, new MultiFileSource(_sentimentDataPath)); + } + }, new MultiFileSource(_sentimentDataPath)); - var text = TextFeaturizingEstimator.Create(env, - new TextFeaturizingEstimator.Arguments() + var text = TextFeaturizingEstimator.Create(env, + new TextFeaturizingEstimator.Arguments() + { + Column = new TextFeaturizingEstimator.Column { - Column = new TextFeaturizingEstimator.Column - { Name = "WordEmbeddings", Source = new[] { "SentimentText" } }, @@ -121,13 +120,10 @@ public void TrainSentiment() ModelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe, }, text); - // Train - var trainer = new SdcaMultiClassTrainer(env, "Label", "Features", maxIterations: 20); - var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features"); - - var predicted = trainer.Train(trainRoles); - _consumer.Consume(predicted); - } + // Train + var trainer = new SdcaMultiClassTrainer(env, "Label", "Features", maxIterations: 20); + var predicted = trainer.Fit(trans); + _consumer.Consume(predicted); } [GlobalSetup(Targets = new string[] { nameof(PredictIris), nameof(PredictIrisBatchOf1), 
nameof(PredictIrisBatchOf2), nameof(PredictIrisBatchOf5) })] diff --git a/test/Microsoft.ML.Benchmarks/Text/MultiClassClassification.cs b/test/Microsoft.ML.Benchmarks/Text/MultiClassClassification.cs index 8918c7b2e5..d7705aaadb 100644 --- a/test/Microsoft.ML.Benchmarks/Text/MultiClassClassification.cs +++ b/test/Microsoft.ML.Benchmarks/Text/MultiClassClassification.cs @@ -25,13 +25,13 @@ public void SetupTrainingSpeedTests() _dataPath_Wiki = Path.GetFullPath(TestDatasets.WikiDetox.trainFilename); if (!File.Exists(_dataPath_Wiki)) - throw new FileNotFoundException(string.Format(Errors.DatasetNotFound, _dataPath_Wiki)); + throw new FileNotFoundException(string.Format(Errors.DatasetNotFound, _dataPath_Wiki)); } [Benchmark] public void CV_Multiclass_WikiDetox_BigramsAndTrichar_OVAAveragedPerceptron() { - string cmd = @"CV k=5 data=" + _dataPath_Wiki + + string cmd = @"CV k=5 data=" + _dataPath_Wiki + " loader=TextLoader{quote=- sparse=- col=Label:R4:0 col=rev_id:TX:1 col=comment:TX:2 col=logged_in:BL:4 col=ns:TX:5 col=sample:TX:6 col=split:TX:7 col=year:R4:3 header=+}" + " xf=Convert{col=logged_in type=R4}" + " xf=CategoricalTransform{col=ns}" + @@ -39,10 +39,8 @@ public void CV_Multiclass_WikiDetox_BigramsAndTrichar_OVAAveragedPerceptron() " xf=Concat{col=Features:FeaturesText,logged_in,ns}" + " tr=OVA{p=AveragedPerceptron{iter=10}}"; - using (var environment = EnvironmentFactory.CreateClassificationEnvironment()) - { - Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); - } + var environment = EnvironmentFactory.CreateClassificationEnvironment(); + Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); } [Benchmark] @@ -56,10 +54,8 @@ public void CV_Multiclass_WikiDetox_BigramsAndTrichar_LightGBMMulticlass() " xf=Concat{col=Features:FeaturesText,logged_in,ns}" + " tr=LightGBMMulticlass{iter=10}"; - using (var environment = EnvironmentFactory.CreateClassificationEnvironment()) - { - Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); - } + var environment = EnvironmentFactory.CreateClassificationEnvironment(); + Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); } [Benchmark] @@ -74,10 +70,8 @@ public void CV_Multiclass_WikiDetox_WordEmbeddings_OVAAveragedPerceptron() " xf=WordEmbeddingsTransform{col=FeaturesWordEmbedding:FeaturesText_TransformedText model=FastTextWikipedia300D}" + " xf=Concat{col=Features:FeaturesText,FeaturesWordEmbedding,logged_in,ns}"; - using (var environment = EnvironmentFactory.CreateClassificationEnvironment()) - { - Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); - } + var environment = EnvironmentFactory.CreateClassificationEnvironment(); + Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); } [Benchmark] @@ -92,10 +86,8 @@ public void CV_Multiclass_WikiDetox_WordEmbeddings_SDCAMC() " xf=WordEmbeddingsTransform{col=FeaturesWordEmbedding:FeaturesText_TransformedText model=FastTextWikipedia300D}" + " xf=Concat{col=Features:FeaturesWordEmbedding,logged_in,ns}"; - using (var environment = EnvironmentFactory.CreateClassificationEnvironment()) - { - Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); - } + var environment = EnvironmentFactory.CreateClassificationEnvironment(); + Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); } } @@ -122,10 +114,8 @@ public void SetupScoringSpeedTests() " tr=OVA{p=AveragedPerceptron{iter=10}}" + " out={" + _modelPath_Wiki + "}"; - using (var environment = EnvironmentFactory.CreateClassificationEnvironment()) - { - Maml.MainCore(environment, cmd, 
alwaysPrintStacktrace: false); - } + var environment = EnvironmentFactory.CreateClassificationEnvironment(); + Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); } [Benchmark] @@ -135,10 +125,8 @@ public void Test_Multiclass_WikiDetox_BigramsAndTrichar_OVAAveragedPerceptron() string modelpath = Path.Combine(Directory.GetCurrentDirectory(), @"WikiModel.fold000.zip"); string cmd = @"Test data=" + _dataPath_Wiki + " in=" + modelpath; - using (var environment = EnvironmentFactory.CreateClassificationEnvironment()) - { - Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); - } + var environment = EnvironmentFactory.CreateClassificationEnvironment(); + Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false); } } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index 5f6fdb1620..30cb1fcc99 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -25,99 +25,95 @@ public TestCSharpApi(ITestOutputHelper output) : base(output) public void TestSimpleExperiment() { var dataPath = GetDataPath("adult.tiny.with-schema.txt"); - using (var env = new ConsoleEnvironment()) - { - var experiment = env.CreateExperiment(); + var env = new MLContext(); + var experiment = env.CreateExperiment(); - var importInput = new Legacy.Data.TextLoader(dataPath); - var importOutput = experiment.Add(importInput); + var importInput = new Legacy.Data.TextLoader(dataPath); + var importOutput = experiment.Add(importInput); - var normalizeInput = new Legacy.Transforms.MinMaxNormalizer - { - Data = importOutput.Data - }; - normalizeInput.AddColumn("NumericFeatures"); - var normalizeOutput = experiment.Add(normalizeInput); - - experiment.Compile(); - experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); - experiment.Run(); - var data = experiment.GetOutput(normalizeOutput.OutputData); - - var schema = data.Schema; - Assert.Equal(5, schema.ColumnCount); - var expected = new[] { "Label", "Workclass", "Categories", "NumericFeatures", "NumericFeatures" }; - for (int i = 0; i < schema.ColumnCount; i++) - Assert.Equal(expected[i], schema.GetColumnName(i)); - } + var normalizeInput = new Legacy.Transforms.MinMaxNormalizer + { + Data = importOutput.Data + }; + normalizeInput.AddColumn("NumericFeatures"); + var normalizeOutput = experiment.Add(normalizeInput); + + experiment.Compile(); + experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); + experiment.Run(); + var data = experiment.GetOutput(normalizeOutput.OutputData); + + var schema = data.Schema; + Assert.Equal(5, schema.ColumnCount); + var expected = new[] { "Label", "Workclass", "Categories", "NumericFeatures", "NumericFeatures" }; + for (int i = 0; i < schema.ColumnCount; i++) + Assert.Equal(expected[i], schema.GetColumnName(i)); } [Fact] public void TestSimpleTrainExperiment() { var dataPath = GetDataPath("adult.tiny.with-schema.txt"); - using (var env = new ConsoleEnvironment()) - { - var experiment = env.CreateExperiment(); - - var importInput = new Legacy.Data.TextLoader(dataPath); - var importOutput = experiment.Add(importInput); - - var catInput = new Legacy.Transforms.CategoricalOneHotVectorizer - { - Data = importOutput.Data - }; - catInput.AddColumn("Categories"); - var catOutput = experiment.Add(catInput); + var env = new MLContext(); + var experiment = env.CreateExperiment(); - var concatInput = new 
Legacy.Transforms.ColumnConcatenator - { - Data = catOutput.OutputData - }; - concatInput.AddColumn("Features", "Categories", "NumericFeatures"); - var concatOutput = experiment.Add(concatInput); - - var sdcaInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier - { - TrainingData = concatOutput.OutputData, - LossFunction = new HingeLossSDCAClassificationLossFunction() { Margin = 1.1f }, - NumThreads = 1, - Shuffle = false - }; - var sdcaOutput = experiment.Add(sdcaInput); + var importInput = new Legacy.Data.TextLoader(dataPath); + var importOutput = experiment.Add(importInput); - var scoreInput = new Legacy.Transforms.DatasetScorer - { - Data = concatOutput.OutputData, - PredictorModel = sdcaOutput.PredictorModel - }; - var scoreOutput = experiment.Add(scoreInput); + var catInput = new Legacy.Transforms.CategoricalOneHotVectorizer + { + Data = importOutput.Data + }; + catInput.AddColumn("Categories"); + var catOutput = experiment.Add(catInput); - var evalInput = new Legacy.Models.BinaryClassificationEvaluator - { - Data = scoreOutput.ScoredData - }; - var evalOutput = experiment.Add(evalInput); + var concatInput = new Legacy.Transforms.ColumnConcatenator + { + Data = catOutput.OutputData + }; + concatInput.AddColumn("Features", "Categories", "NumericFeatures"); + var concatOutput = experiment.Add(concatInput); - experiment.Compile(); - experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); - experiment.Run(); - var data = experiment.GetOutput(evalOutput.OverallMetrics); + var sdcaInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier + { + TrainingData = concatOutput.OutputData, + LossFunction = new HingeLossSDCAClassificationLossFunction() { Margin = 1.1f }, + NumThreads = 1, + Shuffle = false + }; + var sdcaOutput = experiment.Add(sdcaInput); + + var scoreInput = new Legacy.Transforms.DatasetScorer + { + Data = concatOutput.OutputData, + PredictorModel = sdcaOutput.PredictorModel + }; + var scoreOutput = experiment.Add(scoreInput); - var schema = data.Schema; - var b = schema.TryGetColumnIndex("AUC", out int aucCol); + var evalInput = new Legacy.Models.BinaryClassificationEvaluator + { + Data = scoreOutput.ScoredData + }; + var evalOutput = experiment.Add(evalInput); + + experiment.Compile(); + experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); + experiment.Run(); + var data = experiment.GetOutput(evalOutput.OverallMetrics); + + var schema = data.Schema; + var b = schema.TryGetColumnIndex("AUC", out int aucCol); + Assert.True(b); + using (var cursor = data.GetRowCursor(col => col == aucCol)) + { + var getter = cursor.GetGetter(aucCol); + b = cursor.MoveNext(); Assert.True(b); - using (var cursor = data.GetRowCursor(col => col == aucCol)) - { - var getter = cursor.GetGetter(aucCol); - b = cursor.MoveNext(); - Assert.True(b); - double auc = 0; - getter(ref auc); - Assert.Equal(0.93, auc, 2); - b = cursor.MoveNext(); - Assert.False(b); - } + double auc = 0; + getter(ref auc); + Assert.Equal(0.93, auc, 2); + b = cursor.MoveNext(); + Assert.False(b); } } @@ -125,71 +121,69 @@ public void TestSimpleTrainExperiment() public void TestTrainTestMacro() { var dataPath = GetDataPath("adult.tiny.with-schema.txt"); - using (var env = new ConsoleEnvironment()) - { - var subGraph = env.CreateExperiment(); - - var catInput = new Legacy.Transforms.CategoricalOneHotVectorizer(); - catInput.AddColumn("Categories"); - var catOutput = subGraph.Add(catInput); - - var concatInput = new 
Legacy.Transforms.ColumnConcatenator - { - Data = catOutput.OutputData - }; - concatInput.AddColumn("Features", "Categories", "NumericFeatures"); - var concatOutput = subGraph.Add(concatInput); + var env = new MLContext(); + var subGraph = env.CreateExperiment(); - var sdcaInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier - { - TrainingData = concatOutput.OutputData, - LossFunction = new HingeLossSDCAClassificationLossFunction() { Margin = 1.1f }, - NumThreads = 1, - Shuffle = false - }; - var sdcaOutput = subGraph.Add(sdcaInput); - - var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner - { - TransformModels = new ArrayVar(catOutput.Model, concatOutput.Model), - PredictorModel = sdcaOutput.PredictorModel - }; - var modelCombineOutput = subGraph.Add(modelCombine); + var catInput = new Legacy.Transforms.CategoricalOneHotVectorizer(); + catInput.AddColumn("Categories"); + var catOutput = subGraph.Add(catInput); - var experiment = env.CreateExperiment(); + var concatInput = new Legacy.Transforms.ColumnConcatenator + { + Data = catOutput.OutputData + }; + concatInput.AddColumn("Features", "Categories", "NumericFeatures"); + var concatOutput = subGraph.Add(concatInput); - var importInput = new Legacy.Data.TextLoader(dataPath); - var importOutput = experiment.Add(importInput); + var sdcaInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier + { + TrainingData = concatOutput.OutputData, + LossFunction = new HingeLossSDCAClassificationLossFunction() { Margin = 1.1f }, + NumThreads = 1, + Shuffle = false + }; + var sdcaOutput = subGraph.Add(sdcaInput); + + var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner + { + TransformModels = new ArrayVar(catOutput.Model, concatOutput.Model), + PredictorModel = sdcaOutput.PredictorModel + }; + var modelCombineOutput = subGraph.Add(modelCombine); - var trainTestInput = new Legacy.Models.TrainTestBinaryEvaluator - { - TrainingData = importOutput.Data, - TestingData = importOutput.Data, - Nodes = subGraph - }; - trainTestInput.Inputs.Data = catInput.Data; - trainTestInput.Outputs.Model = modelCombineOutput.PredictorModel; - var trainTestOutput = experiment.Add(trainTestInput); + var experiment = env.CreateExperiment(); - experiment.Compile(); - experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); - experiment.Run(); - var data = experiment.GetOutput(trainTestOutput.OverallMetrics); + var importInput = new Legacy.Data.TextLoader(dataPath); + var importOutput = experiment.Add(importInput); - var schema = data.Schema; - var b = schema.TryGetColumnIndex("AUC", out int aucCol); + var trainTestInput = new Legacy.Models.TrainTestBinaryEvaluator + { + TrainingData = importOutput.Data, + TestingData = importOutput.Data, + Nodes = subGraph + }; + trainTestInput.Inputs.Data = catInput.Data; + trainTestInput.Outputs.Model = modelCombineOutput.PredictorModel; + var trainTestOutput = experiment.Add(trainTestInput); + + experiment.Compile(); + experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); + experiment.Run(); + var data = experiment.GetOutput(trainTestOutput.OverallMetrics); + + var schema = data.Schema; + var b = schema.TryGetColumnIndex("AUC", out int aucCol); + Assert.True(b); + using (var cursor = data.GetRowCursor(col => col == aucCol)) + { + var getter = cursor.GetGetter(aucCol); + b = cursor.MoveNext(); Assert.True(b); - using (var cursor = data.GetRowCursor(col => col == aucCol)) - { - var getter = 
cursor.GetGetter(aucCol); - b = cursor.MoveNext(); - Assert.True(b); - double auc = 0; - getter(ref auc); - Assert.Equal(0.93, auc, 2); - b = cursor.MoveNext(); - Assert.False(b); - } + double auc = 0; + getter(ref auc); + Assert.Equal(0.93, auc, 2); + b = cursor.MoveNext(); + Assert.False(b); } } @@ -197,68 +191,66 @@ public void TestTrainTestMacro() public void TestCrossValidationBinaryMacro() { var dataPath = GetDataPath("adult.tiny.with-schema.txt"); - using (var env = new ConsoleEnvironment()) - { - var subGraph = env.CreateExperiment(); - - var catInput = new Legacy.Transforms.CategoricalOneHotVectorizer(); - catInput.AddColumn("Categories"); - var catOutput = subGraph.Add(catInput); + var env = new MLContext(); + var subGraph = env.CreateExperiment(); - var concatInput = new Legacy.Transforms.ColumnConcatenator - { - Data = catOutput.OutputData - }; - concatInput.AddColumn("Features", "Categories", "NumericFeatures"); - var concatOutput = subGraph.Add(concatInput); - - var lrInput = new Legacy.Trainers.LogisticRegressionBinaryClassifier - { - TrainingData = concatOutput.OutputData, - NumThreads = 1 - }; - var lrOutput = subGraph.Add(lrInput); + var catInput = new Legacy.Transforms.CategoricalOneHotVectorizer(); + catInput.AddColumn("Categories"); + var catOutput = subGraph.Add(catInput); - var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner - { - TransformModels = new ArrayVar(catOutput.Model, concatOutput.Model), - PredictorModel = lrOutput.PredictorModel - }; - var modelCombineOutput = subGraph.Add(modelCombine); + var concatInput = new Legacy.Transforms.ColumnConcatenator + { + Data = catOutput.OutputData + }; + concatInput.AddColumn("Features", "Categories", "NumericFeatures"); + var concatOutput = subGraph.Add(concatInput); - var experiment = env.CreateExperiment(); + var lrInput = new Legacy.Trainers.LogisticRegressionBinaryClassifier + { + TrainingData = concatOutput.OutputData, + NumThreads = 1 + }; + var lrOutput = subGraph.Add(lrInput); - var importInput = new Legacy.Data.TextLoader(dataPath); - var importOutput = experiment.Add(importInput); + var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner + { + TransformModels = new ArrayVar(catOutput.Model, concatOutput.Model), + PredictorModel = lrOutput.PredictorModel + }; + var modelCombineOutput = subGraph.Add(modelCombine); - var crossValidateBinary = new Legacy.Models.BinaryCrossValidator - { - Data = importOutput.Data, - Nodes = subGraph - }; - crossValidateBinary.Inputs.Data = catInput.Data; - crossValidateBinary.Outputs.Model = modelCombineOutput.PredictorModel; - var crossValidateOutput = experiment.Add(crossValidateBinary); + var experiment = env.CreateExperiment(); - experiment.Compile(); - importInput.SetInput(env, experiment); - experiment.Run(); - var data = experiment.GetOutput(crossValidateOutput.OverallMetrics[0]); + var importInput = new Legacy.Data.TextLoader(dataPath); + var importOutput = experiment.Add(importInput); - var schema = data.Schema; - var b = schema.TryGetColumnIndex("AUC", out int aucCol); + var crossValidateBinary = new Legacy.Models.BinaryCrossValidator + { + Data = importOutput.Data, + Nodes = subGraph + }; + crossValidateBinary.Inputs.Data = catInput.Data; + crossValidateBinary.Outputs.Model = modelCombineOutput.PredictorModel; + var crossValidateOutput = experiment.Add(crossValidateBinary); + + experiment.Compile(); + importInput.SetInput(env, experiment); + experiment.Run(); + var data = experiment.GetOutput(crossValidateOutput.OverallMetrics[0]); + + 
var schema = data.Schema; + var b = schema.TryGetColumnIndex("AUC", out int aucCol); + Assert.True(b); + using (var cursor = data.GetRowCursor(col => col == aucCol)) + { + var getter = cursor.GetGetter(aucCol); + b = cursor.MoveNext(); Assert.True(b); - using (var cursor = data.GetRowCursor(col => col == aucCol)) - { - var getter = cursor.GetGetter(aucCol); - b = cursor.MoveNext(); - Assert.True(b); - double auc = 0; - getter(ref auc); - Assert.Equal(0.87, auc, 1); - b = cursor.MoveNext(); - Assert.False(b); - } + double auc = 0; + getter(ref auc); + Assert.Equal(0.87, auc, 1); + b = cursor.MoveNext(); + Assert.False(b); } } @@ -266,42 +258,41 @@ public void TestCrossValidationBinaryMacro() public void TestCrossValidationMacro() { var dataPath = GetDataPath(TestDatasets.generatedRegressionDatasetmacro.trainFilename); - using (var env = new ConsoleEnvironment(42)) - { - var subGraph = env.CreateExperiment(); + var env = new MLContext(42); + var subGraph = env.CreateExperiment(); - var nop = new Legacy.Transforms.NoOperation(); - var nopOutput = subGraph.Add(nop); + var nop = new Legacy.Transforms.NoOperation(); + var nopOutput = subGraph.Add(nop); - var generate = new Legacy.Transforms.RandomNumberGenerator(); - generate.Column = new[] { new Legacy.Transforms.GenerateNumberTransformColumn() { Name = "Weight1" } }; - generate.Data = nopOutput.OutputData; - var generateOutput = subGraph.Add(generate); + var generate = new Legacy.Transforms.RandomNumberGenerator(); + generate.Column = new[] { new Legacy.Transforms.GenerateNumberTransformColumn() { Name = "Weight1" } }; + generate.Data = nopOutput.OutputData; + var generateOutput = subGraph.Add(generate); - var learnerInput = new Legacy.Trainers.PoissonRegressor - { - TrainingData = generateOutput.OutputData, - NumThreads = 1, - WeightColumn = "Weight1" - }; - var learnerOutput = subGraph.Add(learnerInput); + var learnerInput = new Legacy.Trainers.PoissonRegressor + { + TrainingData = generateOutput.OutputData, + NumThreads = 1, + WeightColumn = "Weight1" + }; + var learnerOutput = subGraph.Add(learnerInput); - var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner - { - TransformModels = new ArrayVar(nopOutput.Model, generateOutput.Model), - PredictorModel = learnerOutput.PredictorModel - }; - var modelCombineOutput = subGraph.Add(modelCombine); + var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner + { + TransformModels = new ArrayVar(nopOutput.Model, generateOutput.Model), + PredictorModel = learnerOutput.PredictorModel + }; + var modelCombineOutput = subGraph.Add(modelCombine); - var experiment = env.CreateExperiment(); - var importInput = new Legacy.Data.TextLoader(dataPath) + var experiment = env.CreateExperiment(); + var importInput = new Legacy.Data.TextLoader(dataPath) + { + Arguments = new Legacy.Data.TextLoaderArguments + { + Separator = new[] { ';' }, + HasHeader = true, + Column = new[] { - Arguments = new Legacy.Data.TextLoaderArguments - { - Separator = new[] { ';' }, - HasHeader = true, - Column = new[] - { new TextLoaderColumn() { Name = "Label", @@ -316,95 +307,94 @@ public void TestCrossValidationMacro() Type = Legacy.Data.DataKind.Num } } - } - }; - var importOutput = experiment.Add(importInput); + } + }; + var importOutput = experiment.Add(importInput); - var crossValidate = new Legacy.Models.CrossValidator + var crossValidate = new Legacy.Models.CrossValidator + { + Data = importOutput.Data, + Nodes = subGraph, + Kind = Legacy.Models.MacroUtilsTrainerKinds.SignatureRegressorTrainer, + 
TransformModel = null, + WeightColumn = "Weight1" + }; + crossValidate.Inputs.Data = nop.Data; + crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel; + var crossValidateOutput = experiment.Add(crossValidate); + + experiment.Compile(); + importInput.SetInput(env, experiment); + experiment.Run(); + var data = experiment.GetOutput(crossValidateOutput.OverallMetrics); + + var schema = data.Schema; + var b = schema.TryGetColumnIndex("L1(avg)", out int metricCol); + Assert.True(b); + b = schema.TryGetColumnIndex("Fold Index", out int foldCol); + Assert.True(b); + b = schema.TryGetColumnIndex("IsWeighted", out int isWeightedCol); + using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol || col == isWeightedCol)) + { + var getter = cursor.GetGetter(metricCol); + var foldGetter = cursor.GetGetter>(foldCol); + ReadOnlyMemory fold = default; + var isWeightedGetter = cursor.GetGetter(isWeightedCol); + bool isWeighted = default; + double avg = 0; + double weightedAvg = 0; + for (int w = 0; w < 2; w++) { - Data = importOutput.Data, - Nodes = subGraph, - Kind = Legacy.Models.MacroUtilsTrainerKinds.SignatureRegressorTrainer, - TransformModel = null, - WeightColumn = "Weight1" - }; - crossValidate.Inputs.Data = nop.Data; - crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel; - var crossValidateOutput = experiment.Add(crossValidate); - - experiment.Compile(); - importInput.SetInput(env, experiment); - experiment.Run(); - var data = experiment.GetOutput(crossValidateOutput.OverallMetrics); + // Get the average. + b = cursor.MoveNext(); + Assert.True(b); + if (w == 1) + getter(ref weightedAvg); + else + getter(ref avg); + foldGetter(ref fold); + Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold)); + isWeightedGetter(ref isWeighted); + Assert.True(isWeighted == (w == 1)); - var schema = data.Schema; - var b = schema.TryGetColumnIndex("L1(avg)", out int metricCol); - Assert.True(b); - b = schema.TryGetColumnIndex("Fold Index", out int foldCol); - Assert.True(b); - b = schema.TryGetColumnIndex("IsWeighted", out int isWeightedCol); - using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol || col == isWeightedCol)) + // Get the standard deviation. + b = cursor.MoveNext(); + Assert.True(b); + double stdev = 0; + getter(ref stdev); + foldGetter(ref fold); + Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold)); + if (w == 1) + Assert.Equal(1.585, stdev, 3); + else + Assert.Equal(1.39, stdev, 2); + isWeightedGetter(ref isWeighted); + Assert.True(isWeighted == (w == 1)); + } + double sum = 0; + double weightedSum = 0; + for (int f = 0; f < 2; f++) { - var getter = cursor.GetGetter(metricCol); - var foldGetter = cursor.GetGetter>(foldCol); - ReadOnlyMemory fold = default; - var isWeightedGetter = cursor.GetGetter(isWeightedCol); - bool isWeighted = default; - double avg = 0; - double weightedAvg = 0; for (int w = 0; w < 2; w++) { - // Get the average. b = cursor.MoveNext(); Assert.True(b); - if (w == 1) - getter(ref weightedAvg); - else - getter(ref avg); - foldGetter(ref fold); - Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold)); - isWeightedGetter(ref isWeighted); - Assert.True(isWeighted == (w == 1)); - - // Get the standard deviation. 
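// For reference, the OverallMetrics row layout this cursor walk encodes (two
// folds plus a weight column; inferred from the assertions themselves, not from
// separate documentation):
//   "Average" (IsWeighted=false), "Standard Deviation" (IsWeighted=false),
//   "Average" (IsWeighted=true),  "Standard Deviation" (IsWeighted=true),
//   "Fold 0" (false), "Fold 0" (true), "Fold 1" (false), "Fold 1" (true).
// The closing asserts check internal consistency: each "Average" row must equal
// the mean of its per-fold rows, i.e. sum / 2 for two folds.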
- b = cursor.MoveNext(); - Assert.True(b); - double stdev = 0; - getter(ref stdev); + double val = 0; + getter(ref val); foldGetter(ref fold); - Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold)); if (w == 1) - Assert.Equal(1.585, stdev, 3); + weightedSum += val; else - Assert.Equal(1.39, stdev, 2); + sum += val; + Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold)); isWeightedGetter(ref isWeighted); Assert.True(isWeighted == (w == 1)); } - double sum = 0; - double weightedSum = 0; - for (int f = 0; f < 2; f++) - { - for (int w = 0; w < 2; w++) - { - b = cursor.MoveNext(); - Assert.True(b); - double val = 0; - getter(ref val); - foldGetter(ref fold); - if (w == 1) - weightedSum += val; - else - sum += val; - Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold)); - isWeightedGetter(ref isWeighted); - Assert.True(isWeighted == (w == 1)); - } - } - Assert.Equal(weightedAvg, weightedSum / 2); - Assert.Equal(avg, sum / 2); - b = cursor.MoveNext(); - Assert.False(b); } + Assert.Equal(weightedAvg, weightedSum / 2); + Assert.Equal(avg, sum / 2); + b = cursor.MoveNext(); + Assert.False(b); } } @@ -412,207 +402,203 @@ public void TestCrossValidationMacro() public void TestCrossValidationMacroWithMultiClass() { var dataPath = GetDataPath(@"Train-Tiny-28x28.txt"); - using (var env = new ConsoleEnvironment(42)) - { - var subGraph = env.CreateExperiment(); + var env = new MLContext(42); + var subGraph = env.CreateExperiment(); - var nop = new Legacy.Transforms.NoOperation(); - var nopOutput = subGraph.Add(nop); + var nop = new Legacy.Transforms.NoOperation(); + var nopOutput = subGraph.Add(nop); - var learnerInput = new Legacy.Trainers.StochasticDualCoordinateAscentClassifier - { - TrainingData = nopOutput.OutputData, - NumThreads = 1 - }; - var learnerOutput = subGraph.Add(learnerInput); - - var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner - { - TransformModels = new ArrayVar(nopOutput.Model), - PredictorModel = learnerOutput.PredictorModel - }; - var modelCombineOutput = subGraph.Add(modelCombine); + var learnerInput = new Legacy.Trainers.StochasticDualCoordinateAscentClassifier + { + TrainingData = nopOutput.OutputData, + NumThreads = 1 + }; + var learnerOutput = subGraph.Add(learnerInput); - var experiment = env.CreateExperiment(); - var importInput = new Legacy.Data.TextLoader(dataPath); - var importOutput = experiment.Add(importInput); + var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner + { + TransformModels = new ArrayVar(nopOutput.Model), + PredictorModel = learnerOutput.PredictorModel + }; + var modelCombineOutput = subGraph.Add(modelCombine); - var crossValidate = new Legacy.Models.CrossValidator - { - Data = importOutput.Data, - Nodes = subGraph, - Kind = Legacy.Models.MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer, - TransformModel = null - }; - crossValidate.Inputs.Data = nop.Data; - crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel; - var crossValidateOutput = experiment.Add(crossValidate); + var experiment = env.CreateExperiment(); + var importInput = new Legacy.Data.TextLoader(dataPath); + var importOutput = experiment.Add(importInput); - experiment.Compile(); - importInput.SetInput(env, experiment); - experiment.Run(); - var data = experiment.GetOutput(crossValidateOutput.OverallMetrics); + var crossValidate = new Legacy.Models.CrossValidator + { + Data = importOutput.Data, + Nodes = subGraph, + Kind = 
Legacy.Models.MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer, + TransformModel = null + }; + crossValidate.Inputs.Data = nop.Data; + crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel; + var crossValidateOutput = experiment.Add(crossValidate); + + experiment.Compile(); + importInput.SetInput(env, experiment); + experiment.Run(); + var data = experiment.GetOutput(crossValidateOutput.OverallMetrics); + + var schema = data.Schema; + var b = schema.TryGetColumnIndex("Accuracy(micro-avg)", out int metricCol); + Assert.True(b); + b = schema.TryGetColumnIndex("Fold Index", out int foldCol); + Assert.True(b); + using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol)) + { + var getter = cursor.GetGetter(metricCol); + var foldGetter = cursor.GetGetter>(foldCol); + ReadOnlyMemory fold = default; - var schema = data.Schema; - var b = schema.TryGetColumnIndex("Accuracy(micro-avg)", out int metricCol); - Assert.True(b); - b = schema.TryGetColumnIndex("Fold Index", out int foldCol); + // Get the average. + b = cursor.MoveNext(); Assert.True(b); - using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol)) - { - var getter = cursor.GetGetter(metricCol); - var foldGetter = cursor.GetGetter>(foldCol); - ReadOnlyMemory fold = default; + double avg = 0; + getter(ref avg); + foldGetter(ref fold); + Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold)); - // Get the average. - b = cursor.MoveNext(); - Assert.True(b); - double avg = 0; - getter(ref avg); - foldGetter(ref fold); - Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold)); + // Get the standard deviation. + b = cursor.MoveNext(); + Assert.True(b); + double stdev = 0; + getter(ref stdev); + foldGetter(ref fold); + Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold)); + Assert.Equal(0.015, stdev, 3); - // Get the standard deviation. 
+ double sum = 0; + double val = 0; + for (int f = 0; f < 2; f++) + { b = cursor.MoveNext(); Assert.True(b); - double stdev = 0; - getter(ref stdev); + getter(ref val); foldGetter(ref fold); - Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold)); - Assert.Equal(0.025, stdev, 3); - - double sum = 0; - double val = 0; - for (int f = 0; f < 2; f++) - { - b = cursor.MoveNext(); - Assert.True(b); - getter(ref val); - foldGetter(ref fold); - sum += val; - Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold)); - } - Assert.Equal(avg, sum / 2); - b = cursor.MoveNext(); - Assert.False(b); + sum += val; + Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold)); } + Assert.Equal(avg, sum / 2); + b = cursor.MoveNext(); + Assert.False(b); + } - var confusion = experiment.GetOutput(crossValidateOutput.ConfusionMatrix); - schema = confusion.Schema; - b = schema.TryGetColumnIndex("Count", out int countCol); - Assert.True(b); - b = schema.TryGetColumnIndex("Fold Index", out foldCol); - Assert.True(b); - var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, countCol); - Assert.True(type is VectorType vecType && vecType.ItemType is TextType && vecType.Size == 10); - var slotNames = default(VBuffer>); - schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countCol, ref slotNames); - Assert.True(slotNames.Values.Select((s, i) => ReadOnlyMemoryUtils.EqualsStr(i.ToString(), s)).All(x => x)); - using (var curs = confusion.GetRowCursor(col => true)) - { - var countGetter = curs.GetGetter>(countCol); - var foldGetter = curs.GetGetter>(foldCol); - var confCount = default(VBuffer); - var foldIndex = default(ReadOnlyMemory); - int rowCount = 0; - var foldCur = "Fold 0"; - while (curs.MoveNext()) + var confusion = experiment.GetOutput(crossValidateOutput.ConfusionMatrix); + schema = confusion.Schema; + b = schema.TryGetColumnIndex("Count", out int countCol); + Assert.True(b); + b = schema.TryGetColumnIndex("Fold Index", out foldCol); + Assert.True(b); + var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, countCol); + Assert.True(type is VectorType vecType && vecType.ItemType is TextType && vecType.Size == 10); + var slotNames = default(VBuffer>); + schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countCol, ref slotNames); + Assert.True(slotNames.Values.Select((s, i) => ReadOnlyMemoryUtils.EqualsStr(i.ToString(), s)).All(x => x)); + using (var curs = confusion.GetRowCursor(col => true)) + { + var countGetter = curs.GetGetter>(countCol); + var foldGetter = curs.GetGetter>(foldCol); + var confCount = default(VBuffer); + var foldIndex = default(ReadOnlyMemory); + int rowCount = 0; + var foldCur = "Fold 0"; + while (curs.MoveNext()) + { + countGetter(ref confCount); + foldGetter(ref foldIndex); + rowCount++; + Assert.True(ReadOnlyMemoryUtils.EqualsStr(foldCur, foldIndex)); + if (rowCount == 10) { - countGetter(ref confCount); - foldGetter(ref foldIndex); - rowCount++; - Assert.True(ReadOnlyMemoryUtils.EqualsStr(foldCur, foldIndex)); - if (rowCount == 10) - { - rowCount = 0; - foldCur = "Fold 1"; - } + rowCount = 0; + foldCur = "Fold 1"; } - Assert.Equal(0, rowCount); } - - var warnings = experiment.GetOutput(crossValidateOutput.Warnings); - using (var cursor = warnings.GetRowCursor(col => true)) - Assert.False(cursor.MoveNext()); + Assert.Equal(0, rowCount); } + + var warnings = experiment.GetOutput(crossValidateOutput.Warnings); + using (var cursor = warnings.GetRowCursor(col => true)) + Assert.False(cursor.MoveNext()); } [Fact] public void 
TestCrossValidationMacroMultiClassWithWarnings() { var dataPath = GetDataPath(@"Train-Tiny-28x28.txt"); - using (var env = new ConsoleEnvironment(42)) - { - var subGraph = env.CreateExperiment(); + var env = new MLContext(42); + var subGraph = env.CreateExperiment(); - var nop = new Legacy.Transforms.NoOperation(); - var nopOutput = subGraph.Add(nop); + var nop = new Legacy.Transforms.NoOperation(); + var nopOutput = subGraph.Add(nop); - var learnerInput = new Legacy.Trainers.LogisticRegressionClassifier - { - TrainingData = nopOutput.OutputData, - NumThreads = 1 - }; - var learnerOutput = subGraph.Add(learnerInput); - - var experiment = env.CreateExperiment(); - var importInput = new Legacy.Data.TextLoader(dataPath); - var importOutput = experiment.Add(importInput); - - var filter = new Legacy.Transforms.RowRangeFilter(); - filter.Data = importOutput.Data; - filter.Column = "Label"; - filter.Min = 0; - filter.Max = 5; - var filterOutput = experiment.Add(filter); - - var term = new Legacy.Transforms.TextToKeyConverter(); - term.Column = new[] - { + var learnerInput = new Legacy.Trainers.LogisticRegressionClassifier + { + TrainingData = nopOutput.OutputData, + NumThreads = 1 + }; + var learnerOutput = subGraph.Add(learnerInput); + + var experiment = env.CreateExperiment(); + var importInput = new Legacy.Data.TextLoader(dataPath); + var importOutput = experiment.Add(importInput); + + var filter = new Legacy.Transforms.RowRangeFilter(); + filter.Data = importOutput.Data; + filter.Column = "Label"; + filter.Min = 0; + filter.Max = 5; + var filterOutput = experiment.Add(filter); + + var term = new Legacy.Transforms.TextToKeyConverter(); + term.Column = new[] + { new Legacy.Transforms.ValueToKeyMappingTransformerColumn() { Source = "Label", Name = "Strat", Sort = Legacy.Transforms.ValueToKeyMappingTransformerSortOrder.Value } }; - term.Data = filterOutput.OutputData; - var termOutput = experiment.Add(term); + term.Data = filterOutput.OutputData; + var termOutput = experiment.Add(term); - var crossValidate = new Legacy.Models.CrossValidator - { - Data = termOutput.OutputData, - Nodes = subGraph, - Kind = Legacy.Models.MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer, - TransformModel = null, - StratificationColumn = "Strat" - }; - crossValidate.Inputs.Data = nop.Data; - crossValidate.Outputs.PredictorModel = learnerOutput.PredictorModel; - var crossValidateOutput = experiment.Add(crossValidate); - - experiment.Compile(); - importInput.SetInput(env, experiment); - experiment.Run(); - var warnings = experiment.GetOutput(crossValidateOutput.Warnings); + var crossValidate = new Legacy.Models.CrossValidator + { + Data = termOutput.OutputData, + Nodes = subGraph, + Kind = Legacy.Models.MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer, + TransformModel = null, + StratificationColumn = "Strat" + }; + crossValidate.Inputs.Data = nop.Data; + crossValidate.Outputs.PredictorModel = learnerOutput.PredictorModel; + var crossValidateOutput = experiment.Add(crossValidate); + + experiment.Compile(); + importInput.SetInput(env, experiment); + experiment.Run(); + var warnings = experiment.GetOutput(crossValidateOutput.Warnings); + + var schema = warnings.Schema; + var b = schema.TryGetColumnIndex("WarningText", out int warningCol); + Assert.True(b); + using (var cursor = warnings.GetRowCursor(col => col == warningCol)) + { + var getter = cursor.GetGetter>(warningCol); - var schema = warnings.Schema; - var b = schema.TryGetColumnIndex("WarningText", out int warningCol); + b = cursor.MoveNext(); 
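// The read pattern here is the idiomatic IDataView access used throughout these
// tests: find the column index, open a cursor over just that column, then pull
// each value through a by-ref getter (sketch; the getters' generic arguments
// appear to have been lost from the surrounding lines in formatting):
//     var ok = warnings.Schema.TryGetColumnIndex("WarningText", out int col);
//     using (var cur = warnings.GetRowCursor(c => c == col))
//     {
//         var get = cur.GetGetter<ReadOnlyMemory<char>>(col);
//         var text = default(ReadOnlyMemory<char>);
//         while (cur.MoveNext())
//             get(ref text); // fills the buffer in place; no per-row allocation
//     }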
 Assert.True(b);
- using (var cursor = warnings.GetRowCursor(col => col == warningCol))
- {
- var getter = cursor.GetGetter<ReadOnlyMemory<char>>(warningCol);
-
- b = cursor.MoveNext();
- Assert.True(b);
- var warning = default(ReadOnlyMemory<char>);
- getter(ref warning);
- Assert.Contains("test instances with class values not seen in the training set.", warning.ToString());
- b = cursor.MoveNext();
- Assert.True(b);
- getter(ref warning);
- Assert.Contains("Detected columns of variable length: SortedScores, SortedClasses", warning.ToString());
- b = cursor.MoveNext();
- Assert.False(b);
- }
+ var warning = default(ReadOnlyMemory<char>);
+ getter(ref warning);
+ Assert.Contains("test instances with class values not seen in the training set.", warning.ToString());
+ b = cursor.MoveNext();
+ Assert.True(b);
+ getter(ref warning);
+ Assert.Contains("Detected columns of variable length: SortedScores, SortedClasses", warning.ToString());
+ b = cursor.MoveNext();
+ Assert.False(b);
 }
 }
@@ -620,95 +606,93 @@ public void TestCrossValidationMacroMultiClassWithWarnings()
 public void TestCrossValidationMacroWithStratification()
 {
 var dataPath = GetDataPath(@"breast-cancer.txt");
- using (var env = new ConsoleEnvironment(42))
- {
- var subGraph = env.CreateExperiment();
+ var env = new MLContext(42);
+ var subGraph = env.CreateExperiment();
- var nop = new Legacy.Transforms.NoOperation();
- var nopOutput = subGraph.Add(nop);
+ var nop = new Legacy.Transforms.NoOperation();
+ var nopOutput = subGraph.Add(nop);
- var learnerInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier
- {
- TrainingData = nopOutput.OutputData,
- NumThreads = 1
- };
- var learnerOutput = subGraph.Add(learnerInput);
-
- var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
- {
- TransformModels = new ArrayVar<ITransformModel>(nopOutput.Model),
- PredictorModel = learnerOutput.PredictorModel
- };
- var modelCombineOutput = subGraph.Add(modelCombine);
+ var learnerInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier
+ {
+ TrainingData = nopOutput.OutputData,
+ NumThreads = 1
+ };
+ var learnerOutput = subGraph.Add(learnerInput);
- var experiment = env.CreateExperiment();
- var importInput = new Legacy.Data.TextLoader(dataPath);
- importInput.Arguments.Column = new Legacy.Data.TextLoaderColumn[]
- {
+ var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
+ {
+ TransformModels = new ArrayVar<ITransformModel>(nopOutput.Model),
+ PredictorModel = learnerOutput.PredictorModel
+ };
+ var modelCombineOutput = subGraph.Add(modelCombine);
+
+ var experiment = env.CreateExperiment();
+ var importInput = new Legacy.Data.TextLoader(dataPath);
+ importInput.Arguments.Column = new Legacy.Data.TextLoaderColumn[]
+ {
 new Legacy.Data.TextLoaderColumn { Name = "Label", Source = new[] { new Legacy.Data.TextLoaderRange(0) } },
 new Legacy.Data.TextLoaderColumn { Name = "Strat", Source = new[] { new Legacy.Data.TextLoaderRange(1) } },
 new Legacy.Data.TextLoaderColumn { Name = "Features", Source = new[] { new Legacy.Data.TextLoaderRange(2, 9) } }
- };
- var importOutput = experiment.Add(importInput);
+ };
+ var importOutput = experiment.Add(importInput);
- var crossValidate = new Legacy.Models.CrossValidator
- {
- Data = importOutput.Data,
- Nodes = subGraph,
- TransformModel = null,
- StratificationColumn = "Strat"
- };
- crossValidate.Inputs.Data = nop.Data;
- crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
- var crossValidateOutput = experiment.Add(crossValidate);
- experiment.Compile();
- experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
- experiment.Run();
- var data = experiment.GetOutput(crossValidateOutput.OverallMetrics);
-
- var schema = data.Schema;
- var b = schema.TryGetColumnIndex("AUC", out int metricCol);
- Assert.True(b);
- b = schema.TryGetColumnIndex("Fold Index", out int foldCol);
+ var crossValidate = new Legacy.Models.CrossValidator
+ {
+ Data = importOutput.Data,
+ Nodes = subGraph,
+ TransformModel = null,
+ StratificationColumn = "Strat"
+ };
+ crossValidate.Inputs.Data = nop.Data;
+ crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
+ var crossValidateOutput = experiment.Add(crossValidate);
+ experiment.Compile();
+ experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
+ experiment.Run();
+ var data = experiment.GetOutput(crossValidateOutput.OverallMetrics);
+
+ var schema = data.Schema;
+ var b = schema.TryGetColumnIndex("AUC", out int metricCol);
+ Assert.True(b);
+ b = schema.TryGetColumnIndex("Fold Index", out int foldCol);
+ Assert.True(b);
+ using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol))
+ {
+ var getter = cursor.GetGetter<double>(metricCol);
+ var foldGetter = cursor.GetGetter<ReadOnlyMemory<char>>(foldCol);
+ ReadOnlyMemory<char> fold = default;
+
+ // Get the average.
+ b = cursor.MoveNext();
 Assert.True(b);
- using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol))
- {
- var getter = cursor.GetGetter<double>(metricCol);
- var foldGetter = cursor.GetGetter<ReadOnlyMemory<char>>(foldCol);
- ReadOnlyMemory<char> fold = default;
+ double avg = 0;
+ getter(ref avg);
+ foldGetter(ref fold);
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold));
- // Get the average.
- b = cursor.MoveNext();
- Assert.True(b);
- double avg = 0;
- getter(ref avg);
- foldGetter(ref fold);
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold));
+ // Get the standard deviation.
+ b = cursor.MoveNext();
+ Assert.True(b);
+ double stdev = 0;
+ getter(ref stdev);
+ foldGetter(ref fold);
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold));
+ Assert.Equal(0.00488, stdev, 5);
- // Get the standard deviation.
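+ // The OverallMetrics view produced by the CrossValidator macro arrives as one "Average"
+ // row, then one "Standard Deviation" row, then one row per fold ("Fold 0", "Fold 1"),
+ // keyed by the "Fold Index" text column; the loop below walks the two fold rows and
+ // checks that their mean equals the reported average.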
+ double sum = 0;
+ double val = 0;
+ for (int f = 0; f < 2; f++)
+ {
 b = cursor.MoveNext();
 Assert.True(b);
- double stdev = 0;
- getter(ref stdev);
+ getter(ref val);
 foldGetter(ref fold);
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold));
- Assert.Equal(0.00485, stdev, 5);
-
- double sum = 0;
- double val = 0;
- for (int f = 0; f < 2; f++)
- {
- b = cursor.MoveNext();
- Assert.True(b);
- getter(ref val);
- foldGetter(ref fold);
- sum += val;
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold));
- }
- Assert.Equal(avg, sum / 2);
- b = cursor.MoveNext();
- Assert.False(b);
+ sum += val;
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold));
 }
+ Assert.Equal(avg, sum / 2);
+ b = cursor.MoveNext();
+ Assert.False(b);
 }
 }
@@ -716,127 +700,125 @@ public void TestCrossValidationMacroWithStratification()
 public void TestCrossValidationMacroWithNonDefaultNames()
 {
 string dataPath = GetDataPath(@"adult.tiny.with-schema.txt");
- using (var env = new ConsoleEnvironment(42))
- {
- var subGraph = env.CreateExperiment();
+ var env = new MLContext(42);
+ var subGraph = env.CreateExperiment();
- var textToKey = new Legacy.Transforms.TextToKeyConverter();
- textToKey.Column = new[] { new Legacy.Transforms.ValueToKeyMappingTransformerColumn() { Name = "Label1", Source = "Label" } };
- var textToKeyOutput = subGraph.Add(textToKey);
+ var textToKey = new Legacy.Transforms.TextToKeyConverter();
+ textToKey.Column = new[] { new Legacy.Transforms.ValueToKeyMappingTransformerColumn() { Name = "Label1", Source = "Label" } };
+ var textToKeyOutput = subGraph.Add(textToKey);
- var hash = new Legacy.Transforms.HashConverter();
- hash.Column = new[] { new Legacy.Transforms.HashJoiningTransformColumn() { Name = "GroupId1", Source = "Workclass" } };
- hash.Data = textToKeyOutput.OutputData;
- var hashOutput = subGraph.Add(hash);
+ var hash = new Legacy.Transforms.HashConverter();
+ hash.Column = new[] { new Legacy.Transforms.HashJoiningTransformColumn() { Name = "GroupId1", Source = "Workclass" } };
+ hash.Data = textToKeyOutput.OutputData;
+ var hashOutput = subGraph.Add(hash);
- var learnerInput = new Legacy.Trainers.FastTreeRanker
- {
- TrainingData = hashOutput.OutputData,
- NumThreads = 1,
- LabelColumn = "Label1",
- GroupIdColumn = "GroupId1"
- };
- var learnerOutput = subGraph.Add(learnerInput);
-
- var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
- {
- TransformModels = new ArrayVar<ITransformModel>(textToKeyOutput.Model, hashOutput.Model),
- PredictorModel = learnerOutput.PredictorModel
- };
- var modelCombineOutput = subGraph.Add(modelCombine);
-
- var experiment = env.CreateExperiment();
- var importInput = new Legacy.Data.TextLoader(dataPath);
- importInput.Arguments.HasHeader = true;
- importInput.Arguments.Column = new TextLoaderColumn[]
- {
+ var learnerInput = new Legacy.Trainers.FastTreeRanker
+ {
+ TrainingData = hashOutput.OutputData,
+ NumThreads = 1,
+ LabelColumn = "Label1",
+ GroupIdColumn = "GroupId1"
+ };
+ var learnerOutput = subGraph.Add(learnerInput);
+
+ var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
+ {
+ TransformModels = new ArrayVar<ITransformModel>(textToKeyOutput.Model, hashOutput.Model),
+ PredictorModel = learnerOutput.PredictorModel
+ };
+ var modelCombineOutput = subGraph.Add(modelCombine);
+
+ var experiment = env.CreateExperiment();
+ var importInput = new Legacy.Data.TextLoader(dataPath);
+ importInput.Arguments.HasHeader = true;
+ importInput.Arguments.Column = new TextLoaderColumn[]
+ {
"Label", Source = new[] { new TextLoaderRange(0) } }, new TextLoaderColumn { Name = "Workclass", Source = new[] { new TextLoaderRange(1) }, Type = Legacy.Data.DataKind.Text }, new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(9, 14) } } - }; - var importOutput = experiment.Add(importInput); + }; + var importOutput = experiment.Add(importInput); - var crossValidate = new Legacy.Models.CrossValidator - { - Data = importOutput.Data, - Nodes = subGraph, - TransformModel = null, - LabelColumn = "Label1", - GroupColumn = "GroupId1", - NameColumn = "Workclass", - Kind = Legacy.Models.MacroUtilsTrainerKinds.SignatureRankerTrainer - }; - crossValidate.Inputs.Data = textToKey.Data; - crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel; - var crossValidateOutput = experiment.Add(crossValidate); - experiment.Compile(); - experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); - experiment.Run(); - var data = experiment.GetOutput(crossValidateOutput.OverallMetrics); - - var schema = data.Schema; - var b = schema.TryGetColumnIndex("NDCG", out int metricCol); + var crossValidate = new Legacy.Models.CrossValidator + { + Data = importOutput.Data, + Nodes = subGraph, + TransformModel = null, + LabelColumn = "Label1", + GroupColumn = "GroupId1", + NameColumn = "Workclass", + Kind = Legacy.Models.MacroUtilsTrainerKinds.SignatureRankerTrainer + }; + crossValidate.Inputs.Data = textToKey.Data; + crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel; + var crossValidateOutput = experiment.Add(crossValidate); + experiment.Compile(); + experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); + experiment.Run(); + var data = experiment.GetOutput(crossValidateOutput.OverallMetrics); + + var schema = data.Schema; + var b = schema.TryGetColumnIndex("NDCG", out int metricCol); + Assert.True(b); + b = schema.TryGetColumnIndex("Fold Index", out int foldCol); + Assert.True(b); + using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol)) + { + var getter = cursor.GetGetter>(metricCol); + var foldGetter = cursor.GetGetter>(foldCol); + ReadOnlyMemory fold = default; + + // Get the verage. + b = cursor.MoveNext(); Assert.True(b); - b = schema.TryGetColumnIndex("Fold Index", out int foldCol); + var avg = default(VBuffer); + getter(ref avg); + foldGetter(ref fold); + Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold)); + + // Get the standard deviation. + b = cursor.MoveNext(); Assert.True(b); - using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol)) + var stdev = default(VBuffer); + getter(ref stdev); + foldGetter(ref fold); + Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold)); + Assert.Equal(2.462, stdev.Values[0], 3); + Assert.Equal(2.763, stdev.Values[1], 3); + Assert.Equal(3.273, stdev.Values[2], 3); + + var sumBldr = new BufferBuilder(R8Adder.Instance); + sumBldr.Reset(avg.Length, true); + var val = default(VBuffer); + for (int f = 0; f < 2; f++) { - var getter = cursor.GetGetter>(metricCol); - var foldGetter = cursor.GetGetter>(foldCol); - ReadOnlyMemory fold = default; - - // Get the verage. b = cursor.MoveNext(); Assert.True(b); - var avg = default(VBuffer); - getter(ref avg); + getter(ref val); foldGetter(ref fold); - Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold)); - - // Get the standard deviation. 
- b = cursor.MoveNext();
- Assert.True(b);
- var stdev = default(VBuffer<double>);
- getter(ref stdev);
- foldGetter(ref fold);
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold));
- Assert.Equal(2.462, stdev.Values[0], 3);
- Assert.Equal(2.763, stdev.Values[1], 3);
- Assert.Equal(3.273, stdev.Values[2], 3);
-
- var sumBldr = new BufferBuilder<double>(R8Adder.Instance);
- sumBldr.Reset(avg.Length, true);
- var val = default(VBuffer<double>);
- for (int f = 0; f < 2; f++)
- {
- b = cursor.MoveNext();
- Assert.True(b);
- getter(ref val);
- foldGetter(ref fold);
- sumBldr.AddFeatures(0, in val);
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold));
- }
- var sum = default(VBuffer<double>);
- sumBldr.GetResult(ref sum);
- for (int i = 0; i < avg.Length; i++)
- Assert.Equal(avg.Values[i], sum.Values[i] / 2);
- b = cursor.MoveNext();
- Assert.False(b);
+ sumBldr.AddFeatures(0, in val);
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold));
 }
+ var sum = default(VBuffer<double>);
+ sumBldr.GetResult(ref sum);
+ for (int i = 0; i < avg.Length; i++)
+ Assert.Equal(avg.Values[i], sum.Values[i] / 2);
+ b = cursor.MoveNext();
+ Assert.False(b);
+ }
- data = experiment.GetOutput(crossValidateOutput.PerInstanceMetrics);
- Assert.True(data.Schema.TryGetColumnIndex("Instance", out int nameCol));
- using (var cursor = data.GetRowCursor(col => col == nameCol))
- {
- var getter = cursor.GetGetter<ReadOnlyMemory<char>>(nameCol);
- while (cursor.MoveNext())
- {
- ReadOnlyMemory<char> name = default;
- getter(ref name);
- Assert.Subset(new HashSet<string>() { "Private", "?", "Federal-gov" }, new HashSet<string>() { name.ToString() });
- if (cursor.Position > 4)
- break;
- }
+ data = experiment.GetOutput(crossValidateOutput.PerInstanceMetrics);
+ Assert.True(data.Schema.TryGetColumnIndex("Instance", out int nameCol));
+ using (var cursor = data.GetRowCursor(col => col == nameCol))
+ {
+ var getter = cursor.GetGetter<ReadOnlyMemory<char>>(nameCol);
+ while (cursor.MoveNext())
+ {
+ ReadOnlyMemory<char> name = default;
+ getter(ref name);
+ Assert.Subset(new HashSet<string>() { "Private", "?", "Federal-gov" }, new HashSet<string>() { name.ToString() });
+ if (cursor.Position > 4)
+ break;
 }
 }
 }
@@ -845,58 +827,56 @@ public void TestCrossValidationMacroWithNonDefaultNames()
 public void TestOvaMacro()
 {
 var dataPath = GetDataPath(@"iris.txt");
- using (var env = new ConsoleEnvironment(42))
- {
- // Specify subgraph for OVA
- var subGraph = env.CreateExperiment();
- var learnerInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier { NumThreads = 1 };
- var learnerOutput = subGraph.Add(learnerInput);
- // Create pipeline with OVA and multiclass scoring.
- var experiment = env.CreateExperiment();
- var importInput = new Legacy.Data.TextLoader(dataPath);
- importInput.Arguments.Column = new TextLoaderColumn[]
- {
+ var env = new MLContext(42);
+ // Specify subgraph for OVA
+ var subGraph = env.CreateExperiment();
+ var learnerInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier { NumThreads = 1 };
+ var learnerOutput = subGraph.Add(learnerInput);
+ // Create pipeline with OVA and multiclass scoring.
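+ // (One-versus-all trains the binary subgraph learner once per class, label-vs-rest, and
+ // predicts the class whose binary model responds most strongly; with UseProbabilities
+ // set below, the per-class outputs are combined as probabilities rather than raw scores.)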
+ var experiment = env.CreateExperiment();
+ var importInput = new Legacy.Data.TextLoader(dataPath);
+ importInput.Arguments.Column = new TextLoaderColumn[]
+ {
 new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } },
 new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(1,4) } }
- };
- var importOutput = experiment.Add(importInput);
- var oneVersusAll = new Legacy.Models.OneVersusAll
- {
- TrainingData = importOutput.Data,
- Nodes = subGraph,
- UseProbabilities = true,
- };
- var ovaOutput = experiment.Add(oneVersusAll);
- var scoreInput = new Legacy.Transforms.DatasetScorer
- {
- Data = importOutput.Data,
- PredictorModel = ovaOutput.PredictorModel
- };
- var scoreOutput = experiment.Add(scoreInput);
- var evalInput = new Legacy.Models.ClassificationEvaluator
- {
- Data = scoreOutput.ScoredData
- };
- var evalOutput = experiment.Add(evalInput);
- experiment.Compile();
- experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
- experiment.Run();
-
- var data = experiment.GetOutput(evalOutput.OverallMetrics);
- var schema = data.Schema;
- var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol);
+ };
+ var importOutput = experiment.Add(importInput);
+ var oneVersusAll = new Legacy.Models.OneVersusAll
+ {
+ TrainingData = importOutput.Data,
+ Nodes = subGraph,
+ UseProbabilities = true,
+ };
+ var ovaOutput = experiment.Add(oneVersusAll);
+ var scoreInput = new Legacy.Transforms.DatasetScorer
+ {
+ Data = importOutput.Data,
+ PredictorModel = ovaOutput.PredictorModel
+ };
+ var scoreOutput = experiment.Add(scoreInput);
+ var evalInput = new Legacy.Models.ClassificationEvaluator
+ {
+ Data = scoreOutput.ScoredData
+ };
+ var evalOutput = experiment.Add(evalInput);
+ experiment.Compile();
+ experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
+ experiment.Run();
+
+ var data = experiment.GetOutput(evalOutput.OverallMetrics);
+ var schema = data.Schema;
+ var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol);
+ Assert.True(b);
+ using (var cursor = data.GetRowCursor(col => col == accCol))
+ {
+ var getter = cursor.GetGetter<double>(accCol);
+ b = cursor.MoveNext();
 Assert.True(b);
- using (var cursor = data.GetRowCursor(col => col == accCol))
- {
- var getter = cursor.GetGetter<double>(accCol);
- b = cursor.MoveNext();
- Assert.True(b);
- double acc = 0;
- getter(ref acc);
- Assert.Equal(0.96, acc, 2);
- b = cursor.MoveNext();
- Assert.False(b);
- }
+ double acc = 0;
+ getter(ref acc);
+ Assert.Equal(0.96, acc, 2);
+ b = cursor.MoveNext();
+ Assert.False(b);
 }
 }
@@ -904,58 +884,56 @@ public void TestOvaMacro()
 public void TestOvaMacroWithUncalibratedLearner()
 {
 var dataPath = GetDataPath(@"iris.txt");
- using (var env = new ConsoleEnvironment(42))
- {
- // Specify subgraph for OVA
- var subGraph = env.CreateExperiment();
- var learnerInput = new Legacy.Trainers.AveragedPerceptronBinaryClassifier { Shuffle = false };
- var learnerOutput = subGraph.Add(learnerInput);
- // Create pipeline with OVA and multiclass scoring.
- var experiment = env.CreateExperiment();
- var importInput = new Legacy.Data.TextLoader(dataPath);
- importInput.Arguments.Column = new TextLoaderColumn[]
- {
+ var env = new MLContext(42);
+ // Specify subgraph for OVA
+ var subGraph = env.CreateExperiment();
+ var learnerInput = new Legacy.Trainers.AveragedPerceptronBinaryClassifier { Shuffle = false };
+ var learnerOutput = subGraph.Add(learnerInput);
+ // Create pipeline with OVA and multiclass scoring.
+ var experiment = env.CreateExperiment();
+ var importInput = new Legacy.Data.TextLoader(dataPath);
+ importInput.Arguments.Column = new TextLoaderColumn[]
+ {
 new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } },
 new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(1,4) } }
- };
- var importOutput = experiment.Add(importInput);
- var oneVersusAll = new Legacy.Models.OneVersusAll
- {
- TrainingData = importOutput.Data,
- Nodes = subGraph,
- UseProbabilities = true,
- };
- var ovaOutput = experiment.Add(oneVersusAll);
- var scoreInput = new Legacy.Transforms.DatasetScorer
- {
- Data = importOutput.Data,
- PredictorModel = ovaOutput.PredictorModel
- };
- var scoreOutput = experiment.Add(scoreInput);
- var evalInput = new Legacy.Models.ClassificationEvaluator
- {
- Data = scoreOutput.ScoredData
- };
- var evalOutput = experiment.Add(evalInput);
- experiment.Compile();
- experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
- experiment.Run();
-
- var data = experiment.GetOutput(evalOutput.OverallMetrics);
- var schema = data.Schema;
- var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol);
+ };
+ var importOutput = experiment.Add(importInput);
+ var oneVersusAll = new Legacy.Models.OneVersusAll
+ {
+ TrainingData = importOutput.Data,
+ Nodes = subGraph,
+ UseProbabilities = true,
+ };
+ var ovaOutput = experiment.Add(oneVersusAll);
+ var scoreInput = new Legacy.Transforms.DatasetScorer
+ {
+ Data = importOutput.Data,
+ PredictorModel = ovaOutput.PredictorModel
+ };
+ var scoreOutput = experiment.Add(scoreInput);
+ var evalInput = new Legacy.Models.ClassificationEvaluator
+ {
+ Data = scoreOutput.ScoredData
+ };
+ var evalOutput = experiment.Add(evalInput);
+ experiment.Compile();
+ experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
+ experiment.Run();
+
+ var data = experiment.GetOutput(evalOutput.OverallMetrics);
+ var schema = data.Schema;
+ var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol);
+ Assert.True(b);
+ using (var cursor = data.GetRowCursor(col => col == accCol))
+ {
+ var getter = cursor.GetGetter<double>(accCol);
+ b = cursor.MoveNext();
 Assert.True(b);
- using (var cursor = data.GetRowCursor(col => col == accCol))
- {
- var getter = cursor.GetGetter<double>(accCol);
- b = cursor.MoveNext();
- Assert.True(b);
- double acc = 0;
- getter(ref acc);
- Assert.Equal(0.71, acc, 2);
- b = cursor.MoveNext();
- Assert.False(b);
- }
+ double acc = 0;
+ getter(ref acc);
+ Assert.Equal(0.71, acc, 2);
+ b = cursor.MoveNext();
+ Assert.False(b);
 }
 }
@@ -963,37 +941,35 @@ public void TestOvaMacroWithUncalibratedLearner()
 public void TestTensorFlowEntryPoint()
 {
 var dataPath = GetDataPath("Train-Tiny-28x28.txt");
- using (var env = new ConsoleEnvironment(42))
- {
- var experiment = env.CreateExperiment();
+ var env = new MLContext(42);
+ var experiment = env.CreateExperiment();
- var importInput = new Legacy.Data.TextLoader(dataPath);
- importInput.Arguments.Column = new TextLoaderColumn[]
- {
+ var importInput = new Legacy.Data.TextLoader(dataPath);
+ importInput.Arguments.Column = new TextLoaderColumn[]
+ {
 new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } },
 new TextLoaderColumn { Name = "Placeholder", Source = new[] { new TextLoaderRange(1, 784) } }
- };
- var importOutput = experiment.Add(importInput);
+ };
+ var importOutput = experiment.Add(importInput);
- var tfTransformInput = new Legacy.Transforms.TensorFlowScorer
- {
- Data = importOutput.Data,
- ModelLocation = "mnist_model/frozen_saved_model.pb",
- InputColumns = new[] { "Placeholder" },
- OutputColumns = new[] { "Softmax" },
- };
- var tfTransformOutput = experiment.Add(tfTransformInput);
-
- experiment.Compile();
- experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
- experiment.Run();
- var data = experiment.GetOutput(tfTransformOutput.OutputData);
-
- var schema = data.Schema;
- Assert.Equal(3, schema.ColumnCount);
- Assert.Equal("Softmax", schema.GetColumnName(2));
- Assert.Equal(10, (schema.GetColumnType(2) as VectorType)?.Size);
- }
+ var tfTransformInput = new Legacy.Transforms.TensorFlowScorer
+ {
+ Data = importOutput.Data,
+ ModelLocation = "mnist_model/frozen_saved_model.pb",
+ InputColumns = new[] { "Placeholder" },
+ OutputColumns = new[] { "Softmax" },
+ };
+ var tfTransformOutput = experiment.Add(tfTransformInput);
+
+ experiment.Compile();
+ experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
+ experiment.Run();
+ var data = experiment.GetOutput(tfTransformOutput.OutputData);
+
+ var schema = data.Schema;
+ Assert.Equal(3, schema.ColumnCount);
+ Assert.Equal("Softmax", schema.GetColumnName(2));
+ Assert.Equal(10, (schema.GetColumnType(2) as VectorType)?.Size);
 }
 }
}
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestContracts.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestContracts.cs
index 6b4fd5297f..ca1315b83b 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestContracts.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestContracts.cs
@@ -3,8 +3,6 @@
 // See the LICENSE file in the project root for more information.
 using System;
-using Microsoft.ML.Runtime;
-using Microsoft.ML.Runtime.Data;
 using Xunit;
 namespace Microsoft.ML.Runtime.RunTests
 {
@@ -42,7 +40,7 @@ private void Helper(IExceptionContext ectx, MessageSensitivity expected)
 [Fact]
 public void ExceptionSensitivity()
 {
- var env = new ConsoleEnvironment();
+ var env = new MLContext();
 // Default sensitivity should be unknown, that is, all bits set.
 Helper(null, MessageSensitivity.Unknown);
 // If we set it to be not sensitive, then the messages should be marked insensitive,
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs
index 3798246fa8..688287dda2 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs
@@ -13,7 +13,7 @@ public sealed class TestEarlyStoppingCriteria
 {
 private IEarlyStoppingCriterion CreateEarlyStoppingCriterion(string name, string args, bool lowerIsBetter)
 {
- var env = new ConsoleEnvironment()
+ var env = new MLContext()
 .AddStandardComponents();
 var sub = new SubComponent<IEarlyStoppingCriterion, SignatureEarlyStoppingCriterion>(name, args);
 return sub.CreateInstance(env, lowerIsBetter);
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs
index ce01945bcb..2f1df54285 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs
@@ -20,7 +20,7 @@ public class TestHosts
 [Fact]
 public void TestCancellation()
 {
- var env = new ConsoleEnvironment(seed: 42);
+ IHostEnvironment env = new MLContext(seed: 42);
 for (int z = 0; z < 1000; z++)
 {
 var mainHost = env.Register("Main");
diff --git a/test/Microsoft.ML.InferenceTesting/Microsoft.ML.InferenceTesting.csproj b/test/Microsoft.ML.InferenceTesting/Microsoft.ML.InferenceTesting.csproj
deleted file mode 100644
index 60b52497a0..0000000000
--- a/test/Microsoft.ML.InferenceTesting/Microsoft.ML.InferenceTesting.csproj
+++ /dev/null
@@ -1,22 +0,0 @@
-
-
-
- true
- CORECLR
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs b/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs
index 00927936b1..9c8bc34a2f 100644
--- a/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs
+++ b/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs
@@ -166,42 +166,40 @@ public void OnnxStatic()
 var modelFile = "squeezenet/00000001/model.onnx";
- using (var env = new ConsoleEnvironment(null, false, 0, 1, null, null))
+ var env = new MLContext(conc: 1);
+ var imageHeight = 224;
+ var imageWidth = 224;
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+
+ var data = TextLoader.CreateReader(env, ctx => (
+ imagePath: ctx.LoadText(0),
+ name: ctx.LoadText(1)))
+ .Read(dataFile);
+
+ // Note that CamelCase column names are there to match the TF graph node names.
+ var pipe = data.MakeNewEstimator()
+ .Append(row => (
+ row.name,
+ data_0: row.imagePath.LoadAsImage(imageFolder).Resize(imageHeight, imageWidth).ExtractPixels(interleaveArgb: true)))
+ .Append(row => (row.name, softmaxout_1: row.data_0.ApplyOnnxModel(modelFile)));
+
+ TestEstimatorCore(pipe.AsDynamic, data.AsDynamic);
+
+ var result = pipe.Fit(data).Transform(data).AsDynamic;
+ result.Schema.TryGetColumnIndex("softmaxout_1", out int output);
+ using (var cursor = result.GetRowCursor(col => col == output))
 {
- var imageHeight = 224;
- var imageWidth = 224;
- var dataFile = GetDataPath("images/images.tsv");
- var imageFolder = Path.GetDirectoryName(dataFile);
-
- var data = TextLoader.CreateReader(env, ctx => (
- imagePath: ctx.LoadText(0),
- name: ctx.LoadText(1)))
- .Read(dataFile);
-
- // Note that CamelCase column names are there to match the TF graph node names.
- var pipe = data.MakeNewEstimator()
- .Append(row => (
- row.name,
- data_0: row.imagePath.LoadAsImage(imageFolder).Resize(imageHeight, imageWidth).ExtractPixels(interleaveArgb: true)))
- .Append(row => (row.name, softmaxout_1: row.data_0.ApplyOnnxModel(modelFile)));
-
- TestEstimatorCore(pipe.AsDynamic, data.AsDynamic);
-
- var result = pipe.Fit(data).Transform(data).AsDynamic;
- result.Schema.TryGetColumnIndex("softmaxout_1", out int output);
- using (var cursor = result.GetRowCursor(col => col == output))
+ var buffer = default(VBuffer<float>);
+ var getter = cursor.GetGetter<VBuffer<float>>(output);
+ var numRows = 0;
+ while (cursor.MoveNext())
 {
- var buffer = default(VBuffer<float>);
- var getter = cursor.GetGetter<VBuffer<float>>(output);
- var numRows = 0;
- while (cursor.MoveNext())
- {
- getter(ref buffer);
- Assert.Equal(1000, buffer.Length);
- numRows += 1;
- }
- Assert.Equal(3, numRows);
+ getter(ref buffer);
+ Assert.Equal(1000, buffer.Length);
+ numRows += 1;
 }
+ Assert.Equal(3, numRows);
 }
 }
@@ -211,11 +209,9 @@ void TestCommandLine()
 if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
 return;
- using (var env = new ConsoleEnvironment())
- {
- var x = Maml.Main(new[] { @"showschema loader=Text{col=data_0:R4:0-150527} xf=Onnx{InputColumn=data_0 OutputColumn=softmaxout_1 model={squeezenet/00000001/model.onnx}}" });
- Assert.Equal(0, x);
- }
+ var env = new MLContext();
+ var x = Maml.Main(new[] { @"showschema loader=Text{col=data_0:R4:0-150527} xf=Onnx{InputColumn=data_0 OutputColumn=softmaxout_1 model={squeezenet/00000001/model.onnx}}" });
+ Assert.Equal(0, x);
 }
 }
}
diff --git a/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLine.cs b/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLine.cs
index 0839a57ed3..3f4fa341dd 100644
--- a/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLine.cs
+++ b/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLine.cs
@@ -97,7 +97,7 @@ public void CmdParsingSingle()
 ///
 private static void Init(IndentingTextWriter wrt, object defaults)
 {
- var env = new ConsoleEnvironment(seed: 42);
+ var env = new MLContext(seed: 42);
 wrt.WriteLine("Usage:");
 wrt.WriteLine(CmdParser.ArgumentsUsage(env, defaults.GetType(), defaults, false, 200));
 }
@@ -107,7 +107,7 @@ private static void Init(IndentingTextWriter wrt, object defaults)
 ///
 private static void Process(IndentingTextWriter wrt, string text, ArgsBase defaults)
 {
- var env = new ConsoleEnvironment(seed: 42);
+ var env = new MLContext(seed: 42);
 using (wrt.Nest())
 {
 var args1 = defaults.Clone();
diff --git a/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs b/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs
index 6d95829454..51330648ee 100644
--- a/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs
+++ b/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs
@@ -19,7 +19,7 @@ public class CmdLineReverseTests
 [TestCategory("Cmd Parsing")]
 public void ArgumentParseTest()
 {
- var env = new ConsoleEnvironment(seed: 42);
+ var env = new MLContext(seed: 42);
 var innerArg1 = new SimpleArg()
 {
 required = -2,
diff --git a/test/Microsoft.ML.InferenceTesting/InferRecipesCommand.cs b/test/Microsoft.ML.Predictor.Tests/Commands/InferRecipesCommand.cs
similarity index 96%
rename from test/Microsoft.ML.InferenceTesting/InferRecipesCommand.cs
rename to test/Microsoft.ML.Predictor.Tests/Commands/InferRecipesCommand.cs
index a6dbae8248..d84b6f9a7d 100644
--- a/test/Microsoft.ML.InferenceTesting/InferRecipesCommand.cs
+++ b/test/Microsoft.ML.Predictor.Tests/Commands/InferRecipesCommand.cs
@@ -22,8 +22,9 @@ namespace Microsoft.ML.Runtime.MLTesting.Inference
 /// This command generates a suggested RSP to load the text file and recipes it prior to training.
 /// The results are output to the console and also to the RSP file, if it's specified.
 ///
- public sealed class InferRecipesCommand : ICommand
+ internal sealed class InferRecipesCommand : ICommand
 {
+#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
 public sealed class Arguments
 {
 [Argument(ArgumentType.Required, HelpText = "Text file with data to analyze", ShortName = "data")]
 public string DataFile;
@@ -35,6 +36,7 @@ public sealed class Arguments
 [Argument(ArgumentType.AtMostOnce, HelpText = "Optional path to the schema definition file generated by the InferSchema command", ShortName = "schema")]
 public string SchemaDefinitionFile;
 }
+#pragma warning restore CS0649
 private readonly IHost _host;
 private readonly string _dataFile;
diff --git a/test/Microsoft.ML.InferenceTesting/InferSchemaCommand.cs b/test/Microsoft.ML.Predictor.Tests/Commands/InferSchemaCommand.cs
similarity index 95%
rename from test/Microsoft.ML.InferenceTesting/InferSchemaCommand.cs
rename to test/Microsoft.ML.Predictor.Tests/Commands/InferSchemaCommand.cs
index 7a96008bd0..269da88b19 100644
--- a/test/Microsoft.ML.InferenceTesting/InferSchemaCommand.cs
+++ b/test/Microsoft.ML.Predictor.Tests/Commands/InferSchemaCommand.cs
@@ -22,8 +22,9 @@ namespace Microsoft.ML.Runtime.MLTesting.Inference
 /// This command generates a suggested RSP to load the text file and recipes it prior to training.
 /// The results are output to the console and also to the RSP file, if it's specified.
 ///
- public sealed class InferSchemaCommand : ICommand
+ internal sealed class InferSchemaCommand : ICommand
 {
+#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
 public sealed class Arguments
 {
 [Argument(ArgumentType.Required, HelpText = "Text file with data to analyze", ShortName = "data")]
 public string DataFile;
@@ -32,6 +33,7 @@ public sealed class Arguments
 [Argument(ArgumentType.AtMostOnce, HelpText = "Path to the output json file describing the columns", ShortName = "out")]
 public string OutputFile;
 }
+#pragma warning restore CS0649
 private readonly IHost _host;
 private readonly string _dataFile;
diff --git a/test/Microsoft.ML.Predictor.Tests/Microsoft.ML.Predictor.Tests.csproj b/test/Microsoft.ML.Predictor.Tests/Microsoft.ML.Predictor.Tests.csproj
index d089946b87..9cc27b17e2 100644
--- a/test/Microsoft.ML.Predictor.Tests/Microsoft.ML.Predictor.Tests.csproj
+++ b/test/Microsoft.ML.Predictor.Tests/Microsoft.ML.Predictor.Tests.csproj
@@ -16,7 +16,6 @@
-
diff --git a/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs b/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs
index 48f6336ecb..84cbea24ba 100644
--- a/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs
+++ b/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs
@@ -28,111 +28,102 @@ public TestAutoInference(ITestOutputHelper helper)
 [TestCategory("EntryPoints")]
 public void TestLearn()
 {
- using (var env = new ConsoleEnvironment()
- .AddStandardComponents()) // AutoInference.InferPipelines uses ComponentCatalog to read text data
- {
- string pathData = GetDataPath("adult.train");
- string pathDataTest = GetDataPath("adult.test");
- int numOfSampleRows = 1000;
- int batchSize = 5;
- int numIterations = 10;
- int numTransformLevels = 3;
- SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);
-
- // Using the simple, uniform random sampling (with replacement) engine
- PipelineOptimizerBase autoMlEngine = new UniformRandomEngine(env);
-
- // Test initial learning
- var amls = AutoInference.InferPipelines(env, autoMlEngine, pathData, "", out var schema, numTransformLevels, batchSize,
- metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations / 2), MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer);
- env.Check(amls.GetAllEvaluatedPipelines().Length == numIterations / 2);
-
- // Resume learning
- amls.UpdateTerminator(new IterationTerminator(numIterations));
- bestPipeline = amls.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
- env.Check(amls.GetAllEvaluatedPipelines().Length == numIterations);
-
- // Use best pipeline for another task
- var inputFileTrain = new SimpleFileHandle(env, pathData, false, false);
+ var env = new MLContext().AddStandardComponents(); // AutoInference uses ComponentCatalog to find all learners
+ string pathData = GetDataPath("adult.train");
+ string pathDataTest = GetDataPath("adult.test");
+ int numOfSampleRows = 1000;
+ int batchSize = 5;
+ int numIterations = 10;
+ int numTransformLevels = 3;
+ SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);
+
+ // Using the simple, uniform random sampling (with replacement) engine
+ PipelineOptimizerBase autoMlEngine = new UniformRandomEngine(env);
+
+ // Test initial learning
+ var amls = AutoInference.InferPipelines(env, autoMlEngine, pathData, "", out var schema, numTransformLevels, batchSize,
+ metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations / 2), MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer);
+ env.Check(amls.GetAllEvaluatedPipelines().Length == numIterations / 2);
+
+ // Resume learning
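+ // (UpdateTerminator only raises the iteration budget; the next InferPipelines call
+ // continues from the pipelines already evaluated rather than starting a fresh sweep,
+ // which is what the Length == numIterations check below relies on.)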
+ amls.UpdateTerminator(new IterationTerminator(numIterations));
+ bestPipeline = amls.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
+ env.Check(amls.GetAllEvaluatedPipelines().Length == numIterations);
+
+ // Use best pipeline for another task
+ var inputFileTrain = new SimpleFileHandle(env, pathData, false, false);
 #pragma warning disable 0618
- var datasetTrain = ImportTextData.ImportText(env,
- new ImportTextData.Input { InputFile = inputFileTrain, CustomSchema = schema }).Data;
- var inputFileTest = new SimpleFileHandle(env, pathDataTest, false, false);
- var datasetTest = ImportTextData.ImportText(env,
- new ImportTextData.Input { InputFile = inputFileTest, CustomSchema = schema }).Data;
+ var datasetTrain = ImportTextData.ImportText(env,
+ new ImportTextData.Input { InputFile = inputFileTrain, CustomSchema = schema }).Data;
+ var inputFileTest = new SimpleFileHandle(env, pathDataTest, false, false);
+ var datasetTest = ImportTextData.ImportText(env,
+ new ImportTextData.Input { InputFile = inputFileTest, CustomSchema = schema }).Data;
 #pragma warning restore 0618
- // REVIEW: Theoretically, it could be the case that a new, very bad learner is introduced and
- // we get unlucky and only select it every time, such that this test fails. Not
- // likely at all, but a non-zero probability. Should be ok, since all current learners are returning d > .80.
- bestPipeline.RunTrainTestExperiment(datasetTrain, datasetTest, metric, MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer,
- out var testMetricValue, out var trainMetricValue);
- env.Check(testMetricValue > 0.2);
- }
+ // REVIEW: Theoretically, it could be the case that a new, very bad learner is introduced and
+ // we get unlucky and only select it every time, such that this test fails. Not
+ // likely at all, but a non-zero probability. Should be ok, since all current learners are returning d > .80.
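+ // (RunTrainTestExperiment retrains the selected pipeline on the training file, scores
+ // the held-out test file, and reports the chosen metric (AUC here) for both sets.)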
+ bestPipeline.RunTrainTestExperiment(datasetTrain, datasetTest, metric, MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer,
+ out var testMetricValue, out var trainMetricValue);
+ env.Check(testMetricValue > 0.2);
 Done();
 }

 [Fact(Skip = "Need CoreTLC specific baseline update")]
 public void TestTextDatasetLearn()
 {
- using (var env = new ConsoleEnvironment()
- .AddStandardComponents()) // AutoInference uses ComponentCatalog to find all learners
- {
- string pathData = GetDataPath(@"../UnitTest/tweets_labeled_10k_test_validation.tsv");
- int batchSize = 5;
- int numIterations = 35;
- int numTransformLevels = 1;
- int numSampleRows = 100;
- SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.AccuracyMicro);
-
- // Using the simple, uniform random sampling (with replacement) engine
- PipelineOptimizerBase autoMlEngine = new UniformRandomEngine(env);
-
- // Test initial learning
- var amls = AutoInference.InferPipelines(env, autoMlEngine, pathData, "", out var _, numTransformLevels, batchSize,
- metric, out var _, numSampleRows, new IterationTerminator(numIterations),
- MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer);
- env.Check(amls.GetAllEvaluatedPipelines().Length == numIterations);
- }
+ var env = new MLContext().AddStandardComponents(); // AutoInference uses ComponentCatalog to find all learners
+ string pathData = GetDataPath(@"../UnitTest/tweets_labeled_10k_test_validation.tsv");
+ int batchSize = 5;
+ int numIterations = 35;
+ int numTransformLevels = 1;
+ int numSampleRows = 100;
+ SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.AccuracyMicro);
+
+ // Using the simple, uniform random sampling (with replacement) engine
+ PipelineOptimizerBase autoMlEngine = new UniformRandomEngine(env);
+
+ // Test initial learning
+ var amls = AutoInference.InferPipelines(env, autoMlEngine, pathData, "", out var _, numTransformLevels, batchSize,
+ metric, out var _, numSampleRows, new IterationTerminator(numIterations),
+ MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer);
+ env.Check(amls.GetAllEvaluatedPipelines().Length == numIterations);
 Done();
 }

 [Fact]
 public void TestPipelineNodeCloning()
 {
- using (var env = new ConsoleEnvironment()
- .AddStandardComponents()) // RecipeInference.AllowedLearners uses ComponentCatalog to find all learners
- {
- var lr1 = RecipeInference
+ var env = new MLContext().AddStandardComponents(); // AutoInference uses ComponentCatalog to find all learners
+ var lr1 = RecipeInference
 .AllowedLearners(env, MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer)
 .First(learner => learner.PipelineNode != null && learner.LearnerName.Contains("LogisticRegression"));
- var sdca1 = RecipeInference
- .AllowedLearners(env, MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer)
- .First(learner => learner.PipelineNode != null && learner.LearnerName.Contains("StochasticDualCoordinateAscent"));
-
- // Clone and change hyperparam values
- var lr2 = lr1.Clone();
- lr1.PipelineNode.SweepParams[0].RawValue = 1.2f;
- lr2.PipelineNode.SweepParams[0].RawValue = 3.5f;
- var sdca2 = sdca1.Clone();
- sdca1.PipelineNode.SweepParams[0].RawValue = 3;
- sdca2.PipelineNode.SweepParams[0].RawValue = 0;
-
- // Make sure the changes are propagated to entry point objects
- env.Check(lr1.PipelineNode.UpdateProperties());
- env.Check(lr2.PipelineNode.UpdateProperties());
- env.Check(sdca1.PipelineNode.UpdateProperties());
- env.Check(sdca2.PipelineNode.UpdateProperties());
- env.Check(lr1.PipelineNode.CheckEntryPointStateMatchesParamValues());
- env.Check(lr2.PipelineNode.CheckEntryPointStateMatchesParamValues());
- env.Check(sdca1.PipelineNode.CheckEntryPointStateMatchesParamValues());
- env.Check(sdca2.PipelineNode.CheckEntryPointStateMatchesParamValues());
-
- // Make sure second object's set of changes didn't overwrite first object's
- env.Check(!lr1.PipelineNode.SweepParams[0].RawValue.Equals(lr2.PipelineNode.SweepParams[0].RawValue));
- env.Check(!sdca2.PipelineNode.SweepParams[0].RawValue.Equals(sdca1.PipelineNode.SweepParams[0].RawValue));
- }
+ var sdca1 = RecipeInference
+ .AllowedLearners(env, MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer)
+ .First(learner => learner.PipelineNode != null && learner.LearnerName.Contains("StochasticDualCoordinateAscent"));
+
+ // Clone and change hyperparam values
+ var lr2 = lr1.Clone();
+ lr1.PipelineNode.SweepParams[0].RawValue = 1.2f;
+ lr2.PipelineNode.SweepParams[0].RawValue = 3.5f;
+ var sdca2 = sdca1.Clone();
+ sdca1.PipelineNode.SweepParams[0].RawValue = 3;
+ sdca2.PipelineNode.SweepParams[0].RawValue = 0;
+
+ // Make sure the changes are propagated to entry point objects
+ env.Check(lr1.PipelineNode.UpdateProperties());
+ env.Check(lr2.PipelineNode.UpdateProperties());
+ env.Check(sdca1.PipelineNode.UpdateProperties());
+ env.Check(sdca2.PipelineNode.UpdateProperties());
+ env.Check(lr1.PipelineNode.CheckEntryPointStateMatchesParamValues());
+ env.Check(lr2.PipelineNode.CheckEntryPointStateMatchesParamValues());
+ env.Check(sdca1.PipelineNode.CheckEntryPointStateMatchesParamValues());
+ env.Check(sdca2.PipelineNode.CheckEntryPointStateMatchesParamValues());
+
+ // Make sure second object's set of changes didn't overwrite first object's
+ env.Check(!lr1.PipelineNode.SweepParams[0].RawValue.Equals(lr2.PipelineNode.SweepParams[0].RawValue));
+ env.Check(!sdca2.PipelineNode.SweepParams[0].RawValue.Equals(sdca1.PipelineNode.SweepParams[0].RawValue));
 }

 [Fact]
@@ -143,45 +134,42 @@ public void TestHyperparameterFreezing()
 int batchSize = 1;
 int numIterations = 10;
 int numTransformLevels = 3;
- using (var env = new ConsoleEnvironment()
- .AddStandardComponents()) // AutoInference uses ComponentCatalog to find all learners
- {
- SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);
-
- // Using the simple, uniform random sampling (with replacement) brain
- PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env);
-
- // Run initial experiments
- var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize,
- metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations),
- MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer);
-
- // Clear results
- amls.ClearEvaluatedPipelines();
-
- // Get space, remove transforms and all but one learner, freeze hyperparameters on learner.
- var space = amls.GetSearchSpace();
- var transforms = space.Item1.Where(t =>
- t.ExpertType != typeof(TransformInference.Experts.Categorical)).ToArray();
- var learners = new[] { space.Item2.First() };
- var hyperParam = learners[0].PipelineNode.SweepParams.First();
- var frozenParamValue = hyperParam.RawValue;
- hyperParam.Frozen = true;
- amls.UpdateSearchSpace(learners, transforms);
-
- // Allow for one more iteration
- amls.UpdateTerminator(new IterationTerminator(numIterations + 1));
-
- // Do learning. Only retained learner should be left in all pipelines.
- bestPipeline = amls.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
-
- // Make sure all pipelines have retained learner
- Assert.True(amls.GetAllEvaluatedPipelines().All(p => p.Learner.LearnerName == learners[0].LearnerName));
-
- // Make sure hyperparameter value did not change
- Assert.NotNull(bestPipeline);
- Assert.Equal(bestPipeline.Learner.PipelineNode.SweepParams.First().RawValue, frozenParamValue);
- }
+ var env = new MLContext().AddStandardComponents(); // AutoInference uses ComponentCatalog to find all learners
+ SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);
+
+ // Using the simple, uniform random sampling (with replacement) brain
+ PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env);
+
+ // Run initial experiments
+ var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize,
+ metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations),
+ MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer);
+
+ // Clear results
+ amls.ClearEvaluatedPipelines();
+
+ // Get space, remove transforms and all but one learner, freeze hyperparameters on learner.
+ var space = amls.GetSearchSpace();
+ var transforms = space.Item1.Where(t =>
+ t.ExpertType != typeof(TransformInference.Experts.Categorical)).ToArray();
+ var learners = new[] { space.Item2.First() };
+ var hyperParam = learners[0].PipelineNode.SweepParams.First();
+ var frozenParamValue = hyperParam.RawValue;
+ hyperParam.Frozen = true;
+ amls.UpdateSearchSpace(learners, transforms);
+
+ // Allow for one more iteration
+ amls.UpdateTerminator(new IterationTerminator(numIterations + 1));
+
+ // Do learning. Only retained learner should be left in all pipelines.
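+ // (A sweep parameter with Frozen = true keeps its RawValue while the sweeper varies the
+ // remaining parameters; the Assert.Equal against frozenParamValue below depends on that.)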
+ bestPipeline = amls.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
+
+ // Make sure all pipelines have retained learner
+ Assert.True(amls.GetAllEvaluatedPipelines().All(p => p.Learner.LearnerName == learners[0].LearnerName));
+
+ // Make sure hyperparameter value did not change
+ Assert.NotNull(bestPipeline);
+ Assert.Equal(bestPipeline.Learner.PipelineNode.SweepParams.First().RawValue, frozenParamValue);
 }

 [Fact(Skip = "Dataset not available.")]
@@ -192,30 +180,27 @@ public void TestRegressionPipelineWithMinimizingMetric()
 int batchSize = 5;
 int numIterations = 10;
 int numTransformLevels = 1;
- using (var env = new ConsoleEnvironment()
- .AddStandardComponents()) // AutoInference uses ComponentCatalog to find all learners
- {
- SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.AccuracyMicro);
+ var env = new MLContext().AddStandardComponents(); // AutoInference uses ComponentCatalog to find all learners
+ SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.AccuracyMicro);
- // Using the simple, uniform random sampling (with replacement) brain
- PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env);
+ // Using the simple, uniform random sampling (with replacement) brain
+ PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env);
- // Run initial experiments
- var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize,
- metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations),
- MacroUtils.TrainerKinds.SignatureRegressorTrainer);
+ // Run initial experiments
+ var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize,
+ metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations),
+ MacroUtils.TrainerKinds.SignatureRegressorTrainer);
- // Allow for one more iteration
- amls.UpdateTerminator(new IterationTerminator(numIterations + 1));
+ // Allow for one more iteration
+ amls.UpdateTerminator(new IterationTerminator(numIterations + 1));
- // Do learning. Only retained learner should be left in all pipelines.
- bestPipeline = amls.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
+ // Do learning. Only retained learner should be left in all pipelines.
+ bestPipeline = amls.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
- // Make sure hyperparameter value did not change
- Assert.NotNull(bestPipeline);
- Assert.True(amls.GetAllEvaluatedPipelines().All(
- p => p.PerformanceSummary.MetricValue >= bestPipeline.PerformanceSummary.MetricValue));
- }
+ // Make sure hyperparameter value did not change
+ Assert.NotNull(bestPipeline);
+ Assert.True(amls.GetAllEvaluatedPipelines().All(
+ p => p.PerformanceSummary.MetricValue >= bestPipeline.PerformanceSummary.MetricValue));
 }

 [Fact]
@@ -227,27 +212,24 @@ public void TestLearnerConstrainingByName()
 int numIterations = 1;
 int numTransformLevels = 2;
 var retainedLearnerNames = new[] { $"LogisticRegressionBinaryClassifier", $"FastTreeBinaryClassifier" };
- using (var env = new ConsoleEnvironment()
- .AddStandardComponents()) // AutoInference uses ComponentCatalog to find all learners
- {
- SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);
-
- // Using the simple, uniform random sampling (with replacement) brain.
- PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env);
-
- // Run initial experiment.
- var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _,
- numTransformLevels, batchSize, metric, out var _, numOfSampleRows,
- new IterationTerminator(numIterations), MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer);
-
- // Keep only logistic regression and FastTree.
- amls.KeepSelectedLearners(retainedLearnerNames);
- var space = amls.GetSearchSpace();
-
- // Make sure only learners left are those retained.
- Assert.Equal(retainedLearnerNames.Length, space.Item2.Length);
- Assert.True(space.Item2.All(l => retainedLearnerNames.Any(r => r == l.LearnerName)));
- }
+ var env = new MLContext().AddStandardComponents(); // AutoInference uses ComponentCatalog to find all learners
+ SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);
+
+ // Using the simple, uniform random sampling (with replacement) brain.
+ PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env);
+
+ // Run initial experiment.
+ var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _,
+ numTransformLevels, batchSize, metric, out var _, numOfSampleRows,
+ new IterationTerminator(numIterations), MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer);
+
+ // Keep only logistic regression and FastTree.
+ amls.KeepSelectedLearners(retainedLearnerNames);
+ var space = amls.GetSearchSpace();
+
+ // Make sure only learners left are those retained.
+ Assert.Equal(retainedLearnerNames.Length, space.Item2.Length);
+ Assert.True(space.Item2.All(l => retainedLearnerNames.Any(r => r == l.LearnerName)));
 }

 [Fact]
diff --git a/test/Microsoft.ML.Predictor.Tests/TestDatasetInference.cs b/test/Microsoft.ML.Predictor.Tests/TestDatasetInference.cs
index 0b92eb7b30..5cf5a86ceb 100644
--- a/test/Microsoft.ML.Predictor.Tests/TestDatasetInference.cs
+++ b/test/Microsoft.ML.Predictor.Tests/TestDatasetInference.cs
@@ -25,7 +25,7 @@ public TestDatasetInference(ITestOutputHelper helper)
 {
 }

- [Fact(Skip="Disabled")]
+ [Fact(Skip = "Disabled")]
 public void DatasetInferenceTest()
 {
 var datasets = new[]
@@ -35,51 +35,49 @@ public void DatasetInferenceTest()
 {
 GetDataPath(@"..\UnitTest\breast-cancer.txt"),
 };

- using (var env = new ConsoleEnvironment())
+ IHostEnvironment env = new MLContext();
+ var h = env.Register("InferDatasetFeatures", seed: 0, verbose: false);
+
+ using (var ch = h.Start("InferDatasetFeatures"))
 {
- var h = env.Register("InferDatasetFeatures", seed: 0, verbose: false);
- using (var ch = h.Start("InferDatasetFeatures"))
+ for (int i = 0; i < datasets.Length; i++)
 {
+ var sample = TextFileSample.CreateFromFullFile(h, datasets[i]);
+ var splitResult = TextFileContents.TrySplitColumns(h, sample, TextFileContents.DefaultSeparators);
+ if (!splitResult.IsSuccess)
+ throw ch.ExceptDecode("Couldn't detect separator.");

- for (int i = 0; i < datasets.Length; i++)
- {
- var sample = TextFileSample.CreateFromFullFile(h, datasets[i]);
- var splitResult = TextFileContents.TrySplitColumns(h, sample, TextFileContents.DefaultSeparators);
- if (!splitResult.IsSuccess)
- throw ch.ExceptDecode("Couldn't detect separator.");
-
- var typeInfResult = ColumnTypeInference.InferTextFileColumnTypes(Env, sample,
- new ColumnTypeInference.Arguments
- {
- Separator = splitResult.Separator,
- AllowSparse = splitResult.AllowSparse,
- AllowQuote = splitResult.AllowQuote,
- ColumnCount = splitResult.ColumnCount
- });
-
- if (!typeInfResult.IsSuccess)
- return;
-
- ColumnGroupingInference.GroupingColumn[] columns = null;
- bool hasHeader = false;
- columns = InferenceUtils.InferColumnPurposes(ch, h, sample, splitResult, out hasHeader);
- Guid id = new Guid("60C77F4E-DB62-4351-8311-9B392A12968E");
- var commandArgs = new DatasetFeatureInference.Arguments(typeInfResult.Data,
- columns.Select(
- col =>
- new DatasetFeatureInference.Column(col.SuggestedName, col.Purpose, col.ItemKind,
- col.ColumnRangeSelector)).ToArray(), sample.FullFileSize, sample.ApproximateRowCount,
- false, id, true);
-
- string jsonString = DatasetFeatureInference.InferDatasetFeatures(env, commandArgs);
- var outFile = string.Format("dataset-inference-result-{0:00}.txt", i);
- string dataPath = GetOutputPath(@"..\Common\Inference", outFile);
- using (var sw = new StreamWriter(File.Create(dataPath)))
- sw.WriteLine(jsonString);
-
- CheckEquality(@"..\Common\Inference", outFile);
- }
+ var typeInfResult = ColumnTypeInference.InferTextFileColumnTypes(Env, sample,
+ new ColumnTypeInference.Arguments
+ {
+ Separator = splitResult.Separator,
+ AllowSparse = splitResult.AllowSparse,
+ AllowQuote = splitResult.AllowQuote,
+ ColumnCount = splitResult.ColumnCount
+ });
+
+ if (!typeInfResult.IsSuccess)
+ return;
+
+ ColumnGroupingInference.GroupingColumn[] columns = null;
+ bool hasHeader = false;
+ columns = InferenceUtils.InferColumnPurposes(ch, h, sample, splitResult, out hasHeader);
+ Guid id = new Guid("60C77F4E-DB62-4351-8311-9B392A12968E");
+ var commandArgs = new DatasetFeatureInference.Arguments(typeInfResult.Data,
+ columns.Select(
+ col =>
+ new DatasetFeatureInference.Column(col.SuggestedName, col.Purpose, col.ItemKind,
+ col.ColumnRangeSelector)).ToArray(), sample.FullFileSize, sample.ApproximateRowCount,
+ false, id, true);
+
+ string jsonString = DatasetFeatureInference.InferDatasetFeatures(env, commandArgs);
+ var outFile = string.Format("dataset-inference-result-{0:00}.txt", i);
+ string dataPath = GetOutputPath(@"..\Common\Inference", outFile);
+ using (var sw = new StreamWriter(File.Create(dataPath)))
+ sw.WriteLine(jsonString);
+
+ CheckEquality(@"..\Common\Inference", outFile);
 }
 }
 Done();
@@ -93,26 +91,24 @@ public void InferSchemaCommandTest()
 GetDataPath(Path.Combine("..", "data", "wikipedia-detox-250-line-data.tsv"))
 };

- using (var env = new ConsoleEnvironment())
+ IHostEnvironment env = new MLContext();
+ var h = env.Register("InferSchemaCommandTest", seed: 0, verbose: false);
+ using (var ch = h.Start("InferSchemaCommandTest"))
 {
- var h = env.Register("InferSchemaCommandTest", seed: 0, verbose: false);
- using (var ch = h.Start("InferSchemaCommandTest"))
+ for (int i = 0; i < datasets.Length; i++)
 {
- for (int i = 0; i < datasets.Length; i++)
+ var outFile = string.Format("dataset-infer-schema-result-{0:00}.txt", i);
+ string dataPath = GetOutputPath(Path.Combine("..", "Common", "Inference"), outFile);
+ var args = new InferSchemaCommand.Arguments()
 {
- var outFile = string.Format("dataset-infer-schema-result-{0:00}.txt", i);
- string dataPath = GetOutputPath(Path.Combine("..", "Common", "Inference"), outFile);
- var args = new InferSchemaCommand.Arguments()
- {
- DataFile = datasets[i],
- OutputFile = dataPath,
- };
+ DataFile = datasets[i],
+ OutputFile = dataPath,
+ };

- var cmd = new InferSchemaCommand(Env, args);
- cmd.Run();
+ var cmd = new InferSchemaCommand(Env, args);
+ cmd.Run();

- CheckEquality(Path.Combine("..", "Common", "Inference"), outFile);
- }
+ CheckEquality(Path.Combine("..", "Common", "Inference"), outFile);
outFile); } } Done(); @@ -128,26 +124,24 @@ public void InferRecipesCommandTest() GetDataPath(Path.Combine("..", "data", "wikipedia-detox-250-line-data-schema.txt"))) }; - using (var env = new ConsoleEnvironment()) + IHostEnvironment env = new MLContext(); + var h = env.Register("InferRecipesCommandTest", seed: 0, verbose: false); + using (var ch = h.Start("InferRecipesCommandTest")) { - var h = env.Register("InferRecipesCommandTest", seed: 0, verbose: false); - using (var ch = h.Start("InferRecipesCommandTest")) + for (int i = 0; i < datasets.Length; i++) { - for (int i = 0; i < datasets.Length; i++) + var outFile = string.Format("dataset-infer-recipe-result-{0:00}.txt", i); + string dataPath = GetOutputPath(Path.Combine("..", "Common", "Inference"), outFile); + var args = new InferRecipesCommand.Arguments() { - var outFile = string.Format("dataset-infer-recipe-result-{0:00}.txt", i); - string dataPath = GetOutputPath(Path.Combine("..", "Common", "Inference"), outFile); - var args = new InferRecipesCommand.Arguments() - { - DataFile = datasets[i].Item1, - SchemaDefinitionFile = datasets[i].Item2, - RspOutputFile = dataPath - }; - var cmd = new InferRecipesCommand(Env, args); - cmd.Run(); - - CheckEquality(Path.Combine("..", "Common", "Inference"), outFile); - } + DataFile = datasets[i].Item1, + SchemaDefinitionFile = datasets[i].Item2, + RspOutputFile = dataPath + }; + var cmd = new InferRecipesCommand(Env, args); + cmd.Run(); + + CheckEquality(Path.Combine("..", "Common", "Inference"), outFile); } } Done(); diff --git a/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs b/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs index b72c9773d3..5e277daa05 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs @@ -123,42 +123,40 @@ public void PipelineSweeperNoTransforms() const int batchSize = 5; const int numIterations = 20; const int numTransformLevels = 2; - using (var env = new ConsoleEnvironment()) - { - SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc); + var env = new MLContext(); + SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc); - // Using the simple, uniform random sampling (with replacement) engine - PipelineOptimizerBase autoMlEngine = new UniformRandomEngine(Env); + // Using the simple, uniform random sampling (with replacement) engine + PipelineOptimizerBase autoMlEngine = new UniformRandomEngine(Env); - // Create search object - var amls = new AutoInference.AutoMlMlState(Env, metric, autoMlEngine, new IterationTerminator(numIterations), - MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer, datasetTrain, datasetTest); + // Create search object + var amls = new AutoInference.AutoMlMlState(Env, metric, autoMlEngine, new IterationTerminator(numIterations), + MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer, datasetTrain, datasetTest); - // Infer search space - amls.InferSearchSpace(numTransformLevels); + // Infer search space + amls.InferSearchSpace(numTransformLevels); - // Create macro object - var pipelineSweepInput = new Microsoft.ML.Legacy.Models.PipelineSweeper() - { - BatchSize = batchSize, - }; - - var exp = new Experiment(Env); - var output = exp.Add(pipelineSweepInput); - exp.Compile(); - exp.SetInput(pipelineSweepInput.TrainingData, datasetTrain); - exp.SetInput(pipelineSweepInput.TestingData, datasetTest); - 
exp.SetInput(pipelineSweepInput.State, amls); - exp.SetInput(pipelineSweepInput.CandidateOutputs, new IDataView[0]); - exp.Run(); - - // Make sure you get back an AutoMlState, and that it ran for correct number of iterations - // with at least minimal performance values (i.e., best should have AUC better than 0.1 on this dataset). - AutoInference.AutoMlMlState amlsOut = (AutoInference.AutoMlMlState)exp.GetOutput(output.State); - Assert.NotNull(amlsOut); - Assert.Equal(amlsOut.GetAllEvaluatedPipelines().Length, numIterations); - Assert.True(amlsOut.GetBestPipeline().PerformanceSummary.MetricValue > 0.8); - } + // Create macro object + var pipelineSweepInput = new Microsoft.ML.Legacy.Models.PipelineSweeper() + { + BatchSize = batchSize, + }; + + var exp = new Experiment(Env); + var output = exp.Add(pipelineSweepInput); + exp.Compile(); + exp.SetInput(pipelineSweepInput.TrainingData, datasetTrain); + exp.SetInput(pipelineSweepInput.TestingData, datasetTest); + exp.SetInput(pipelineSweepInput.State, amls); + exp.SetInput(pipelineSweepInput.CandidateOutputs, new IDataView[0]); + exp.Run(); + + // Make sure you get back an AutoMlState, and that it ran for the correct number of iterations + // with at least minimal performance values (i.e., best should have AUC better than 0.8 on this dataset, matching the assertion below). + AutoInference.AutoMlMlState amlsOut = (AutoInference.AutoMlMlState)exp.GetOutput(output.State); + Assert.NotNull(amlsOut); + Assert.Equal(amlsOut.GetAllEvaluatedPipelines().Length, numIterations); + Assert.True(amlsOut.GetBestPipeline().PerformanceSummary.MetricValue > 0.8); } [Fact] diff --git a/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs b/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs index 3b28fc1cfa..bdb9677495 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information.
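The file-by-file changes from here on repeat one mechanical substitution, so it is worth sketching once. A minimal before/after, assuming only the Microsoft.ML namespaces these test files already import (MLContext implements IHostEnvironment, which the hunks above rely on directly):

    // Old pattern, removed throughout this diff: a disposable, console-bound environment.
    // using (var env = new ConsoleEnvironment(0, verbose: true))
    // {
    //     ... test body, nested one level deep ...
    // }

    // New pattern: MLContext needs no using block; a fixed seed keeps runs reproducible.
    var env = new MLContext(0);

Because the using blocks disappear, most of the remaining churn in these hunks is a one-level re-indentation of otherwise unchanged test bodies.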
-using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.ImageAnalytics; using Xunit; @@ -20,7 +19,7 @@ public ImageAnalyticsTests(ITestOutputHelper output) [Fact] public void SimpleImageSmokeTest() { - var env = new ConsoleEnvironment(0, verbose: true); + var env = new MLContext(0); var reader = TextLoader.CreateReader(env, ctx => ctx.LoadText(0).LoadAsImage().AsGrayscale().Resize(10, 8).ExtractPixels()); diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index b5fb351b99..ab7049a58a 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -58,7 +58,7 @@ private void CheckSchemaHasColumn(ISchema schema, string name, out int idx) [Fact] public void SimpleTextLoaderCopyColumnsTest() { - var env = new ConsoleEnvironment(0, verbose: true); + var env = new MLContext(0); const string data = "0 hello 3.14159 -0 2\n" + "1 1 2 4 15"; @@ -159,7 +159,7 @@ private static Obnoxious3 MakeObnoxious3(Scalar hi, Obnoxious1 my, T [Fact] public void SimpleTextLoaderObnoxiousTypeTest() { - var env = new ConsoleEnvironment(0, verbose: true); + var env = new MLContext(0); const string data = "0 hello 3.14159 -0 2\n" + "1 1 2 4 15"; @@ -206,7 +206,7 @@ private static KeyValuePair P(string name, ColumnType type) [Fact] public void AssertStaticSimple() { - var env = new ConsoleEnvironment(0, verbose: true); + var env = new MLContext(0); var schema = SimpleSchemaUtils.Create(env, P("hello", TextType.Instance), P("my", new VectorType(NumberType.I8, 5)), @@ -230,7 +230,7 @@ public void AssertStaticSimple() [Fact] public void AssertStaticSimpleFailure() { - var env = new ConsoleEnvironment(0, verbose: true); + var env = new MLContext(0); var schema = SimpleSchemaUtils.Create(env, P("hello", TextType.Instance), P("my", new VectorType(NumberType.I8, 5)), @@ -262,7 +262,7 @@ private sealed class MetaCounted : ICounted [Fact] public void AssertStaticKeys() { - var env = new ConsoleEnvironment(0, verbose: true); + var env = new MLContext(0); var counted = new MetaCounted(); // We'll test a few things here. First, the case where the key-value metadata is text. 
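For orientation in the hunks above and below: the static-pipeline tests build their readers through TextLoader.CreateReader, where a delegate describes the columns as a named tuple and the reader stays strongly typed end to end. A condensed sketch, using only calls that appear verbatim in these tests (dataPath stands in for one of the test data files):

    var env = new MLContext(0);
    var reader = TextLoader.CreateReader(env, ctx => (
        label: ctx.LoadBool(0),     // column 0 read as a boolean label
        text: ctx.LoadText(1)));    // column 1 read as raw text
    var data = reader.Read(new MultiFileSource(dataPath));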
@@ -369,7 +369,7 @@ public void AssertStaticKeys() [Fact] public void Normalizer() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(0); var dataPath = GetDataPath("generated_regression_dataset.csv"); var dataSource = new MultiFileSource(dataPath); @@ -394,7 +394,7 @@ public void Normalizer() [Fact] public void NormalizerWithOnFit() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(0); var dataPath = GetDataPath("generated_regression_dataset.csv"); var dataSource = new MultiFileSource(dataPath); @@ -438,7 +438,7 @@ public void NormalizerWithOnFit() [Fact] public void ToKey() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(0); var dataPath = GetDataPath("iris.data"); var reader = TextLoader.CreateReader(env, c => (label: c.LoadText(4), values: c.LoadFloat(0, 3)), @@ -476,7 +476,7 @@ public void ToKey() [Fact] public void ConcatWith() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(0); var dataPath = GetDataPath("iris.data"); var reader = TextLoader.CreateReader(env, c => (label: c.LoadText(4), values: c.LoadFloat(0, 3), value: c.LoadFloat(2)), @@ -514,7 +514,7 @@ public void ConcatWith() [Fact] public void Tokenize() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(0); var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv"); var reader = TextLoader.CreateReader(env, ctx => ( label: ctx.LoadBool(0), @@ -543,7 +543,7 @@ public void Tokenize() [Fact] public void NormalizeTextAndRemoveStopWords() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(0); var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv"); var reader = TextLoader.CreateReader(env, ctx => ( label: ctx.LoadBool(0), @@ -572,7 +572,7 @@ public void NormalizeTextAndRemoveStopWords() [Fact] public void ConvertToWordBag() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(0); var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv"); var reader = TextLoader.CreateReader(env, ctx => ( label: ctx.LoadBool(0), @@ -601,7 +601,7 @@ public void ConvertToWordBag() [Fact] public void Ngrams() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(0); var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv"); var reader = TextLoader.CreateReader(env, ctx => ( label: ctx.LoadBool(0), @@ -631,7 +631,7 @@ public void Ngrams() [Fact] public void LpGcNormAndWhitening() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(0); var dataPath = GetDataPath("generated_regression_dataset.csv"); var dataSource = new MultiFileSource(dataPath); @@ -669,7 +669,7 @@ public void LpGcNormAndWhitening() [Fact(Skip = "LDA transform cannot be trained on empty data, schema propagation fails")] public void LdaTopicModel() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(0); var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv"); var reader = TextLoader.CreateReader(env, ctx => ( label: ctx.LoadBool(0), @@ -696,7 +696,7 @@ public void LdaTopicModel() [Fact(Skip = "FeatureSeclection transform cannot be trained on empty data, schema propagation fails")] public void FeatureSelection() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(0); var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv"); var reader = TextLoader.CreateReader(env, ctx => ( label: ctx.LoadBool(0), @@ -725,7 +725,7 @@ public void FeatureSelection() [Fact] public void TrainTestSplit() { - var env = new 
ConsoleEnvironment(seed: 0); + var env = new MLContext(0); var dataPath = GetDataPath(TestDatasets.iris.trainFilename); var dataSource = new MultiFileSource(dataPath); @@ -755,7 +755,7 @@ public void TrainTestSplit() [Fact] public void PrincipalComponentAnalysis() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(0); var dataPath = GetDataPath("generated_regression_dataset.csv"); var dataSource = new MultiFileSource(dataPath); @@ -778,10 +778,10 @@ public void PrincipalComponentAnalysis() [Fact] public void NAIndicatorStatic() { - var Env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(0); string dataPath = GetDataPath("breast-cancer.txt"); - var reader = TextLoader.CreateReader(Env, ctx => ( + var reader = TextLoader.CreateReader(env, ctx => ( ScalarFloat: ctx.LoadFloat(1), ScalarDouble: ctx.LoadDouble(1), VectorFloat: ctx.LoadFloat(1, 4), @@ -798,12 +798,12 @@ public void NAIndicatorStatic() D: row.VectorDoulbe.IsMissingValue() )); - IDataView newData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4); + IDataView newData = TakeFilter.Create(env, est.Fit(data).Transform(data).AsDynamic, 4); Assert.NotNull(newData); - bool[] ScalarFloat = newData.GetColumn(Env, "A").ToArray(); - bool[] ScalarDouble = newData.GetColumn(Env, "B").ToArray(); - bool[][] VectorFloat = newData.GetColumn(Env, "C").ToArray(); - bool[][] VectorDoulbe = newData.GetColumn(Env, "D").ToArray(); + bool[] ScalarFloat = newData.GetColumn(env, "A").ToArray(); + bool[] ScalarDouble = newData.GetColumn(env, "B").ToArray(); + bool[][] VectorFloat = newData.GetColumn(env, "C").ToArray(); + bool[][] VectorDoulbe = newData.GetColumn(env, "D").ToArray(); Assert.NotNull(ScalarFloat); Assert.NotNull(ScalarDouble); @@ -822,7 +822,7 @@ public void NAIndicatorStatic() [Fact] public void TextNormalizeStatic() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(0); var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv"); var reader = TextLoader.CreateReader(env, ctx => ( label: ctx.LoadBool(0), @@ -862,7 +862,7 @@ public void TextNormalizeStatic() [Fact] public void TestPcaStatic() { - var env = new ConsoleEnvironment(seed: 1); + var env = new MLContext(0); var dataSource = GetDataPath("generated_regression_dataset.csv"); var reader = TextLoader.CreateReader(env, c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)), diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index c689cba935..b59c2118b5 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -32,7 +32,7 @@ public Training(ITestOutputHelper output) : base(output) [Fact] public void SdcaRegression() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0, conc: 1); var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); var dataSource = new MultiFileSource(dataPath); @@ -45,7 +45,8 @@ public void SdcaRegression() LinearRegressionPredictor pred = null; var est = reader.MakeNewEstimator() - .Append(r => (r.label, score: ctx.Trainers.Sdca(r.label, r.features, maxIterations: 2, onFit: p => pred = p))); + .Append(r => (r.label, score: ctx.Trainers.Sdca(r.label, r.features, maxIterations: 2, + onFit: p => pred = p, advancedSettings: s => s.NumThreads = 1))); var pipe = reader.Append(est); @@ -74,7 +75,7 @@ public void SdcaRegression() [Fact] public void SdcaRegressionNameCollision() { - var env = new 
ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); var dataSource = new MultiFileSource(dataPath); var ctx = new RegressionContext(env); @@ -85,7 +86,7 @@ public void SdcaRegressionNameCollision() separator: ';', hasHeader: true); var est = reader.MakeNewEstimator() - .Append(r => (r.label, r.Score, score: ctx.Trainers.Sdca(r.label, r.features, maxIterations: 2))); + .Append(r => (r.label, r.Score, score: ctx.Trainers.Sdca(r.label, r.features, maxIterations: 2, advancedSettings: s => s.NumThreads = 1))); var pipe = reader.Append(est); @@ -104,7 +105,7 @@ public void SdcaRegressionNameCollision() [Fact] public void SdcaBinaryClassification() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); var dataSource = new MultiFileSource(dataPath); var ctx = new BinaryClassificationContext(env); @@ -118,7 +119,8 @@ public void SdcaBinaryClassification() var est = reader.MakeNewEstimator() .Append(r => (r.label, preds: ctx.Trainers.Sdca(r.label, r.features, maxIterations: 2, - onFit: (p, c) => { pred = p; cali = c; }))); + onFit: (p, c) => { pred = p; cali = c; }, + advancedSettings: s => s.NumThreads = 1))); var pipe = reader.Append(est); @@ -149,7 +151,7 @@ public void SdcaBinaryClassification() [Fact] public void SdcaBinaryClassificationNoCalibration() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); var dataSource = new MultiFileSource(dataPath); var ctx = new BinaryClassificationContext(env); @@ -165,7 +167,8 @@ public void SdcaBinaryClassificationNoCalibration() var est = reader.MakeNewEstimator() .Append(r => (r.label, preds: ctx.Trainers.Sdca(r.label, r.features, maxIterations: 2, - loss: loss, onFit: p => pred = p))); + loss: loss, onFit: p => pred = p, + advancedSettings: s => s.NumThreads = 1))); var pipe = reader.Append(est); @@ -192,7 +195,7 @@ public void SdcaBinaryClassificationNoCalibration() [Fact] public void AveragePerceptronNoCalibration() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); var dataSource = new MultiFileSource(dataPath); var ctx = new BinaryClassificationContext(env); @@ -228,7 +231,7 @@ public void AveragePerceptronNoCalibration() [Fact] public void FfmBinaryClassification() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); var dataSource = new MultiFileSource(dataPath); var ctx = new BinaryClassificationContext(env); @@ -260,7 +263,7 @@ public void FfmBinaryClassification() [Fact] public void SdcaMulticlass() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.iris.trainFilename); var dataSource = new MultiFileSource(dataPath); @@ -310,7 +313,7 @@ public void SdcaMulticlass() [Fact] public void CrossValidate() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.iris.trainFilename); var dataSource = new MultiFileSource(dataPath); @@ -334,7 +337,7 @@ public void CrossValidate() [Fact] public void FastTreeBinaryClassification() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = 
GetDataPath(TestDatasets.breastCancer.trainFilename); var dataSource = new MultiFileSource(dataPath); var ctx = new BinaryClassificationContext(env); @@ -373,7 +376,7 @@ public void FastTreeBinaryClassification() [Fact] public void FastTreeRegression() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); var dataSource = new MultiFileSource(dataPath); @@ -415,7 +418,7 @@ public void FastTreeRegression() [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only public void LightGbmBinaryClassification() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); var dataSource = new MultiFileSource(dataPath); var ctx = new BinaryClassificationContext(env); @@ -455,7 +458,7 @@ public void LightGbmBinaryClassification() [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only public void LightGbmRegression() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); var dataSource = new MultiFileSource(dataPath); @@ -497,7 +500,7 @@ public void LightGbmRegression() [Fact] public void PoissonRegression() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); var dataSource = new MultiFileSource(dataPath); @@ -513,7 +516,8 @@ public void PoissonRegression() .Append(r => (r.label, score: ctx.Trainers.PoissonRegression(r.label, r.features, l1Weight: 2, enoforceNoNegativity: true, - onFit: (p) => { pred = p; }))); + onFit: (p) => { pred = p; }, + advancedSettings: s => s.NumThreads = 1))); var pipe = reader.Append(est); @@ -539,7 +543,7 @@ public void PoissonRegression() [Fact] public void LogisticRegressionBinaryClassification() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); var dataSource = new MultiFileSource(dataPath); var ctx = new BinaryClassificationContext(env); @@ -552,7 +556,8 @@ public void LogisticRegressionBinaryClassification() var est = reader.MakeNewEstimator() .Append(r => (r.label, preds: ctx.Trainers.LogisticRegressionBinaryClassifier(r.label, r.features, l1Weight: 10, - onFit: (p) => { pred = p; }))); + onFit: (p) => { pred = p; }, + advancedSettings: s => s.NumThreads = 1))); var pipe = reader.Append(est); @@ -577,7 +582,7 @@ public void LogisticRegressionBinaryClassification() [Fact] public void MulticlassLogisticRegression() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.iris.trainFilename); var dataSource = new MultiFileSource(dataPath); @@ -592,7 +597,8 @@ public void MulticlassLogisticRegression() .Append(r => (label: r.label.ToKey(), r.features)) .Append(r => (r.label, preds: ctx.Trainers.MultiClassLogisticRegression( r.label, - r.features, onFit: p => pred = p))); + r.features, onFit: p => pred = p, + advancedSettings: s => s.NumThreads = 1))); var pipe = reader.Append(est); @@ -620,7 +626,7 @@ public void MulticlassLogisticRegression() [Fact] public void OnlineGradientDescent() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = 
GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); var dataSource = new MultiFileSource(dataPath); @@ -663,11 +669,10 @@ public void OnlineGradientDescent() [Fact] public void KMeans() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0, conc: 1); var dataPath = GetDataPath(TestDatasets.iris.trainFilename); var dataSource = new MultiFileSource(dataPath); - var ctx = new ClusteringContext(env); var reader = TextLoader.CreateReader(env, c => (label: c.LoadText(0), features: c.LoadFloat(1, 4))); @@ -675,7 +680,7 @@ public void KMeans() var est = reader.MakeNewEstimator() .Append(r => (label: r.label.ToKey(), r.features)) - .Append(r => (r.label, r.features, preds: ctx.Trainers.KMeans(r.features, clustersCount: 3, onFit: p => pred = p))); + .Append(r => (r.label, r.features, preds: env.Clustering.Trainers.KMeans(r.features, clustersCount: 3, onFit: p => pred = p, advancedSettings: s => s.NumThreads = 1))); var pipe = reader.Append(est); @@ -691,23 +696,23 @@ public void KMeans() var data = model.Read(dataSource); - var metrics = ctx.Evaluate(data, r => r.preds.score, r => r.label, r => r.features); + var metrics = env.Clustering.Evaluate(data, r => r.preds.score, r => r.label, r => r.features); Assert.NotNull(metrics); Assert.InRange(metrics.AvgMinScore, 0.5262, 0.5264); Assert.InRange(metrics.Nmi, 0.73, 0.77); Assert.InRange(metrics.Dbi, 0.662, 0.667); - metrics = ctx.Evaluate(data, r => r.preds.score, label: r => r.label); + metrics = env.Clustering.Evaluate(data, r => r.preds.score, label: r => r.label); Assert.NotNull(metrics); Assert.InRange(metrics.AvgMinScore, 0.5262, 0.5264); Assert.True(metrics.Dbi == 0.0); - metrics = ctx.Evaluate(data, r => r.preds.score, features: r => r.features); + metrics = env.Clustering.Evaluate(data, r => r.preds.score, features: r => r.features); Assert.True(double.IsNaN(metrics.Nmi)); - metrics = ctx.Evaluate(data, r => r.preds.score); + metrics = env.Clustering.Evaluate(data, r => r.preds.score); Assert.NotNull(metrics); Assert.InRange(metrics.AvgMinScore, 0.5262, 0.5264); Assert.True(double.IsNaN(metrics.Nmi)); @@ -718,7 +723,7 @@ public void KMeans() [Fact] public void FastTreeRanking() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.adultRanking.trainFilename); var dataSource = new MultiFileSource(dataPath); @@ -759,7 +764,7 @@ public void FastTreeRanking() [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only public void LightGBMRanking() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.adultRanking.trainFilename); var dataSource = new MultiFileSource(dataPath); @@ -800,7 +805,7 @@ public void LightGBMRanking() [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only public void MultiClassLightGBM() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.iris.trainFilename); var dataSource = new MultiFileSource(dataPath); @@ -838,7 +843,7 @@ public void MultiClassLightGBM() [Fact] public void MultiClassNaiveBayesTrainer() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.iris.trainFilename); var dataSource = new MultiFileSource(dataPath); @@ -883,7 +888,7 @@ public void MultiClassNaiveBayesTrainer() [Fact] public void 
HogwildSGDBinaryClassification() { - var env = new ConsoleEnvironment(seed: 0); + var env = new MLContext(seed: 0); var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); var dataSource = new MultiFileSource(dataPath); var ctx = new BinaryClassificationContext(env); @@ -896,7 +901,8 @@ public void HogwildSGDBinaryClassification() var est = reader.MakeNewEstimator() .Append(r => (r.label, preds: ctx.Trainers.StochasticGradientDescentClassificationTrainer(r.label, r.features, l2Weight: 0, - onFit: (p) => { pred = p; }))); + onFit: (p) => { pred = p; }, + advancedSettings: s => s.NumThreads = 1))); var pipe = reader.Append(est); diff --git a/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs b/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs index 2f19a5c4f3..1cd0065703 100644 --- a/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs +++ b/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs @@ -20,19 +20,16 @@ public void UniformRandomSweeperReturnsDistinctValuesWhenProposeSweep() { DiscreteValueGenerator valueGenerator = CreateDiscreteValueGenerator(); - using (var writer = new StreamWriter(new MemoryStream())) - using (var env = new ConsoleEnvironment(42, outWriter: writer, errWriter: writer)) - { - var sweeper = new UniformRandomSweeper(env, + var env = new MLContext(42); + var sweeper = new UniformRandomSweeper(env, new SweeperBase.ArgumentsBase(), new[] { valueGenerator }); - var results = sweeper.ProposeSweeps(3); - Assert.NotNull(results); + var results = sweeper.ProposeSweeps(3); + Assert.NotNull(results); - int length = results.Length; - Assert.Equal(2, length); - } + int length = results.Length; + Assert.Equal(2, length); } [Fact] @@ -40,19 +37,16 @@ public void RandomGridSweeperReturnsDistinctValuesWhenProposeSweep() { DiscreteValueGenerator valueGenerator = CreateDiscreteValueGenerator(); - using (var writer = new StreamWriter(new MemoryStream())) - using (var env = new ConsoleEnvironment(42, outWriter: writer, errWriter: writer)) - { - var sweeper = new RandomGridSweeper(env, - new RandomGridSweeper.Arguments(), - new[] { valueGenerator }); + var env = new MLContext(42); + var sweeper = new RandomGridSweeper(env, + new RandomGridSweeper.Arguments(), + new[] { valueGenerator }); - var results = sweeper.ProposeSweeps(3); - Assert.NotNull(results); + var results = sweeper.ProposeSweeps(3); + Assert.NotNull(results); - int length = results.Length; - Assert.Equal(2, length); - } + int length = results.Length; + Assert.Equal(2, length); } private static DiscreteValueGenerator CreateDiscreteValueGenerator() diff --git a/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs b/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs index d191084ba8..0718a7edde 100644 --- a/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs +++ b/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs @@ -94,39 +94,37 @@ public void TestDiscreteValueSweep(double normalizedValue, string expected) [Fact] public void TestRandomSweeper() { - using (var env = new ConsoleEnvironment(42)) + var env = new MLContext(42); + var args = new SweeperBase.ArgumentsBase() { - var args = new SweeperBase.ArgumentsBase() - { - SweptParameters = new[] { + SweptParameters = new[] { ComponentFactoryUtils.CreateFromFunction( environ => new LongValueGenerator(new LongParamArguments() { Name = "foo", Min = 10, Max = 20 })), ComponentFactoryUtils.CreateFromFunction( environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 100, Max = 200 })) } - }; + }; - var sweeper = new UniformRandomSweeper(env, args); - var initialList = 
sweeper.ProposeSweeps(5, new List()); - Assert.Equal(5, initialList.Length); - foreach (var parameterSet in initialList) + var sweeper = new UniformRandomSweeper(env, args); + var initialList = sweeper.ProposeSweeps(5, new List()); + Assert.Equal(5, initialList.Length); + foreach (var parameterSet in initialList) + { + foreach (var parameterValue in parameterSet) { - foreach (var parameterValue in parameterSet) + if (parameterValue.Name == "foo") { - if (parameterValue.Name == "foo") - { - var val = long.Parse(parameterValue.ValueText); - Assert.InRange(val, 10, 20); - } - else if (parameterValue.Name == "bar") - { - var val = long.Parse(parameterValue.ValueText); - Assert.InRange(val, 100, 200); - } - else - { - Assert.True(false, "Wrong parameter"); - } + var val = long.Parse(parameterValue.ValueText); + Assert.InRange(val, 10, 20); + } + else if (parameterValue.Name == "bar") + { + var val = long.Parse(parameterValue.ValueText); + Assert.InRange(val, 100, 200); + } + else + { + Assert.True(false, "Wrong parameter"); } } } @@ -136,282 +134,273 @@ public void TestRandomSweeper() public void TestSimpleSweeperAsync() { var random = new Random(42); - using (var env = new ConsoleEnvironment(42)) + var env = new MLContext(42); + const int sweeps = 100; + var sweeper = new SimpleAsyncSweeper(env, new SweeperBase.ArgumentsBase { - int sweeps = 100; - var sweeper = new SimpleAsyncSweeper(env, new SweeperBase.ArgumentsBase - { - SweptParameters = new IComponentFactory[] { + SweptParameters = new IComponentFactory[] { ComponentFactoryUtils.CreateFromFunction( environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5 })), ComponentFactoryUtils.CreateFromFunction( environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true })) } - }); + }); - var paramSets = new List(); - for (int i = 0; i < sweeps; i++) - { - var task = sweeper.Propose(); - Assert.True(task.IsCompleted); - paramSets.Add(task.Result.ParameterSet); - var result = new RunResult(task.Result.ParameterSet, random.NextDouble(), true); - sweeper.Update(task.Result.Id, result); - } - Assert.Equal(sweeps, paramSets.Count); - CheckAsyncSweeperResult(paramSets); + var paramSets = new List(); + for (int i = 0; i < sweeps; i++) + { + var task = sweeper.Propose(); + Assert.True(task.IsCompleted); + paramSets.Add(task.Result.ParameterSet); + var result = new RunResult(task.Result.ParameterSet, random.NextDouble(), true); + sweeper.Update(task.Result.Id, result); + } + Assert.Equal(sweeps, paramSets.Count); + CheckAsyncSweeperResult(paramSets); - // Test consumption without ever calling Update. - var gridArgs = new RandomGridSweeper.Arguments(); - gridArgs.SweptParameters = new IComponentFactory[] { + // Test consumption without ever calling Update. 
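Grid-based sweepers enumerate a fixed lattice of configurations, so a proposal never depends on an observed metric; the block below exercises exactly that by driving a RandomGridSweeper-backed SimpleAsyncSweeper through all one hundred proposals without posting a single result. Model-based sweepers, by contrast, only improve when the caller closes the loop, as the first half of this test does. In sketch form (metric is a caller-supplied score, hypothetical here):

    var task = sweeper.Propose();
    double metric = 0.42; // stand-in for a real evaluation of task.Result.ParameterSet
    sweeper.Update(task.Result.Id, new RunResult(task.Result.ParameterSet, metric, true));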
+ var gridArgs = new RandomGridSweeper.Arguments(); + gridArgs.SweptParameters = new IComponentFactory[] { ComponentFactoryUtils.CreateFromFunction( environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})), ComponentFactoryUtils.CreateFromFunction( environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 100, LogBase = true })) }; - var gridSweeper = new SimpleAsyncSweeper(env, gridArgs); - paramSets.Clear(); - for (int i = 0; i < sweeps; i++) - { - var task = gridSweeper.Propose(); - Assert.True(task.IsCompleted); - paramSets.Add(task.Result.ParameterSet); - } - Assert.Equal(sweeps, paramSets.Count); - CheckAsyncSweeperResult(paramSets); + var gridSweeper = new SimpleAsyncSweeper(env, gridArgs); + paramSets.Clear(); + for (int i = 0; i < sweeps; i++) + { + var task = gridSweeper.Propose(); + Assert.True(task.IsCompleted); + paramSets.Add(task.Result.ParameterSet); } + Assert.Equal(sweeps, paramSets.Count); + CheckAsyncSweeperResult(paramSets); } [Fact] public void TestDeterministicSweeperAsyncCancellation() { var random = new Random(42); - using (var env = new ConsoleEnvironment(42)) - { - var args = new DeterministicSweeperAsync.Arguments(); - args.BatchSize = 5; - args.Relaxation = 1; - - args.Sweeper = ComponentFactoryUtils.CreateFromFunction( - environ => new KdoSweeper(environ, - new KdoSweeper.Arguments() - { - SweptParameters = new IComponentFactory[] { + var env = new MLContext(42); + var args = new DeterministicSweeperAsync.Arguments(); + args.BatchSize = 5; + args.Relaxation = 1; + + args.Sweeper = ComponentFactoryUtils.CreateFromFunction( + environ => new KdoSweeper(environ, + new KdoSweeper.Arguments() + { + SweptParameters = new IComponentFactory[] { ComponentFactoryUtils.CreateFromFunction( t => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})), ComponentFactoryUtils.CreateFromFunction( t => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true })) - } - })); + } + })); - var sweeper = new DeterministicSweeperAsync(env, args); + var sweeper = new DeterministicSweeperAsync(env, args); - int sweeps = 20; - var tasks = new List>(); - int numCompleted = 0; - for (int i = 0; i < sweeps; i++) - { - var task = sweeper.Propose(); - if (i < args.BatchSize - args.Relaxation) - { - Assert.True(task.IsCompleted); - sweeper.Update(task.Result.Id, new RunResult(task.Result.ParameterSet, random.NextDouble(), true)); - numCompleted++; - } - else - tasks.Add(task); - } - // Cancel after the first barrier and check if the number of registered actions - // is indeed 2 * batchSize. - sweeper.Cancel(); - Task.WaitAll(tasks.ToArray()); - foreach (var task in tasks) + int sweeps = 20; + var tasks = new List>(); + int numCompleted = 0; + for (int i = 0; i < sweeps; i++) + { + var task = sweeper.Propose(); + if (i < args.BatchSize - args.Relaxation) { - if (task.Result != null) - numCompleted++; + Assert.True(task.IsCompleted); + sweeper.Update(task.Result.Id, new RunResult(task.Result.ParameterSet, random.NextDouble(), true)); + numCompleted++; } - Assert.Equal(args.BatchSize + args.BatchSize, numCompleted); + else + tasks.Add(task); } + // Cancel after the first barrier and check if the number of registered actions + // is indeed 2 * batchSize. 
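To make the asserted count concrete: with BatchSize = 5 and Relaxation = 1, the first BatchSize - Relaxation = 4 proposals complete synchronously and are counted inside the loop; the rest queue up against the batch barrier. The Cancel call below releases every outstanding task, some still carrying a proposal and the others resolving to null, and adding the tasks that carried a result to the four loop completions gives BatchSize + BatchSize = 10, which is what the final assertion checks.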
+ sweeper.Cancel(); + Task.WaitAll(tasks.ToArray()); + foreach (var task in tasks) + { + if (task.Result != null) + numCompleted++; + } + Assert.Equal(args.BatchSize + args.BatchSize, numCompleted); } [Fact] public void TestDeterministicSweeperAsync() { var random = new Random(42); - using (var env = new ConsoleEnvironment(42)) - { - var args = new DeterministicSweeperAsync.Arguments(); - args.BatchSize = 5; - args.Relaxation = args.BatchSize - 1; - - args.Sweeper = ComponentFactoryUtils.CreateFromFunction( - environ => new SmacSweeper(environ, - new SmacSweeper.Arguments() - { - SweptParameters = new IComponentFactory[] { + var env = new MLContext(42); + var args = new DeterministicSweeperAsync.Arguments(); + args.BatchSize = 5; + args.Relaxation = args.BatchSize - 1; + + args.Sweeper = ComponentFactoryUtils.CreateFromFunction( + environ => new SmacSweeper(environ, + new SmacSweeper.Arguments() + { + SweptParameters = new IComponentFactory[] { ComponentFactoryUtils.CreateFromFunction( t => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})), ComponentFactoryUtils.CreateFromFunction( t => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true })) - } - })); - - var sweeper = new DeterministicSweeperAsync(env, args); + } + })); - // Test single-threaded consumption. - int sweeps = 10; - var paramSets = new List(); - for (int i = 0; i < sweeps; i++) - { - var task = sweeper.Propose(); - Assert.True(task.IsCompleted); - paramSets.Add(task.Result.ParameterSet); - var result = new RunResult(task.Result.ParameterSet, random.NextDouble(), true); - sweeper.Update(task.Result.Id, result); - } - Assert.Equal(sweeps, paramSets.Count); - CheckAsyncSweeperResult(paramSets); - - // Create two batches and test if the 2nd batch is executed after the synchronization barrier is reached. - object mlock = new object(); - var tasks = new Task[sweeps]; - args.Relaxation = args.Relaxation - 1; - sweeper = new DeterministicSweeperAsync(env, args); - paramSets.Clear(); - var results = new List>(); - for (int i = 0; i < args.BatchSize; i++) - { - var task = sweeper.Propose(); - Assert.True(task.IsCompleted); - tasks[i] = task; - if (task.Result == null) - continue; - results.Add(new KeyValuePair(task.Result.Id, new RunResult(task.Result.ParameterSet, 0.42, true))); - } - // Register consumers for the 2nd batch. Those consumers will await until at least one run - // in the previous batch has been posted to the sweeper. - for (int i = args.BatchSize; i < 2 * args.BatchSize; i++) - { - var task = sweeper.Propose(); - Assert.False(task.IsCompleted); - tasks[i] = task; - } - // Call update to unblock the 2nd batch. - foreach (var run in results) - sweeper.Update(run.Key, run.Value); + var sweeper = new DeterministicSweeperAsync(env, args); - Task.WaitAll(tasks); - tasks.All(t => t.IsCompleted); + // Test single-threaded consumption. + int sweeps = 10; + var paramSets = new List(); + for (int i = 0; i < sweeps; i++) + { + var task = sweeper.Propose(); + Assert.True(task.IsCompleted); + paramSets.Add(task.Result.ParameterSet); + var result = new RunResult(task.Result.ParameterSet, random.NextDouble(), true); + sweeper.Update(task.Result.Id, result); + } + Assert.Equal(sweeps, paramSets.Count); + CheckAsyncSweeperResult(paramSets); + + // Create two batches and test if the 2nd batch is executed after the synchronization barrier is reached. 
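The handshake being tested: a DeterministicSweeperAsync serves a full batch of proposals immediately, then parks further proposals until results for the first batch are posted. Reduced to its skeleton, with the same calls the test makes below:

    // Batch 1: served synchronously.
    for (int i = 0; i < args.BatchSize; i++)
        Assert.True(sweeper.Propose().IsCompleted);
    // Batch 2: parked at the barrier until batch-1 results arrive.
    var parked = sweeper.Propose();
    Assert.False(parked.IsCompleted);
    // Posting RunResults via sweeper.Update(id, result) releases the parked tasks.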
+ object mlock = new object(); + var tasks = new Task[sweeps]; + args.Relaxation = args.Relaxation - 1; + sweeper = new DeterministicSweeperAsync(env, args); + paramSets.Clear(); + var results = new List>(); + for (int i = 0; i < args.BatchSize; i++) + { + var task = sweeper.Propose(); + Assert.True(task.IsCompleted); + tasks[i] = task; + if (task.Result == null) + continue; + results.Add(new KeyValuePair(task.Result.Id, new RunResult(task.Result.ParameterSet, 0.42, true))); + } + // Register consumers for the 2nd batch. Those consumers will await until at least one run + // in the previous batch has been posted to the sweeper. + for (int i = args.BatchSize; i < 2 * args.BatchSize; i++) + { + var task = sweeper.Propose(); + Assert.False(task.IsCompleted); + tasks[i] = task; } + // Call update to unblock the 2nd batch. + foreach (var run in results) + sweeper.Update(run.Key, run.Value); + + Task.WaitAll(tasks); + tasks.All(t => t.IsCompleted); } [Fact] public void TestDeterministicSweeperAsyncParallel() { var random = new Random(42); - using (var env = new ConsoleEnvironment(42)) - { - int batchSize = 5; - int sweeps = 20; - var paramSets = new List(); - var args = new DeterministicSweeperAsync.Arguments(); - args.BatchSize = batchSize; - args.Relaxation = batchSize - 2; - - args.Sweeper = ComponentFactoryUtils.CreateFromFunction( - environ => new SmacSweeper(environ, - new SmacSweeper.Arguments() - { - SweptParameters = new IComponentFactory[] { + var env = new MLContext(42); + const int batchSize = 5; + const int sweeps = 20; + var paramSets = new List(); + var args = new DeterministicSweeperAsync.Arguments(); + args.BatchSize = batchSize; + args.Relaxation = batchSize - 2; + + args.Sweeper = ComponentFactoryUtils.CreateFromFunction( + environ => new SmacSweeper(environ, + new SmacSweeper.Arguments() + { + SweptParameters = new IComponentFactory[] { ComponentFactoryUtils.CreateFromFunction( t => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})), ComponentFactoryUtils.CreateFromFunction( t => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true })) - } - })); + } + })); - var sweeper = new DeterministicSweeperAsync(env, args); + var sweeper = new DeterministicSweeperAsync(env, args); - var mlock = new object(); - var options = new ParallelOptions(); - options.MaxDegreeOfParallelism = 4; + var mlock = new object(); + var options = new ParallelOptions(); + options.MaxDegreeOfParallelism = 4; - // Sleep randomly to simulate doing work. - int[] sleeps = new int[sweeps]; - for (int i = 0; i < sleeps.Length; i++) - sleeps[i] = random.Next(10, 100); - var r = Parallel.For(0, sweeps, options, (int i) => - { - var task = sweeper.Propose(); - task.Wait(); - Assert.Equal(TaskStatus.RanToCompletion, task.Status); - var paramWithId = task.Result; - if (paramWithId == null) - return; - Thread.Sleep(sleeps[i]); - var result = new RunResult(paramWithId.ParameterSet, 0.42, true); - sweeper.Update(paramWithId.Id, result); - lock (mlock) - paramSets.Add(paramWithId.ParameterSet); - }); - Assert.True(paramSets.Count <= sweeps); - CheckAsyncSweeperResult(paramSets); - } + // Sleep randomly to simulate doing work. 
+ int[] sleeps = new int[sweeps]; + for (int i = 0; i < sleeps.Length; i++) + sleeps[i] = random.Next(10, 100); + var r = Parallel.For(0, sweeps, options, (int i) => + { + var task = sweeper.Propose(); + task.Wait(); + Assert.Equal(TaskStatus.RanToCompletion, task.Status); + var paramWithId = task.Result; + if (paramWithId == null) + return; + Thread.Sleep(sleeps[i]); + var result = new RunResult(paramWithId.ParameterSet, 0.42, true); + sweeper.Update(paramWithId.Id, result); + lock (mlock) + paramSets.Add(paramWithId.ParameterSet); + }); + Assert.True(paramSets.Count <= sweeps); + CheckAsyncSweeperResult(paramSets); } [Fact] public async Task TestNelderMeadSweeperAsync() { var random = new Random(42); - using (var env = new ConsoleEnvironment(42)) - { - int batchSize = 5; - int sweeps = 40; - var paramSets = new List(); - var args = new DeterministicSweeperAsync.Arguments(); - args.BatchSize = batchSize; - args.Relaxation = 0; - - args.Sweeper = ComponentFactoryUtils.CreateFromFunction( - environ => { - var param = new IComponentFactory[] { + var env = new MLContext(42); + const int batchSize = 5; + const int sweeps = 40; + var paramSets = new List(); + var args = new DeterministicSweeperAsync.Arguments(); + args.BatchSize = batchSize; + args.Relaxation = 0; + + args.Sweeper = ComponentFactoryUtils.CreateFromFunction( + environ => + { + var param = new IComponentFactory[] { ComponentFactoryUtils.CreateFromFunction( innerEnviron => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})), ComponentFactoryUtils.CreateFromFunction( innerEnviron => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true })) - }; + }; - var nelderMeadSweeperArgs = new NelderMeadSweeper.Arguments() - { - SweptParameters = param, - FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction( - (firstBatchSweeperEnviron, firstBatchSweeperArgs) => - new RandomGridSweeper(environ, new RandomGridSweeper.Arguments() { SweptParameters = param })) - }; + var nelderMeadSweeperArgs = new NelderMeadSweeper.Arguments() + { + SweptParameters = param, + FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction( + (firstBatchSweeperEnviron, firstBatchSweeperArgs) => + new RandomGridSweeper(environ, new RandomGridSweeper.Arguments() { SweptParameters = param })) + }; - return new NelderMeadSweeper(environ, nelderMeadSweeperArgs); - } - ); + return new NelderMeadSweeper(environ, nelderMeadSweeperArgs); + } + ); - var sweeper = new DeterministicSweeperAsync(env, args); - var mlock = new object(); - double[] metrics = new double[sweeps]; - for (int i = 0; i < metrics.Length; i++) - metrics[i] = random.NextDouble(); + var sweeper = new DeterministicSweeperAsync(env, args); + var mlock = new object(); + double[] metrics = new double[sweeps]; + for (int i = 0; i < metrics.Length; i++) + metrics[i] = random.NextDouble(); - for (int i = 0; i < sweeps; i++) - { - var paramWithId = await sweeper.Propose(); - if (paramWithId == null) - return; - var result = new RunResult(paramWithId.ParameterSet, metrics[i], true); - sweeper.Update(paramWithId.Id, result); - lock (mlock) - paramSets.Add(paramWithId.ParameterSet); - } - Assert.True(paramSets.Count <= sweeps); - CheckAsyncSweeperResult(paramSets); + for (int i = 0; i < sweeps; i++) + { + var paramWithId = await sweeper.Propose(); + if (paramWithId == null) + return; + var result = new RunResult(paramWithId.ParameterSet, metrics[i], true); + sweeper.Update(paramWithId.Id, result); + lock (mlock) + 
paramSets.Add(paramWithId.ParameterSet); } + Assert.True(paramSets.Count <= sweeps); + CheckAsyncSweeperResult(paramSets); } private void CheckAsyncSweeperResult(List paramSets) @@ -442,275 +431,266 @@ private void CheckAsyncSweeperResult(List paramSets) [Fact] public void TestRandomGridSweeper() { - using (var env = new ConsoleEnvironment(42)) + var env = new MLContext(42); + var args = new RandomGridSweeper.Arguments() { - var args = new RandomGridSweeper.Arguments() - { - SweptParameters = new[] { + SweptParameters = new[] { ComponentFactoryUtils.CreateFromFunction( environ => new LongValueGenerator(new LongParamArguments() { Name = "foo", Min = 10, Max = 20, NumSteps = 3 })), ComponentFactoryUtils.CreateFromFunction( environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 100, Max = 10000, LogBase = true, StepSize = 10 })) } - }; - var sweeper = new RandomGridSweeper(env, args); - var initialList = sweeper.ProposeSweeps(5, new List()); - Assert.Equal(5, initialList.Length); - var gridPoint = new bool[3][] { + }; + var sweeper = new RandomGridSweeper(env, args); + var initialList = sweeper.ProposeSweeps(5, new List()); + Assert.Equal(5, initialList.Length); + var gridPoint = new bool[3][] { new bool[3], new bool[3], new bool[3] }; - int i = 0; - int j = 0; - foreach (var parameterSet in initialList) + int i = 0; + int j = 0; + foreach (var parameterSet in initialList) + { + foreach (var parameterValue in parameterSet) { - foreach (var parameterValue in parameterSet) + if (parameterValue.Name == "foo") { - if (parameterValue.Name == "foo") - { - var val = long.Parse(parameterValue.ValueText); - Assert.True(val == 10 || val == 15 || val == 20); - i = (val == 10) ? 0 : (val == 15) ? 1 : 2; - } - else if (parameterValue.Name == "bar") - { - var val = long.Parse(parameterValue.ValueText); - Assert.True(val == 100 || val == 1000 || val == 10000); - j = (val == 100) ? 0 : (val == 1000) ? 1 : 2; - } - else - { - Assert.True(false, "Wrong parameter"); - } + var val = long.Parse(parameterValue.ValueText); + Assert.True(val == 10 || val == 15 || val == 20); + i = (val == 10) ? 0 : (val == 15) ? 1 : 2; + } + else if (parameterValue.Name == "bar") + { + var val = long.Parse(parameterValue.ValueText); + Assert.True(val == 100 || val == 1000 || val == 10000); + j = (val == 100) ? 0 : (val == 1000) ? 1 : 2; + } + else + { + Assert.True(false, "Wrong parameter"); } - Assert.False(gridPoint[i][j]); - gridPoint[i][j] = true; } + Assert.False(gridPoint[i][j]); + gridPoint[i][j] = true; + } - var nextList = sweeper.ProposeSweeps(5, initialList.Select(p => new RunResult(p))); - Assert.Equal(4, nextList.Length); - foreach (var parameterSet in nextList) + var nextList = sweeper.ProposeSweeps(5, initialList.Select(p => new RunResult(p))); + Assert.Equal(4, nextList.Length); + foreach (var parameterSet in nextList) + { + foreach (var parameterValue in parameterSet) { - foreach (var parameterValue in parameterSet) + if (parameterValue.Name == "foo") { - if (parameterValue.Name == "foo") - { - var val = long.Parse(parameterValue.ValueText); - Assert.True(val == 10 || val == 15 || val == 20); - i = (val == 10) ? 0 : (val == 15) ? 1 : 2; - } - else if (parameterValue.Name == "bar") - { - var val = long.Parse(parameterValue.ValueText); - Assert.True(val == 100 || val == 1000 || val == 10000); - j = (val == 100) ? 0 : (val == 1000) ? 
1 : 2; - } - else - { - Assert.True(false, "Wrong parameter"); - } + var val = long.Parse(parameterValue.ValueText); + Assert.True(val == 10 || val == 15 || val == 20); + i = (val == 10) ? 0 : (val == 15) ? 1 : 2; + } + else if (parameterValue.Name == "bar") + { + var val = long.Parse(parameterValue.ValueText); + Assert.True(val == 100 || val == 1000 || val == 10000); + j = (val == 100) ? 0 : (val == 1000) ? 1 : 2; + } + else + { + Assert.True(false, "Wrong parameter"); } - Assert.False(gridPoint[i][j]); - gridPoint[i][j] = true; } + Assert.False(gridPoint[i][j]); + gridPoint[i][j] = true; + } - gridPoint = new bool[3][] { + gridPoint = new bool[3][] { new bool[3], new bool[3], new bool[3] }; - var lastList = sweeper.ProposeSweeps(10, null); - Assert.Equal(9, lastList.Length); - foreach (var parameterSet in lastList) + var lastList = sweeper.ProposeSweeps(10, null); + Assert.Equal(9, lastList.Length); + foreach (var parameterSet in lastList) + { + foreach (var parameterValue in parameterSet) { - foreach (var parameterValue in parameterSet) + if (parameterValue.Name == "foo") { - if (parameterValue.Name == "foo") - { - var val = long.Parse(parameterValue.ValueText); - Assert.True(val == 10 || val == 15 || val == 20); - i = (val == 10) ? 0 : (val == 15) ? 1 : 2; - } - else if (parameterValue.Name == "bar") - { - var val = long.Parse(parameterValue.ValueText); - Assert.True(val == 100 || val == 1000 || val == 10000); - j = (val == 100) ? 0 : (val == 1000) ? 1 : 2; - } - else - { - Assert.True(false, "Wrong parameter"); - } + var val = long.Parse(parameterValue.ValueText); + Assert.True(val == 10 || val == 15 || val == 20); + i = (val == 10) ? 0 : (val == 15) ? 1 : 2; + } + else if (parameterValue.Name == "bar") + { + var val = long.Parse(parameterValue.ValueText); + Assert.True(val == 100 || val == 1000 || val == 10000); + j = (val == 100) ? 0 : (val == 1000) ? 
1 : 2; + } + else + { + Assert.True(false, "Wrong parameter"); } - Assert.False(gridPoint[i][j]); - gridPoint[i][j] = true; } - Assert.True(gridPoint.All(bArray => bArray.All(b => b))); + Assert.False(gridPoint[i][j]); + gridPoint[i][j] = true; } + Assert.True(gridPoint.All(bArray => bArray.All(b => b))); } [Fact] public void TestNelderMeadSweeper() { var random = new Random(42); - using (var env = new ConsoleEnvironment(42)) - { - var param = new IComponentFactory[] { + var env = new MLContext(42); + var param = new IComponentFactory[] { ComponentFactoryUtils.CreateFromFunction( environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})), ComponentFactoryUtils.CreateFromFunction( environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true })) }; - var args = new NelderMeadSweeper.Arguments() - { - SweptParameters = param, - FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction( - (environ, firstBatchArgs) => { - return new RandomGridSweeper(environ, new RandomGridSweeper.Arguments() { SweptParameters = param }); - } - ) - }; - var sweeper = new NelderMeadSweeper(env, args); - var sweeps = sweeper.ProposeSweeps(5, new List()); - Assert.Equal(3, sweeps.Length); - - var results = new List(); - for (int i = 1; i < 10; i++) + var args = new NelderMeadSweeper.Arguments() + { + SweptParameters = param, + FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction( + (environ, firstBatchArgs) => + { + return new RandomGridSweeper(environ, new RandomGridSweeper.Arguments() { SweptParameters = param }); + } + ) + }; + var sweeper = new NelderMeadSweeper(env, args); + var sweeps = sweeper.ProposeSweeps(5, new List()); + Assert.Equal(3, sweeps.Length); + + var results = new List(); + for (int i = 1; i < 10; i++) + { + foreach (var parameterSet in sweeps) { - foreach (var parameterSet in sweeps) + foreach (var parameterValue in parameterSet) { - foreach (var parameterValue in parameterSet) + if (parameterValue.Name == "foo") + { + var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture); + Assert.InRange(val, 1, 5); + } + else if (parameterValue.Name == "bar") { - if (parameterValue.Name == "foo") - { - var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture); - Assert.InRange(val, 1, 5); - } - else if (parameterValue.Name == "bar") - { - var val = long.Parse(parameterValue.ValueText); - Assert.InRange(val, 1, 1000); - } - else - { - Assert.True(false, "Wrong parameter"); - } + var val = long.Parse(parameterValue.ValueText); + Assert.InRange(val, 1, 1000); + } + else + { + Assert.True(false, "Wrong parameter"); } - results.Add(new RunResult(parameterSet, random.NextDouble(), true)); } - - sweeps = sweeper.ProposeSweeps(5, results); + results.Add(new RunResult(parameterSet, random.NextDouble(), true)); } - Assert.True(sweeps.Length <= 5); + + sweeps = sweeper.ProposeSweeps(5, results); } + Assert.True(sweeps.Length <= 5); } [Fact] public void TestNelderMeadSweeperWithDefaultFirstBatchSweeper() { var random = new Random(42); - using (var env = new ConsoleEnvironment(42)) - { - var param = new IComponentFactory[] { + var env = new MLContext(42); + var param = new IComponentFactory[] { ComponentFactoryUtils.CreateFromFunction( environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})), ComponentFactoryUtils.CreateFromFunction( environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = 
true })) }; - var args = new NelderMeadSweeper.Arguments(); - args.SweptParameters = param; - var sweeper = new NelderMeadSweeper(env, args); - var sweeps = sweeper.ProposeSweeps(5, new List()); - Assert.Equal(3, sweeps.Length); + var args = new NelderMeadSweeper.Arguments(); + args.SweptParameters = param; + var sweeper = new NelderMeadSweeper(env, args); + var sweeps = sweeper.ProposeSweeps(5, new List()); + Assert.Equal(3, sweeps.Length); - var results = new List(); - for (int i = 1; i < 10; i++) + var results = new List(); + for (int i = 1; i < 10; i++) + { + foreach (var parameterSet in sweeps) { - foreach (var parameterSet in sweeps) + foreach (var parameterValue in parameterSet) { - foreach (var parameterValue in parameterSet) + if (parameterValue.Name == "foo") + { + var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture); + Assert.InRange(val, 1, 5); + } + else if (parameterValue.Name == "bar") { - if (parameterValue.Name == "foo") - { - var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture); - Assert.InRange(val, 1, 5); - } - else if (parameterValue.Name == "bar") - { - var val = long.Parse(parameterValue.ValueText); - Assert.InRange(val, 1, 1000); - } - else - { - Assert.True(false, "Wrong parameter"); - } + var val = long.Parse(parameterValue.ValueText); + Assert.InRange(val, 1, 1000); + } + else + { + Assert.True(false, "Wrong parameter"); } - results.Add(new RunResult(parameterSet, random.NextDouble(), true)); } - - sweeps = sweeper.ProposeSweeps(5, results); + results.Add(new RunResult(parameterSet, random.NextDouble(), true)); } - Assert.True(sweeps == null || sweeps.Length <= 5); + + sweeps = sweeper.ProposeSweeps(5, results); } + Assert.True(sweeps == null || sweeps.Length <= 5); } [Fact] public void TestSmacSweeper() { - RunMTAThread(() => + var random = new Random(42); + var env = new MLContext(42); + const int maxInitSweeps = 5; + var args = new SmacSweeper.Arguments() { - var random = new Random(42); - using (var env = new ConsoleEnvironment(42)) - { - int maxInitSweeps = 5; - var args = new SmacSweeper.Arguments() - { - NumberInitialPopulation = 20, - SweptParameters = new IComponentFactory[] { + NumberInitialPopulation = 20, + SweptParameters = new IComponentFactory[] { ComponentFactoryUtils.CreateFromFunction( environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})), ComponentFactoryUtils.CreateFromFunction( environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 100, LogBase = true })) } - }; + }; - var sweeper = new SmacSweeper(env, args); - var results = new List(); - var sweeps = sweeper.ProposeSweeps(maxInitSweeps, results); - Assert.Equal(Math.Min(args.NumberInitialPopulation, maxInitSweeps), sweeps.Length); + var sweeper = new SmacSweeper(env, args); + var results = new List(); + var sweeps = sweeper.ProposeSweeps(maxInitSweeps, results); + Assert.Equal(Math.Min(args.NumberInitialPopulation, maxInitSweeps), sweeps.Length); - for (int i = 1; i < 10; i++) + for (int i = 1; i < 10; i++) + { + foreach (var parameterSet in sweeps) + { + foreach (var parameterValue in parameterSet) { - foreach (var parameterSet in sweeps) + if (parameterValue.Name == "foo") { - foreach (var parameterValue in parameterSet) - { - if (parameterValue.Name == "foo") - { - var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture); - Assert.InRange(val, 1, 5); - } - else if (parameterValue.Name == "bar") - { - var val = 
long.Parse(parameterValue.ValueText); - Assert.InRange(val, 1, 1000); - } - else - { - Assert.True(false, "Wrong parameter"); - } - } - results.Add(new RunResult(parameterSet, random.NextDouble(), true)); + var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture); + Assert.InRange(val, 1, 5); + } + else if (parameterValue.Name == "bar") + { + var val = long.Parse(parameterValue.ValueText); + Assert.InRange(val, 1, 1000); + } + else + { + Assert.True(false, "Wrong parameter"); } - - sweeps = sweeper.ProposeSweeps(5, results); } - Assert.Equal(5, sweeps.Length); + results.Add(new RunResult(parameterSet, random.NextDouble(), true)); } - }); + + sweeps = sweeper.ProposeSweeps(5, results); + } + // Because only unique configurations are considered, the number asked for may exceed the number actually returned. + Assert.True(sweeps.Length <= 5); } } } diff --git a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs index 273028ddaa..0a1e7ada2f 100644 --- a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs +++ b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs @@ -74,7 +74,8 @@ protected BaseTestBaseline(ITestOutputHelper output) : base(output) // The writer to write to test log files. protected StreamWriter LogWriter; - protected ConsoleEnvironment Env; + private protected ConsoleEnvironment _env; + protected IHostEnvironment Env => _env; protected MLContext ML; private bool _normal; private bool _passed; @@ -96,7 +97,7 @@ protected override void Initialize() string logPath = Path.Combine(logDir, FullTestName + LogSuffix); LogWriter = OpenWriter(logPath); _passed = true; - Env = new ConsoleEnvironment(42, outWriter: LogWriter, errWriter: LogWriter) + _env = new ConsoleEnvironment(42, outWriter: LogWriter, errWriter: LogWriter) .AddStandardComponents(); ML = new MLContext(42); } @@ -106,9 +107,8 @@ protected override void Initialize() // It is called as a first step in test clean up. protected override void Cleanup() { - if (Env != null) - Env.Dispose(); - Env = null; + _env?.Dispose(); + _env = null; Contracts.Assert(IsActive); Log("Test {0}: {1}: {2}", TestName, @@ -535,7 +535,7 @@ private bool MatchNumberWithTolerance(MatchCollection firstCollection, MatchColl double f1 = double.Parse(firstCollection[i].ToString()); double f2 = double.Parse(secondCollection[i].ToString()); - if(!CompareNumbersWithTolerance(f1, f2, i, digitsOfPrecision)) + if (!CompareNumbersWithTolerance(f1, f2, i, digitsOfPrecision)) { return false; } @@ -562,23 +562,23 @@ public bool CompareNumbersWithTolerance(double expected, double actual, int? ite // would fail the inRange == true check, but would succeed the following, and we do consider those two numbers // (1.82844949 - 1.8284502) = -0.00000071 - double delta2 = 0; - if (!inRange) - { - delta2 = Math.Round(expected - actual, digitsOfPrecision); - inRange = delta2 >= -allowedVariance && delta2 <= allowedVariance; - } + double delta2 = 0; + if (!inRange) + { + delta2 = Math.Round(expected - actual, digitsOfPrecision); + inRange = delta2 >= -allowedVariance && delta2 <= allowedVariance; + } - if (!inRange) - { - var message = iterationOnCollection != null ? "" : $"Output and baseline mismatch at line {iterationOnCollection}." + Environment.NewLine; + if (!inRange) + { + var message = iterationOnCollection == null ? "" : $"Output and baseline mismatch at line {iterationOnCollection}."
+ Environment.NewLine; - Fail(_allowMismatch, message + - $"Values to compare are {expected} and {actual}" + Environment.NewLine + - $"\t AllowedVariance: {allowedVariance}" + Environment.NewLine + - $"\t delta: {delta}" + Environment.NewLine + - $"\t delta2: {delta2}" + Environment.NewLine); - } + Fail(_allowMismatch, message + + $"Values to compare are {expected} and {actual}" + Environment.NewLine + + $"\t AllowedVariance: {allowedVariance}" + Environment.NewLine + + $"\t delta: {delta}" + Environment.NewLine + + $"\t delta2: {delta2}" + Environment.NewLine); + } return inRange; } @@ -817,11 +817,8 @@ protected static StreamReader OpenReader(string path) /// protected static int MainForTest(string args) { - using (var env = new ConsoleEnvironment()) - { - int result = Maml.MainCore(env, args, false); - return result; - } + var env = new MLContext(); + return Maml.MainCore(env, args, false); } } diff --git a/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs b/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs index 8eb4edb1b5..cf0303bc05 100644 --- a/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs +++ b/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs @@ -158,7 +158,7 @@ protected void Run(RunContext ctx, int digitsOfPrecision = DigitsOfPrecision) { // Not capturing into a specific log. Log("*** Start raw predictor output"); - res = MainForTest(Env, LogWriter, string.Join(" ", ctx.Command, runcmd), ctx.BaselineProgress); + res = MainForTest(_env, LogWriter, string.Join(" ", ctx.Command, runcmd), ctx.BaselineProgress); Log("*** End raw predictor output, return={0}", res); return; } @@ -189,7 +189,7 @@ protected void Run(RunContext ctx, int digitsOfPrecision = DigitsOfPrecision) Log(" Saving ini file: {0}", str); } - MainForTest(Env, LogWriter, str); + MainForTest(_env, LogWriter, str); files.ForEach(file => CheckEqualityNormalized(dir, file, digitsOfPrecision: digitsOfPrecision)); } diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 306a8a9920..5e5b9d4cd6 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -123,7 +123,7 @@ public void SavePipeLabelParsers() string name = TestName + "4-out.txt"; string pathOut = DeleteOutputPath("SavePipe", name); using (var writer = OpenWriter(pathOut)) - using (Env.RedirectChannelOutput(writer, writer)) + using (_env.RedirectChannelOutput(writer, writer)) { TestCore(pathData, true, new[] { @@ -133,7 +133,7 @@ public void SavePipeLabelParsers() "xf=SelectColumns{keepcol=RawLabel keepcol=FileLabelNum keepcol=FileLabelKey hidden=-}" }, suffix: "4"); writer.WriteLine(ProgressLogLine); - Env.PrintProgress(); + _env.PrintProgress(); } CheckEqualityNormalized("SavePipe", name); diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index ff9c83160b..03ca7d8600 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -172,19 +172,16 @@ private void CheckSameSchemaShape(SchemaShape promised, SchemaShape delivered) /// protected IDataLoader TestCore(string pathData, bool keepHidden, string[] argsPipe, Action actLoader = null, string suffix = "", string suffixBase = null, bool checkBaseline = true, - bool forceDense = false, bool logCurs = false, ConsoleEnvironment env = null, bool roundTripText = 
true, + bool forceDense = false, bool logCurs = false, bool roundTripText = true, bool checkTranspose = false, bool checkId = true, bool baselineSchema = true) { Contracts.AssertValue(Env); - if (env == null) - env = Env; MultiFileSource files; IDataLoader compositeLoader; - var pipe1 = compositeLoader = CreatePipeDataLoader(env, pathData, argsPipe, out files); + var pipe1 = compositeLoader = CreatePipeDataLoader(_env, pathData, argsPipe, out files); - if (actLoader != null) - actLoader(compositeLoader); + actLoader?.Invoke(compositeLoader); // Re-apply pipe to the loader and check equality. var comp = compositeLoader as CompositeDataLoader; @@ -194,7 +191,7 @@ protected IDataLoader TestCore(string pathData, bool keepHidden, string[] argsPi srcLoader = comp.View; while (srcLoader is IDataTransform) srcLoader = ((IDataTransform)srcLoader).Source; - var reappliedPipe = ApplyTransformUtils.ApplyAllTransformsToData(env, comp.View, srcLoader); + var reappliedPipe = ApplyTransformUtils.ApplyAllTransformsToData(_env, comp.View, srcLoader); if (!CheckMetadataTypes(reappliedPipe.Schema)) Failed(); @@ -210,12 +207,12 @@ protected IDataLoader TestCore(string pathData, bool keepHidden, string[] argsPi string pathLog = DeleteOutputPath("SavePipe", name); using (var writer = OpenWriter(pathLog)) - using (env.RedirectChannelOutput(writer, writer)) + using (_env.RedirectChannelOutput(writer, writer)) { long count = 0; // Set the concurrency to 1 for this; restore later. - int conc = env.ConcurrencyFactor; - env.ConcurrencyFactor = 1; + int conc = _env.ConcurrencyFactor; + _env.ConcurrencyFactor = 1; using (var curs = pipe1.GetRowCursor(c => true, null)) { while (curs.MoveNext()) @@ -224,14 +221,14 @@ protected IDataLoader TestCore(string pathData, bool keepHidden, string[] argsPi } } writer.WriteLine("Cursored through {0} rows", count); - env.ConcurrencyFactor = conc; + _env.ConcurrencyFactor = conc; } CheckEqualityNormalized("SavePipe", name); } var pathModel = SavePipe(pipe1, suffix); - var pipe2 = LoadPipe(pathModel, env, files); + var pipe2 = LoadPipe(pathModel, _env, files); if (!CheckMetadataTypes(pipe2.Schema)) Failed(); @@ -243,21 +240,21 @@ protected IDataLoader TestCore(string pathData, bool keepHidden, string[] argsPi if (pipe1.Schema.ColumnCount > 0) { // The text saver fails if there are no columns, so we cannot check in that case. - if (!SaveLoadText(pipe1, env, keepHidden, suffix, suffixBase, checkBaseline, forceDense, roundTripText)) + if (!SaveLoadText(pipe1, _env, keepHidden, suffix, suffixBase, checkBaseline, forceDense, roundTripText)) Failed(); // The transpose saver likewise fails for the same reason. - if (checkTranspose && !SaveLoadTransposed(pipe1, env, suffix)) + if (checkTranspose && !SaveLoadTransposed(pipe1, _env, suffix)) Failed(); } - if (!SaveLoad(pipe1, env, suffix)) + if (!SaveLoad(pipe1, _env, suffix)) Failed(); // Check that the pipe doesn't shuffle when it cannot :). if (srcLoader != null) { // First we need to cache the data so it can be shuffled. 
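// (Why the caching step below: a loader that streams rows straight from its
// source cannot honor a shuffled cursor, so the data is first materialized in
// a CacheDataView, after which GetRowCursor can be handed a random source.
// A minimal sketch of the pattern this hunk exercises, on the same 0.x API;
// the null third argument is assumed here to be the optional prefetch hint:
//
//     var cached = new CacheDataView(_env, srcLoader, null);  // materialize rows
//     var piped = ApplyTransformUtils.ApplyAllTransformsToData(_env, comp.View, cached);
//     using (var curs = piped.GetRowCursor(col => true, new SysRandom(123)))
//     {
//         while (curs.MoveNext()) { /* rows arrive in shuffled order */ }
//     }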
- var cachedData = new CacheDataView(env, srcLoader, null); - var newPipe = ApplyTransformUtils.ApplyAllTransformsToData(env, comp.View, cachedData); + var cachedData = new CacheDataView(_env, srcLoader, null); + var newPipe = ApplyTransformUtils.ApplyAllTransformsToData(_env, comp.View, cachedData); if (!newPipe.CanShuffle) { using (var c1 = newPipe.GetRowCursor(col => true, new SysRandom(123))) diff --git a/test/Microsoft.ML.TestFramework/ModelHelper.cs b/test/Microsoft.ML.TestFramework/ModelHelper.cs index 94341e4a85..3fe189183b 100644 --- a/test/Microsoft.ML.TestFramework/ModelHelper.cs +++ b/test/Microsoft.ML.TestFramework/ModelHelper.cs @@ -13,7 +13,7 @@ namespace Microsoft.ML.TestFramework { public static class ModelHelper { - private static ConsoleEnvironment s_environment = new ConsoleEnvironment(seed: 1); + private static IHostEnvironment s_environment = new MLContext(seed: 1); private static ITransformModel s_housePriceModel; public static void WriteKcHousePriceModel(string dataPath, string outputModelPath) @@ -35,7 +35,6 @@ public static void WriteKcHousePriceModel(string dataPath, Stream stream) { s_housePriceModel = CreateKcHousePricePredictorModel(dataPath); } - s_housePriceModel.Save(s_environment, stream); } diff --git a/test/Microsoft.ML.TestFramework/TestCommandBase.cs b/test/Microsoft.ML.TestFramework/TestCommandBase.cs index 75f8d5d1e0..4e5c04395f 100644 --- a/test/Microsoft.ML.TestFramework/TestCommandBase.cs +++ b/test/Microsoft.ML.TestFramework/TestCommandBase.cs @@ -292,10 +292,10 @@ protected bool TestCore(RunContextBase ctx, string cmdName, string args, int dig Contracts.AssertValueOrNull(args); OutputPath outputPath = ctx.StdoutPath(); using (var newWriter = OpenWriter(outputPath.Path)) - using (Env.RedirectChannelOutput(newWriter, newWriter)) + using (_env.RedirectChannelOutput(newWriter, newWriter)) { - Env.ResetProgressChannel(); - int res = MainForTest(Env, newWriter, string.Format("{0} {1}", cmdName, args), ctx.BaselineProgress); + _env.ResetProgressChannel(); + int res = MainForTest(_env, newWriter, string.Format("{0} {1}", cmdName, args), ctx.BaselineProgress); if (res != 0) Log("*** Predictor returned {0}", res); } @@ -322,7 +322,7 @@ protected bool TestCore(RunContextBase ctx, string cmdName, string args, int dig /// /// The arguments for MAML. /// Whether to print the progress summary. If true, progress summary will appear in the end of baseline output file. 
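// The signature change just below tightens visibility with C# 7.2's
// private protected, which admits only derived classes within the same
// assembly; MainForTest, like the _env field callers now pass into it, stops
// being reachable from outside the test framework. The same pattern appears
// in BaseTestBaseline above (a sketch of that hunk, not new API):
//
//     private protected ConsoleEnvironment _env;   // concrete, disposable environment
//     protected IHostEnvironment Env => _env;      // read-only interface view for all derived tests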
- protected static int MainForTest(ConsoleEnvironment env, TextWriter writer, string args, bool printProgress = false) + private protected static int MainForTest(ConsoleEnvironment env, TextWriter writer, string args, bool printProgress = false) { Contracts.AssertValue(env); Contracts.AssertValue(writer); @@ -364,7 +364,7 @@ private bool TestCoreCore(RunContextBase ctx, string cmdName, string dataPath, P return TestCoreCore(ctx, cmdName, dataPath, situation, inModelPath, outModelPath, loaderArgs, extraArgs, DigitsOfPrecision, toCompare); } - private bool TestCoreCore(RunContextBase ctx, string cmdName, string dataPath, PathArgument.Usage situation, + private bool TestCoreCore(RunContextBase ctx, string cmdName, string dataPath, PathArgument.Usage situation, OutputPath inModelPath, OutputPath outModelPath, string loaderArgs, string extraArgs, int digitsOfPrecision, params PathArgument[] toCompare) { Contracts.AssertNonEmpty(cmdName); @@ -503,24 +503,22 @@ private string DataArg(string dataPath) protected void TestPipeFromModel(string dataPath, OutputPath model) { - using (var env = new ConsoleEnvironment(42)) - { - var files = new MultiFileSource(dataPath); + var env = new MLContext(seed: 42); + var files = new MultiFileSource(dataPath); - bool tmp; - IDataView pipe; - using (var file = Env.OpenInputFile(model.Path)) - using (var strm = file.OpenReadStream()) - using (var rep = RepositoryReader.Open(strm, env)) - { - ModelLoadContext.LoadModel(env, - out pipe, rep, ModelFileUtils.DirDataLoaderModel, files); - } - - using (var c = pipe.GetRowCursor(col => true)) - tmp = CheckSameValues(c, pipe, true, true, true); - Check(tmp, "Single value same failed"); + bool tmp; + IDataView pipe; + using (var file = Env.OpenInputFile(model.Path)) + using (var strm = file.OpenReadStream()) + using (var rep = RepositoryReader.Open(strm, env)) + { + ModelLoadContext.LoadModel(env, + out pipe, rep, ModelFileUtils.DirDataLoaderModel, files); } + + using (var c = pipe.GetRowCursor(col => true)) + tmp = CheckSameValues(c, pipe, true, true, true); + Check(tmp, "Single value same failed"); } } @@ -1969,7 +1967,7 @@ public void CommandTrainingBinaryFieldAwareFactorizationMachineWithInitializatio string data = GetDataPath("breast-cancer.txt"); OutputPath model = ModelPath(); - TestCore("traintest", data, loaderArgs, extraArgs + " test=" + data, digitsOfPrecision:5); + TestCore("traintest", data, loaderArgs, extraArgs + " test=" + data, digitsOfPrecision: 5); _step++; TestInOutCore("traintest", data, model, extraArgs + " " + loaderArgs + " " + "cont+" + " " + "test=" + data); @@ -1988,17 +1986,17 @@ public void CommandTrainingBinaryFactorizationMachineWithValidation() string args = $"{loaderArgs} data={trainData} valid={validData} test={validData} {extraArgs} out={model}"; OutputPath outputPath = StdoutPath(); using (var newWriter = OpenWriter(outputPath.Path)) - using (Env.RedirectChannelOutput(newWriter, newWriter)) + using (_env.RedirectChannelOutput(newWriter, newWriter)) { - Env.ResetProgressChannel(); - int res = MainForTest(Env, newWriter, string.Format("{0} {1}", "traintest", args), true); + _env.ResetProgressChannel(); + int res = MainForTest(_env, newWriter, string.Format("{0} {1}", "traintest", args), true); Assert.True(res == 0); } // see https://github.com/dotnet/machinelearning/issues/404 // in Linux, the clang sqrt() results vary highly from the ones in mac and Windows. 
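// (Concretely, digitsOfPrecision: 4 relaxes CompareNumbersWithTolerance above.
// Assuming allowedVariance = 10^-digitsOfPrecision and two hypothetical
// baseline values: Math.Round(0.785512 - 0.785468, 4) = 0.0000 sits inside
// the +/-0.0001 band, so the pair passes at 4 digits, whereas at 5 digits
// Math.Round(0.000044, 5) = 0.00004 exceeds +/-0.00001 and would fail.)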
if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - Assert.True(outputPath.CheckEqualityNormalized(digitsOfPrecision:4)); + Assert.True(outputPath.CheckEqualityNormalized(digitsOfPrecision: 4)); else Assert.True(outputPath.CheckEqualityNormalized()); @@ -2017,10 +2015,10 @@ public void CommandTrainingBinaryFieldAwareFactorizationMachineWithValidation() string args = $"{loaderArgs} data={trainData} valid={validData} test={validData} {extraArgs} out={model}"; OutputPath outputPath = StdoutPath(); using (var newWriter = OpenWriter(outputPath.Path)) - using (Env.RedirectChannelOutput(newWriter, newWriter)) + using (_env.RedirectChannelOutput(newWriter, newWriter)) { - Env.ResetProgressChannel(); - int res = MainForTest(Env, newWriter, string.Format("{0} {1}", "traintest", args), true); + _env.ResetProgressChannel(); + int res = MainForTest(_env, newWriter, string.Format("{0} {1}", "traintest", args), true); Assert.Equal(0, res); } @@ -2044,15 +2042,15 @@ public void CommandTrainingBinaryFactorizationMachineWithValidationAndInitializa OutputPath outputPath = StdoutPath(); string args = $"data={data} test={data} valid={data} in={model.Path} cont+" + " " + loaderArgs + " " + extraArgs; using (var newWriter = OpenWriter(outputPath.Path)) - using (Env.RedirectChannelOutput(newWriter, newWriter)) + using (_env.RedirectChannelOutput(newWriter, newWriter)) { - Env.ResetProgressChannel(); - int res = MainForTest(Env, newWriter, string.Format("{0} {1}", "traintest", args), true); + _env.ResetProgressChannel(); + int res = MainForTest(_env, newWriter, string.Format("{0} {1}", "traintest", args), true); Assert.True(res == 0); } if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - Assert.True(outputPath.CheckEqualityNormalized(digitsOfPrecision:4)); + Assert.True(outputPath.CheckEqualityNormalized(digitsOfPrecision: 4)); else Assert.True(outputPath.CheckEqualityNormalized()); @@ -2074,10 +2072,10 @@ public void CommandTrainingBinaryFieldAwareFactorizationMachineWithValidationAnd OutputPath outputPath = StdoutPath(); string args = $"data={data} test={data} valid={data} in={model.Path} cont+" + " " + loaderArgs + " " + extraArgs; using (var newWriter = OpenWriter(outputPath.Path)) - using (Env.RedirectChannelOutput(newWriter, newWriter)) + using (_env.RedirectChannelOutput(newWriter, newWriter)) { - Env.ResetProgressChannel(); - int res = MainForTest(Env, newWriter, string.Format("{0} {1}", "traintest", args), true); + _env.ResetProgressChannel(); + int res = MainForTest(_env, newWriter, string.Format("{0} {1}", "traintest", args), true); Assert.True(res == 0); } diff --git a/test/Microsoft.ML.TestFramework/TestSparseDataView.cs b/test/Microsoft.ML.TestFramework/TestSparseDataView.cs index a674618dd8..75cd3ca516 100644 --- a/test/Microsoft.ML.TestFramework/TestSparseDataView.cs +++ b/test/Microsoft.ML.TestFramework/TestSparseDataView.cs @@ -48,28 +48,26 @@ private void GenericSparseDataView(T[] v1, T[] v2) new SparseExample() { X = new VBuffer (5, 3, v1, new int[] { 0, 2, 4 }) }, new SparseExample() { X = new VBuffer (5, 3, v2, new int[] { 0, 1, 3 }) } }; - using (var host = new ConsoleEnvironment()) + var env = new MLContext(); + var data = env.CreateStreamingDataView(inputs); + var value = new VBuffer(); + int n = 0; + using (var cur = data.GetRowCursor(i => true)) { - var data = host.CreateStreamingDataView(inputs); - var value = new VBuffer(); - int n = 0; - using (var cur = data.GetRowCursor(i => true)) + var getter = cur.GetGetter>(0); + while (cur.MoveNext()) { - var getter = 
cur.GetGetter>(0); - while (cur.MoveNext()) - { - getter(ref value); - Assert.True(value.GetValues().Length == 3); - ++n; - } - } - Assert.True(n == 2); - var iter = data.AsEnumerable>(host, false).GetEnumerator(); - n = 0; - while (iter.MoveNext()) + getter(ref value); + Assert.True(value.GetValues().Length == 3); ++n; - Assert.True(n == 2); + } } + Assert.True(n == 2); + var iter = data.AsEnumerable>(env, false).GetEnumerator(); + n = 0; + while (iter.MoveNext()) + ++n; + Assert.True(n == 2); } [Fact] @@ -90,28 +88,26 @@ private void GenericDenseDataView(T[] v1, T[] v2) new DenseExample() { X = v1 }, new DenseExample() { X = v2 } }; - using (var host = new ConsoleEnvironment()) + var env = new MLContext(); + var data = env.CreateStreamingDataView(inputs); + var value = new VBuffer(); + int n = 0; + using (var cur = data.GetRowCursor(i => true)) { - var data = host.CreateStreamingDataView(inputs); - var value = new VBuffer(); - int n = 0; - using (var cur = data.GetRowCursor(i => true)) + var getter = cur.GetGetter>(0); + while (cur.MoveNext()) { - var getter = cur.GetGetter>(0); - while (cur.MoveNext()) - { - getter(ref value); - Assert.True(value.GetValues().Length == 3); - ++n; - } - } - Assert.True(n == 2); - var iter = data.AsEnumerable>(host, false).GetEnumerator(); - n = 0; - while (iter.MoveNext()) + getter(ref value); + Assert.True(value.GetValues().Length == 3); ++n; - Assert.True(n == 2); + } } + Assert.True(n == 2); + var iter = data.AsEnumerable>(env, false).GetEnumerator(); + n = 0; + while (iter.MoveNext()) + ++n; + Assert.True(n == 2); } } } diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs index 2e3b71ab01..8862864885 100644 --- a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs +++ b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs @@ -59,15 +59,13 @@ public void CheckConstructor() public void CanSuccessfullyApplyATransform() { var collection = CollectionDataSource.Create(new List() { new Input { Number1 = 1, String1 = "1" } }); - using (var environment = new ConsoleEnvironment()) - { + var environment = new MLContext(); Experiment experiment = environment.CreateExperiment(); Legacy.ILearningPipelineDataStep output = (Legacy.ILearningPipelineDataStep)collection.ApplyStep(null, experiment); Assert.NotNull(output.Data); Assert.NotNull(output.Data.VarName); Assert.Null(output.Model); - } } [Fact] @@ -79,9 +77,8 @@ public void CanSuccessfullyEnumerated() new Input { Number1 = 3, String1 = "3" } }); - using (var environment = new ConsoleEnvironment()) - { - Experiment experiment = environment.CreateExperiment(); + var environment = new MLContext(); + Experiment experiment = environment.CreateExperiment(); Legacy.ILearningPipelineDataStep output = collection.ApplyStep(null, experiment) as Legacy.ILearningPipelineDataStep; experiment.Compile(); @@ -128,7 +125,6 @@ public void CanSuccessfullyEnumerated() Assert.False(cursor.MoveNext()); } - } } [Fact] @@ -294,7 +290,7 @@ public class ConversionSimpleClass public float fFloat; public double fDouble; public bool fBool; - public string fString=""; + public string fString = ""; } public bool CompareObjectValues(object x, object y, Type type) @@ -418,17 +414,15 @@ public void RoundTripConversionWithBasicTypes() new ConversionSimpleClass() }; - using (var env = new ConsoleEnvironment()) + var env = new MLContext(); + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + var 
originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) { - var dataView = ComponentCreation.CreateDataView(env, data); - var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); - var originalEnumerator = data.GetEnumerator(); - while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) - { - Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); - } - Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); } + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); } public class ConversionNotSupportedMinValueClass @@ -442,27 +436,25 @@ public class ConversionNotSupportedMinValueClass [Fact] public void ConversionExceptionsBehavior() { - using (var env = new ConsoleEnvironment()) + var env = new MLContext(); + var data = new ConversionNotSupportedMinValueClass[1]; + foreach (var field in typeof(ConversionNotSupportedMinValueClass).GetFields()) { - var data = new ConversionNotSupportedMinValueClass[1]; - foreach (var field in typeof(ConversionNotSupportedMinValueClass).GetFields()) + data[0] = new ConversionNotSupportedMinValueClass(); + FieldInfo fi; + if ((fi = field.FieldType.GetField("MinValue")) != null) + { + field.SetValue(data[0], fi.GetValue(null)); + } + var dataView = ComponentCreation.CreateDataView(env, data); + var enumerator = dataView.AsEnumerable(env, false).GetEnumerator(); + try + { + enumerator.MoveNext(); + Assert.True(false); + } + catch { - data[0] = new ConversionNotSupportedMinValueClass(); - FieldInfo fi; - if ((fi = field.FieldType.GetField("MinValue")) != null) - { - field.SetValue(data[0], fi.GetValue(null)); - } - var dataView = ComponentCreation.CreateDataView(env, data); - var enumerator = dataView.AsEnumerable(env, false).GetEnumerator(); - try - { - enumerator.MoveNext(); - Assert.True(false); - } - catch - { - } } } } @@ -496,15 +488,13 @@ public void ClassWithConstFieldsConversion() new ClassWithConstField(){ fInt=-1, fString ="" }, }; - using (var env = new ConsoleEnvironment()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); - var originalEnumerator = data.GetEnumerator(); - while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) - Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); - Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); - } + var env = new MLContext(); + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); } @@ -524,15 +514,13 @@ public void ClassWithMixOfFieldsAndPropertiesConversion() new ClassWithMixOfFieldsAndProperties(){ IntProp=-1, fString ="" }, }; - using (var env = new ConsoleEnvironment()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); - var originalEnumerator = data.GetEnumerator(); - while (enumeratorSimple.MoveNext() && 
originalEnumerator.MoveNext()) - Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); - Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); - } + var env = new MLContext(); + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); } public abstract class BaseClassWithInheritedProperties @@ -580,28 +568,25 @@ public void ClassWithPrivateFieldsAndPropertiesConversion() new ClassWithPrivateFieldsAndProperties(){ StringProp ="baba" } }; - using (var env = new ConsoleEnvironment()) + var env = new MLContext(); + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) { - var dataView = ComponentCreation.CreateDataView(env, data); - var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); - var originalEnumerator = data.GetEnumerator(); - while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) - { - Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); - Assert.True(enumeratorSimple.Current.UnusedPropertyWithPrivateSetter == 100); - } - Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); + Assert.True(enumeratorSimple.Current.UnusedPropertyWithPrivateSetter == 100); } + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); } public class ClassWithInheritedProperties : BaseClassWithInheritedProperties { - private int _fInt; private long _fLong; private byte _fByte2; - public int IntProp { get { return _fInt; } set { _fInt = value; } } - public override long LongProp { get { return _fLong; } set { _fLong = value; } } - public override byte ByteProp { get { return _fByte2; } set { _fByte2 = value; } } + public int IntProp { get; set; } + public override long LongProp { get => _fLong; set => _fLong = value; } + public override byte ByteProp { get => _fByte2; set => _fByte2 = value; } } [Fact] @@ -613,15 +598,13 @@ public void ClassWithInheritedPropertiesConversion() new ClassWithInheritedProperties(){ IntProp=-1, StringProp ="", LongProp=2, ByteProp=4 }, }; - using (var env = new ConsoleEnvironment()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); - var originalEnumerator = data.GetEnumerator(); - while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) - Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); - Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); - } + var env = new MLContext(); + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + 
Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); } public class ClassWithArrays @@ -666,44 +649,30 @@ public void RoundTripConversionWithArrays() }; - using (var env = new ConsoleEnvironment()) + var env = new MLContext(); + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) { - var dataView = ComponentCreation.CreateDataView(env, data); - var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); - var originalEnumerator = data.GetEnumerator(); - while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) - { - Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); - } - Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); } + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); } public class ClassWithArrayProperties { - private string[] _fString; - private int[] _fInt; - private uint[] _fuInt; - private short[] _fShort; - private ushort[] _fuShort; - private sbyte[] _fsByte; - private byte[] _fByte; - private long[] _fLong; - private ulong[] _fuLong; - private float[] _fFloat; - private double[] _fDouble; - private bool[] _fBool; - public string[] StringProp { get { return _fString; } set { _fString = value; } } - public int[] IntProp { get { return _fInt; } set { _fInt = value; } } - public uint[] UIntProp { get { return _fuInt; } set { _fuInt = value; } } - public short[] ShortProp { get { return _fShort; } set { _fShort = value; } } - public ushort[] UShortProp { get { return _fuShort; } set { _fuShort = value; } } - public sbyte[] SByteProp { get { return _fsByte; } set { _fsByte = value; } } - public byte[] ByteProp { get { return _fByte; } set { _fByte = value; } } - public long[] LongProp { get { return _fLong; } set { _fLong = value; } } - public ulong[] ULongProp { get { return _fuLong; } set { _fuLong = value; } } - public float[] FloatProp { get { return _fFloat; } set { _fFloat = value; } } - public double[] DobuleProp { get { return _fDouble; } set { _fDouble = value; } } - public bool[] BoolProp { get { return _fBool; } set { _fBool = value; } } + public string[] StringProp { get; set; } + public int[] IntProp { get; set; } + public uint[] UIntProp { get; set; } + public short[] ShortProp { get; set; } + public ushort[] UShortProp { get; set; } + public sbyte[] SByteProp { get; set; } + public byte[] ByteProp { get; set; } + public long[] LongProp { get; set; } + public ulong[] ULongProp { get; set; } + public float[] FloatProp { get; set; } + public double[] DobuleProp { get; set; } + public bool[] BoolProp { get; set; } } [Fact] @@ -731,27 +700,23 @@ public void RoundTripConversionWithArrayPropertiess() new ClassWithArrayProperties() }; - using (var env = new ConsoleEnvironment()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); - var originalEnumerator = data.GetEnumerator(); - while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) - { - Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); - } - 
Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); - } + var env = new MLContext(); + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); } - class ClassWithGetter + private sealed class ClassWithGetter { private DateTime _dateTime = DateTime.Now; - public float Day { get { return _dateTime.Day; } } - public int Hour { get { return _dateTime.Hour; } } + public float Day => _dateTime.Day; + public int Hour => _dateTime.Hour; } - class ClassWithSetter + private sealed class ClassWithSetter { public float Day { private get; set; } public int Hour { private get; set; } @@ -772,16 +737,14 @@ public void PrivateGetSetProperties() new ClassWithGetter() }; - using (var env = new ConsoleEnvironment()) + var env = new MLContext(); + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) { - var dataView = ComponentCreation.CreateDataView(env, data); - var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); - var originalEnumerator = data.GetEnumerator(); - while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) - { - Assert.True(enumeratorSimple.Current.GetDay == originalEnumerator.Current.Day && - enumeratorSimple.Current.GetHour == originalEnumerator.Current.Hour); - } + Assert.True(enumeratorSimple.Current.GetDay == originalEnumerator.Current.Day && + enumeratorSimple.Current.GetHour == originalEnumerator.Current.Hour); } } } diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index 3c4522bf1d..cde998e2b4 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
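// The image tests below take the same migration as the conversion tests above:
// each using (var env = new ConsoleEnvironment()) { ... } block collapses into
// a plain MLContext, and the added using Microsoft.ML.Runtime; import brings
// the IHostEnvironment interface it is assigned through into scope. A minimal
// before/after sketch on the 0.x API used throughout this diff:
//
//     // before: the environment is disposable and lexically scoped
//     using (var env = new ConsoleEnvironment())
//     {
//         var view = ComponentCreation.CreateDataView(env, data);
//     }
//
//     // after: MLContext is not IDisposable, so the using block goes away
//     IHostEnvironment env = new MLContext();
//     var view = ComponentCreation.CreateDataView(env, data);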
+using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.ImageAnalytics; @@ -25,72 +26,69 @@ public ImageTests(ITestOutputHelper output) : base(output) [Fact] public void TestEstimatorChain() { - using (var env = new ConsoleEnvironment()) + var env = new MLContext(); + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = TextLoader.Create(env, new TextLoader.Arguments() { - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + Column = new[] { - Column = new[] - { new TextLoader.Column("ImagePath", DataKind.TX, 0), new TextLoader.Column("Name", DataKind.TX, 1), } - }, new MultiFileSource(dataFile)); - var invalidData = TextLoader.Create(env, new TextLoader.Arguments() + }, new MultiFileSource(dataFile)); + var invalidData = TextLoader.Create(env, new TextLoader.Arguments() + { + Column = new[] { - Column = new[] - { new TextLoader.Column("ImagePath", DataKind.R4, 0), } - }, new MultiFileSource(dataFile)); + }, new MultiFileSource(dataFile)); - var pipe = new ImageLoadingEstimator(env, imageFolder, ("ImagePath", "ImageReal")) - .Append(new ImageResizingEstimator(env, "ImageReal", "ImageReal", 100, 100)) - .Append(new ImagePixelExtractingEstimator(env, "ImageReal", "ImagePixels")) - .Append(new ImageGrayscalingEstimator(env, ("ImageReal", "ImageGray"))); + var pipe = new ImageLoadingEstimator(env, imageFolder, ("ImagePath", "ImageReal")) + .Append(new ImageResizingEstimator(env, "ImageReal", "ImageReal", 100, 100)) + .Append(new ImagePixelExtractingEstimator(env, "ImageReal", "ImagePixels")) + .Append(new ImageGrayscalingEstimator(env, ("ImageReal", "ImageGray"))); - TestEstimatorCore(pipe, data, null, invalidData); - } + TestEstimatorCore(pipe, data, null, invalidData); Done(); } [Fact] public void TestEstimatorSaveLoad() { - using (var env = new ConsoleEnvironment()) + IHostEnvironment env = new MLContext(); + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = TextLoader.Create(env, new TextLoader.Arguments() { - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + Column = new[] { - Column = new[] - { new TextLoader.Column("ImagePath", DataKind.TX, 0), new TextLoader.Column("Name", DataKind.TX, 1), } - }, new MultiFileSource(dataFile)); + }, new MultiFileSource(dataFile)); - var pipe = new ImageLoadingEstimator(env, imageFolder, ("ImagePath", "ImageReal")) - .Append(new ImageResizingEstimator(env, "ImageReal", "ImageReal", 100, 100)) - .Append(new ImagePixelExtractingEstimator(env, "ImageReal", "ImagePixels")) - .Append(new ImageGrayscalingEstimator(env, ("ImageReal", "ImageGray"))); + var pipe = new ImageLoadingEstimator(env, imageFolder, ("ImagePath", "ImageReal")) + .Append(new ImageResizingEstimator(env, "ImageReal", "ImageReal", 100, 100)) + .Append(new ImagePixelExtractingEstimator(env, "ImageReal", "ImagePixels")) + .Append(new ImageGrayscalingEstimator(env, ("ImageReal", "ImageGray"))); - pipe.GetOutputSchema(Core.Data.SchemaShape.Create(data.Schema)); - var model = pipe.Fit(data); - - using (var file = env.CreateTempFile()) - { - using (var fs = file.CreateWriteStream()) - model.SaveTo(env, fs); - var model2 = TransformerChain.LoadFrom(env, file.OpenReadStream()); + 
pipe.GetOutputSchema(Core.Data.SchemaShape.Create(data.Schema)); + var model = pipe.Fit(data); - var newCols = ((ImageLoaderTransform)model2.First()).Columns; - var oldCols = ((ImageLoaderTransform)model.First()).Columns; - Assert.True(newCols - .Zip(oldCols, (x, y) => x == y) - .All(x => x)); - } + var tempPath = Path.GetTempFileName(); + using (var file = new SimpleFileHandle(env, tempPath, true, true)) + { + using (var fs = file.CreateWriteStream()) + model.SaveTo(env, fs); + var model2 = TransformerChain.LoadFrom(env, file.OpenReadStream()); + + var newCols = ((ImageLoaderTransform)model2.First()).Columns; + var oldCols = ((ImageLoaderTransform)model.First()).Columns; + Assert.True(newCols + .Zip(oldCols, (x, y) => x == y) + .All(x => x)); } Done(); } @@ -98,50 +96,48 @@ public void TestEstimatorSaveLoad() [Fact] public void TestSaveImages() { - using (var env = new ConsoleEnvironment()) + var env = new MLContext(); + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = TextLoader.Create(env, new TextLoader.Arguments() { - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + Column = new[] { - Column = new[] - { new TextLoader.Column("ImagePath", DataKind.TX, 0), new TextLoader.Column("Name", DataKind.TX, 1), } - }, new MultiFileSource(dataFile)); - var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + }, new MultiFileSource(dataFile)); + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + { + Column = new ImageLoaderTransform.Column[1] { - Column = new ImageLoaderTransform.Column[1] - { new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" } - }, - ImageFolder = imageFolder - }, data); + }, + ImageFolder = imageFolder + }, data); - IDataView cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() - { - Column = new ImageResizerTransform.Column[1]{ + IDataView cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Name= "ImageCropped", Source = "ImageReal", ImageHeight =100, ImageWidth = 100, Resizing = ImageResizerTransform.ResizingKind.IsoPad} } - }, images); + }, images); - cropped.Schema.TryGetColumnIndex("ImagePath", out int pathColumn); - cropped.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); - using (var cursor = cropped.GetRowCursor((x) => true)) - { - var pathGetter = cursor.GetGetter>(pathColumn); - ReadOnlyMemory path = default; - var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); - Bitmap bitmap = default; - while (cursor.MoveNext()) - { - pathGetter(ref path); - bitmapCropGetter(ref bitmap); - Assert.NotNull(bitmap); - var fileToSave = GetOutputPath(Path.GetFileNameWithoutExtension(path.ToString()) + ".cropped.jpg"); - bitmap.Save(fileToSave, System.Drawing.Imaging.ImageFormat.Jpeg); - } + cropped.Schema.TryGetColumnIndex("ImagePath", out int pathColumn); + cropped.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); + using (var cursor = cropped.GetRowCursor((x) => true)) + { + var pathGetter = cursor.GetGetter>(pathColumn); + ReadOnlyMemory path = default; + var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); + Bitmap bitmap = default; + while (cursor.MoveNext()) + { + pathGetter(ref path); + bitmapCropGetter(ref bitmap); + 
Assert.NotNull(bitmap); + var fileToSave = GetOutputPath(Path.GetFileNameWithoutExtension(path.ToString()) + ".cropped.jpg"); + bitmap.Save(fileToSave, System.Drawing.Imaging.ImageFormat.Jpeg); } } Done(); @@ -150,68 +146,66 @@ public void TestSaveImages() [Fact] public void TestGreyscaleTransformImages() { - using (var env = new ConsoleEnvironment()) + IHostEnvironment env = new MLContext(); + var imageHeight = 150; + var imageWidth = 100; + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = TextLoader.Create(env, new TextLoader.Arguments() { - var imageHeight = 150; - var imageWidth = 100; - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + Column = new[] { - Column = new[] - { new TextLoader.Column("ImagePath", DataKind.TX, 0), new TextLoader.Column("Name", DataKind.TX, 1), } - }, new MultiFileSource(dataFile)); - var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + }, new MultiFileSource(dataFile)); + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + { + Column = new ImageLoaderTransform.Column[1] { - Column = new ImageLoaderTransform.Column[1] - { new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" } - }, - ImageFolder = imageFolder - }, data); - var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() - { - Column = new ImageResizerTransform.Column[1]{ + }, + ImageFolder = imageFolder + }, data); + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Name= "ImageCropped", Source = "ImageReal", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } - }, images); + }, images); - IDataView grey = ImageGrayscaleTransform.Create(env, new ImageGrayscaleTransform.Arguments() - { - Column = new ImageGrayscaleTransform.Column[1]{ + IDataView grey = ImageGrayscaleTransform.Create(env, new ImageGrayscaleTransform.Arguments() + { + Column = new ImageGrayscaleTransform.Column[1]{ new ImageGrayscaleTransform.Column() { Name= "ImageGrey", Source = "ImageCropped"} } - }, cropped); + }, cropped); - var fname = nameof(TestGreyscaleTransformImages) + "_model.zip"; + var fname = nameof(TestGreyscaleTransformImages) + "_model.zip"; - var fh = env.CreateOutputFile(fname); - using (var ch = env.Start("save")) - TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(grey)); + var fh = env.CreateOutputFile(fname); + using (var ch = env.Start("save")) + TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(grey)); - grey = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); - DeleteOutputPath(fname); + grey = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); + DeleteOutputPath(fname); - grey.Schema.TryGetColumnIndex("ImageGrey", out int greyColumn); - using (var cursor = grey.GetRowCursor((x) => true)) - { - var bitmapGetter = cursor.GetGetter(greyColumn); - Bitmap bitmap = default; - while (cursor.MoveNext()) - { - bitmapGetter(ref bitmap); - Assert.NotNull(bitmap); - for (int x = 0; x < imageWidth; x++) - for (int y = 0; y < imageHeight; y++) - { - var pixel = bitmap.GetPixel(x, y); - // greyscale image has same values for R,G and B - Assert.True(pixel.R == pixel.G 
&& pixel.G == pixel.B); - } - } + grey.Schema.TryGetColumnIndex("ImageGrey", out int greyColumn); + using (var cursor = grey.GetRowCursor((x) => true)) + { + var bitmapGetter = cursor.GetGetter(greyColumn); + Bitmap bitmap = default; + while (cursor.MoveNext()) + { + bitmapGetter(ref bitmap); + Assert.NotNull(bitmap); + for (int x = 0; x < imageWidth; x++) + for (int y = 0; y < imageHeight; y++) + { + var pixel = bitmap.GetPixel(x, y); + // greyscale image has same values for R,G and B + Assert.True(pixel.R == pixel.G && pixel.G == pixel.B); + } } } Done(); @@ -220,88 +214,86 @@ public void TestGreyscaleTransformImages() [Fact] public void TestBackAndForthConversionWithAlphaInterleave() { - using (var env = new ConsoleEnvironment()) + IHostEnvironment env = new MLContext(); + const int imageHeight = 100; + const int imageWidth = 130; + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = TextLoader.Create(env, new TextLoader.Arguments() { - var imageHeight = 100; - var imageWidth = 130; - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + Column = new[] { - Column = new[] - { new TextLoader.Column("ImagePath", DataKind.TX, 0), new TextLoader.Column("Name", DataKind.TX, 1), } - }, new MultiFileSource(dataFile)); - var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + }, new MultiFileSource(dataFile)); + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + { + Column = new ImageLoaderTransform.Column[1] { - Column = new ImageLoaderTransform.Column[1] - { new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" } - }, - ImageFolder = imageFolder - }, data); - var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() - { - Column = new ImageResizerTransform.Column[1]{ + }, + ImageFolder = imageFolder + }, data); + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } - }, images); + }, images); - var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() - { - InterleaveArgb = true, - Offset = 127.5f, - Scale = 2f / 255, - Column = new ImagePixelExtractorTransform.Column[1]{ + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() + { + InterleaveArgb = true, + Offset = 127.5f, + Scale = 2f / 255, + Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=true} } - }, cropped); + }, cropped); - IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() - { - InterleaveArgb = true, - Offset = -1f, - Scale = 255f / 2, - Column = new VectorToImageTransform.Column[1]{ + IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() + { + InterleaveArgb = true, + Offset = -1f, + Scale = 255f / 2, + Column = new VectorToImageTransform.Column[1]{ new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true} } - }, 
pixels); + }, pixels); - var fname = nameof(TestBackAndForthConversionWithAlphaInterleave) + "_model.zip"; + var fname = nameof(TestBackAndForthConversionWithAlphaInterleave) + "_model.zip"; - var fh = env.CreateOutputFile(fname); - using (var ch = env.Start("save")) - TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); + var fh = env.CreateOutputFile(fname); + using (var ch = env.Start("save")) + TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); - backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); - DeleteOutputPath(fname); + backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); + DeleteOutputPath(fname); - backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); - backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); - using (var cursor = backToBitmaps.GetRowCursor((x) => true)) - { - var bitmapGetter = cursor.GetGetter(bitmapColumn); - Bitmap restoredBitmap = default; - - var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); - Bitmap croppedBitmap = default; - while (cursor.MoveNext()) - { - bitmapGetter(ref restoredBitmap); - Assert.NotNull(restoredBitmap); - bitmapCropGetter(ref croppedBitmap); - Assert.NotNull(croppedBitmap); - for (int x = 0; x < imageWidth; x++) - for (int y = 0; y < imageHeight; y++) - { - var c = croppedBitmap.GetPixel(x, y); - var r = restoredBitmap.GetPixel(x, y); - Assert.True(c == r); - } - } + backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); + backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); + using (var cursor = backToBitmaps.GetRowCursor((x) => true)) + { + var bitmapGetter = cursor.GetGetter(bitmapColumn); + Bitmap restoredBitmap = default; + + var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); + Bitmap croppedBitmap = default; + while (cursor.MoveNext()) + { + bitmapGetter(ref restoredBitmap); + Assert.NotNull(restoredBitmap); + bitmapCropGetter(ref croppedBitmap); + Assert.NotNull(croppedBitmap); + for (int x = 0; x < imageWidth; x++) + for (int y = 0; y < imageHeight; y++) + { + var c = croppedBitmap.GetPixel(x, y); + var r = restoredBitmap.GetPixel(x, y); + Assert.True(c == r); + } } } Done(); @@ -310,88 +302,86 @@ public void TestBackAndForthConversionWithAlphaInterleave() [Fact] public void TestBackAndForthConversionWithoutAlphaInterleave() { - using (var env = new ConsoleEnvironment()) + IHostEnvironment env = new MLContext(); + const int imageHeight = 100; + const int imageWidth = 130; + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = TextLoader.Create(env, new TextLoader.Arguments() { - var imageHeight = 100; - var imageWidth = 130; - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + Column = new[] { - Column = new[] - { new TextLoader.Column("ImagePath", DataKind.TX, 0), new TextLoader.Column("Name", DataKind.TX, 1), } - }, new MultiFileSource(dataFile)); - var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + }, new MultiFileSource(dataFile)); + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + { + Column = new ImageLoaderTransform.Column[1] { - Column = new ImageLoaderTransform.Column[1] - { new ImageLoaderTransform.Column() { 
Source= "ImagePath", Name="ImageReal" } - }, - ImageFolder = imageFolder - }, data); - var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() - { - Column = new ImageResizerTransform.Column[1]{ + }, + ImageFolder = imageFolder + }, data); + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } - }, images); + }, images); - var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() - { - InterleaveArgb = true, - Offset = 127.5f, - Scale = 2f / 255, - Column = new ImagePixelExtractorTransform.Column[1]{ + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() + { + InterleaveArgb = true, + Offset = 127.5f, + Scale = 2f / 255, + Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=false} } - }, cropped); + }, cropped); - IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() - { - InterleaveArgb = true, - Offset = -1f, - Scale = 255f / 2, - Column = new VectorToImageTransform.Column[1]{ + IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() + { + InterleaveArgb = true, + Offset = -1f, + Scale = 255f / 2, + Column = new VectorToImageTransform.Column[1]{ new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=false} } - }, pixels); + }, pixels); - var fname = nameof(TestBackAndForthConversionWithoutAlphaInterleave) + "_model.zip"; + var fname = nameof(TestBackAndForthConversionWithoutAlphaInterleave) + "_model.zip"; - var fh = env.CreateOutputFile(fname); - using (var ch = env.Start("save")) - TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); + var fh = env.CreateOutputFile(fname); + using (var ch = env.Start("save")) + TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); - backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); - DeleteOutputPath(fname); + backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); + DeleteOutputPath(fname); - backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); - backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); - using (var cursor = backToBitmaps.GetRowCursor((x) => true)) - { - var bitmapGetter = cursor.GetGetter(bitmapColumn); - Bitmap restoredBitmap = default; - - var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); - Bitmap croppedBitmap = default; - while (cursor.MoveNext()) - { - bitmapGetter(ref restoredBitmap); - Assert.NotNull(restoredBitmap); - bitmapCropGetter(ref croppedBitmap); - Assert.NotNull(croppedBitmap); - for (int x = 0; x < imageWidth; x++) - for (int y = 0; y < imageHeight; y++) - { - var c = croppedBitmap.GetPixel(x, y); - var r = restoredBitmap.GetPixel(x, y); - Assert.True(c.R == r.R && c.G == r.G && c.B == r.B); - } - } + backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); + backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); + 
using (var cursor = backToBitmaps.GetRowCursor((x) => true)) + { + var bitmapGetter = cursor.GetGetter(bitmapColumn); + Bitmap restoredBitmap = default; + + var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); + Bitmap croppedBitmap = default; + while (cursor.MoveNext()) + { + bitmapGetter(ref restoredBitmap); + Assert.NotNull(restoredBitmap); + bitmapCropGetter(ref croppedBitmap); + Assert.NotNull(croppedBitmap); + for (int x = 0; x < imageWidth; x++) + for (int y = 0; y < imageHeight; y++) + { + var c = croppedBitmap.GetPixel(x, y); + var r = restoredBitmap.GetPixel(x, y); + Assert.True(c.R == r.R && c.G == r.G && c.B == r.B); + } } } Done(); @@ -400,88 +390,86 @@ public void TestBackAndForthConversionWithoutAlphaInterleave() [Fact] public void TestBackAndForthConversionWithAlphaNoInterleave() { - using (var env = new ConsoleEnvironment()) + IHostEnvironment env = new MLContext(); + const int imageHeight = 100; + const int imageWidth = 130; + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = TextLoader.Create(env, new TextLoader.Arguments() { - var imageHeight = 100; - var imageWidth = 130; - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + Column = new[] { - Column = new[] - { new TextLoader.Column("ImagePath", DataKind.TX, 0), new TextLoader.Column("Name", DataKind.TX, 1), } - }, new MultiFileSource(dataFile)); - var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + }, new MultiFileSource(dataFile)); + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + { + Column = new ImageLoaderTransform.Column[1] { - Column = new ImageLoaderTransform.Column[1] - { new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" } - }, - ImageFolder = imageFolder - }, data); - var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() - { - Column = new ImageResizerTransform.Column[1]{ + }, + ImageFolder = imageFolder + }, data); + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } - }, images); + }, images); - var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() - { - InterleaveArgb = false, - Offset = 127.5f, - Scale = 2f / 255, - Column = new ImagePixelExtractorTransform.Column[1]{ + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() + { + InterleaveArgb = false, + Offset = 127.5f, + Scale = 2f / 255, + Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=true} } - }, cropped); + }, cropped); - IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() - { - InterleaveArgb = false, - Offset = -1f, - Scale = 255f / 2, - Column = new VectorToImageTransform.Column[1]{ + IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() + { + InterleaveArgb = false, + Offset = -1f, + Scale = 255f / 2, + Column = new VectorToImageTransform.Column[1]{ new VectorToImageTransform.Column() { 
Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true} } - }, pixels); + }, pixels); - var fname = nameof(TestBackAndForthConversionWithAlphaNoInterleave) + "_model.zip"; + var fname = nameof(TestBackAndForthConversionWithAlphaNoInterleave) + "_model.zip"; - var fh = env.CreateOutputFile(fname); - using (var ch = env.Start("save")) - TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); + var fh = env.CreateOutputFile(fname); + using (var ch = env.Start("save")) + TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); - backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); - DeleteOutputPath(fname); + backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); + DeleteOutputPath(fname); - backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); - backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); - using (var cursor = backToBitmaps.GetRowCursor((x) => true)) - { - var bitmapGetter = cursor.GetGetter(bitmapColumn); - Bitmap restoredBitmap = default; - - var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); - Bitmap croppedBitmap = default; - while (cursor.MoveNext()) - { - bitmapGetter(ref restoredBitmap); - Assert.NotNull(restoredBitmap); - bitmapCropGetter(ref croppedBitmap); - Assert.NotNull(croppedBitmap); - for (int x = 0; x < imageWidth; x++) - for (int y = 0; y < imageHeight; y++) - { - var c = croppedBitmap.GetPixel(x, y); - var r = restoredBitmap.GetPixel(x, y); - Assert.True(c == r); - } - } + backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); + backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); + using (var cursor = backToBitmaps.GetRowCursor((x) => true)) + { + var bitmapGetter = cursor.GetGetter(bitmapColumn); + Bitmap restoredBitmap = default; + + var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); + Bitmap croppedBitmap = default; + while (cursor.MoveNext()) + { + bitmapGetter(ref restoredBitmap); + Assert.NotNull(restoredBitmap); + bitmapCropGetter(ref croppedBitmap); + Assert.NotNull(croppedBitmap); + for (int x = 0; x < imageWidth; x++) + for (int y = 0; y < imageHeight; y++) + { + var c = croppedBitmap.GetPixel(x, y); + var r = restoredBitmap.GetPixel(x, y); + Assert.True(c == r); + } } } Done(); @@ -490,88 +478,86 @@ public void TestBackAndForthConversionWithAlphaNoInterleave() [Fact] public void TestBackAndForthConversionWithoutAlphaNoInterleave() { - using (var env = new ConsoleEnvironment()) + IHostEnvironment env = new MLContext(); + const int imageHeight = 100; + const int imageWidth = 130; + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = TextLoader.Create(env, new TextLoader.Arguments() { - var imageHeight = 100; - var imageWidth = 130; - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + Column = new[] { - Column = new[] - { new TextLoader.Column("ImagePath", DataKind.TX, 0), new TextLoader.Column("Name", DataKind.TX, 1), } - }, new MultiFileSource(dataFile)); - var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + }, new MultiFileSource(dataFile)); + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + { + 
Column = new ImageLoaderTransform.Column[1] { - Column = new ImageLoaderTransform.Column[1] - { new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" } - }, - ImageFolder = imageFolder - }, data); - var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() - { - Column = new ImageResizerTransform.Column[1]{ + }, + ImageFolder = imageFolder + }, data); + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } - }, images); + }, images); - var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() - { - InterleaveArgb = false, - Offset = 127.5f, - Scale = 2f / 255, - Column = new ImagePixelExtractorTransform.Column[1]{ + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() + { + InterleaveArgb = false, + Offset = 127.5f, + Scale = 2f / 255, + Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=false} } - }, cropped); + }, cropped); - IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() - { - InterleaveArgb = false, - Offset = -1f, - Scale = 255f / 2, - Column = new VectorToImageTransform.Column[1]{ + IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() + { + InterleaveArgb = false, + Offset = -1f, + Scale = 255f / 2, + Column = new VectorToImageTransform.Column[1]{ new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=false} } - }, pixels); - - var fname = nameof(TestBackAndForthConversionWithoutAlphaNoInterleave) + "_model.zip"; + }, pixels); - var fh = env.CreateOutputFile(fname); - using (var ch = env.Start("save")) - TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); + var fname = nameof(TestBackAndForthConversionWithoutAlphaNoInterleave) + "_model.zip"; - backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); - DeleteOutputPath(fname); + var fh = env.CreateOutputFile(fname); + using (var ch = env.Start("save")) + TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); + backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); + DeleteOutputPath(fname); - backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); - backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); - using (var cursor = backToBitmaps.GetRowCursor((x) => true)) - { - var bitmapGetter = cursor.GetGetter(bitmapColumn); - Bitmap restoredBitmap = default; - var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); - Bitmap croppedBitmap = default; - while (cursor.MoveNext()) - { - bitmapGetter(ref restoredBitmap); - Assert.NotNull(restoredBitmap); - bitmapCropGetter(ref croppedBitmap); - Assert.NotNull(croppedBitmap); - for (int x = 0; x < imageWidth; x++) - for (int y = 0; y < imageHeight; y++) - { - var c = croppedBitmap.GetPixel(x, y); - var r = restoredBitmap.GetPixel(x, y); - Assert.True(c.R == r.R && c.G == r.G && c.B == r.B); - } - } + 
backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); + backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); + using (var cursor = backToBitmaps.GetRowCursor((x) => true)) + { + var bitmapGetter = cursor.GetGetter(bitmapColumn); + Bitmap restoredBitmap = default; + + var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); + Bitmap croppedBitmap = default; + while (cursor.MoveNext()) + { + bitmapGetter(ref restoredBitmap); + Assert.NotNull(restoredBitmap); + bitmapCropGetter(ref croppedBitmap); + Assert.NotNull(croppedBitmap); + for (int x = 0; x < imageWidth; x++) + for (int y = 0; y < imageHeight; y++) + { + var c = croppedBitmap.GetPixel(x, y); + var r = restoredBitmap.GetPixel(x, y); + Assert.True(c.R == r.R && c.G == r.G && c.B == r.B); + } } } Done(); @@ -580,84 +566,82 @@ public void TestBackAndForthConversionWithoutAlphaNoInterleave() [Fact] public void TestBackAndForthConversionWithAlphaInterleaveNoOffset() { - using (var env = new ConsoleEnvironment()) + IHostEnvironment env = new MLContext(); + const int imageHeight = 100; + const int imageWidth = 130; + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = TextLoader.Create(env, new TextLoader.Arguments() { - var imageHeight = 100; - var imageWidth = 130; - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + Column = new[] { - Column = new[] - { new TextLoader.Column("ImagePath", DataKind.TX, 0), new TextLoader.Column("Name", DataKind.TX, 1), } - }, new MultiFileSource(dataFile)); - var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + }, new MultiFileSource(dataFile)); + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + { + Column = new ImageLoaderTransform.Column[1] { - Column = new ImageLoaderTransform.Column[1] - { new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" } - }, - ImageFolder = imageFolder - }, data); - var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() - { - Column = new ImageResizerTransform.Column[1]{ + }, + ImageFolder = imageFolder + }, data); + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } - }, images); + }, images); - var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() - { - InterleaveArgb = true, - Column = new ImagePixelExtractorTransform.Column[1]{ + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() + { + InterleaveArgb = true, + Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=true} } - }, cropped); + }, cropped); - IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() - { - InterleaveArgb = true, - Column = new VectorToImageTransform.Column[1]{ + IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() + { + InterleaveArgb = true, + Column = new VectorToImageTransform.Column[1]{ new 
VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true} } - }, pixels); + }, pixels); - var fname = nameof(TestBackAndForthConversionWithAlphaInterleaveNoOffset) + "_model.zip"; + var fname = nameof(TestBackAndForthConversionWithAlphaInterleaveNoOffset) + "_model.zip"; - var fh = env.CreateOutputFile(fname); - using (var ch = env.Start("save")) - TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); + var fh = env.CreateOutputFile(fname); + using (var ch = env.Start("save")) + TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); - backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); - DeleteOutputPath(fname); + backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); + DeleteOutputPath(fname); - backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); - backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); - using (var cursor = backToBitmaps.GetRowCursor((x) => true)) - { - var bitmapGetter = cursor.GetGetter(bitmapColumn); - Bitmap restoredBitmap = default; - - var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); - Bitmap croppedBitmap = default; - while (cursor.MoveNext()) - { - bitmapGetter(ref restoredBitmap); - Assert.NotNull(restoredBitmap); - bitmapCropGetter(ref croppedBitmap); - Assert.NotNull(croppedBitmap); - for (int x = 0; x < imageWidth; x++) - for (int y = 0; y < imageHeight; y++) - { - var c = croppedBitmap.GetPixel(x, y); - var r = restoredBitmap.GetPixel(x, y); - Assert.True(c == r); - } - } + backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); + backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); + using (var cursor = backToBitmaps.GetRowCursor((x) => true)) + { + var bitmapGetter = cursor.GetGetter(bitmapColumn); + Bitmap restoredBitmap = default; + + var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); + Bitmap croppedBitmap = default; + while (cursor.MoveNext()) + { + bitmapGetter(ref restoredBitmap); + Assert.NotNull(restoredBitmap); + bitmapCropGetter(ref croppedBitmap); + Assert.NotNull(croppedBitmap); + for (int x = 0; x < imageWidth; x++) + for (int y = 0; y < imageHeight; y++) + { + var c = croppedBitmap.GetPixel(x, y); + var r = restoredBitmap.GetPixel(x, y); + Assert.True(c == r); + } } } Done(); @@ -666,84 +650,82 @@ public void TestBackAndForthConversionWithAlphaInterleaveNoOffset() [Fact] public void TestBackAndForthConversionWithoutAlphaInterleaveNoOffset() { - using (var env = new ConsoleEnvironment()) + IHostEnvironment env = new MLContext(); + const int imageHeight = 100; + const int imageWidth = 130; + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = TextLoader.Create(env, new TextLoader.Arguments() { - var imageHeight = 100; - var imageWidth = 130; - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + Column = new[] { - Column = new[] - { new TextLoader.Column("ImagePath", DataKind.TX, 0), new TextLoader.Column("Name", DataKind.TX, 1), } - }, new MultiFileSource(dataFile)); - var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + }, new MultiFileSource(dataFile)); + var images = 
ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + { + Column = new ImageLoaderTransform.Column[1] { - Column = new ImageLoaderTransform.Column[1] - { new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" } - }, - ImageFolder = imageFolder - }, data); - var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() - { - Column = new ImageResizerTransform.Column[1]{ + }, + ImageFolder = imageFolder + }, data); + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } - }, images); + }, images); - var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() - { - InterleaveArgb = true, - Column = new ImagePixelExtractorTransform.Column[1]{ + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() + { + InterleaveArgb = true, + Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=false} } - }, cropped); + }, cropped); - IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() - { - InterleaveArgb = true, - Column = new VectorToImageTransform.Column[1]{ + IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() + { + InterleaveArgb = true, + Column = new VectorToImageTransform.Column[1]{ new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=false} } - }, pixels); + }, pixels); - var fname = nameof(TestBackAndForthConversionWithoutAlphaInterleaveNoOffset) + "_model.zip"; + var fname = nameof(TestBackAndForthConversionWithoutAlphaInterleaveNoOffset) + "_model.zip"; - var fh = env.CreateOutputFile(fname); - using (var ch = env.Start("save")) - TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); + var fh = env.CreateOutputFile(fname); + using (var ch = env.Start("save")) + TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); - backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); - DeleteOutputPath(fname); + backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); + DeleteOutputPath(fname); - backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); - backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); - using (var cursor = backToBitmaps.GetRowCursor((x) => true)) - { - var bitmapGetter = cursor.GetGetter(bitmapColumn); - Bitmap restoredBitmap = default; - - var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); - Bitmap croppedBitmap = default; - while (cursor.MoveNext()) - { - bitmapGetter(ref restoredBitmap); - Assert.NotNull(restoredBitmap); - bitmapCropGetter(ref croppedBitmap); - Assert.NotNull(croppedBitmap); - for (int x = 0; x < imageWidth; x++) - for (int y = 0; y < imageHeight; y++) - { - var c = croppedBitmap.GetPixel(x, y); - var r = restoredBitmap.GetPixel(x, y); - Assert.True(c.R == r.R && c.G == r.G && c.B == r.B); - } - } + backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); + 
backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); + using (var cursor = backToBitmaps.GetRowCursor((x) => true)) + { + var bitmapGetter = cursor.GetGetter(bitmapColumn); + Bitmap restoredBitmap = default; + + var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); + Bitmap croppedBitmap = default; + while (cursor.MoveNext()) + { + bitmapGetter(ref restoredBitmap); + Assert.NotNull(restoredBitmap); + bitmapCropGetter(ref croppedBitmap); + Assert.NotNull(croppedBitmap); + for (int x = 0; x < imageWidth; x++) + for (int y = 0; y < imageHeight; y++) + { + var c = croppedBitmap.GetPixel(x, y); + var r = restoredBitmap.GetPixel(x, y); + Assert.True(c.R == r.R && c.G == r.G && c.B == r.B); + } } } Done(); @@ -752,84 +734,82 @@ public void TestBackAndForthConversionWithoutAlphaInterleaveNoOffset() [Fact] public void TestBackAndForthConversionWithAlphaNoInterleaveNoOffset() { - using (var env = new ConsoleEnvironment()) + IHostEnvironment env = new MLContext(); + const int imageHeight = 100; + var imageWidth = 130; + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = TextLoader.Create(env, new TextLoader.Arguments() { - var imageHeight = 100; - var imageWidth = 130; - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + Column = new[] { - Column = new[] - { new TextLoader.Column("ImagePath", DataKind.TX, 0), new TextLoader.Column("Name", DataKind.TX, 1), } - }, new MultiFileSource(dataFile)); - var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + }, new MultiFileSource(dataFile)); + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + { + Column = new ImageLoaderTransform.Column[1] { - Column = new ImageLoaderTransform.Column[1] - { new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" } - }, - ImageFolder = imageFolder - }, data); - var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() - { - Column = new ImageResizerTransform.Column[1]{ + }, + ImageFolder = imageFolder + }, data); + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } - }, images); + }, images); - var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() - { - InterleaveArgb = false, - Column = new ImagePixelExtractorTransform.Column[1]{ + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() + { + InterleaveArgb = false, + Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=true} } - }, cropped); + }, cropped); - IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() - { - InterleaveArgb = false, - Column = new VectorToImageTransform.Column[1]{ + IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() + { + InterleaveArgb = false, + Column = new VectorToImageTransform.Column[1]{ new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , 
ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true} } - }, pixels); + }, pixels); - var fname = nameof(TestBackAndForthConversionWithAlphaNoInterleaveNoOffset) + "_model.zip"; + var fname = nameof(TestBackAndForthConversionWithAlphaNoInterleaveNoOffset) + "_model.zip"; - var fh = env.CreateOutputFile(fname); - using (var ch = env.Start("save")) - TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); + var fh = env.CreateOutputFile(fname); + using (var ch = env.Start("save")) + TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); - backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); - DeleteOutputPath(fname); + backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); + DeleteOutputPath(fname); - backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); - backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); - using (var cursor = backToBitmaps.GetRowCursor((x) => true)) - { - var bitmapGetter = cursor.GetGetter(bitmapColumn); - Bitmap restoredBitmap = default; - - var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); - Bitmap croppedBitmap = default; - while (cursor.MoveNext()) - { - bitmapGetter(ref restoredBitmap); - Assert.NotNull(restoredBitmap); - bitmapCropGetter(ref croppedBitmap); - Assert.NotNull(croppedBitmap); - for (int x = 0; x < imageWidth; x++) - for (int y = 0; y < imageHeight; y++) - { - var c = croppedBitmap.GetPixel(x, y); - var r = restoredBitmap.GetPixel(x, y); - Assert.True(c == r); - } - } + backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); + backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); + using (var cursor = backToBitmaps.GetRowCursor((x) => true)) + { + var bitmapGetter = cursor.GetGetter(bitmapColumn); + Bitmap restoredBitmap = default; + + var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); + Bitmap croppedBitmap = default; + while (cursor.MoveNext()) + { + bitmapGetter(ref restoredBitmap); + Assert.NotNull(restoredBitmap); + bitmapCropGetter(ref croppedBitmap); + Assert.NotNull(croppedBitmap); + for (int x = 0; x < imageWidth; x++) + for (int y = 0; y < imageHeight; y++) + { + var c = croppedBitmap.GetPixel(x, y); + var r = restoredBitmap.GetPixel(x, y); + Assert.True(c == r); + } } } Done(); @@ -838,87 +818,85 @@ public void TestBackAndForthConversionWithAlphaNoInterleaveNoOffset() [Fact] public void TestBackAndForthConversionWithoutAlphaNoInterleaveNoOffset() { - using (var env = new ConsoleEnvironment()) + IHostEnvironment env = new MLContext(); + const int imageHeight = 100; + const int imageWidth = 130; + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = TextLoader.Create(env, new TextLoader.Arguments() { - var imageHeight = 100; - var imageWidth = 130; - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + Column = new[] { - Column = new[] - { new TextLoader.Column("ImagePath", DataKind.TX, 0), new TextLoader.Column("Name", DataKind.TX, 1), } - }, new MultiFileSource(dataFile)); - var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + }, new MultiFileSource(dataFile)); + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + { + Column = new 
ImageLoaderTransform.Column[1] { - Column = new ImageLoaderTransform.Column[1] - { new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" } - }, - ImageFolder = imageFolder - }, data); - var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() - { - Column = new ImageResizerTransform.Column[1]{ + }, + ImageFolder = imageFolder + }, data); + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } - }, images); + }, images); - var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() - { - InterleaveArgb = false, - Column = new ImagePixelExtractorTransform.Column[1]{ + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() + { + InterleaveArgb = false, + Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=false} } - }, cropped); + }, cropped); - IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() - { - InterleaveArgb = false, - Column = new VectorToImageTransform.Column[1]{ + IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() + { + InterleaveArgb = false, + Column = new VectorToImageTransform.Column[1]{ new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=false} } - }, pixels); - - var fname = nameof(TestBackAndForthConversionWithoutAlphaNoInterleaveNoOffset) + "_model.zip"; + }, pixels); - var fh = env.CreateOutputFile(fname); - using (var ch = env.Start("save")) - TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); + var fname = nameof(TestBackAndForthConversionWithoutAlphaNoInterleaveNoOffset) + "_model.zip"; - backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); - DeleteOutputPath(fname); + var fh = env.CreateOutputFile(fname); + using (var ch = env.Start("save")) + TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps)); + backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile)); + DeleteOutputPath(fname); - backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); - backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); - using (var cursor = backToBitmaps.GetRowCursor((x) => true)) - { - var bitmapGetter = cursor.GetGetter(bitmapColumn); - Bitmap restoredBitmap = default; - var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); - Bitmap croppedBitmap = default; - while (cursor.MoveNext()) - { - bitmapGetter(ref restoredBitmap); - Assert.NotNull(restoredBitmap); - bitmapCropGetter(ref croppedBitmap); - Assert.NotNull(croppedBitmap); - for (int x = 0; x < imageWidth; x++) - for (int y = 0; y < imageHeight; y++) - { - var c = croppedBitmap.GetPixel(x, y); - var r = restoredBitmap.GetPixel(x, y); - Assert.True(c.R == r.R && c.G == r.G && c.B == r.B); - } - } + backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn); + backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); + using 
(var cursor = backToBitmaps.GetRowCursor((x) => true)) + { + var bitmapGetter = cursor.GetGetter(bitmapColumn); + Bitmap restoredBitmap = default; + + var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn); + Bitmap croppedBitmap = default; + while (cursor.MoveNext()) + { + bitmapGetter(ref restoredBitmap); + Assert.NotNull(restoredBitmap); + bitmapCropGetter(ref croppedBitmap); + Assert.NotNull(croppedBitmap); + for (int x = 0; x < imageWidth; x++) + for (int y = 0; y < imageHeight; y++) + { + var c = croppedBitmap.GetPixel(x, y); + var r = restoredBitmap.GetPixel(x, y); + Assert.True(c.R == r.R && c.G == r.G && c.B == r.B); + } } + Done(); } - Done(); } } } diff --git a/test/Microsoft.ML.Tests/OnnxTests.cs b/test/Microsoft.ML.Tests/OnnxTests.cs index 1cac50611c..8e513a5a93 100644 --- a/test/Microsoft.ML.Tests/OnnxTests.cs +++ b/test/Microsoft.ML.Tests/OnnxTests.cs @@ -82,70 +82,68 @@ public class BreastCancerClusterPrediction [Fact] public void InitializerCreationTest() { - using (var env = new ConsoleEnvironment()) - { - // Create the actual implementation - var ctxImpl = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "com.test", Runtime.Model.Onnx.OnnxVersion.Stable); - - // Use implementation as in the actual conversion code - var ctx = ctxImpl as OnnxContext; - ctx.AddInitializer(9.4f, "float"); - ctx.AddInitializer(17L, "int64"); - ctx.AddInitializer("36", "string"); - ctx.AddInitializer(new List { 9.4f, 1.7f, 3.6f }, new List { 1, 3 }, "floats"); - ctx.AddInitializer(new List { 94L, 17L, 36L }, new List { 1, 3 }, "int64s"); - ctx.AddInitializer(new List { "94" , "17", "36" }, new List { 1, 3 }, "strings"); - - var model = ctxImpl.MakeModel(); - - var floatScalar = model.Graph.Initializer[0]; - Assert.True(floatScalar.Name == "float"); - Assert.True(floatScalar.Dims.Count == 0); - Assert.True(floatScalar.FloatData.Count == 1); - Assert.True(floatScalar.FloatData[0] == 9.4f); - - var int64Scalar = model.Graph.Initializer[1]; - Assert.True(int64Scalar.Name == "int64"); - Assert.True(int64Scalar.Dims.Count == 0); - Assert.True(int64Scalar.Int64Data.Count == 1); - Assert.True(int64Scalar.Int64Data[0] == 17L); - - var stringScalar = model.Graph.Initializer[2]; - Assert.True(stringScalar.Name == "string"); - Assert.True(stringScalar.Dims.Count == 0); - Assert.True(stringScalar.StringData.Count == 1); - Assert.True(stringScalar.StringData[0].ToStringUtf8() == "36"); - - var floatsTensor = model.Graph.Initializer[3]; - Assert.True(floatsTensor.Name == "floats"); - Assert.True(floatsTensor.Dims.Count == 2); - Assert.True(floatsTensor.Dims[0] == 1); - Assert.True(floatsTensor.Dims[1] == 3); - Assert.True(floatsTensor.FloatData.Count == 3); - Assert.True(floatsTensor.FloatData[0] == 9.4f); - Assert.True(floatsTensor.FloatData[1] == 1.7f); - Assert.True(floatsTensor.FloatData[2] == 3.6f); - - var int64sTensor = model.Graph.Initializer[4]; - Assert.True(int64sTensor.Name == "int64s"); - Assert.True(int64sTensor.Dims.Count == 2); - Assert.True(int64sTensor.Dims[0] == 1); - Assert.True(int64sTensor.Dims[1] == 3); - Assert.True(int64sTensor.Int64Data.Count == 3); - Assert.True(int64sTensor.Int64Data[0] == 94L); - Assert.True(int64sTensor.Int64Data[1] == 17L); - Assert.True(int64sTensor.Int64Data[2] == 36L); - - var stringsTensor = model.Graph.Initializer[5]; - Assert.True(stringsTensor.Name == "strings"); - Assert.True(stringsTensor.Dims.Count == 2); - Assert.True(stringsTensor.Dims[0] == 1); - Assert.True(stringsTensor.Dims[1] == 3); - Assert.True(stringsTensor.StringData.Count == 3); - 
Assert.True(stringsTensor.StringData[0].ToStringUtf8() == "94");
-            Assert.True(stringsTensor.StringData[1].ToStringUtf8() == "17");
-            Assert.True(stringsTensor.StringData[2].ToStringUtf8() == "36");
-            }
+        var env = new MLContext();
+        // Create the actual implementation
+        var ctxImpl = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "com.test", Runtime.Model.Onnx.OnnxVersion.Stable);
+
+        // Use implementation as in the actual conversion code
+        var ctx = ctxImpl as OnnxContext;
+        ctx.AddInitializer(9.4f, "float");
+        ctx.AddInitializer(17L, "int64");
+        ctx.AddInitializer("36", "string");
+        ctx.AddInitializer(new List<float> { 9.4f, 1.7f, 3.6f }, new List<long> { 1, 3 }, "floats");
+        ctx.AddInitializer(new List<long> { 94L, 17L, 36L }, new List<long> { 1, 3 }, "int64s");
+        ctx.AddInitializer(new List<string> { "94", "17", "36" }, new List<long> { 1, 3 }, "strings");
+
+        var model = ctxImpl.MakeModel();
+
+        var floatScalar = model.Graph.Initializer[0];
+        Assert.True(floatScalar.Name == "float");
+        Assert.True(floatScalar.Dims.Count == 0);
+        Assert.True(floatScalar.FloatData.Count == 1);
+        Assert.True(floatScalar.FloatData[0] == 9.4f);
+
+        var int64Scalar = model.Graph.Initializer[1];
+        Assert.True(int64Scalar.Name == "int64");
+        Assert.True(int64Scalar.Dims.Count == 0);
+        Assert.True(int64Scalar.Int64Data.Count == 1);
+        Assert.True(int64Scalar.Int64Data[0] == 17L);
+
+        var stringScalar = model.Graph.Initializer[2];
+        Assert.True(stringScalar.Name == "string");
+        Assert.True(stringScalar.Dims.Count == 0);
+        Assert.True(stringScalar.StringData.Count == 1);
+        Assert.True(stringScalar.StringData[0].ToStringUtf8() == "36");
+
+        var floatsTensor = model.Graph.Initializer[3];
+        Assert.True(floatsTensor.Name == "floats");
+        Assert.True(floatsTensor.Dims.Count == 2);
+        Assert.True(floatsTensor.Dims[0] == 1);
+        Assert.True(floatsTensor.Dims[1] == 3);
+        Assert.True(floatsTensor.FloatData.Count == 3);
+        Assert.True(floatsTensor.FloatData[0] == 9.4f);
+        Assert.True(floatsTensor.FloatData[1] == 1.7f);
+        Assert.True(floatsTensor.FloatData[2] == 3.6f);
+
+        var int64sTensor = model.Graph.Initializer[4];
+        Assert.True(int64sTensor.Name == "int64s");
+        Assert.True(int64sTensor.Dims.Count == 2);
+        Assert.True(int64sTensor.Dims[0] == 1);
+        Assert.True(int64sTensor.Dims[1] == 3);
+        Assert.True(int64sTensor.Int64Data.Count == 3);
+        Assert.True(int64sTensor.Int64Data[0] == 94L);
+        Assert.True(int64sTensor.Int64Data[1] == 17L);
+        Assert.True(int64sTensor.Int64Data[2] == 36L);
+
+        var stringsTensor = model.Graph.Initializer[5];
+        Assert.True(stringsTensor.Name == "strings");
+        Assert.True(stringsTensor.Dims.Count == 2);
+        Assert.True(stringsTensor.Dims[0] == 1);
+        Assert.True(stringsTensor.Dims[1] == 3);
+        Assert.True(stringsTensor.StringData.Count == 3);
+        Assert.True(stringsTensor.StringData[0].ToStringUtf8() == "94");
+        Assert.True(stringsTensor.StringData[1].ToStringUtf8() == "17");
+        Assert.True(stringsTensor.StringData[2].ToStringUtf8() == "36");
     }

     [Fact]
@@ -259,8 +257,12 @@ public void KeyToVectorWithBagTest()
             });

             var vectorizer = new CategoricalOneHotVectorizer();
-            var categoricalColumn = new OneHotEncodingTransformerColumn() {
-                OutputKind = OneHotEncodingTransformerOutputKind.Bag, Name = "F2", Source = "F2" };
+            var categoricalColumn = new OneHotEncodingTransformerColumn()
+            {
+                OutputKind = OneHotEncodingTransformerOutputKind.Bag,
+                Name = "F2",
+                Source = "F2"
+            };
             vectorizer.Column = new OneHotEncodingTransformerColumn[1] { categoricalColumn };
             pipeline.Add(vectorizer);
             pipeline.Add(new ColumnConcatenator("Features", "F1", "F2"));
@@
-306,7 +308,7 @@ public void WordEmbeddingsTest() { Separator = new[] { '\t' }, HasHeader = false, - Column = new [] + Column = new[] { new TextLoaderColumn() { @@ -317,7 +319,7 @@ public void WordEmbeddingsTest() } } }); - + var modelPath = GetDataPath(@"shortsentiment.emd"); var embed = new WordEmbeddings() { CustomLookupTable = modelPath }; embed.AddColumn("Cat", "Cat"); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs index 2f5ba899fc..a9a13385a8 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs @@ -153,7 +153,7 @@ private void TrainRegression(string trainDataPath, string testDataPath, string m [Fact] public void TrainRegressionModel() => TrainRegression(GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename), GetDataPath(TestDatasets.generatedRegressionDataset.testFilename), - DeleteOutputPath("cook_model.zip")); + DeleteOutputPath("cook_model_static.zip")); private ITransformer TrainOnIris(string irisDataPath) { diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs index 1bfa09377b..32486abb27 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs @@ -38,7 +38,7 @@ public void New_TrainWithInitialPredictor() var secondTrainer = ml.BinaryClassification.Trainers.AveragedPerceptron("Label","Features"); var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features"); - var finalModel = secondTrainer.Train(new TrainContext(trainRoles, initialPredictor: firstModel.Model)); + var finalModel = ((ITrainer)secondTrainer).Train(new TrainContext(trainRoles, initialPredictor: firstModel.Model)); } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs index 9507d4bce3..b3086a68a1 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs @@ -62,80 +62,78 @@ private class TwoIChannelsOnlyOneWithAttribute [Fact] public void CursorChannelExposedInMapTransform() { - using (var env = new ConsoleEnvironment(0)) - { - // Correct use of CursorChannel attribute. - var data1 = Utils.CreateArray(10, new OneIChannelWithAttribute()); - var idv1 = env.CreateDataView(data1); - Assert.Null(data1[0].Channel); - - var filter1 = LambdaTransform.CreateFilter(env, idv1, - (input, state) => - { - Assert.NotNull(input.Channel); - return false; - }, null); - filter1.GetRowCursor(col => true).MoveNext(); - - // Error case: non-IChannel field marked with attribute. - var data2 = Utils.CreateArray(10, new OneStringWithAttribute()); - var idv2 = env.CreateDataView(data2); - Assert.Null(data2[0].Channel); - - var filter2 = LambdaTransform.CreateFilter(env, idv2, - (input, state) => - { - Assert.Null(input.Channel); - return false; - }, null); - try + var env = new MLContext(seed: 0); + // Correct use of CursorChannel attribute. 
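Backing up to the TrainWithInitialPredictor change above: the new ((ITrainer)secondTrainer).Train(new TrainContext(...)) call needs the cast when Train(TrainContext) is reachable only through the ITrainer interface, for example because it is implemented explicitly while the trainer's own public surface steers callers toward Fit. A small self-contained illustration of that C# mechanic (the types here are made up for the sketch, not ML.NET's):

using System;

interface ITrainer
{
    string Train(string context);
}

class PerceptronLikeTrainer : ITrainer
{
    // Explicit interface implementation: invisible on the class's own surface.
    string ITrainer.Train(string context) => "trained from " + context;

    public string Fit(string data) => "fitted on " + data;
}

class CastDemo
{
    static void Main()
    {
        var trainer = new PerceptronLikeTrainer();
        // trainer.Train("ctx");   // would not compile: Train is interface-only
        Console.WriteLine(((ITrainer)trainer).Train("initial predictor"));
        Console.WriteLine(trainer.Fit("training data"));
    }
}

Back in CursorChannelExposedInMapTransform, the correct-use case introduced by the comment above starts from an array of rows wrapped as a data view.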
+ var data1 = Utils.CreateArray(10, new OneIChannelWithAttribute()); + var idv1 = env.CreateDataView(data1); + Assert.Null(data1[0].Channel); + + var filter1 = LambdaTransform.CreateFilter(env, idv1, + (input, state) => { - filter2.GetRowCursor(col => true).MoveNext(); - Assert.True(false, "Throw an error if attribute is applied to a field that is not an IChannel."); - } - catch (InvalidOperationException ex) + Assert.NotNull(input.Channel); + return false; + }, null); + filter1.GetRowCursor(col => true).MoveNext(); + + // Error case: non-IChannel field marked with attribute. + var data2 = Utils.CreateArray(10, new OneStringWithAttribute()); + var idv2 = env.CreateDataView(data2); + Assert.Null(data2[0].Channel); + + var filter2 = LambdaTransform.CreateFilter(env, idv2, + (input, state) => { - Assert.True(ex.IsMarked()); - } + Assert.Null(input.Channel); + return false; + }, null); + try + { + filter2.GetRowCursor(col => true).MoveNext(); + Assert.True(false, "Throw an error if attribute is applied to a field that is not an IChannel."); + } + catch (InvalidOperationException ex) + { + Assert.True(ex.IsMarked()); + } - // Error case: multiple fields marked with attributes. - var data3 = Utils.CreateArray(10, new TwoIChannelsWithAttributes()); - var idv3 = env.CreateDataView(data3); - Assert.Null(data3[0].ChannelOne); - Assert.Null(data3[2].ChannelTwo); + // Error case: multiple fields marked with attributes. + var data3 = Utils.CreateArray(10, new TwoIChannelsWithAttributes()); + var idv3 = env.CreateDataView(data3); + Assert.Null(data3[0].ChannelOne); + Assert.Null(data3[2].ChannelTwo); - var filter3 = LambdaTransform.CreateFilter(env, idv3, - (input, state) => - { - Assert.Null(input.ChannelOne); - Assert.Null(input.ChannelTwo); - return false; - }, null); - try + var filter3 = LambdaTransform.CreateFilter(env, idv3, + (input, state) => { - filter3.GetRowCursor(col => true).MoveNext(); - Assert.True(false, "Throw an error if attribute is applied to a field that is not an IChannel."); - } - catch (InvalidOperationException ex) - { - Assert.True(ex.IsMarked()); - } + Assert.Null(input.ChannelOne); + Assert.Null(input.ChannelTwo); + return false; + }, null); + try + { + filter3.GetRowCursor(col => true).MoveNext(); + Assert.True(false, "Throw an error if attribute is applied to a field that is not an IChannel."); + } + catch (InvalidOperationException ex) + { + Assert.True(ex.IsMarked()); + } - // Correct case: non-marked IChannel field is not touched. - var example4 = new TwoIChannelsOnlyOneWithAttribute(); - Assert.Null(example4.ChannelTwo); - Assert.Null(example4.ChannelOne); - var idv4 = env.CreateDataView(Utils.CreateArray(10, example4)); + // Correct case: non-marked IChannel field is not touched. 
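The filter scenarios above share one verification pattern: force a cursoring pass with GetRowCursor(...).MoveNext() and let the lambda's asserts run. For the two error cases, the try/catch plus Assert.True(false, ...) construction can be expressed more directly with xUnit's Assert.Throws, which fails the test when nothing is thrown and hands the exception back for follow-up checks; a sketch of the equivalent shape (the throwing action is a stand-in for the cursor call, and IsMarked is assumed to be the repository's existing extension):

using System;
using Xunit;

public class ThrowsPatternSketch
{
    [Fact]
    public void NonChannelFieldIsRejected()
    {
        // Stand-in for filter2.GetRowCursor(col => true).MoveNext(),
        // which is expected to throw for a non-IChannel field.
        Action cursorAndMove = () => throw new InvalidOperationException("marked");

        var ex = Assert.Throws<InvalidOperationException>(cursorAndMove);
        // Follow-up check, mirroring Assert.True(ex.IsMarked()) in the test.
        Assert.Contains("marked", ex.Message);
    }
}

The correct case that the comment above introduces then follows, with a filter whose lambda asserts that only the attributed channel is populated.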
+        var example4 = new TwoIChannelsOnlyOneWithAttribute();
+        Assert.Null(example4.ChannelTwo);
+        Assert.Null(example4.ChannelOne);
+        var idv4 = env.CreateDataView(Utils.CreateArray(10, example4));
-            var filter4 = LambdaTransform.CreateFilter(env, idv4,
-                (input, state) =>
-                {
-                    Assert.Null(input.ChannelOne);
-                    Assert.NotNull(input.ChannelTwo);
-                    return false;
-                }, null);
-            filter1.GetRowCursor(col => true).MoveNext();
-        }
+        var filter4 = LambdaTransform.CreateFilter(env, idv4,
+            (input, state) =>
+            {
+                Assert.Null(input.ChannelOne);
+                Assert.NotNull(input.ChannelTwo);
+                return false;
+            }, null);
+        filter1.GetRowCursor(col => true).MoveNext();
     }

     public class BreastCancerExample
@@ -149,44 +147,40 @@ public class BreastCancerExample
     [Fact]
     public void LambdaTransformCreate()
     {
-            using (var env = new ConsoleEnvironment(42))
-            {
-                var data = ReadBreastCancerExamples();
-                var idv = env.CreateDataView(data);
+        var env = new MLContext(seed: 42);
+        var data = ReadBreastCancerExamples();
+        var idv = env.CreateDataView(data);

-                var filter = LambdaTransform.CreateFilter(env, idv,
-                    (input, state) => input.Label == 0, null);
+        var filter = LambdaTransform.CreateFilter(env, idv,
+            (input, state) => input.Label == 0, null);

-                Assert.Null(filter.GetRowCount());
+        Assert.Null(filter.GetRowCount());

-                // test re-apply
-                var applied = env.CreateDataView(data);
-                applied = ApplyTransformUtils.ApplyAllTransformsToData(env, filter, applied);
+        // test re-apply
+        var applied = env.CreateDataView(data);
+        applied = ApplyTransformUtils.ApplyAllTransformsToData(env, filter, applied);

-                var saver = new TextSaver(env, new TextSaver.Arguments());
-                Assert.True(applied.Schema.TryGetColumnIndex("Label", out int label));
-                using (var fs = File.Create(GetOutputPath(OutputRelativePath, "lambda-output.tsv")))
-                    saver.SaveData(fs, applied, label);
-            }
+        var saver = new TextSaver(env, new TextSaver.Arguments());
+        Assert.True(applied.Schema.TryGetColumnIndex("Label", out int label));
+        using (var fs = File.Create(GetOutputPath(OutputRelativePath, "lambda-output.tsv")))
+            saver.SaveData(fs, applied, label);
     }

     [Fact]
     public void TrainAveragedPerceptronWithCache()
     {
-            using (var env = new ConsoleEnvironment(0))
-            {
-                var dataFile = GetDataPath("breast-cancer.txt");
-                var loader = TextLoader.Create(env, new TextLoader.Arguments(), new MultiFileSource(dataFile));
-                var globalCounter = 0;
-                var xf = LambdaTransform.CreateFilter(env, loader,
-                    (i, s) => true,
-                    s => { globalCounter++; });
+        var env = new MLContext(0);
+        var dataFile = GetDataPath("breast-cancer.txt");
+        var loader = TextLoader.Create(env, new TextLoader.Arguments(), new MultiFileSource(dataFile));
+        var globalCounter = 0;
+        var xf = LambdaTransform.CreateFilter(env, loader,
+            (i, s) => true,
+            s => { globalCounter++; });

-                new AveragedPerceptronTrainer(env, "Label", "Features", numIterations: 2).Fit(xf).Transform(xf);
+        new AveragedPerceptronTrainer(env, "Label", "Features", numIterations: 2).Fit(xf).Transform(xf);

-                // Make sure there were 2 cursoring events.
-                Assert.Equal(1, globalCounter);
-            }
+        // Make sure there was only one cursoring event, i.e. the cache was hit.
+ Assert.Equal(1, globalCounter); } private List ReadBreastCancerExamples() diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index 089edc4407..a67834cda0 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -420,8 +420,7 @@ private LearningPipeline PreparePipelineSymSGD() WordFeatureExtractor = new NGramNgramExtractor() { NgramLength = 2, AllLengths = true } }); - - pipeline.Add(new SymSgdBinaryClassifier() { NumberOfThreads = 1}); + pipeline.Add(new SymSgdBinaryClassifier() { NumberOfThreads = 1 }); pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); return pipeline; diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs index b64f073825..a362aecacc 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs @@ -25,50 +25,48 @@ public void TrainAndPredictIrisModelUsingDirectInstantiationTest() string dataPath = GetDataPath("iris.txt"); string testDataPath = dataPath; - using (var env = new ConsoleEnvironment(seed: 1, conc: 1)) - { - // Pipeline - var loader = TextLoader.ReadFile(env, - new TextLoader.Arguments() + var env = new MLContext(seed: 1, conc: 1); + // Pipeline + var loader = TextLoader.ReadFile(env, + new TextLoader.Arguments() + { + HasHeader = false, + Column = new[] { - HasHeader = false, - Column = new[] - { new TextLoader.Column("Label", DataKind.R4, 0), new TextLoader.Column("SepalLength", DataKind.R4, 1), new TextLoader.Column("SepalWidth", DataKind.R4, 2), new TextLoader.Column("PetalLength", DataKind.R4, 3), new TextLoader.Column("PetalWidth", DataKind.R4, 4) - } - }, new MultiFileSource(dataPath)); + } + }, new MultiFileSource(dataPath)); - IDataView pipeline = new ColumnConcatenatingTransformer(env, "Features", - "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(loader); + IDataView pipeline = new ColumnConcatenatingTransformer(env, "Features", + "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(loader); - // NormalizingEstimator is not automatically added though the trainer has 'NormalizeFeatures' On/Auto - pipeline = NormalizeTransform.CreateMinMaxNormalizer(env, pipeline, "Features"); + // NormalizingEstimator is not automatically added though the trainer has 'NormalizeFeatures' On/Auto + pipeline = NormalizeTransform.CreateMinMaxNormalizer(env, pipeline, "Features"); - // Train - var trainer = new SdcaMultiClassTrainer(env, "Label", "Features", advancedSettings: (s) => s.NumThreads = 1); + // Train + var trainer = new SdcaMultiClassTrainer(env, "Label", "Features", advancedSettings: s => s.NumThreads = 1); - // Explicity adding CacheDataView since caching is not working though trainer has 'Caching' On/Auto - var cached = new CacheDataView(env, pipeline, prefetch: null); - var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); - var pred = trainer.Train(trainRoles); + // Explicity adding CacheDataView since caching is not working though trainer has 'Caching' On/Auto + var cached = new CacheDataView(env, pipeline, prefetch: null); + var trainRoles = new RoleMappedData(cached, label: "Label", feature: 
"Features"); + var pred = trainer.Train(trainRoles); - // Get scorer and evaluate the predictions from test data - IDataScorerTransform testDataScorer = GetScorer(env, pipeline, pred, testDataPath); - var metrics = Evaluate(env, testDataScorer); - CompareMatrics(metrics); + // Get scorer and evaluate the predictions from test data + IDataScorerTransform testDataScorer = GetScorer(env, pipeline, pred, testDataPath); + var metrics = Evaluate(env, testDataScorer); + CompareMatrics(metrics); - // Create prediction engine and test predictions - var model = env.CreatePredictionEngine(testDataScorer); - ComparePredictions(model); + // Create prediction engine and test predictions + var model = env.CreatePredictionEngine(testDataScorer); + ComparePredictions(model); - // Get feature importance i.e. weight vector - var summary = ((MulticlassLogisticRegressionPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema); - Assert.Equal(7.757864, Convert.ToDouble(summary[0].Value), 5); - } + // Get feature importance i.e. weight vector + var summary = pred.GetSummaryInKeyValuePairs(trainRoles.Schema); + Assert.Equal(7.76443, Convert.ToDouble(summary[0].Value), 5); } private void ComparePredictions(PredictionEngine model) diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs index 594ec65d66..820d146cc1 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs @@ -22,10 +22,9 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTest() var dataPath = GetDataPath(SentimentDataPath); var testDataPath = GetDataPath(SentimentTestPath); - using (var env = new ConsoleEnvironment(seed: 1, conc: 1)) - { - // Pipeline - var loader = TextLoader.ReadFile(env, + var env = new MLContext(seed: 1, conc: 1); + // Pipeline + var loader = TextLoader.ReadFile(env, new TextLoader.Arguments() { Separator = "tab", @@ -37,46 +36,45 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTest() } }, new MultiFileSource(dataPath)); - var trans = TextFeaturizingEstimator.Create(env, new TextFeaturizingEstimator.Arguments() + var trans = TextFeaturizingEstimator.Create(env, new TextFeaturizingEstimator.Arguments() + { + Column = new TextFeaturizingEstimator.Column { - Column = new TextFeaturizingEstimator.Column - { - Name = "Features", - Source = new[] { "SentimentText" } - }, - OutputTokens = true, - KeepPunctuations = false, - StopWordsRemover = new PredefinedStopWordsRemoverFactory(), - VectorNormalizer = TextFeaturizingEstimator.TextNormKind.L2, - CharFeatureExtractor = new NgramExtractingTransformer.NgramExtractorArguments() { NgramLength = 3, AllLengths = false }, - WordFeatureExtractor = new NgramExtractingTransformer.NgramExtractorArguments() { NgramLength = 2, AllLengths = true }, + Name = "Features", + Source = new[] { "SentimentText" } }, - loader); - - // Train - var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, - numLeaves:5, numTrees:5, minDatapointsInLeaves: 2); - - var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features"); - var pred = trainer.Train(trainRoles); - - // Get scorer and evaluate the predictions from test data - IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); - var metrics = EvaluateBinary(env, 
testDataScorer); - ValidateBinaryMetrics(metrics); - - // Create prediction engine and test predictions - var model = env.CreateBatchPredictionEngine(testDataScorer); - var sentiments = GetTestData(); - var predictions = model.Predict(sentiments, false); - Assert.Equal(2, predictions.Count()); - Assert.True(predictions.ElementAt(0).Sentiment); - Assert.True(predictions.ElementAt(1).Sentiment); - - // Get feature importance based on feature gain during training - var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema); - Assert.Equal(1.0, (double)summary[0].Value, 1); - } + OutputTokens = true, + KeepPunctuations = false, + StopWordsRemover = new PredefinedStopWordsRemoverFactory(), + VectorNormalizer = TextFeaturizingEstimator.TextNormKind.L2, + CharFeatureExtractor = new NgramExtractingTransformer.NgramExtractorArguments() { NgramLength = 3, AllLengths = false }, + WordFeatureExtractor = new NgramExtractingTransformer.NgramExtractorArguments() { NgramLength = 2, AllLengths = true }, + }, + loader); + + // Train + var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, + numLeaves: 5, numTrees: 5, minDatapointsInLeaves: 2); + + var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features"); + var pred = trainer.Train(trainRoles); + + // Get scorer and evaluate the predictions from test data + IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); + var metrics = EvaluateBinary(env, testDataScorer); + ValidateBinaryMetrics(metrics); + + // Create prediction engine and test predictions + var model = env.CreateBatchPredictionEngine(testDataScorer); + var sentiments = GetTestData(); + var predictions = model.Predict(sentiments, false); + Assert.Equal(2, predictions.Count()); + Assert.True(predictions.ElementAt(0).Sentiment); + Assert.True(predictions.ElementAt(1).Sentiment); + + // Get feature importance based on feature gain during training + var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema); + Assert.Equal(1.0, (double)summary[0].Value, 1); } [Fact] @@ -85,10 +83,9 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTestWithWordE var dataPath = GetDataPath(SentimentDataPath); var testDataPath = GetDataPath(SentimentTestPath); - using (var env = new ConsoleEnvironment(seed: 1, conc: 1)) - { - // Pipeline - var loader = TextLoader.ReadFile(env, + var env = new MLContext(seed: 1, conc: 1); + // Pipeline + var loader = TextLoader.ReadFile(env, new TextLoader.Arguments() { Separator = "tab", @@ -100,60 +97,60 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTestWithWordE } }, new MultiFileSource(dataPath)); - var text = TextFeaturizingEstimator.Create(env, new TextFeaturizingEstimator.Arguments() + var text = TextFeaturizingEstimator.Create(env, new TextFeaturizingEstimator.Arguments() + { + Column = new TextFeaturizingEstimator.Column { - Column = new TextFeaturizingEstimator.Column - { - Name = "WordEmbeddings", - Source = new[] { "SentimentText" } - }, - OutputTokens = true, - KeepPunctuations= false, - StopWordsRemover = new PredefinedStopWordsRemoverFactory(), - VectorNormalizer = TextFeaturizingEstimator.TextNormKind.None, - CharFeatureExtractor = null, - WordFeatureExtractor = null, + Name = "WordEmbeddings", + Source = new[] { "SentimentText" } }, - loader); - - var trans = WordEmbeddingsExtractingTransformer.Create(env, new 
WordEmbeddingsExtractingTransformer.Arguments() + OutputTokens = true, + KeepPunctuations = false, + StopWordsRemover = new PredefinedStopWordsRemoverFactory(), + VectorNormalizer = TextFeaturizingEstimator.TextNormKind.None, + CharFeatureExtractor = null, + WordFeatureExtractor = null, + }, + loader); + + var trans = WordEmbeddingsExtractingTransformer.Create(env, new WordEmbeddingsExtractingTransformer.Arguments() + { + Column = new WordEmbeddingsExtractingTransformer.Column[1] { - Column = new WordEmbeddingsExtractingTransformer.Column[1] - { new WordEmbeddingsExtractingTransformer.Column { Name = "Features", Source = "WordEmbeddings_TransformedText" } - }, - ModelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe, - }, text); - // Train - var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, numLeaves: 5, numTrees:5, minDatapointsInLeaves:2); - - var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features"); - var pred = trainer.Train(trainRoles); - // Get scorer and evaluate the predictions from test data - IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); - var metrics = EvaluateBinary(env, testDataScorer); - - // SSWE is a simple word embedding model + we train on a really small dataset, so metrics are not great. - Assert.Equal(.6667, metrics.Accuracy, 4); - Assert.Equal(.71, metrics.Auc, 1); - Assert.Equal(.58, metrics.Auprc, 2); - // Create prediction engine and test predictions - var model = env.CreateBatchPredictionEngine(testDataScorer); - var sentiments = GetTestData(); - var predictions = model.Predict(sentiments, false); - Assert.Equal(2, predictions.Count()); - Assert.True(predictions.ElementAt(0).Sentiment); - Assert.True(predictions.ElementAt(1).Sentiment); - - // Get feature importance based on feature gain during training - var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema); - Assert.Equal(1.0, (double)summary[0].Value, 1); - } + }, + ModelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe, + }, text); + // Train + var trainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, numLeaves: 5, numTrees: 5, minDatapointsInLeaves: 2); + + var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features"); + var pred = trainer.Train(trainRoles); + // Get scorer and evaluate the predictions from test data + IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); + var metrics = EvaluateBinary(env, testDataScorer); + + // SSWE is a simple word embedding model + we train on a really small dataset, so metrics are not great. 
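A note on the metric checks that follow: they use xUnit's Assert.Equal(double expected, double actual, int precision) overload, which rounds both operands to the given number of decimal places before comparing, so each metric is pinned only as tightly as its stability warrants (four places for accuracy, one for AUC). A minimal illustration of the rounding behavior, with hypothetical values:

    // xUnit rounds both sides with Math.Round(value, precision) and then
    // compares the results, so tiny floating-point drift does not fail the test.
    Assert.Equal(0.6667, 0.66671, 4); // passes: both round to 0.6667
    Assert.Equal(0.71, 0.7149, 2);    // passes: both round to 0.71
    Assert.Equal(0.71, 0.7151, 2);    // fails: actual rounds to 0.72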
+            Assert.Equal(.6667, metrics.Accuracy, 4);
+            Assert.Equal(.71, metrics.Auc, 1);
+            Assert.Equal(.58, metrics.Auprc, 2);
+            // Create prediction engine and test predictions
+            var model = env.CreateBatchPredictionEngine(testDataScorer);
+            var sentiments = GetTestData();
+            var predictions = model.Predict(sentiments, false);
+            Assert.Equal(2, predictions.Count());
+            Assert.True(predictions.ElementAt(0).Sentiment);
+            Assert.True(predictions.ElementAt(1).Sentiment);
+
+            // Get feature importance based on feature gain during training
+            var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema);
+            Assert.Equal(1.0, (double)summary[0].Value, 1);
         }
+
         private BinaryClassificationMetrics EvaluateBinary(IHostEnvironment env, IDataView scoredData)
         {
             var dataEval = new RoleMappedData(scoredData, label: "Label", feature: "Features", opt: true);
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index 3529dd6b03..46c1f9da23 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -35,10 +35,9 @@ private class TestData
         public void TensorFlowTransformMatrixMultiplicationTest()
         {
             var model_location = "model_matmul/frozen_saved_model.pb";
-            using (var env = new ConsoleEnvironment(seed: 1, conc: 1))
-            {
-                // Pipeline
-                var loader = ComponentCreation.CreateDataView(env,
+            var env = new MLContext(seed: 1, conc: 1);
+            // Pipeline
+            var loader = ComponentCreation.CreateDataView(env,
                 new List<TestData>(new TestData[] {
                     new TestData() { a = new[] { 1.0f, 2.0f, 3.0f, 4.0f },
                                      b = new[] { 1.0f, 2.0f,
@@ -48,32 +47,30 @@ public void TensorFlowTransformMatrixMultiplicationTest()
                                      b = new[] { 3.0f, 3.0f, 3.0f, 3.0f } } }));
-                var trans = TensorFlowTransform.Create(env, loader, model_location, new[] { "c" }, new[] { "a", "b" });
-
-                using (var cursor = trans.GetRowCursor(a => true))
-                {
-                    var cgetter = cursor.GetGetter<VBuffer<float>>(2);
-                    Assert.True(cursor.MoveNext());
-                    VBuffer<float> c = default;
-                    cgetter(ref c);
-
-                    Assert.Equal(1.0 * 1.0 + 2.0 * 3.0, c.Values[0]);
-                    Assert.Equal(1.0 * 2.0 + 2.0 * 4.0, c.Values[1]);
-                    Assert.Equal(3.0 * 1.0 + 4.0 * 3.0, c.Values[2]);
-                    Assert.Equal(3.0 * 2.0 + 4.0 * 4.0, c.Values[3]);
+            var trans = TensorFlowTransform.Create(env, loader, model_location, new[] { "c" }, new[] { "a", "b" });
-                    Assert.True(cursor.MoveNext());
-                    c = default;
-                    cgetter(ref c);
-
-                    Assert.Equal(2.0 * 3.0 + 2.0 * 3.0, c.Values[0]);
-                    Assert.Equal(2.0 * 3.0 + 2.0 * 3.0, c.Values[1]);
-                    Assert.Equal(2.0 * 3.0 + 2.0 * 3.0, c.Values[2]);
-                    Assert.Equal(2.0 * 3.0 + 2.0 * 3.0, c.Values[3]);
-
-                    Assert.False(cursor.MoveNext());
-
-                }
+            using (var cursor = trans.GetRowCursor(a => true))
+            {
+                var cgetter = cursor.GetGetter<VBuffer<float>>(2);
+                Assert.True(cursor.MoveNext());
+                VBuffer<float> c = default;
+                cgetter(ref c);
+
+                Assert.Equal(1.0 * 1.0 + 2.0 * 3.0, c.Values[0]);
+                Assert.Equal(1.0 * 2.0 + 2.0 * 4.0, c.Values[1]);
+                Assert.Equal(3.0 * 1.0 + 4.0 * 3.0, c.Values[2]);
+                Assert.Equal(3.0 * 2.0 + 4.0 * 4.0, c.Values[3]);
+
+                Assert.True(cursor.MoveNext());
+                c = default;
+                cgetter(ref c);
+
+                Assert.Equal(2.0 * 3.0 + 2.0 * 3.0, c.Values[0]);
+                Assert.Equal(2.0 * 3.0 + 2.0 * 3.0, c.Values[1]);
+                Assert.Equal(2.0 * 3.0 + 2.0 * 3.0, c.Values[2]);
+                Assert.Equal(2.0 * 3.0 + 2.0 * 3.0, c.Values[3]);
+
+                Assert.False(cursor.MoveNext());
             }
         }

@@ -81,58 +78,56 @@ public void 
TensorFlowTransformMatrixMultiplicationTest() public void TensorFlowTransformObjectDetectionTest() { var model_location = @"C:\models\TensorFlow\ssd_mobilenet_v1_coco_2018_01_28\frozen_inference_graph.pb"; - using (var env = new ConsoleEnvironment(seed: 1, conc: 1)) + var env = new MLContext(seed: 1, conc: 1); + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); - var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + Column = new ImageLoaderTransform.Column[1] { - Column = new ImageLoaderTransform.Column[1] - { new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" } - }, - ImageFolder = imageFolder - }, data); - var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() - { - Column = new ImageResizerTransform.Column[1]{ + }, + ImageFolder = imageFolder + }, data); + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =32, ImageWidth = 32, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } - }, images); - var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() - { - Column = new ImagePixelExtractorTransform.Column[1]{ + }, images); + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() + { + Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "image_tensor", UseAlpha=false, InterleaveArgb=true, Convert = false} } - }, cropped); - - var tf = TensorFlowTransform.Create(env, pixels, model_location, - new[] { "detection_boxes", "detection_scores", "num_detections", "detection_classes" }, - new[] { "image_tensor" }); - - tf.Schema.TryGetColumnIndex("image_tensor", out int input); - tf.Schema.TryGetColumnIndex("detection_boxes", out int boxes); - tf.Schema.TryGetColumnIndex("detection_scores", out int scores); - tf.Schema.TryGetColumnIndex("num_detections", out int num); - tf.Schema.TryGetColumnIndex("detection_classes", out int classes); - using (var curs = tf.GetRowCursor(col => col == classes || col == num || col == scores || col == boxes || col == input)) - { - var getInput = curs.GetGetter>(input); - var getBoxes = curs.GetGetter>(boxes); - var getScores = curs.GetGetter>(scores); - var getNum = curs.GetGetter>(num); - var getClasses = curs.GetGetter>(classes); - var buffer = default(VBuffer); - var inputBuffer = default(VBuffer); - while (curs.MoveNext()) - { - getInput(ref inputBuffer); - getBoxes(ref buffer); - getScores(ref buffer); - getNum(ref buffer); - getClasses(ref buffer); - } + }, cropped); + + var tf = TensorFlowTransform.Create(env, pixels, model_location, + new[] { "detection_boxes", "detection_scores", "num_detections", "detection_classes" }, + new[] { "image_tensor" }); + + tf.Schema.TryGetColumnIndex("image_tensor", out int input); + tf.Schema.TryGetColumnIndex("detection_boxes", out int boxes); + 
tf.Schema.TryGetColumnIndex("detection_scores", out int scores); + tf.Schema.TryGetColumnIndex("num_detections", out int num); + tf.Schema.TryGetColumnIndex("detection_classes", out int classes); + using (var curs = tf.GetRowCursor(col => col == classes || col == num || col == scores || col == boxes || col == input)) + { + var getInput = curs.GetGetter>(input); + var getBoxes = curs.GetGetter>(boxes); + var getScores = curs.GetGetter>(scores); + var getNum = curs.GetGetter>(num); + var getClasses = curs.GetGetter>(classes); + var buffer = default(VBuffer); + var inputBuffer = default(VBuffer); + while (curs.MoveNext()) + { + getInput(ref inputBuffer); + getBoxes(ref buffer); + getScores(ref buffer); + getNum(ref buffer); + getClasses(ref buffer); } } } @@ -141,47 +136,45 @@ public void TensorFlowTransformObjectDetectionTest() public void TensorFlowTransformInceptionTest() { var model_location = @"C:\models\TensorFlow\tensorflow_inception_graph.pb"; - using (var env = new ConsoleEnvironment(seed: 1, conc: 1)) + var env = new MLContext(seed: 1, conc: 1); + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); - var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + Column = new ImageLoaderTransform.Column[1] { - Column = new ImageLoaderTransform.Column[1] - { new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" } - }, - ImageFolder = imageFolder - }, data); - var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() - { - Column = new ImageResizerTransform.Column[1]{ + }, + ImageFolder = imageFolder + }, data); + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =224, ImageWidth = 224, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } - }, images); - var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() - { - Column = new ImagePixelExtractorTransform.Column[1]{ + }, images); + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() + { + Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "input", UseAlpha=false, InterleaveArgb=true, Convert = true} } - }, cropped); + }, cropped); - var tf = TensorFlowTransform.Create(env, pixels, model_location, new[] { "softmax2_pre_activation" }, new[] { "input" }); + var tf = TensorFlowTransform.Create(env, pixels, model_location, new[] { "softmax2_pre_activation" }, new[] { "input" }); - tf.Schema.TryGetColumnIndex("input", out int input); - tf.Schema.TryGetColumnIndex("softmax2_pre_activation", out int b); - using (var curs = tf.GetRowCursor(col => col == b || col == input)) - { - var get = curs.GetGetter>(b); - var getInput = curs.GetGetter>(input); - var buffer = default(VBuffer); - var inputBuffer = default(VBuffer); - while (curs.MoveNext()) - { - getInput(ref inputBuffer); - get(ref 
buffer); - } + tf.Schema.TryGetColumnIndex("input", out int input); + tf.Schema.TryGetColumnIndex("softmax2_pre_activation", out int b); + using (var curs = tf.GetRowCursor(col => col == b || col == input)) + { + var get = curs.GetGetter>(b); + var getInput = curs.GetGetter>(input); + var buffer = default(VBuffer); + var inputBuffer = default(VBuffer); + while (curs.MoveNext()) + { + getInput(ref inputBuffer); + get(ref buffer); } } } @@ -189,133 +182,130 @@ public void TensorFlowTransformInceptionTest() [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only public void TensorFlowInputsOutputsSchemaTest() { - using (var env = new ConsoleEnvironment(seed: 1, conc: 1)) - { - var model_location = "mnist_model/frozen_saved_model.pb"; - var schema = TensorFlowUtils.GetModelSchema(env, model_location); - Assert.Equal(86, schema.ColumnCount); - Assert.True(schema.TryGetColumnIndex("Placeholder", out int col)); - var type = (VectorType)schema.GetColumnType(col); - Assert.Equal(2, type.Dimensions.Length); - Assert.Equal(28, type.Dimensions[0]); - Assert.Equal(28, type.Dimensions[1]); - var metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); - Assert.NotNull(metadataType); - Assert.True(metadataType is TextType); - ReadOnlyMemory opType = default; - schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); - Assert.Equal("Placeholder", opType.ToString()); - metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); - Assert.Null(metadataType); - - Assert.True(schema.TryGetColumnIndex("conv2d/Conv2D/ReadVariableOp", out col)); - type = (VectorType)schema.GetColumnType(col); - Assert.Equal(new[] { 5, 5, 1, 32 }, type.Dimensions); - metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); - Assert.NotNull(metadataType); - Assert.True(metadataType is TextType); - schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); - Assert.Equal("Identity", opType.ToString()); - metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); - Assert.NotNull(metadataType); - VBuffer> inputOps = default; - schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps); - Assert.Equal(1, inputOps.Length); - Assert.Equal("conv2d/kernel", inputOps.Values[0].ToString()); - - Assert.True(schema.TryGetColumnIndex("conv2d/Conv2D", out col)); - type = (VectorType)schema.GetColumnType(col); - Assert.Equal(new[] { 28, 28, 32 }, type.Dimensions); - metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); - Assert.NotNull(metadataType); - Assert.True(metadataType is TextType); - schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); - Assert.Equal("Conv2D", opType.ToString()); - metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); - Assert.NotNull(metadataType); - schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps); - Assert.Equal(2, inputOps.Length); - Assert.Equal("reshape/Reshape", inputOps.Values[0].ToString()); - Assert.Equal("conv2d/Conv2D/ReadVariableOp", inputOps.Values[1].ToString()); - - Assert.True(schema.TryGetColumnIndex("Softmax", out col)); - type = (VectorType)schema.GetColumnType(col); - Assert.Equal(new[] { 10 }, type.Dimensions); - metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); - Assert.NotNull(metadataType); - Assert.True(metadataType is TextType); - schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); - Assert.Equal("Softmax", opType.ToString()); - metadataType = 
schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); - Assert.NotNull(metadataType); - schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps); - Assert.Equal(1, inputOps.Length); - Assert.Equal("sequential/dense_1/BiasAdd", inputOps.Values[0].ToString()); - - model_location = "model_matmul/frozen_saved_model.pb"; - schema = TensorFlowUtils.GetModelSchema(env, model_location); - char name = 'a'; - for (int i = 0; i < schema.ColumnCount; i++) - { - Assert.Equal(name.ToString(), schema.GetColumnName(i)); - type = (VectorType)schema.GetColumnType(i); - Assert.Equal(new[] { 2, 2 }, type.Dimensions); - name++; - } + var env = new MLContext(seed: 1, conc: 1); + var model_location = "mnist_model/frozen_saved_model.pb"; + var schema = TensorFlowUtils.GetModelSchema(env, model_location); + Assert.Equal(86, schema.ColumnCount); + Assert.True(schema.TryGetColumnIndex("Placeholder", out int col)); + var type = (VectorType)schema.GetColumnType(col); + Assert.Equal(2, type.Dimensions.Length); + Assert.Equal(28, type.Dimensions[0]); + Assert.Equal(28, type.Dimensions[1]); + var metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); + Assert.NotNull(metadataType); + Assert.True(metadataType is TextType); + ReadOnlyMemory opType = default; + schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); + Assert.Equal("Placeholder", opType.ToString()); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); + Assert.Null(metadataType); + + Assert.True(schema.TryGetColumnIndex("conv2d/Conv2D/ReadVariableOp", out col)); + type = (VectorType)schema.GetColumnType(col); + Assert.Equal(new[] { 5, 5, 1, 32 }, type.Dimensions); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); + Assert.NotNull(metadataType); + Assert.True(metadataType is TextType); + schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); + Assert.Equal("Identity", opType.ToString()); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); + Assert.NotNull(metadataType); + VBuffer> inputOps = default; + schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps); + Assert.Equal(1, inputOps.Length); + Assert.Equal("conv2d/kernel", inputOps.Values[0].ToString()); + + Assert.True(schema.TryGetColumnIndex("conv2d/Conv2D", out col)); + type = (VectorType)schema.GetColumnType(col); + Assert.Equal(new[] { 28, 28, 32 }, type.Dimensions); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); + Assert.NotNull(metadataType); + Assert.True(metadataType is TextType); + schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); + Assert.Equal("Conv2D", opType.ToString()); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); + Assert.NotNull(metadataType); + schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps); + Assert.Equal(2, inputOps.Length); + Assert.Equal("reshape/Reshape", inputOps.Values[0].ToString()); + Assert.Equal("conv2d/Conv2D/ReadVariableOp", inputOps.Values[1].ToString()); + + Assert.True(schema.TryGetColumnIndex("Softmax", out col)); + type = (VectorType)schema.GetColumnType(col); + Assert.Equal(new[] { 10 }, type.Dimensions); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); + Assert.NotNull(metadataType); + Assert.True(metadataType is TextType); + schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); + Assert.Equal("Softmax", opType.ToString()); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); + 
Assert.NotNull(metadataType); + schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps); + Assert.Equal(1, inputOps.Length); + Assert.Equal("sequential/dense_1/BiasAdd", inputOps.Values[0].ToString()); + + model_location = "model_matmul/frozen_saved_model.pb"; + schema = TensorFlowUtils.GetModelSchema(env, model_location); + char name = 'a'; + for (int i = 0; i < schema.ColumnCount; i++) + { + Assert.Equal(name.ToString(), schema.GetColumnName(i)); + type = (VectorType)schema.GetColumnType(i); + Assert.Equal(new[] { 2, 2 }, type.Dimensions); + name++; } } [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only public void TensorFlowTransformMNISTConvTest() { - var model_location = "mnist_model/frozen_saved_model.pb"; - using (var env = new ConsoleEnvironment(seed: 1, conc: 1)) - { - var dataPath = GetDataPath("Train-Tiny-28x28.txt"); - var testDataPath = GetDataPath("MNIST.Test.tiny.txt"); + const string model_location = "mnist_model/frozen_saved_model.pb"; + var env = new MLContext(seed: 1, conc: 1); + var dataPath = GetDataPath("Train-Tiny-28x28.txt"); + var testDataPath = GetDataPath("MNIST.Test.tiny.txt"); - // Pipeline - var loader = TextLoader.ReadFile(env, - new TextLoader.Arguments() + // Pipeline + var loader = TextLoader.ReadFile(env, + new TextLoader.Arguments() + { + Separator = "tab", + HasHeader = true, + Column = new[] { - Separator = "tab", - HasHeader = true, - Column = new[] - { new TextLoader.Column("Label", DataKind.Num,0), new TextLoader.Column("Placeholder", DataKind.Num,new []{new TextLoader.Range(1, 784) }) - } - }, new MultiFileSource(dataPath)); + } + }, new MultiFileSource(dataPath)); - IDataView trans = ColumnsCopyingTransformer.Create(env, new ColumnsCopyingTransformer.Arguments() - { - Column = new[] { new ColumnsCopyingTransformer.Column() - { Name = "reshape_input", Source = "Placeholder" } - } - }, loader); - trans = TensorFlowTransform.Create(env, trans, model_location, new[] { "Softmax", "dense/Relu" }, new[] { "Placeholder", "reshape_input" }); - trans = new ColumnConcatenatingTransformer(env, "Features", "Softmax", "dense/Relu").Transform(trans); + IDataView trans = ColumnsCopyingTransformer.Create(env, new ColumnsCopyingTransformer.Arguments() + { + Column = new[] { new ColumnsCopyingTransformer.Column() + { Name = "reshape_input", Source = "Placeholder" } + } + }, loader); + trans = TensorFlowTransform.Create(env, trans, model_location, new[] { "Softmax", "dense/Relu" }, new[] { "Placeholder", "reshape_input" }); + trans = new ColumnConcatenatingTransformer(env, "Features", "Softmax", "dense/Relu").Transform(trans); - var trainer = new LightGbmMulticlassTrainer(env, "Label", "Features"); + var trainer = new LightGbmMulticlassTrainer(env, "Label", "Features"); - var cached = new CacheDataView(env, trans, prefetch: null); - var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); - var pred = trainer.Train(trainRoles); + var cached = new CacheDataView(env, trans, prefetch: null); + var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); + var pred = trainer.Train(trainRoles); - // Get scorer and evaluate the predictions from test data - IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); - var metrics = Evaluate(env, testDataScorer); + // Get scorer and evaluate the predictions from test data + IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); + var metrics = Evaluate(env, testDataScorer); 
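As an aside, the max/maxIndex scan used further down in this test (and repeated in the other MNIST tests) is a hand-rolled argmax over the prediction's score vector. If it were ever factored out, the helper could look roughly like this (ArgMax is a hypothetical name, not part of this change):

    // Hypothetical helper, not in this diff: index of the largest score.
    private static int ArgMax(float[] scores)
    {
        int best = 0;
        for (int i = 1; i < scores.Length; i++)
        {
            if (scores[i] > scores[best])
                best = i;
        }
        return best;
    }

    // The repeated assertion pattern would then collapse to:
    // Assert.Equal(5, ArgMax(prediction.PredictedLabels));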
- Assert.Equal(0.99, metrics.AccuracyMicro, 2); - Assert.Equal(1.0, metrics.AccuracyMacro, 2); + Assert.Equal(0.99, metrics.AccuracyMicro, 2); + Assert.Equal(1.0, metrics.AccuracyMacro, 2); - // Create prediction engine and test predictions - var model = env.CreatePredictionEngine(testDataScorer); + // Create prediction engine and test predictions + var model = env.CreatePredictionEngine(testDataScorer); - var sample1 = new MNISTData() - { - Placeholder = new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + var sample1 = new MNISTData() + { + Placeholder = new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -328,23 +318,22 @@ public void TensorFlowTransformMNISTConvTest() 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 253, 190, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 190, 253, 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 241, 225, 160, 108, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81, 240, 253, 253, 119, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 186, 253, 253, 150, 27, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 93, 252, 253, 187, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 249, 253, 249, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 46, 130, 183, 253, 253, 207, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 148, 229, 253, 253, 253, 250, 182, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 114, 221, 253, 253, 253, 253, 201, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 23, 66, 213, 253, 253, 253, 253, 198, 81, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 171, 219, 253, 253, 253, 253, 195, 80, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 55, 172, 226, 253, 253, 253, 253, 244, 133, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 136, 253, 253, 253, 212, 135, 132, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } - }; + }; - var prediction = model.Predict(sample1); + var prediction = model.Predict(sample1); - float max = -1; - int maxIndex = -1; - for (int i = 0; i < prediction.PredictedLabels.Length; i++) + float max = -1; + int maxIndex = -1; + for (int i = 0; i < prediction.PredictedLabels.Length; i++) + { + if (prediction.PredictedLabels[i] > max) { - if (prediction.PredictedLabels[i] > max) - { - max = prediction.PredictedLabels[i]; - maxIndex = i; - } + max = prediction.PredictedLabels[i]; + maxIndex = i; } - - Assert.Equal(5, maxIndex); } + + Assert.Equal(5, maxIndex); } [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only @@ -362,81 +351,80 @@ private void ExecuteTFTransformMNISTLRTrainingTest(bool shuffle, int? 
shuffleSee var model_location = "mnist_lr_model"; try { - using (var env = new ConsoleEnvironment(seed: 1, conc: 1)) - { - var dataPath = GetDataPath("Train-Tiny-28x28.txt"); - var testDataPath = GetDataPath("MNIST.Test.tiny.txt"); + var env = new MLContext(seed: 1, conc: 1); + var dataPath = GetDataPath("Train-Tiny-28x28.txt"); + var testDataPath = GetDataPath("MNIST.Test.tiny.txt"); - // Pipeline - var loader = TextLoader.ReadFile(env, - new TextLoader.Arguments() + // Pipeline + var loader = TextLoader.ReadFile(env, + new TextLoader.Arguments() + { + Separator = "tab", + HasHeader = false, + Column = new[] { - Separator = "tab", - HasHeader = false, - Column = new[] - { new TextLoader.Column("Label", DataKind.Num,0), new TextLoader.Column("Placeholder", DataKind.Num,new []{new TextLoader.Range(1, 784) }) - } - }, new MultiFileSource(dataPath)); + } + }, new MultiFileSource(dataPath)); - IDataView trans = new OneHotEncodingEstimator(env, "Label", "OneHotLabel").Fit(loader).Transform(loader); - trans = NormalizeTransform.CreateMinMaxNormalizer(env, trans, "Features", "Placeholder"); + IDataView trans = new OneHotEncodingEstimator(env, "Label", "OneHotLabel").Fit(loader).Transform(loader); + trans = NormalizeTransform.CreateMinMaxNormalizer(env, trans, "Features", "Placeholder"); + + var args = new TensorFlowTransform.Arguments() + { + ModelLocation = model_location, + InputColumns = new[] { "Features" }, + OutputColumns = new[] { "Prediction", "b" }, + LabelColumn = "OneHotLabel", + TensorFlowLabel = "Label", + OptimizationOperation = "SGDOptimizer", + LossOperation = "Loss", + Epoch = 10, + LearningRateOperation = "SGDOptimizer/learning_rate", + LearningRate = 0.001f, + BatchSize = 20, + ReTrain = true + }; - var args = new TensorFlowTransform.Arguments() - { - ModelLocation = model_location, - InputColumns = new[] { "Features" }, - OutputColumns = new[] { "Prediction", "b" }, - LabelColumn = "OneHotLabel", - TensorFlowLabel = "Label", - OptimizationOperation = "SGDOptimizer", - LossOperation = "Loss", - Epoch = 10, - LearningRateOperation = "SGDOptimizer/learning_rate", - LearningRate = 0.001f, - BatchSize = 20, - ReTrain = true - }; - - IDataView trainedTfDataView = null; - if (shuffle) - { - var shuffledView = new RowShufflingTransformer(env, new RowShufflingTransformer.Arguments() - { - ForceShuffle = shuffle, - ForceShuffleSeed = shuffleSeed - }, trans); - trainedTfDataView = new TensorFlowEstimator(env, args).Fit(shuffledView).Transform(trans); - } - else + IDataView trainedTfDataView = null; + if (shuffle) + { + var shuffledView = new RowShufflingTransformer(env, new RowShufflingTransformer.Arguments() { - trainedTfDataView = new TensorFlowEstimator(env, args).Fit(trans).Transform(trans); - } + ForceShuffle = shuffle, + ForceShuffleSeed = shuffleSeed + }, trans); + trainedTfDataView = new TensorFlowEstimator(env, args).Fit(shuffledView).Transform(trans); + } + else + { + trainedTfDataView = new TensorFlowEstimator(env, args).Fit(trans).Transform(trans); + } - trans = new ColumnConcatenatingTransformer(env, "Features", "Prediction").Transform(trainedTfDataView); + trans = new ColumnConcatenatingTransformer(env, "Features", "Prediction").Transform(trainedTfDataView); - var trainer = new LightGbmMulticlassTrainer(env, "Label", "Features"); + var trainer = new LightGbmMulticlassTrainer(env, "Label", "Features"); - var cached = new CacheDataView(env, trans, prefetch: null); - var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); + var cached = new 
CacheDataView(env, trans, prefetch: null); + var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); - var pred = trainer.Train(trainRoles); + var pred = trainer.Train(trainRoles); - // Get scorer and evaluate the predictions from test data - IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); - var metrics = Evaluate(env, testDataScorer); + // Get scorer and evaluate the predictions from test data + IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); + var metrics = Evaluate(env, testDataScorer); - Assert.Equal(expectedMicroAccuracy, metrics.AccuracyMicro, 2); - Assert.Equal(expectedMacroAccruacy, metrics.AccuracyMacro, 2); + Assert.Equal(expectedMicroAccuracy, metrics.AccuracyMicro, 2); + Assert.Equal(expectedMacroAccruacy, metrics.AccuracyMacro, 2); - // Create prediction engine and test predictions - var model = env.CreatePredictionEngine(testDataScorer); + // Create prediction engine and test predictions + var model = env.CreatePredictionEngine(testDataScorer); - var sample1 = new MNISTData() - { - Placeholder = new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + var sample1 = new MNISTData() + { + Placeholder = new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -449,34 +437,33 @@ private void ExecuteTFTransformMNISTLRTrainingTest(bool shuffle, int? shuffleSee 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 253, 190, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 190, 253, 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 241, 225, 160, 108, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81, 240, 253, 253, 119, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 186, 253, 253, 150, 27, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 93, 252, 253, 187, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 249, 253, 249, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 46, 130, 183, 253, 253, 207, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 148, 229, 253, 253, 253, 250, 182, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 114, 221, 253, 253, 253, 253, 201, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 23, 66, 213, 253, 253, 253, 253, 198, 81, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 171, 219, 253, 253, 253, 253, 195, 80, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 55, 172, 226, 253, 253, 253, 253, 244, 133, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 136, 253, 253, 253, 212, 135, 132, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } - }; + }; - var prediction = model.Predict(sample1); + var prediction = model.Predict(sample1); - float max = -1; - int maxIndex = -1; - for (int i = 0; i < 
prediction.PredictedLabels.Length; i++) + float max = -1; + int maxIndex = -1; + for (int i = 0; i < prediction.PredictedLabels.Length; i++) + { + if (prediction.PredictedLabels[i] > max) { - if (prediction.PredictedLabels[i] > max) - { - max = prediction.PredictedLabels[i]; - maxIndex = i; - } + max = prediction.PredictedLabels[i]; + maxIndex = i; } + } - Assert.Equal(5, maxIndex); + Assert.Equal(5, maxIndex); - // Check if the bias actually got changed after the training. - using (var cursor = trainedTfDataView.GetRowCursor(a => true)) + // Check if the bias actually got changed after the training. + using (var cursor = trainedTfDataView.GetRowCursor(a => true)) + { + trainedTfDataView.Schema.TryGetColumnIndex("b", out int bias); + var getter = cursor.GetGetter>(bias); + if (cursor.MoveNext()) { - trainedTfDataView.Schema.TryGetColumnIndex("b", out int bias); - var getter = cursor.GetGetter>(bias); - if (cursor.MoveNext()) - { - var trainedBias = default(VBuffer); - getter(ref trainedBias); - Assert.NotEqual(trainedBias.Values, new float[] { 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f }); - } + var trainedBias = default(VBuffer); + getter(ref trainedBias); + Assert.NotEqual(trainedBias.Values, new float[] { 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f }); } } } @@ -512,86 +499,85 @@ public void TensorFlowTransformMNISTConvTrainingTest() private void ExecuteTFTransformMNISTConvTrainingTest(bool shuffle, int? shuffleSeed, double expectedMicroAccuracy, double expectedMacroAccruacy) { - var model_location = "mnist_conv_model"; + const string modelLocation = "mnist_conv_model"; try { - using (var env = new ConsoleEnvironment(seed: 1, conc: 1)) - { - var dataPath = GetDataPath("Train-Tiny-28x28.txt"); - var testDataPath = GetDataPath("MNIST.Test.tiny.txt"); + var env = new MLContext(seed: 1, conc: 1); + var dataPath = GetDataPath("Train-Tiny-28x28.txt"); + var testDataPath = GetDataPath("MNIST.Test.tiny.txt"); - // Pipeline - var loader = TextLoader.ReadFile(env, - new TextLoader.Arguments() + // Pipeline + var loader = TextLoader.ReadFile(env, + new TextLoader.Arguments() + { + Separator = "tab", + HasHeader = false, + Column = new[] { - Separator = "tab", - HasHeader = false, - Column = new[] - { new TextLoader.Column("Label", DataKind.I8,0), new TextLoader.Column("Placeholder", DataKind.Num,new []{new TextLoader.Range(1, 784) }) - } - }, new MultiFileSource(dataPath)); + } + }, new MultiFileSource(dataPath)); - IDataView trans = new ColumnsCopyingTransformer(env, - ("Placeholder", "Features")).Transform(loader); + IDataView trans = new ColumnsCopyingTransformer(env, + ("Placeholder", "Features")).Transform(loader); + + var args = new TensorFlowTransform.Arguments() + { + ModelLocation = modelLocation, + InputColumns = new[] { "Features" }, + OutputColumns = new[] { "Prediction" }, + LabelColumn = "Label", + TensorFlowLabel = "Label", + OptimizationOperation = "MomentumOp", + LossOperation = "Loss", + MetricOperation = "Accuracy", + Epoch = 10, + LearningRateOperation = "learning_rate", + LearningRate = 0.01f, + BatchSize = 20, + ReTrain = true + }; - var args = new TensorFlowTransform.Arguments() - { - ModelLocation = model_location, - InputColumns = new[] { "Features" }, - OutputColumns = new[] { "Prediction" }, - LabelColumn = "Label", - TensorFlowLabel = "Label", - OptimizationOperation = "MomentumOp", - LossOperation = "Loss", - MetricOperation = "Accuracy", - Epoch = 10, - LearningRateOperation = "learning_rate", - LearningRate = 0.01f, - BatchSize = 20, - 
ReTrain = true - }; - - IDataView trainedTfDataView = null; - if (shuffle) + IDataView trainedTfDataView = null; + if (shuffle) + { + var shuffledView = new RowShufflingTransformer(env, new RowShufflingTransformer.Arguments() { - var shuffledView = new RowShufflingTransformer(env, new RowShufflingTransformer.Arguments() - { - ForceShuffle = shuffle, - ForceShuffleSeed = shuffleSeed - }, trans); - trainedTfDataView = new TensorFlowEstimator(env, args).Fit(shuffledView).Transform(trans); - } + ForceShuffle = shuffle, + ForceShuffleSeed = shuffleSeed + }, trans); + trainedTfDataView = new TensorFlowEstimator(env, args).Fit(shuffledView).Transform(trans); + } else - { - trainedTfDataView = new TensorFlowEstimator(env, args).Fit(trans).Transform(trans); - } + { + trainedTfDataView = new TensorFlowEstimator(env, args).Fit(trans).Transform(trans); + } - trans = new ColumnConcatenatingTransformer(env, "Features", "Prediction").Transform(trainedTfDataView); - trans = new TypeConvertingTransformer(env, new TypeConvertingTransformer.ColumnInfo("Label", "Label", DataKind.R4)).Transform(trans); + trans = new ColumnConcatenatingTransformer(env, "Features", "Prediction").Transform(trainedTfDataView); + trans = new TypeConvertingTransformer(env, new TypeConvertingTransformer.ColumnInfo("Label", "Label", DataKind.R4)).Transform(trans); - var trainer = new LightGbmMulticlassTrainer(env, "Label", "Features"); + var trainer = new LightGbmMulticlassTrainer(env, "Label", "Features"); - var cached = new CacheDataView(env, trans, prefetch: null); - var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); + var cached = new CacheDataView(env, trans, prefetch: null); + var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); - var pred = trainer.Train(trainRoles); + var pred = trainer.Train(trainRoles); - // Get scorer and evaluate the predictions from test data - IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); - var metrics = Evaluate(env, testDataScorer); + // Get scorer and evaluate the predictions from test data + IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); + var metrics = Evaluate(env, testDataScorer); - Assert.Equal(expectedMicroAccuracy, metrics.AccuracyMicro, 2); - Assert.Equal(expectedMacroAccruacy, metrics.AccuracyMacro, 2); + Assert.Equal(expectedMicroAccuracy, metrics.AccuracyMicro, 2); + Assert.Equal(expectedMacroAccruacy, metrics.AccuracyMacro, 2); - // Create prediction engine and test predictions - var model = env.CreatePredictionEngine(testDataScorer); + // Create prediction engine and test predictions + var model = env.CreatePredictionEngine(testDataScorer); - var sample1 = new MNISTData() - { - Placeholder = new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + var sample1 = new MNISTData() + { + Placeholder = new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -604,83 +590,81 @@ private void ExecuteTFTransformMNISTConvTrainingTest(bool shuffle, int? 
shuffleS 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 253, 190, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 190, 253, 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 241, 225, 160, 108, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81, 240, 253, 253, 119, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 186, 253, 253, 150, 27, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 93, 252, 253, 187, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 249, 253, 249, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 46, 130, 183, 253, 253, 207, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 148, 229, 253, 253, 253, 250, 182, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 114, 221, 253, 253, 253, 253, 201, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 23, 66, 213, 253, 253, 253, 253, 198, 81, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 171, 219, 253, 253, 253, 253, 195, 80, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 55, 172, 226, 253, 253, 253, 253, 244, 133, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 136, 253, 253, 253, 212, 135, 132, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } - }; + }; - var prediction = model.Predict(sample1); + var prediction = model.Predict(sample1); - float max = -1; - int maxIndex = -1; - for (int i = 0; i < prediction.PredictedLabels.Length; i++) + float max = -1; + int maxIndex = -1; + for (int i = 0; i < prediction.PredictedLabels.Length; i++) + { + if (prediction.PredictedLabels[i] > max) { - if (prediction.PredictedLabels[i] > max) - { - max = prediction.PredictedLabels[i]; - maxIndex = i; - } + max = prediction.PredictedLabels[i]; + maxIndex = i; } - - Assert.Equal(5, maxIndex); } + + Assert.Equal(5, maxIndex); } finally { // This test changes the state of the model. // Cleanup folder so that other test can also use the same model. 
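For context, CleanUp is a helper defined elsewhere in this test class; the hunk below only renames its argument. Conceptually, a ReTrain = true run writes updated weights back into the model folder, so the cleanup just has to discard the mutated copy so other tests start from a pristine model. An illustrative sketch of that idea (not the actual helper):

    // Illustrative only; the real CleanUp helper lives elsewhere in this file.
    // Retraining mutates the on-disk model, so delete the dirty copy.
    private static void DeleteMutatedModel(string modelDir)
    {
        if (System.IO.Directory.Exists(modelDir))
            System.IO.Directory.Delete(modelDir, recursive: true);
    }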
- CleanUp(model_location); + CleanUp(modelLocation); } } [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only public void TensorFlowTransformMNISTConvSavedModelTest() { - var model_location = "mnist_model"; - using (var env = new ConsoleEnvironment(seed: 1, conc: 1)) - { - var dataPath = GetDataPath("Train-Tiny-28x28.txt"); - var testDataPath = GetDataPath("MNIST.Test.tiny.txt"); + const string modelLocation = "mnist_model"; + var env = new MLContext(seed: 1, conc: 1); + var dataPath = GetDataPath("Train-Tiny-28x28.txt"); + var testDataPath = GetDataPath("MNIST.Test.tiny.txt"); - // Pipeline - var loader = TextLoader.ReadFile(env, - new TextLoader.Arguments() + // Pipeline + var loader = TextLoader.ReadFile(env, + new TextLoader.Arguments() + { + Separator = "tab", + HasHeader = true, + Column = new[] { - Separator = "tab", - HasHeader = true, - Column = new[] - { new TextLoader.Column("Label", DataKind.Num,0), new TextLoader.Column("Placeholder", DataKind.Num,new []{new TextLoader.Range(1, 784) }) - } - }, new MultiFileSource(dataPath)); + } + }, new MultiFileSource(dataPath)); - IDataView trans = ColumnsCopyingTransformer.Create(env, new ColumnsCopyingTransformer.Arguments() - { - Column = new[] { new ColumnsCopyingTransformer.Column() - { Name = "reshape_input", Source = "Placeholder" } - } - }, loader); - trans = TensorFlowTransform.Create(env, trans, model_location, new[] { "Softmax", "dense/Relu" }, new[] { "Placeholder", "reshape_input" }); - trans = new ColumnConcatenatingTransformer(env, "Features", "Softmax", "dense/Relu").Transform(trans); + IDataView trans = ColumnsCopyingTransformer.Create(env, new ColumnsCopyingTransformer.Arguments() + { + Column = new[] { new ColumnsCopyingTransformer.Column() + { Name = "reshape_input", Source = "Placeholder" } + } + }, loader); + trans = TensorFlowTransform.Create(env, trans, modelLocation, new[] { "Softmax", "dense/Relu" }, new[] { "Placeholder", "reshape_input" }); + trans = new ColumnConcatenatingTransformer(env, "Features", "Softmax", "dense/Relu").Transform(trans); - var trainer = new LightGbmMulticlassTrainer(env, "Label", "Features"); + var trainer = new LightGbmMulticlassTrainer(env, "Label", "Features"); - var cached = new CacheDataView(env, trans, prefetch: null); - var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); - var pred = trainer.Train(trainRoles); + var cached = new CacheDataView(env, trans, prefetch: null); + var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); + var pred = trainer.Train(trainRoles); - // Get scorer and evaluate the predictions from test data - IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); - var metrics = Evaluate(env, testDataScorer); + // Get scorer and evaluate the predictions from test data + IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); + var metrics = Evaluate(env, testDataScorer); - Assert.Equal(0.99, metrics.AccuracyMicro, 2); - Assert.Equal(1.0, metrics.AccuracyMacro, 2); + Assert.Equal(0.99, metrics.AccuracyMicro, 2); + Assert.Equal(1.0, metrics.AccuracyMacro, 2); - // Create prediction engine and test predictions - var model = env.CreatePredictionEngine(testDataScorer); + // Create prediction engine and test predictions + var model = env.CreatePredictionEngine(testDataScorer); - var sample1 = new MNISTData() - { - Placeholder = new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, + var sample1 = new MNISTData() + { + Placeholder = new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -693,23 +677,22 @@ public void TensorFlowTransformMNISTConvSavedModelTest() 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 253, 190, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 190, 253, 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 241, 225, 160, 108, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81, 240, 253, 253, 119, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 186, 253, 253, 150, 27, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 93, 252, 253, 187, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 249, 253, 249, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 46, 130, 183, 253, 253, 207, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 148, 229, 253, 253, 253, 250, 182, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 114, 221, 253, 253, 253, 253, 201, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 23, 66, 213, 253, 253, 253, 253, 198, 81, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 171, 219, 253, 253, 253, 253, 195, 80, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 55, 172, 226, 253, 253, 253, 253, 244, 133, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 136, 253, 253, 253, 212, 135, 132, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } - }; + }; - var prediction = model.Predict(sample1); + var prediction = model.Predict(sample1); - float max = -1; - int maxIndex = -1; - for (int i = 0; i < prediction.PredictedLabels.Length; i++) + float max = -1; + int maxIndex = -1; + for (int i = 0; i < prediction.PredictedLabels.Length; i++) + { + if (prediction.PredictedLabels[i] > max) { - if (prediction.PredictedLabels[i] > max) - { - max = prediction.PredictedLabels[i]; - maxIndex = i; - } + max = prediction.PredictedLabels[i]; + maxIndex = i; } - - Assert.Equal(5, maxIndex); } + + Assert.Equal(5, maxIndex); } [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only @@ -773,64 +756,62 @@ public void TensorFlowTransformCifar() { var model_location = "cifar_model/frozen_model.pb"; - using (var env = new ConsoleEnvironment()) + var env = new MLContext(); + var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(env, model_location); + var schema = tensorFlowModel.GetInputSchema(); + Assert.True(schema.TryGetColumnIndex("Input", out int column)); + var type = (VectorType)schema.GetColumnType(column); + var imageHeight = type.Dimensions[0]; + var imageWidth = type.Dimensions[1]; + + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = TextLoader.Create(env, new TextLoader.Arguments() { - var tensorFlowModel = 
TensorFlowUtils.LoadTensorFlowModel(env, model_location); - var schema = tensorFlowModel.GetInputSchema(); - Assert.True(schema.TryGetColumnIndex("Input", out int column)); - var type = (VectorType)schema.GetColumnType(column); - var imageHeight = type.Dimensions[0]; - var imageWidth = type.Dimensions[1]; - - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + Column = new[] { - Column = new[] - { new TextLoader.Column("ImagePath", DataKind.TX, 0), new TextLoader.Column("Name", DataKind.TX, 1), } - }, new MultiFileSource(dataFile)); - var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + }, new MultiFileSource(dataFile)); + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + { + Column = new ImageLoaderTransform.Column[1] { - Column = new ImageLoaderTransform.Column[1] - { new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" } - }, - ImageFolder = imageFolder - }, data); - var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() - { - Column = new ImageResizerTransform.Column[1]{ + }, + ImageFolder = imageFolder + }, data); + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } - }, images); + }, images); - var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() - { - Column = new ImagePixelExtractorTransform.Column[1]{ + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() + { + Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "Input", UseAlpha=false, InterleaveArgb=true} } - }, cropped); + }, cropped); - IDataView trans = TensorFlowTransform.Create(env, pixels, tensorFlowModel, new[] { "Output" }, new[] { "Input" }); + IDataView trans = TensorFlowTransform.Create(env, pixels, tensorFlowModel, new[] { "Output" }, new[] { "Input" }); - trans.Schema.TryGetColumnIndex("Output", out int output); - using (var cursor = trans.GetRowCursor(col => col == output)) - { - var buffer = default(VBuffer); - var getter = cursor.GetGetter>(output); - var numRows = 0; - while (cursor.MoveNext()) - { - getter(ref buffer); - Assert.Equal(10, buffer.Length); - numRows += 1; - } - Assert.Equal(3, numRows); + trans.Schema.TryGetColumnIndex("Output", out int output); + using (var cursor = trans.GetRowCursor(col => col == output)) + { + var buffer = default(VBuffer); + var getter = cursor.GetGetter>(output); + var numRows = 0; + while (cursor.MoveNext()) + { + getter(ref buffer); + Assert.Equal(10, buffer.Length); + numRows += 1; } + Assert.Equal(3, numRows); } } @@ -839,64 +820,62 @@ public void TensorFlowTransformCifarSavedModel() { var model_location = "cifar_saved_model"; - using (var env = new ConsoleEnvironment()) + var env = new MLContext(); + var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(env, model_location); + var schema = tensorFlowModel.GetInputSchema(); + Assert.True(schema.TryGetColumnIndex("Input", out int column)); + var type = (VectorType)schema.GetColumnType(column); + var imageHeight = type.Dimensions[0]; + var imageWidth = 
type.Dimensions[1]; + + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = TextLoader.Create(env, new TextLoader.Arguments() { - var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(env, model_location); - var schema = tensorFlowModel.GetInputSchema(); - Assert.True(schema.TryGetColumnIndex("Input", out int column)); - var type = (VectorType)schema.GetColumnType(column); - var imageHeight = type.Dimensions[0]; - var imageWidth = type.Dimensions[1]; - - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + Column = new[] { - Column = new[] - { new TextLoader.Column("ImagePath", DataKind.TX, 0), new TextLoader.Column("Name", DataKind.TX, 1), } - }, new MultiFileSource(dataFile)); - var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + }, new MultiFileSource(dataFile)); + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + { + Column = new ImageLoaderTransform.Column[1] { - Column = new ImageLoaderTransform.Column[1] - { new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" } - }, - ImageFolder = imageFolder - }, data); - var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() - { - Column = new ImageResizerTransform.Column[1]{ + }, + ImageFolder = imageFolder + }, data); + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } - }, images); + }, images); - var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() - { - Column = new ImagePixelExtractorTransform.Column[1]{ + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() + { + Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "Input", UseAlpha=false, InterleaveArgb=true} } - }, cropped); + }, cropped); - IDataView trans = TensorFlowTransform.Create(env, pixels, tensorFlowModel, new[] { "Output" }, new[] { "Input" }); + IDataView trans = TensorFlowTransform.Create(env, pixels, tensorFlowModel, new[] { "Output" }, new[] { "Input" }); - trans.Schema.TryGetColumnIndex("Output", out int output); - using (var cursor = trans.GetRowCursor(col => col == output)) - { - var buffer = default(VBuffer); - var getter = cursor.GetGetter>(output); - var numRows = 0; - while (cursor.MoveNext()) - { - getter(ref buffer); - Assert.Equal(10, buffer.Length); - numRows += 1; - } - Assert.Equal(3, numRows); + trans.Schema.TryGetColumnIndex("Output", out int output); + using (var cursor = trans.GetRowCursor(col => col == output)) + { + var buffer = default(VBuffer); + var getter = cursor.GetGetter>(output); + var numRows = 0; + while (cursor.MoveNext()) + { + getter(ref buffer); + Assert.Equal(10, buffer.Length); + numRows += 1; } + Assert.Equal(3, numRows); } } @@ -905,53 +884,51 @@ public void TensorFlowTransformCifarInvalidShape() { var model_location = "cifar_model/frozen_model.pb"; - using (var env = new ConsoleEnvironment()) + var env = new MLContext(); + var imageHeight = 28; + var imageWidth = 28; + var dataFile = 
GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = TextLoader.Create(env, new TextLoader.Arguments() { - var imageHeight = 28; - var imageWidth = 28; - var dataFile = GetDataPath("images/images.tsv"); - var imageFolder = Path.GetDirectoryName(dataFile); - var data = TextLoader.Create(env, new TextLoader.Arguments() + Column = new[] { - Column = new[] - { new TextLoader.Column("ImagePath", DataKind.TX, 0), new TextLoader.Column("Name", DataKind.TX, 1), } - }, new MultiFileSource(dataFile)); - var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + }, new MultiFileSource(dataFile)); + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + { + Column = new ImageLoaderTransform.Column[1] { - Column = new ImageLoaderTransform.Column[1] - { new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" } - }, - ImageFolder = imageFolder - }, data); - var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() - { - Column = new ImageResizerTransform.Column[1]{ + }, + ImageFolder = imageFolder + }, data); + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} } - }, images); + }, images); - var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() - { - Column = new ImagePixelExtractorTransform.Column[1]{ + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() + { + Column = new ImagePixelExtractorTransform.Column[1]{ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "Input", UseAlpha=false, InterleaveArgb=true} } - }, cropped); + }, cropped); - var thrown = false; - try - { - IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, new[] { "Output" }, new[] { "Input" }); - } - catch - { - thrown = true; - } - Assert.True(thrown); + var thrown = false; + try + { + IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, new[] { "Output" }, new[] { "Input" }); + } + catch + { + thrown = true; } + Assert.True(thrown); } } } diff --git a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs index 2f2820692f..01b7a7aa90 100644 --- a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs @@ -133,10 +133,8 @@ void TestOldSavingAndLoading() [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline void TestCommandLine() { - using (var env = new ConsoleEnvironment()) - { - Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=a:R4:0-3 col=b:R4:0-3} xf=TFTransform{inputs=a inputs=b outputs=c modellocation={model_matmul/frozen_saved_model.pb}}"}), (int)0); - } + var env = new MLContext(); + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=a:R4:0-3 col=b:R4:0-3} xf=TFTransform{inputs=a inputs=b outputs=c modellocation={model_matmul/frozen_saved_model.pb}}" }), (int)0); } [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only @@ -144,91 +142,87 @@ public void TestTensorFlowStatic() { var modelLocation = "cifar_model/frozen_model.pb"; - 
diff --git a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
index 2f2820692f..01b7a7aa90 100644
--- a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
+++ b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
@@ -133,10 +133,8 @@ void TestOldSavingAndLoading()
         [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // x86 output differs from Baseline
         void TestCommandLine()
         {
-            using (var env = new ConsoleEnvironment())
-            {
-                Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=a:R4:0-3 col=b:R4:0-3} xf=TFTransform{inputs=a inputs=b outputs=c modellocation={model_matmul/frozen_saved_model.pb}}"}), (int)0);
-            }
+            var env = new MLContext();
+            Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=a:R4:0-3 col=b:R4:0-3} xf=TFTransform{inputs=a inputs=b outputs=c modellocation={model_matmul/frozen_saved_model.pb}}" }), (int)0);
         }

         [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only
@@ -144,91 +142,87 @@ public void TestTensorFlowStatic()
         {
             var modelLocation = "cifar_model/frozen_model.pb";

-            using (var env = new ConsoleEnvironment())
-            {
-                var imageHeight = 32;
-                var imageWidth = 32;
-                var dataFile = GetDataPath("images/images.tsv");
-                var imageFolder = Path.GetDirectoryName(dataFile);
+            var env = new MLContext();
+            var imageHeight = 32;
+            var imageWidth = 32;
+            var dataFile = GetDataPath("images/images.tsv");
+            var imageFolder = Path.GetDirectoryName(dataFile);

-                var data = TextLoader.CreateReader(env, ctx => (
-                    imagePath: ctx.LoadText(0),
-                    name: ctx.LoadText(1)))
-                    .Read(dataFile);
+            var data = TextLoader.CreateReader(env, ctx => (
+                imagePath: ctx.LoadText(0),
+                name: ctx.LoadText(1)))
+                .Read(dataFile);

-                // Note that CamelCase column names are there to match the TF graph node names.
-                var pipe = data.MakeNewEstimator()
-                    .Append(row => (
-                        row.name,
-                        Input: row.imagePath.LoadAsImage(imageFolder).Resize(imageHeight, imageWidth).ExtractPixels(interleaveArgb: true)))
-                    .Append(row => (row.name, Output: row.Input.ApplyTensorFlowGraph(modelLocation)));
+            // Note that CamelCase column names are there to match the TF graph node names.
+            var pipe = data.MakeNewEstimator()
+                .Append(row => (
+                    row.name,
+                    Input: row.imagePath.LoadAsImage(imageFolder).Resize(imageHeight, imageWidth).ExtractPixels(interleaveArgb: true)))
+                .Append(row => (row.name, Output: row.Input.ApplyTensorFlowGraph(modelLocation)));

-                TestEstimatorCore(pipe.AsDynamic, data.AsDynamic);
+            TestEstimatorCore(pipe.AsDynamic, data.AsDynamic);

-                var result = pipe.Fit(data).Transform(data).AsDynamic;
-                result.Schema.TryGetColumnIndex("Output", out int output);
-                using (var cursor = result.GetRowCursor(col => col == output))
+            var result = pipe.Fit(data).Transform(data).AsDynamic;
+            result.Schema.TryGetColumnIndex("Output", out int output);
+            using (var cursor = result.GetRowCursor(col => col == output))
+            {
+                var buffer = default(VBuffer<float>);
+                var getter = cursor.GetGetter<VBuffer<float>>(output);
+                var numRows = 0;
+                while (cursor.MoveNext())
                 {
-                    var buffer = default(VBuffer<float>);
-                    var getter = cursor.GetGetter<VBuffer<float>>(output);
-                    var numRows = 0;
-                    while (cursor.MoveNext())
-                    {
-                        getter(ref buffer);
-                        Assert.Equal(10, buffer.Length);
-                        numRows += 1;
-                    }
-                    Assert.Equal(3, numRows);
+                    getter(ref buffer);
+                    Assert.Equal(10, buffer.Length);
+                    numRows += 1;
                 }
+                Assert.Equal(3, numRows);
             }
         }

         [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))]
         public void TestTensorFlowStaticWithSchema()
         {
-            var modelLocation = "cifar_model/frozen_model.pb";
+            const string modelLocation = "cifar_model/frozen_model.pb";

-            using (var env = new ConsoleEnvironment())
-            {
-                var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(env, modelLocation);
-                var schema = tensorFlowModel.GetInputSchema();
-                Assert.True(schema.TryGetColumnIndex("Input", out int column));
-                var type = (VectorType)schema.GetColumnType(column);
-                var imageHeight = type.Dimensions[0];
-                var imageWidth = type.Dimensions[1];
+            var env = new MLContext();
+            var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(env, modelLocation);
+            var schema = tensorFlowModel.GetInputSchema();
+            Assert.True(schema.TryGetColumnIndex("Input", out int column));
+            var type = (VectorType)schema.GetColumnType(column);
+            var imageHeight = type.Dimensions[0];
+            var imageWidth = type.Dimensions[1];

-                var dataFile = GetDataPath("images/images.tsv");
-                var imageFolder = Path.GetDirectoryName(dataFile);
+            var dataFile = GetDataPath("images/images.tsv");
+            var imageFolder = Path.GetDirectoryName(dataFile);

-                var data = TextLoader.CreateReader(env, ctx => (
-                    imagePath: ctx.LoadText(0),
-                    name: ctx.LoadText(1)))
-                    .Read(dataFile);
+            var data = TextLoader.CreateReader(env, ctx => (
+                imagePath: ctx.LoadText(0),
+                name: ctx.LoadText(1)))
+                .Read(dataFile);

-                // Note that CamelCase column names are there to match the TF graph node names.
-                var pipe = data.MakeNewEstimator()
-                    .Append(row => (
-                        row.name,
-                        Input: row.imagePath.LoadAsImage(imageFolder).Resize(imageHeight, imageWidth).ExtractPixels(interleaveArgb: true)))
-                    .Append(row => (row.name, Output: row.Input.ApplyTensorFlowGraph(tensorFlowModel)));
+            // Note that CamelCase column names are there to match the TF graph node names.
+            var pipe = data.MakeNewEstimator()
+                .Append(row => (
+                    row.name,
+                    Input: row.imagePath.LoadAsImage(imageFolder).Resize(imageHeight, imageWidth).ExtractPixels(interleaveArgb: true)))
+                .Append(row => (row.name, Output: row.Input.ApplyTensorFlowGraph(tensorFlowModel)));

-                TestEstimatorCore(pipe.AsDynamic, data.AsDynamic);
+            TestEstimatorCore(pipe.AsDynamic, data.AsDynamic);

-                var result = pipe.Fit(data).Transform(data).AsDynamic;
-                result.Schema.TryGetColumnIndex("Output", out int output);
-                using (var cursor = result.GetRowCursor(col => col == output))
+            var result = pipe.Fit(data).Transform(data).AsDynamic;
+            result.Schema.TryGetColumnIndex("Output", out int output);
+            using (var cursor = result.GetRowCursor(col => col == output))
+            {
+                var buffer = default(VBuffer<float>);
+                var getter = cursor.GetGetter<VBuffer<float>>(output);
+                var numRows = 0;
+                while (cursor.MoveNext())
                 {
-                    var buffer = default(VBuffer<float>);
-                    var getter = cursor.GetGetter<VBuffer<float>>(output);
-                    var numRows = 0;
-                    while (cursor.MoveNext())
-                    {
-                        getter(ref buffer);
-                        Assert.Equal(10, buffer.Length);
-                        numRows += 1;
-                    }
-                    Assert.Equal(3, numRows);
+                    getter(ref buffer);
+                    Assert.Equal(10, buffer.Length);
+                    numRows += 1;
                 }
+                Assert.Equal(3, numRows);
             }
         }

diff --git a/test/Microsoft.ML.Tests/TermEstimatorTests.cs b/test/Microsoft.ML.Tests/TermEstimatorTests.cs
index 1c465c5669..4caeaa27ae 100644
--- a/test/Microsoft.ML.Tests/TermEstimatorTests.cs
+++ b/test/Microsoft.ML.Tests/TermEstimatorTests.cs
@@ -140,7 +140,7 @@ void TestMetadataCopy()
             var dataView = ComponentCreation.CreateDataView(Env, data);
             var termEst = new ValueToKeyMappingEstimator(Env, new[] { new ValueToKeyMappingTransformer.ColumnInfo("Term" ,"T") });
-
+
             var termTransformer = termEst.Fit(dataView);
             var result = termTransformer.Transform(dataView);
             result.Schema.TryGetColumnIndex("T", out int termIndex);
@@ -155,10 +155,8 @@ void TestMetadataCopy()
         [Fact]
         void TestCommandLine()
         {
-            using (var env = new ConsoleEnvironment())
-            {
-                Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} in=f:\2.txt" }));
-            }
+            var env = new MLContext();
+            Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} in=f:\2.txt" }));
         }

         private void ValidateTermTransformer(IDataView result)
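The `TestCommandLine` methods in both files reduce to the same smoke test: feed a `showschema` command line to `Maml.Main` and expect exit code 0, meaning the loader and transform arguments parsed and the output schema could be computed — the tests can point at a nonexistent `f:\` path precisely because no rows are scored. The shape of that test as a sketch (assuming `Maml` lives in `Microsoft.ML.Runtime.Tools`, its usual home in this era):

```csharp
using Microsoft.ML.Runtime.Tools;
using Xunit;

public class CommandLineSmokeTest
{
    [Fact]
    public void ShowSchemaWithTermTransform()
    {
        // Exit code 0: the command line parsed and schema propagation succeeded.
        Assert.Equal(0, Maml.Main(new[]
        {
            @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} in=f:\2.txt"
        }));
    }
}
```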
diff --git a/test/Microsoft.ML.Tests/TextLoaderTests.cs b/test/Microsoft.ML.Tests/TextLoaderTests.cs
index 575146fbf7..fb83ffab24 100644
--- a/test/Microsoft.ML.Tests/TextLoaderTests.cs
+++ b/test/Microsoft.ML.Tests/TextLoaderTests.cs
@@ -46,7 +46,7 @@ public void TestTextLoaderDataTypes()

             Assert.True(cursor.MoveNext());

-            sbyte[] sByteTargets = new sbyte[] { sbyte.MinValue, sbyte.MaxValue, default};
+            sbyte[] sByteTargets = new sbyte[] { sbyte.MinValue, sbyte.MaxValue, default };
             short[] shortTargets = new short[] { short.MinValue, short.MaxValue, default };
             int[] intTargets = new int[] { int.MinValue, int.MaxValue, default };
             long[] longTargets = new long[] { long.MinValue, long.MaxValue, default };
@@ -96,7 +96,7 @@ public void TestTextLoaderInvalidLongMin()
                     "loader=Text{col=DvInt8:I8:0 sep=comma}",
                 }, logCurs: true);
             }
-            catch(Exception ex)
+            catch (Exception ex)
             {
                 Assert.Equal("Could not parse value -9223372036854775809 in line 1, column DvInt8", ex.Message);
                 return;
@@ -142,7 +142,7 @@ public TextLoaderTests(ITestOutputHelper output)
         public void ConstructorDoesntThrow()
         {
             Assert.NotNull(new Legacy.Data.TextLoader("fakeFile.txt").CreateFrom<Input>());
-            Assert.NotNull(new Legacy.Data.TextLoader("fakeFile.txt").CreateFrom<Input>(useHeader:true));
+            Assert.NotNull(new Legacy.Data.TextLoader("fakeFile.txt").CreateFrom<Input>(useHeader: true));
             Assert.NotNull(new Legacy.Data.TextLoader("fakeFile.txt").CreateFrom<Input>());
             Assert.NotNull(new Legacy.Data.TextLoader("fakeFile.txt").CreateFrom<Input>(useHeader: false));
             Assert.NotNull(new Legacy.Data.TextLoader("fakeFile.txt").CreateFrom<Input>(useHeader: false, supportSparse: false, trimWhitespace: false));
@@ -158,15 +158,13 @@ public void CanSuccessfullyApplyATransform()
         {
             var loader = new Legacy.Data.TextLoader("fakeFile.txt").CreateFrom<Input>();

-            using (var environment = new ConsoleEnvironment())
-            {
-                Experiment experiment = environment.CreateExperiment();
-                Legacy.ILearningPipelineDataStep output = loader.ApplyStep(null, experiment) as Legacy.ILearningPipelineDataStep;
+            var environment = new MLContext();
+            Experiment experiment = environment.CreateExperiment();
+            Legacy.ILearningPipelineDataStep output = loader.ApplyStep(null, experiment) as Legacy.ILearningPipelineDataStep;

-                Assert.NotNull(output.Data);
-                Assert.NotNull(output.Data.VarName);
-                Assert.Null(output.Model);
-            }
+            Assert.NotNull(output.Data);
+            Assert.NotNull(output.Data.VarName);
+            Assert.Null(output.Model);
         }

         [Fact]
@@ -174,56 +172,54 @@ public void CanSuccessfullyRetrieveQuotedData()
         {
             string dataPath = GetDataPath("QuotingData.csv");
             var loader = new Legacy.Data.TextLoader(dataPath).CreateFrom<QuoteInput>(useHeader: true, separator: ',', allowQuotedStrings: true, supportSparse: false);
-
-            using (var environment = new ConsoleEnvironment())
-            {
-                Experiment experiment = environment.CreateExperiment();
-                Legacy.ILearningPipelineDataStep output = loader.ApplyStep(null, experiment) as Legacy.ILearningPipelineDataStep;

-                experiment.Compile();
-                loader.SetInput(environment, experiment);
-                experiment.Run();
+            var environment = new MLContext();
+            Experiment experiment = environment.CreateExperiment();
+            Legacy.ILearningPipelineDataStep output = loader.ApplyStep(null, experiment) as Legacy.ILearningPipelineDataStep;

-                IDataView data = experiment.GetOutput(output.Data);
-                Assert.NotNull(data);
+            experiment.Compile();
+            loader.SetInput(environment, experiment);
+            experiment.Run();

-                using (var cursor = data.GetRowCursor((a => true)))
-                {
-                    var IDGetter = cursor.GetGetter<float>(0);
-                    var TextGetter = cursor.GetGetter<ReadOnlyMemory<char>>(1);
+            IDataView data = experiment.GetOutput(output.Data);
+            Assert.NotNull(data);

-                    Assert.True(cursor.MoveNext());
+            using (var cursor = data.GetRowCursor((a => true)))
+            {
+                var IDGetter = cursor.GetGetter<float>(0);
+                var TextGetter = cursor.GetGetter<ReadOnlyMemory<char>>(1);
+
+                Assert.True(cursor.MoveNext());

-                    float ID = 0;
-                    IDGetter(ref ID);
-                    Assert.Equal(1, ID);
+                float ID = 0;
+                IDGetter(ref ID);
+                Assert.Equal(1, ID);

-                    ReadOnlyMemory<char> Text = new ReadOnlyMemory<char>();
-                    TextGetter(ref Text);
-                    Assert.Equal("This text contains comma, within quotes.", Text.ToString());
+                ReadOnlyMemory<char> Text = new ReadOnlyMemory<char>();
+                TextGetter(ref Text);
+                Assert.Equal("This text contains comma, within quotes.", Text.ToString());

-                    Assert.True(cursor.MoveNext());
+                Assert.True(cursor.MoveNext());

-                    ID = 0;
-                    IDGetter(ref ID);
-                    Assert.Equal(2, ID);
+                ID = 0;
+                IDGetter(ref ID);
+                Assert.Equal(2, ID);

-                    Text = new ReadOnlyMemory<char>();
-                    TextGetter(ref Text);
-                    Assert.Equal("This text contains extra punctuations and special characters.;*<>?!@#$%^&*()_+=-{}|[]:;'", Text.ToString());
+                Text = new ReadOnlyMemory<char>();
+                TextGetter(ref Text);
+                Assert.Equal("This text contains extra punctuations and special characters.;*<>?!@#$%^&*()_+=-{}|[]:;'", Text.ToString());

-                    Assert.True(cursor.MoveNext());
+                Assert.True(cursor.MoveNext());

-                    ID = 0;
-                    IDGetter(ref ID);
-                    Assert.Equal(3, ID);
+                ID = 0;
+                IDGetter(ref ID);
+                Assert.Equal(3, ID);

-                    Text = new ReadOnlyMemory<char>();
-                    TextGetter(ref Text);
-                    Assert.Equal("This text has no quotes", Text.ToString());
+                Text = new ReadOnlyMemory<char>();
+                TextGetter(ref Text);
+                Assert.Equal("This text has no quotes", Text.ToString());

-                    Assert.False(cursor.MoveNext());
-                }
+                Assert.False(cursor.MoveNext());
             }
         }
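The quoting test reads the loaded view through typed getters; the generic arguments restored above (`float` for the ID column, `ReadOnlyMemory<char>` for the text) follow from the values the assertions compare against. The core read loop, extracted as a sketch:

```csharp
using System;
using Microsoft.ML.Runtime.Data;

public static class TypedCursorReads
{
    // Walks an (ID, Text) view the way CanSuccessfullyRetrieveQuotedData does.
    public static void DumpRows(IDataView data)
    {
        using (var cursor = data.GetRowCursor(col => true))
        {
            var idGetter = cursor.GetGetter<float>(0);
            var textGetter = cursor.GetGetter<ReadOnlyMemory<char>>(1);

            float id = 0;
            ReadOnlyMemory<char> text = default;
            while (cursor.MoveNext())
            {
                // Getters fill caller-provided storage, so the loop allocates nothing per row.
                idGetter(ref id);
                textGetter(ref text);
                Console.WriteLine($"{id}: {text}");
            }
        }
    }
}
```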
@@ -233,21 +229,20 @@ public void CanSuccessfullyRetrieveSparseData()
         {
             string dataPath = GetDataPath("SparseData.txt");
             var loader = new Legacy.Data.TextLoader(dataPath).CreateFrom<SparseInput>(useHeader: true, allowQuotedStrings: false, supportSparse: true);

-            using (var environment = new ConsoleEnvironment())
-            {
-                Experiment experiment = environment.CreateExperiment();
-                Legacy.ILearningPipelineDataStep output = loader.ApplyStep(null, experiment) as Legacy.ILearningPipelineDataStep;
+            var environment = new MLContext();
+            Experiment experiment = environment.CreateExperiment();
+            Legacy.ILearningPipelineDataStep output = loader.ApplyStep(null, experiment) as Legacy.ILearningPipelineDataStep;

-                experiment.Compile();
-                loader.SetInput(environment, experiment);
-                experiment.Run();
+            experiment.Compile();
+            loader.SetInput(environment, experiment);
+            experiment.Run();

-                IDataView data = experiment.GetOutput(output.Data);
-                Assert.NotNull(data);
+            IDataView data = experiment.GetOutput(output.Data);
+            Assert.NotNull(data);

-                using (var cursor = data.GetRowCursor((a => true)))
-                {
-                    var getters = new ValueGetter<float>[]{
+            using (var cursor = data.GetRowCursor((a => true)))
+            {
+                var getters = new ValueGetter<float>[]{
                         cursor.GetGetter<float>(0),
                         cursor.GetGetter<float>(1),
                         cursor.GetGetter<float>(2),
@@ -256,38 +251,37 @@ public void CanSuccessfullyRetrieveSparseData()
                     };

-                    Assert.True(cursor.MoveNext());
-
-                    float[] targets = new float[] { 1, 2, 3, 4, 5 };
-                    for (int i = 0; i < getters.Length; i++)
-                    {
-                        float value = 0;
-                        getters[i](ref value);
-                        Assert.Equal(targets[i], value);
-                    }
+                Assert.True(cursor.MoveNext());

-                    Assert.True(cursor.MoveNext());
+                float[] targets = new float[] { 1, 2, 3, 4, 5 };
+                for (int i = 0; i < getters.Length; i++)
+                {
+                    float value = 0;
+                    getters[i](ref value);
+                    Assert.Equal(targets[i], value);
+                }

-                    targets = new float[] { 0, 0, 0, 4, 5 };
-                    for (int i = 0; i < getters.Length; i++)
-                    {
-                        float value = 0;
-                        getters[i](ref value);
-                        Assert.Equal(targets[i], value);
-                    }
+                Assert.True(cursor.MoveNext());

-                    Assert.True(cursor.MoveNext());
+                targets = new float[] { 0, 0, 0, 4, 5 };
+                for (int i = 0; i < getters.Length; i++)
+                {
+                    float value = 0;
+                    getters[i](ref value);
+                    Assert.Equal(targets[i], value);
+                }

-                    targets = new float[] { 0, 2, 0, 0, 0 };
-                    for (int i = 0; i < getters.Length; i++)
-                    {
-                        float value = 0;
-                        getters[i](ref value);
-                        Assert.Equal(targets[i], value);
-                    }
+                Assert.True(cursor.MoveNext());

-                    Assert.False(cursor.MoveNext());
+                targets = new float[] { 0, 2, 0, 0, 0 };
+                for (int i = 0; i < getters.Length; i++)
+                {
+                    float value = 0;
+                    getters[i](ref value);
+                    Assert.Equal(targets[i], value);
                 }
+
+                Assert.False(cursor.MoveNext());
             }
         }
@@ -298,52 +292,50 @@ public void CanSuccessfullyTrimSpaces()
         {
             string dataPath = GetDataPath("TrimData.csv");
             var loader = new Legacy.Data.TextLoader(dataPath).CreateFrom<QuoteInput>(useHeader: true, separator: ',', allowQuotedStrings: false, supportSparse: false, trimWhitespace: true);

-            using (var environment = new ConsoleEnvironment())
-            {
-                Experiment experiment = environment.CreateExperiment();
-                Legacy.ILearningPipelineDataStep output = loader.ApplyStep(null, experiment) as Legacy.ILearningPipelineDataStep;
+            var environment = new MLContext();
+            Experiment experiment = environment.CreateExperiment();
+            Legacy.ILearningPipelineDataStep output = loader.ApplyStep(null, experiment) as Legacy.ILearningPipelineDataStep;

-                experiment.Compile();
-                loader.SetInput(environment, experiment);
-                experiment.Run();
+            experiment.Compile();
+            loader.SetInput(environment, experiment);
+            experiment.Run();

-                IDataView data = experiment.GetOutput(output.Data);
-                Assert.NotNull(data);
+            IDataView data = experiment.GetOutput(output.Data);
+            Assert.NotNull(data);

-                using (var cursor = data.GetRowCursor((a => true)))
-                {
-                    var IDGetter = cursor.GetGetter<float>(0);
-                    var TextGetter = cursor.GetGetter<ReadOnlyMemory<char>>(1);
+            using (var cursor = data.GetRowCursor((a => true)))
+            {
+                var IDGetter = cursor.GetGetter<float>(0);
+                var TextGetter = cursor.GetGetter<ReadOnlyMemory<char>>(1);
+
+                Assert.True(cursor.MoveNext());

-                    Assert.True(cursor.MoveNext());
+                float ID = 0;
+                IDGetter(ref ID);
+                Assert.Equal(1, ID);

-                    float ID = 0;
-                    IDGetter(ref ID);
-                    Assert.Equal(1, ID);
+                ReadOnlyMemory<char> Text = new ReadOnlyMemory<char>();
+                TextGetter(ref Text);
+                Assert.Equal("There is a space at the end", Text.ToString());

-                    ReadOnlyMemory<char> Text = new ReadOnlyMemory<char>();
-                    TextGetter(ref Text);
-                    Assert.Equal("There is a space at the end", Text.ToString());
+                Assert.True(cursor.MoveNext());

-                    Assert.True(cursor.MoveNext());
+                ID = 0;
+                IDGetter(ref ID);
+                Assert.Equal(2, ID);

-                    ID = 0;
-                    IDGetter(ref ID);
-                    Assert.Equal(2, ID);
+                Text = new ReadOnlyMemory<char>();
+                TextGetter(ref Text);
+                Assert.Equal("There is no space at the end", Text.ToString());

-                    Text = new ReadOnlyMemory<char>();
-                    TextGetter(ref Text);
-                    Assert.Equal("There is no space at the end", Text.ToString());
-
-                    Assert.False(cursor.MoveNext());
-                }
+                Assert.False(cursor.MoveNext());
             }
         }

         [Fact]
         public void ThrowsExceptionWithPropertyName()
         {
-            Exception ex = Assert.Throws<System.InvalidOperationException>( () => new Legacy.Data.TextLoader("fakefile.txt").CreateFrom<ModelWithoutColumnAttribute>() );
+            Exception ex = Assert.Throws<System.InvalidOperationException>(() => new Legacy.Data.TextLoader("fakefile.txt").CreateFrom<ModelWithoutColumnAttribute>());
             Assert.StartsWith("Field or property String1 is missing ColumnAttribute", ex.Message);
         }
@@ -351,9 +343,9 @@ public void ThrowsExceptionWithPropertyName()
         public void CanSuccessfullyColumnNameProperty()
         {
             var loader = new Legacy.Data.TextLoader("fakefile.txt").CreateFrom<ModelWithColumnNameAttribute>();
-            Assert.Equal("Col1",loader.Arguments.Column[0].Name);
-            Assert.Equal("Col2",loader.Arguments.Column[1].Name);
-            Assert.Equal("String_3",loader.Arguments.Column[2].Name);
+            Assert.Equal("Col1", loader.Arguments.Column[0].Name);
+            Assert.Equal("Col2", loader.Arguments.Column[1].Name);
+            Assert.Equal("String_3", loader.Arguments.Column[2].Name);
         }

         public class QuoteInput
@@ -413,7 +405,7 @@ public class ModelWithColumnNameAttribute
             [Column("1")]
             [ColumnName("Col2")]
             public string String_2;
-            [Column("3")]
+            [Column("3")]
             public string String_3;
         }
     }
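The quoted, sparse, and trim tests all run the same legacy `Experiment` choreography, now anchored on `MLContext`: apply the loader step, compile, bind the input file, run, then fetch the produced `IDataView`. A sketch of that sequence under the Legacy API as it appears in these hunks (the `Legacy` alias and any namespace beyond what the diff itself shows are assumptions):

```csharp
using Microsoft.ML;
using Microsoft.ML.Runtime.Data;
using Legacy = Microsoft.ML.Legacy; // assumed alias, matching how these tests qualify the types

public static class LegacyExperimentRunner
{
    // Compile/SetInput/Run/GetOutput, exactly as the three loader tests above do.
    public static IDataView RunLoader(Legacy.Data.TextLoader loader)
    {
        var environment = new MLContext();
        var experiment = environment.CreateExperiment();
        var output = loader.ApplyStep(null, experiment) as Legacy.ILearningPipelineDataStep;

        experiment.Compile();
        loader.SetInput(environment, experiment);
        experiment.Run();

        return experiment.GetOutput(output.Data);
    }
}
```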
diff --git a/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs
index 5d951f1ce0..09714ffd22 100644
--- a/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs
@@ -40,32 +40,28 @@ class TestMetaClass
         void TestWorking()
         {
             var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
-            using (var env = new ConsoleEnvironment())
-            {
-                var dataView = ComponentCreation.CreateDataView(env, data);
-                var est = new ColumnsCopyingEstimator(env, new[] { ("A", "D"), ("B", "E"), ("A", "F") });
-                var transformer = est.Fit(dataView);
-                var result = transformer.Transform(dataView);
-                ValidateCopyColumnTransformer(result);
-            }
+            var env = new MLContext();
+            var dataView = ComponentCreation.CreateDataView(env, data);
+            var est = new ColumnsCopyingEstimator(env, new[] { ("A", "D"), ("B", "E"), ("A", "F") });
+            var transformer = est.Fit(dataView);
+            var result = transformer.Transform(dataView);
+            ValidateCopyColumnTransformer(result);
         }

         [Fact]
         void TestBadOriginalSchema()
         {
             var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
-            using (var env = new ConsoleEnvironment())
+            var env = new MLContext();
+            var dataView = ComponentCreation.CreateDataView(env, data);
+            var est = new ColumnsCopyingEstimator(env, new[] { ("D", "A"), ("B", "E") });
+            try
+            {
+                var transformer = est.Fit(dataView);
+                Assert.False(true);
+            }
+            catch
             {
-                var dataView = ComponentCreation.CreateDataView(env, data);
-                var est = new ColumnsCopyingEstimator(env, new[] { ("D", "A"), ("B", "E") });
-                try
-                {
-                    var transformer = est.Fit(dataView);
-                    Assert.False(true);
-                }
-                catch
-                {
-                }
             }
         }
@@ -74,20 +70,18 @@ void TestBadTransformSchema()
         {
             var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
             var xydata = new[] { new TestClassXY() { X = 10, Y = 100 }, new TestClassXY() { X = -1, Y = -100 } };
-            using (var env = new ConsoleEnvironment())
+            var env = new MLContext();
+            var dataView = ComponentCreation.CreateDataView(env, data);
+            var xyDataView = ComponentCreation.CreateDataView(env, xydata);
+            var est = new ColumnsCopyingEstimator(env, new[] { ("A", "D"), ("B", "E"), ("A", "F") });
+            var transformer = est.Fit(dataView);
+            try
+            {
+                var result = transformer.Transform(xyDataView);
+                Assert.False(true);
+            }
+            catch
             {
-                var dataView = ComponentCreation.CreateDataView(env, data);
-                var xyDataView = ComponentCreation.CreateDataView(env, xydata);
-                var est = new ColumnsCopyingEstimator(env, new[] { ("A", "D"), ("B", "E"), ("A", "F") });
-                var transformer = est.Fit(dataView);
-                try
-                {
-                    var result = transformer.Transform(xyDataView);
-                    Assert.False(true);
-                }
-                catch
-                {
-                }
             }
         }
@@ -95,20 +89,17 @@ void TestBadTransformSchema()
         void TestSavingAndLoading()
         {
             var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
-            using (var env = new ConsoleEnvironment())
+            var env = new MLContext();
+            var dataView = ComponentCreation.CreateDataView(env, data);
+            var est = new ColumnsCopyingEstimator(env, new[] { ("A", "D"), ("B", "E"), ("A", "F") });
+            var transformer = est.Fit(dataView);
+            using (var ms = new MemoryStream())
             {
-                var dataView = ComponentCreation.CreateDataView(env, data);
-                var est = new ColumnsCopyingEstimator(env, new[] { ("A", "D"), ("B", "E"), ("A", "F") });
-                var transformer = est.Fit(dataView);
-                using (var ms = new MemoryStream())
-                {
-                    transformer.SaveTo(env, ms);
-                    ms.Position = 0;
-                    var loadedTransformer = TransformerChain.LoadFrom(env, ms);
-                    var result = loadedTransformer.Transform(dataView);
-                    ValidateCopyColumnTransformer(result);
-                }
-
+                transformer.SaveTo(env, ms);
+                ms.Position = 0;
+                var loadedTransformer = TransformerChain.LoadFrom(env, ms);
+                var result = loadedTransformer.Transform(dataView);
+                ValidateCopyColumnTransformer(result);
             }
         }
@@ -116,20 +107,18 @@ void TestSavingAndLoading()
         void TestOldSavingAndLoading()
         {
             var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
-            using (var env = new ConsoleEnvironment())
+            IHostEnvironment env = new MLContext();
+            var dataView = ComponentCreation.CreateDataView(env, data);
+            var est = new ColumnsCopyingEstimator(env, new[] { ("A", "D"), ("B", "E"), ("A", "F") });
+            var transformer = est.Fit(dataView);
+            var result = transformer.Transform(dataView);
+            var resultRoles = new RoleMappedData(result);
+            using (var ms = new MemoryStream())
             {
-                var dataView = ComponentCreation.CreateDataView(env, data);
-                var est = new ColumnsCopyingEstimator(env, new[] { ("A", "D"), ("B", "E"), ("A", "F") });
-                var transformer = est.Fit(dataView);
-                var result = transformer.Transform(dataView);
-                var resultRoles = new RoleMappedData(result);
-                using (var ms = new MemoryStream())
-                {
-                    TrainUtils.SaveModel(env, env.Start("saving"), ms, null, resultRoles);
-                    ms.Position = 0;
-                    var loadedView = ModelFileUtils.LoadTransforms(env, dataView, ms);
-                    ValidateCopyColumnTransformer(loadedView);
-                }
+                TrainUtils.SaveModel(env, env.Start("saving"), ms, null, resultRoles);
+                ms.Position = 0;
+                var loadedView = ModelFileUtils.LoadTransforms(env, dataView, ms);
+                ValidateCopyColumnTransformer(loadedView);
             }
         }
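`TestSavingAndLoading` and `TestOldSavingAndLoading` bracket the two serialization paths: the estimator-style round trip through `SaveTo`/`TransformerChain.LoadFrom`, and the older `TrainUtils.SaveModel`/`ModelFileUtils.LoadTransforms` route. The new-style round trip, reduced to a sketch (the `ITransformer` namespace is an assumption for this era):

```csharp
using System.IO;
using Microsoft.ML;
using Microsoft.ML.Core.Data; // assumed location of ITransformer here
using Microsoft.ML.Runtime.Data;

public static class TransformerRoundTrip
{
    // Serialize a fitted transformer and reload it from the same stream,
    // as TestSavingAndLoading does before re-validating the output.
    public static IDataView SaveAndReload(MLContext env, ITransformer transformer, IDataView dataView)
    {
        using (var ms = new MemoryStream())
        {
            transformer.SaveTo(env, ms);
            ms.Position = 0; // rewind: LoadFrom reads from the current position
            var loadedTransformer = TransformerChain.LoadFrom(env, ms);
            return loadedTransformer.Transform(dataView);
        }
    }
}
```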
@@ -137,37 +126,33 @@ void TestOldSavingAndLoading()
         void TestMetadataCopy()
         {
             var data = new[] { new TestMetaClass() { Term = "A", NotUsed = 1 }, new TestMetaClass() { Term = "B" }, new TestMetaClass() { Term = "C" } };
-            using (var env = new ConsoleEnvironment())
+            var env = new MLContext();
+            var dataView = ComponentCreation.CreateDataView(env, data);
+            var term = ValueToKeyMappingTransformer.Create(env, new ValueToKeyMappingTransformer.Arguments()
             {
-                var dataView = ComponentCreation.CreateDataView(env, data);
-                var term = ValueToKeyMappingTransformer.Create(env, new ValueToKeyMappingTransformer.Arguments()
-                {
-                    Column = new[] { new ValueToKeyMappingTransformer.Column() { Source = "Term", Name = "T" } }
-                }, dataView);
-                var est = new ColumnsCopyingEstimator(env, "T", "T1");
-                var transformer = est.Fit(term);
-                var result = transformer.Transform(term);
-                result.Schema.TryGetColumnIndex("T", out int termIndex);
-                result.Schema.TryGetColumnIndex("T1", out int copyIndex);
-                var names1 = default(VBuffer<ReadOnlyMemory<char>>);
-                var names2 = default(VBuffer<ReadOnlyMemory<char>>);
-                var type1 = result.Schema.GetColumnType(termIndex);
-                var itemType1 = (type1 as VectorType)?.ItemType ?? type1;
-                int size = (itemType1 as KeyType)?.Count ?? -1;
-                var type2 = result.Schema.GetColumnType(copyIndex);
-                result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, termIndex, ref names1);
-                result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, copyIndex, ref names2);
-                Assert.True(CompareVec(in names1, in names2, size, (a, b) => a.Span.SequenceEqual(b.Span)));
-            }
+                Column = new[] { new ValueToKeyMappingTransformer.Column() { Source = "Term", Name = "T" } }
+            }, dataView);
+            var est = new ColumnsCopyingEstimator(env, "T", "T1");
+            var transformer = est.Fit(term);
+            var result = transformer.Transform(term);
+            result.Schema.TryGetColumnIndex("T", out int termIndex);
+            result.Schema.TryGetColumnIndex("T1", out int copyIndex);
+            var names1 = default(VBuffer<ReadOnlyMemory<char>>);
+            var names2 = default(VBuffer<ReadOnlyMemory<char>>);
+            var type1 = result.Schema.GetColumnType(termIndex);
+            var itemType1 = (type1 as VectorType)?.ItemType ?? type1;
+            int size = (itemType1 as KeyType)?.Count ?? -1;
+            var type2 = result.Schema.GetColumnType(copyIndex);
+            result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, termIndex, ref names1);
+            result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, copyIndex, ref names2);
+            Assert.True(CompareVec(in names1, in names2, size, (a, b) => a.Span.SequenceEqual(b.Span)));
         }

         [Fact]
         void TestCommandLine()
         {
-            using (var env = new ConsoleEnvironment())
-            {
-                Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=copy{col=B:A} in=f:\1.txt" }), (int)0);
-            }
+            var env = new MLContext();
+            Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=copy{col=B:A} in=f:\1.txt" }), (int)0);
         }

         private void ValidateCopyColumnTransformer(IDataView result)
diff --git a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs
index 2b780cd808..1e5e053c96 100644
--- a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs
@@ -66,23 +66,21 @@ public void TestCustomTransformer()
             IDataView transformedData;
             // We create a temporary environment to instantiate the custom transformer. This is to ensure that we don't need the same
             // environment for saving and loading.
-            using (var tempoEnv = new ConsoleEnvironment())
+            var tempoEnv = new MLContext();
+            var customEst = new CustomMappingEstimator<MyInput, MyOutput>(tempoEnv, MyLambda.MyAction, "MyLambda");
+
+            try
             {
-                var customEst = new CustomMappingEstimator<MyInput, MyOutput>(tempoEnv, MyLambda.MyAction, "MyLambda");
-
-                try
-                {
-                    TestEstimatorCore(customEst, data);
-                    Assert.True(false, "Cannot work without MEF injection");
-                }
-                catch (Exception)
-                {
-                    // REVIEW: we should have a common mechanism that will make sure this is 'our' exception thrown.
-                }
-                ML.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(MyLambda)));
                 TestEstimatorCore(customEst, data);
-                transformedData = customEst.Fit(data).Transform(data);
+                Assert.True(false, "Cannot work without MEF injection");
+            }
+            catch (Exception)
+            {
+                // REVIEW: we should have a common mechanism that will make sure this is 'our' exception thrown.
             }
+            ML.CompositionContainer = new CompositionContainer(new TypeCatalog(typeof(MyLambda)));
+            TestEstimatorCore(customEst, data);
+            transformedData = customEst.Fit(data).Transform(data);

             var inputs = transformedData.AsEnumerable<MyInput>(ML, true);
             var outputs = transformedData.AsEnumerable<MyOutput>(ML, true);
diff --git a/test/Microsoft.ML.Tests/Transformers/PcaTests.cs b/test/Microsoft.ML.Tests/Transformers/PcaTests.cs
index 3cc11f5d20..43ad1f7d70 100644
--- a/test/Microsoft.ML.Tests/Transformers/PcaTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/PcaTests.cs
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.

+using Microsoft.ML.Runtime;
 using Microsoft.ML.Runtime.Data;
 using Microsoft.ML.Runtime.Data.IO;
 using Microsoft.ML.Runtime.RunTests;
@@ -15,14 +16,14 @@ namespace Microsoft.ML.Tests.Transformers
 {
     public sealed class PcaTests : TestDataPipeBase
     {
-        private readonly ConsoleEnvironment _env;
+        private readonly IHostEnvironment _env;
         private readonly string _dataSource;
         private readonly TextSaver _saver;

         public PcaTests(ITestOutputHelper helper)
             : base(helper)
         {
-            _env = new ConsoleEnvironment(seed: 1);
+            _env = new MLContext(seed: 1);
             _dataSource = GetDataPath("generated_regression_dataset.csv");
             _saver = new TextSaver(_env, new TextSaver.Arguments { Silent = true, OutputHeader = false });
         }
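The `PcaTests` hunk is the pattern this whole PR applies: fields move from the concrete, disposable `ConsoleEnvironment` to the `IHostEnvironment` interface, with a seeded `MLContext` as the instance, so tests shed their `using` blocks while keeping deterministic randomness. In miniature:

```csharp
using Microsoft.ML;
using Microsoft.ML.Runtime;

public sealed class SeededEnvironmentFixture
{
    // MLContext needs no Dispose, unlike ConsoleEnvironment; the fixed seed
    // keeps randomized components (e.g. PCA initialization) reproducible.
    public IHostEnvironment Env { get; } = new MLContext(seed: 1);
}
```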
diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs
index c93a82db0b..3c46a729fc 100644
--- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs
@@ -2,16 +2,17 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.

+using Microsoft.ML.Runtime;
 using Microsoft.ML.Runtime.Data;
 using Microsoft.ML.Runtime.Data.IO;
 using Microsoft.ML.Runtime.RunTests;
 using Microsoft.ML.Transforms;
 using Microsoft.ML.Transforms.Text;
 using Microsoft.ML.Transforms.Categorical;
+using Microsoft.ML.Transforms.Conversions;
 using System.IO;
 using Xunit;
 using Xunit.Abstractions;
-using Microsoft.ML.Transforms.Conversions;

 namespace Microsoft.ML.Tests.Transformers
 {
@@ -239,7 +240,7 @@ public void NgramWorkout()
         [Fact(Skip = "LDA transform cannot be trained on empty data, schema propagation fails")]
         public void LdaWorkout()
         {
-            var env = new ConsoleEnvironment(seed: 42, conc: 1);
+            IHostEnvironment env = new MLContext(seed: 42, conc: 1);
             string sentimentDataPath = GetDataPath("wikipedia-detox-250-line-data.tsv");
             var data = TextLoader.CreateReader(env, ctx => (
                         label: ctx.LoadBool(0),
diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs
index a72e131bc4..834e3f669f 100644
--- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs
+++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs
@@ -13,114 +13,109 @@ namespace Microsoft.ML.Tests
     public sealed class TimeSeries
     {
-        public class Prediction
+        private sealed class Prediction
         {
+#pragma warning disable CS0649
             [VectorType(4)]
             public double[] Change;
+#pragma warning restore CS0649
         }

-        sealed class Data
+        private sealed class Data
         {
             public float Value;

-            public Data(float value)
-            {
-                Value = value;
-            }
+            public Data(float value) => Value = value;
         }

         [Fact]
         public void ChangeDetection()
         {
-            using (var env = new ConsoleEnvironment(conc: 1))
+            var env = new MLContext(conc: 1);
+            const int size = 10;
+            List<Data> data = new List<Data>(size);
+            var dataView = env.CreateStreamingDataView(data);
+            for (int i = 0; i < size / 2; i++)
+                data.Add(new Data(5));
+
+            for (int i = 0; i < size / 2; i++)
+                data.Add(new Data((float)(5 + i * 1.1)));
+
+            var args = new IidChangePointDetector.Arguments()
             {
-                const int size = 10;
-                List<Data> data = new List<Data>(size);
-                var dataView = env.CreateStreamingDataView(data);
-                for (int i = 0; i < size / 2; i++)
-                    data.Add(new Data(5));
-
-                for (int i = 0; i < size / 2; i++)
-                    data.Add(new Data((float)(5 + i * 1.1)));
-
-                var args = new IidChangePointDetector.Arguments()
-                {
-                    Confidence = 80,
-                    Source = "Value",
-                    Name = "Change",
-                    ChangeHistoryLength = size
-                };
-                // Train
-                var detector = new IidChangePointEstimator(env, args).Fit(dataView);
-                // Transform
-                var output = detector.Transform(dataView);
-                // Get predictions
-                var enumerator = output.AsEnumerable<Prediction>(env, true).GetEnumerator();
-                Prediction row = null;
-                List<double> expectedValues = new List<double>() { 0, 5, 0.5, 5.1200000000000114E-08, 0, 5, 0.4999999995, 5.1200000046080209E-08, 0, 5, 0.4999999995, 5.1200000092160303E-08,
+                Confidence = 80,
+                Source = "Value",
+                Name = "Change",
+                ChangeHistoryLength = size
+            };
+            // Train
+            var detector = new IidChangePointEstimator(env, args).Fit(dataView);
+            // Transform
+            var output = detector.Transform(dataView);
+            // Get predictions
+            var enumerator = output.AsEnumerable<Prediction>(env, true).GetEnumerator();
+            Prediction row = null;
+            List<double> expectedValues = new List<double>() { 0, 5, 0.5, 5.1200000000000114E-08, 0, 5, 0.4999999995, 5.1200000046080209E-08, 0, 5, 0.4999999995, 5.1200000092160303E-08,
                 0, 5, 0.4999999995, 5.12000001382404E-08};
-                int index = 0;
-                while (enumerator.MoveNext() && index < expectedValues.Count)
-                {
-                    row = enumerator.Current;
-
-                    Assert.Equal(expectedValues[index++], row.Change[0]);
-                    Assert.Equal(expectedValues[index++], row.Change[1]);
-                    Assert.Equal(expectedValues[index++], row.Change[2]);
-                    Assert.Equal(expectedValues[index++], row.Change[3]);
-                }
+            int index = 0;
+            while (enumerator.MoveNext() && index < expectedValues.Count)
+            {
+                row = enumerator.Current;
+
+                Assert.Equal(expectedValues[index++], row.Change[0]);
+                Assert.Equal(expectedValues[index++], row.Change[1]);
+                Assert.Equal(expectedValues[index++], row.Change[2]);
+                Assert.Equal(expectedValues[index++], row.Change[3]);
             }
         }

         [Fact]
         public void ChangePointDetectionWithSeasonality()
         {
-            using (var env = new ConsoleEnvironment(conc: 1))
+            var env = new MLContext(conc: 1);
+            const int ChangeHistorySize = 10;
+            const int SeasonalitySize = 10;
+            const int NumberOfSeasonsInTraining = 5;
+            const int MaxTrainingSize = NumberOfSeasonsInTraining * SeasonalitySize;
+
+            List<Data> data = new List<Data>();
+            var dataView = env.CreateStreamingDataView(data);
+
+            var args = new SsaChangePointDetector.Arguments()
             {
-                const int ChangeHistorySize = 10;
-                const int SeasonalitySize = 10;
-                const int NumberOfSeasonsInTraining = 5;
-                const int MaxTrainingSize = NumberOfSeasonsInTraining * SeasonalitySize;
-
-                List<Data> data = new List<Data>();
-                var dataView = env.CreateStreamingDataView(data);
-
-                var args = new SsaChangePointDetector.Arguments()
-                {
-                    Confidence = 95,
-                    Source = "Value",
-                    Name = "Change",
-                    ChangeHistoryLength = ChangeHistorySize,
-                    TrainingWindowSize = MaxTrainingSize,
-                    SeasonalWindowSize = SeasonalitySize
-                };
-
-                for (int j = 0; j < NumberOfSeasonsInTraining; j++)
-                    for (int i = 0; i < SeasonalitySize; i++)
-                        data.Add(new Data(i));
-
-                for (int i = 0; i < ChangeHistorySize; i++)
-                    data.Add(new Data(i * 100));
-
-                // Train
-                var detector = new SsaChangePointEstimator(env, args).Fit(dataView);
-                // Transform
-                var output = detector.Transform(dataView);
-                // Get predictions
-                var enumerator = output.AsEnumerable<Prediction>(env, true).GetEnumerator();
-                Prediction row = null;
-                List<double> expectedValues = new List<double>() { 0, -3.31410598754883, 0.5, 5.12000000000001E-08, 0, 1.5700820684432983, 5.2001145245395008E-07,
+                Confidence = 95,
+                Source = "Value",
+                Name = "Change",
+                ChangeHistoryLength = ChangeHistorySize,
+                TrainingWindowSize = MaxTrainingSize,
+                SeasonalWindowSize = SeasonalitySize
+            };
+
+            for (int j = 0; j < NumberOfSeasonsInTraining; j++)
+                for (int i = 0; i < SeasonalitySize; i++)
+                    data.Add(new Data(i));
+
+            for (int i = 0; i < ChangeHistorySize; i++)
+                data.Add(new Data(i * 100));
+
+            // Train
+            var detector = new SsaChangePointEstimator(env, args).Fit(dataView);
+            // Transform
+            var output = detector.Transform(dataView);
+            // Get predictions
+            var enumerator = output.AsEnumerable<Prediction>(env, true).GetEnumerator();
+            Prediction row = null;
+            List<double> expectedValues = new List<double>() { 0, -3.31410598754883, 0.5, 5.12000000000001E-08, 0, 1.5700820684432983, 5.2001145245395008E-07,
                 0.012414560443710681, 0, 1.2854313254356384, 0.28810801662678009, 0.02038940454467935, 0, -1.0950627326965332, 0.36663890634019225, 0.026956459625565483};
-                int index = 0;
-                while (enumerator.MoveNext() && index < expectedValues.Count)
-                {
-                    row = enumerator.Current;
-                    Assert.Equal(expectedValues[index++], row.Change[0], precision: 7); // Alert
-                    Assert.Equal(expectedValues[index++], row.Change[1], precision: 7); // Raw score
-                    Assert.Equal(expectedValues[index++], row.Change[2], precision: 7); // P-Value score
-                    Assert.Equal(expectedValues[index++], row.Change[3], precision: 7); // Martingale score
-                }
+            int index = 0;
+            while (enumerator.MoveNext() && index < expectedValues.Count)
+            {
+                row = enumerator.Current;
+                Assert.Equal(expectedValues[index++], row.Change[0], precision: 7); // Alert
+                Assert.Equal(expectedValues[index++], row.Change[1], precision: 7); // Raw score
+                Assert.Equal(expectedValues[index++], row.Change[2], precision: 7); // P-Value score
+                Assert.Equal(expectedValues[index++], row.Change[3], precision: 7); // Martingale score
             }
         }
     }
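Finally, the time-series hunks share one train/transform/enumerate recipe. The sketch below condenses `ChangeDetection` (the detector arguments and the four-slot `Change` vector — alert, raw score, p-value, martingale score — come straight from the diff; the using directives are assumptions, since the file's header is not part of this diff):

```csharp
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Runtime.Api;                  // VectorType attribute, CreateStreamingDataView, AsEnumerable
using Microsoft.ML.Runtime.TimeSeriesProcessing; // assumed home of the IID change point types

public static class ChangePointSketch
{
    private sealed class Data { public float Value; public Data(float value) => Value = value; }
    private sealed class Prediction { [VectorType(4)] public double[] Change; }

    // Fits an IID change point detector on a series and yields one
    // [alert, raw score, p-value, martingale score] row per input point.
    public static IEnumerable<Prediction> Detect(IReadOnlyList<float> series)
    {
        var env = new MLContext(conc: 1);
        var data = new List<Data>();
        foreach (var value in series)
            data.Add(new Data(value));
        // The streaming view is lazy, which is why the test above can create it
        // before filling the list; populating first is simply tidier.
        var dataView = env.CreateStreamingDataView(data);

        var args = new IidChangePointDetector.Arguments()
        {
            Confidence = 80,
            Source = "Value",
            Name = "Change",
            ChangeHistoryLength = data.Count
        };
        var detector = new IidChangePointEstimator(env, args).Fit(dataView);
        var output = detector.Transform(dataView);
        return output.AsEnumerable<Prediction>(env, reuseRowObject: false);
    }
}
```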