diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln
index fb15345497..bad42afd5d 100644
--- a/Microsoft.ML.sln
+++ b/Microsoft.ML.sln
@@ -17,8 +17,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CpuMath", "src
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.PipelineInference", "src\Microsoft.ML.PipelineInference\Microsoft.ML.PipelineInference.csproj", "{2D7391C9-8254-4B8F-BF26-FADAF8F02F44}"
EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.InferenceTesting", "test\Microsoft.ML.InferenceTesting\Microsoft.ML.InferenceTesting.csproj", "{E278EC99-E6EE-49FE-92E6-0A309A478D98}"
-EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Data", "src\Microsoft.ML.Data\Microsoft.ML.Data.csproj", "{AD92D96B-0E96-4F22-8DCE-892E13B1F282}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Onnx", "src\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj", "{65D0603E-B96C-4DFC-BDD1-705891B88C18}"
@@ -175,14 +173,6 @@ Global
{2D7391C9-8254-4B8F-BF26-FADAF8F02F44}.Release|Any CPU.Build.0 = Release|Any CPU
{2D7391C9-8254-4B8F-BF26-FADAF8F02F44}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
{2D7391C9-8254-4B8F-BF26-FADAF8F02F44}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
- {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
- {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Debug|Any CPU.Build.0 = Debug|Any CPU
- {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
- {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
- {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Release|Any CPU.ActiveCfg = Release|Any CPU
- {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Release|Any CPU.Build.0 = Release|Any CPU
- {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
- {E278EC99-E6EE-49FE-92E6-0A309A478D98}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
{AD92D96B-0E96-4F22-8DCE-892E13B1F282}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{AD92D96B-0E96-4F22-8DCE-892E13B1F282}.Debug|Any CPU.Build.0 = Debug|Any CPU
{AD92D96B-0E96-4F22-8DCE-892E13B1F282}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
@@ -520,7 +510,6 @@ Global
{EC743D1D-7691-43B7-B9B0-5F2F7018A8F6} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{46F2F967-C23F-4076-858D-33F7DA9BD2DA} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{2D7391C9-8254-4B8F-BF26-FADAF8F02F44} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
- {E278EC99-E6EE-49FE-92E6-0A309A478D98} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{AD92D96B-0E96-4F22-8DCE-892E13B1F282} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{65D0603E-B96C-4DFC-BDD1-705891B88C18} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{707BB22C-7E5F-497A-8C2F-74578F675705} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
diff --git a/src/Microsoft.ML.Api/GenerateCodeCommand.cs b/src/Microsoft.ML.Api/GenerateCodeCommand.cs
index 26136971af..4855a94219 100644
--- a/src/Microsoft.ML.Api/GenerateCodeCommand.cs
+++ b/src/Microsoft.ML.Api/GenerateCodeCommand.cs
@@ -24,7 +24,7 @@ namespace Microsoft.ML.Runtime.Api
///
/// REVIEW: Consider adding support for generating VBuffers instead of arrays, maybe for high dimensionality vectors.
///
- public sealed class GenerateCodeCommand : ICommand
+ internal sealed class GenerateCodeCommand : ICommand
{
public const string LoadName = "GenerateSamplePredictionCode";
private const string CodeTemplatePath = "Microsoft.ML.Api.GeneratedCodeTemplate.csresource";
diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs
index 3450dd2bbf..7110565d12 100644
--- a/src/Microsoft.ML.Api/TypedCursor.cs
+++ b/src/Microsoft.ML.Api/TypedCursor.cs
@@ -528,8 +528,6 @@ public ICursor GetRootCursor()
///
public static class CursoringUtils
{
- private const string NeedEnvObsoleteMessage = "This method is obsolete. Please use the overload that takes an additional 'env' argument. An environment can be created via new LocalEnvironment().";
-
///
/// Generate a strongly-typed cursorable wrapper of the .
///
@@ -550,24 +548,6 @@ public static ICursorable AsCursorable(this IDataView data, IHostEnv
return TypedCursorable.Create(env, data, ignoreMissingColumns, schemaDefinition);
}
- ///
- /// Generate a strongly-typed cursorable wrapper of the .
- ///
- /// The user-defined row type.
- /// The underlying data view.
- /// Whether to ignore the case when a requested column is not present in the data view.
- /// Optional user-provided schema definition. If it is not present, the schema is inferred from the definition of T.
- /// The cursorable wrapper of .
- [Obsolete(NeedEnvObsoleteMessage)]
- public static ICursorable AsCursorable(this IDataView data, bool ignoreMissingColumns = false,
- SchemaDefinition schemaDefinition = null)
- where TRow : class, new()
- {
- // REVIEW: Take an env as a parameter.
- var env = new ConsoleEnvironment();
- return data.AsCursorable(env, ignoreMissingColumns, schemaDefinition);
- }
-
///
/// Convert an into a strongly-typed .
///
@@ -589,24 +569,5 @@ public static IEnumerable AsEnumerable(this IDataView data, IHostEnv
var engine = new PipeEngine(env, data, ignoreMissingColumns, schemaDefinition);
return engine.RunPipe(reuseRowObject);
}
-
- ///
- /// Convert an into a strongly-typed .
- ///
- /// The user-defined row type.
- /// The underlying data view.
- /// Whether to return the same object on every row, or allocate a new one per row.
- /// Whether to ignore the case when a requested column is not present in the data view.
- /// Optional user-provided schema definition. If it is not present, the schema is inferred from the definition of T.
- /// The that holds the data in . It can be enumerated multiple times.
- [Obsolete(NeedEnvObsoleteMessage)]
- public static IEnumerable AsEnumerable(this IDataView data, bool reuseRowObject,
- bool ignoreMissingColumns = false, SchemaDefinition schemaDefinition = null)
- where TRow : class, new()
- {
- // REVIEW: Take an env as a parameter.
- var env = new ConsoleEnvironment();
- return data.AsEnumerable(env, reuseRowObject, ignoreMissingColumns, schemaDefinition);
- }
}
}
diff --git a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
index 28c623f3c2..81018dad3e 100644
--- a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
+++ b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
@@ -40,7 +40,8 @@ internal enum SettingsFlags
///
/// This allows components to be created by name, signature type, and a settings string.
///
- public interface ICommandLineComponentFactory : IComponentFactory
+ [BestFriend]
+ internal interface ICommandLineComponentFactory : IComponentFactory
{
Type SignatureType { get; }
string Name { get; }
diff --git a/src/Common/AssemblyLoadingUtils.cs b/src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs
similarity index 97%
rename from src/Common/AssemblyLoadingUtils.cs
rename to src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs
index ab2f2c563a..e947776bc9 100644
--- a/src/Common/AssemblyLoadingUtils.cs
+++ b/src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs
@@ -6,11 +6,13 @@
using System;
using System.IO;
using System.IO.Compression;
-using System.Linq;
using System.Reflection;
namespace Microsoft.ML.Runtime
{
+ [Obsolete("The usage for this is intended for the internal command line utilities and is not intended for anything related to the API. " +
+ "Please consider another way of doing whatever it is you're attempting to accomplish.")]
+ [BestFriend]
internal static class AssemblyLoadingUtils
{
///
diff --git a/src/Microsoft.ML.Core/Data/ICommand.cs b/src/Microsoft.ML.Core/Data/ICommand.cs
index c300b4f3a5..44d4c7340b 100644
--- a/src/Microsoft.ML.Core/Data/ICommand.cs
+++ b/src/Microsoft.ML.Core/Data/ICommand.cs
@@ -11,9 +11,11 @@ namespace Microsoft.ML.Runtime.Command
///
/// The signature for commands.
///
- public delegate void SignatureCommand();
+ [BestFriend]
+ internal delegate void SignatureCommand();
- public interface ICommand
+ [BestFriend]
+ internal interface ICommand
{
void Run();
}
diff --git a/src/Microsoft.ML.Core/Data/IFileHandle.cs b/src/Microsoft.ML.Core/Data/IFileHandle.cs
index 2b57187250..37b871b7b6 100644
--- a/src/Microsoft.ML.Core/Data/IFileHandle.cs
+++ b/src/Microsoft.ML.Core/Data/IFileHandle.cs
@@ -61,7 +61,7 @@ public sealed class SimpleFileHandle : IFileHandle
// handle has been disposed.
private List _streams;
- private bool IsDisposed { get { return _streams == null; } }
+ private bool IsDisposed => _streams == null;
public SimpleFileHandle(IExceptionContext ectx, string path, bool needsWrite, bool autoDelete)
{
@@ -84,15 +84,9 @@ public SimpleFileHandle(IExceptionContext ectx, string path, bool needsWrite, bo
_streams = new List();
}
- public bool CanWrite
- {
- get { return !_wrote && !IsDisposed; }
- }
+ public bool CanWrite => !_wrote && !IsDisposed;
- public bool CanRead
- {
- get { return _wrote && !IsDisposed; }
- }
+ public bool CanRead => _wrote && !IsDisposed;
public void Dispose()
{
diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
index 0b8c097e7d..eb0d57845c 100644
--- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
+++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
@@ -72,6 +72,8 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
/// The suffix and prefix are optional. A common use for suffix is to specify an extension, eg, ".txt".
/// The use of suffix and prefix, including whether they have any affect, is up to the host environment.
///
+ [Obsolete("The host environment is not disposable, so it is inappropriate to use this method. " +
+ "Please handle your own temporary files within the component yourself, including their proper disposal and deletion.")]
IFileHandle CreateTempFile(string suffix = null, string prefix = null);
///
@@ -188,7 +190,8 @@ public readonly struct ChannelMessage
///
public string Message => _args != null ? string.Format(_message, _args) : _message;
- public ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string message)
+ [BestFriend]
+ internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string message)
{
Contracts.CheckNonEmpty(message, nameof(message));
Kind = kind;
@@ -197,7 +200,8 @@ public ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, s
_args = null;
}
- public ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string fmt, params object[] args)
+ [BestFriend]
+ internal ChannelMessage(ChannelMessageKind kind, MessageSensitivity sensitivity, string fmt, params object[] args)
{
Contracts.CheckNonEmpty(fmt, nameof(fmt));
Contracts.CheckNonEmpty(args, nameof(args));
diff --git a/src/Microsoft.ML.Core/Data/ProgressReporter.cs b/src/Microsoft.ML.Core/Data/ProgressReporter.cs
index f7741b462c..191364e2a3 100644
--- a/src/Microsoft.ML.Core/Data/ProgressReporter.cs
+++ b/src/Microsoft.ML.Core/Data/ProgressReporter.cs
@@ -14,7 +14,8 @@ namespace Microsoft.ML.Runtime.Data
///
/// The progress reporting classes used by descendants.
///
- public static class ProgressReporting
+ [BestFriend]
+ internal static class ProgressReporting
{
///
/// The progress channel for .
diff --git a/src/Microsoft.ML.Core/Data/ServerChannel.cs b/src/Microsoft.ML.Core/Data/ServerChannel.cs
index b11e962ab2..a9b33d1986 100644
--- a/src/Microsoft.ML.Core/Data/ServerChannel.cs
+++ b/src/Microsoft.ML.Core/Data/ServerChannel.cs
@@ -19,7 +19,8 @@ namespace Microsoft.ML.Runtime
/// delegates will be published in some fashion, with the target scenario being
/// that the library will publish some sort of restful API.
///
- public sealed class ServerChannel : ServerChannel.IPendingBundleNotification, IDisposable
+ [BestFriend]
+ internal sealed class ServerChannel : ServerChannel.IPendingBundleNotification, IDisposable
{
// See ServerChannel.md for a more elaborate discussion of high level usage and design.
private readonly IChannelProvider _chp;
@@ -250,7 +251,8 @@ public void AddDoneAction(Action onDone)
}
}
- public static class ServerChannelUtilities
+ [BestFriend]
+ internal static class ServerChannelUtilities
{
///
/// Convenience method for that looks more idiomatic to typical
diff --git a/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs b/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs
index 79a8c028ef..0163222fc1 100644
--- a/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/EntryPointModuleAttribute.cs
@@ -9,13 +9,15 @@ namespace Microsoft.ML.Runtime.EntryPoints
///
/// This is a signature for classes that are 'holders' of entry points and components.
///
- public delegate void SignatureEntryPointModule();
+ [BestFriend]
+ internal delegate void SignatureEntryPointModule();
///
/// A simplified assembly attribute for marking EntryPoint modules.
///
[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
- public sealed class EntryPointModuleAttribute : LoadableClassAttributeBase
+ [BestFriend]
+ internal sealed class EntryPointModuleAttribute : LoadableClassAttributeBase
{
public EntryPointModuleAttribute(Type loaderType)
: base(null, typeof(void), loaderType, null, new[] { typeof(SignatureEntryPointModule) }, loaderType.FullName)
diff --git a/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs b/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
index f64c8d0758..c4d4325f79 100644
--- a/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/EntryPointUtils.cs
@@ -12,7 +12,8 @@
namespace Microsoft.ML.Runtime.EntryPoints
{
- public static class EntryPointUtils
+ [BestFriend]
+ internal static class EntryPointUtils
{
private static bool IsValueWithinRange(TlcModule.RangeAttribute range, object obj)
{
diff --git a/src/Microsoft.ML.Core/EntryPoints/IMlState.cs b/src/Microsoft.ML.Core/EntryPoints/IMlState.cs
index 52b0828256..41ea062861 100644
--- a/src/Microsoft.ML.Core/EntryPoints/IMlState.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/IMlState.cs
@@ -10,5 +10,5 @@ namespace Microsoft.ML.Runtime.EntryPoints
/// black box to the graph. The macro itself will then case to the concrete type.
///
public interface IMlState
- {}
+ { }
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs b/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs
index 4dc02993d2..d538f636ce 100644
--- a/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/ModuleArgs.cs
@@ -18,7 +18,8 @@ namespace Microsoft.ML.Runtime.EntryPoints
/// This class defines attributes to annotate module inputs, outputs, entry points etc. when defining
/// the module interface.
///
- public static class TlcModule
+ [BestFriend]
+ internal static class TlcModule
{
///
/// An attribute used to annotate the component.
diff --git a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
index e683a42413..3e27ce2516 100644
--- a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
+++ b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
@@ -5,8 +5,6 @@
#pragma warning disable 420 // volatile with Interlocked.CompareExchange
using System;
-using System.Collections.Concurrent;
-using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
@@ -15,7 +13,12 @@ namespace Microsoft.ML.Runtime.Data
{
using Stopwatch = System.Diagnostics.Stopwatch;
- public sealed class ConsoleEnvironment : HostEnvironmentBase
+ /// <summary>
+ /// The console environment. As its name suggests, should be limited to those applications that deliberately want
+ /// console functionality.
+ /// </summary>
+ [BestFriend]
+ internal sealed class ConsoleEnvironment : HostEnvironmentBase
{
public const string ComponentHistoryKey = "ComponentHistory";
diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
index 6cfb4157be..31ab23e28f 100644
--- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
+++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
@@ -15,7 +15,8 @@ namespace Microsoft.ML.Runtime.Data
/// Base class for channel providers. This is a common base class for.
/// The ParentFullName, ShortName, and FullName may be null or empty.
///
- public abstract class ChannelProviderBase : IExceptionContext
+ [BestFriend]
+ internal abstract class ChannelProviderBase : IExceptionContext
{
///
/// Data keys that are attached to the exception thrown via the exception context.
@@ -79,42 +80,22 @@ public virtual TException Process(TException ex)
///
/// Message source (a channel) that generated the message being dispatched.
///
- public interface IMessageSource
+ [BestFriend]
+ internal interface IMessageSource
{
string ShortName { get; }
string FullName { get; }
bool Verbose { get; }
}
- ///
- /// A that is also a channel listener can attach
- /// listeners for messages, as sent through .
- ///
- public interface IMessageDispatcher : IHostEnvironment
- {
- ///
- /// Listen on this environment to messages of a particular type.
- ///
- /// The message type
- /// The action to perform when a message of the
- /// appropriate type is received.
- void AddListener(Action listenerFunc);
-
- ///
- /// Removes a previously added listener.
- ///
- /// The message type
- /// The previous listener function that is now being removed.
- void RemoveListener(Action listenerFunc);
- }
-
///
/// A basic host environment suited for many environments.
/// This also supports modifying the concurrency factor, provides the ability to subscribe to pipes via the
/// AddListener/RemoveListener methods, and exposes the to
/// query progress.
///
- public abstract class HostEnvironmentBase : ChannelProviderBase, IHostEnvironment, IDisposable, IChannelProvider, IMessageDispatcher
+ [BestFriend]
+ internal abstract class HostEnvironmentBase : ChannelProviderBase, IHostEnvironment, IDisposable, IChannelProvider
where TEnv : HostEnvironmentBase
{
///
diff --git a/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs b/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs
index ebd7d41486..72b08e2715 100644
--- a/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs
+++ b/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs
@@ -13,7 +13,8 @@ namespace Microsoft.ML.Runtime
///
/// A telemetry message.
///
- public abstract class TelemetryMessage
+ [BestFriend]
+ internal abstract class TelemetryMessage
{
public static TelemetryMessage CreateCommand(string commandName, string commandText)
{
@@ -40,7 +41,8 @@ public static TelemetryMessage CreateException(Exception exception)
///
/// Message with one long text and bunch of small properties (limit on value is ~1020 chars)
///
- public sealed class TelemetryTrace : TelemetryMessage
+ [BestFriend]
+ internal sealed class TelemetryTrace : TelemetryMessage
{
public readonly string Text;
public readonly string Name;
@@ -57,7 +59,8 @@ public TelemetryTrace(string text, string name, string type)
///
/// Message with exception
///
- public sealed class TelemetryException : TelemetryMessage
+ [BestFriend]
+ internal sealed class TelemetryException : TelemetryMessage
{
public readonly Exception Exception;
public TelemetryException(Exception exception)
@@ -70,7 +73,8 @@ public TelemetryException(Exception exception)
///
/// Message with metric value and it properites
///
- public sealed class TelemetryMetric : TelemetryMessage
+ [BestFriend]
+ internal sealed class TelemetryMetric : TelemetryMessage
{
public readonly string Name;
public readonly double Value;
diff --git a/src/Microsoft.ML.Core/Prediction/ITrainer.cs b/src/Microsoft.ML.Core/Prediction/ITrainer.cs
index 6647f77592..5e796aa602 100644
--- a/src/Microsoft.ML.Core/Prediction/ITrainer.cs
+++ b/src/Microsoft.ML.Core/Prediction/ITrainer.cs
@@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Runtime.Data;
+using System;
namespace Microsoft.ML.Runtime
{
@@ -39,7 +40,8 @@ namespace Microsoft.ML.Runtime
/// The base interface for a trainers. Implementors should not implement this interface directly,
/// but rather implement the more specific .
///
- public interface ITrainer
+ [BestFriend]
+ internal interface ITrainer
{
///
/// Auxiliary information about the trainer in terms of its capabilities
@@ -66,7 +68,8 @@ public interface ITrainer
/// and produces a predictor.
///
/// Type of predictor produced
- public interface ITrainer : ITrainer
+ [BestFriend]
+ internal interface ITrainer : ITrainer
where TPredictor : IPredictor
{
///
@@ -77,7 +80,8 @@ public interface ITrainer : ITrainer
new TPredictor Train(TrainContext context);
}
- public static class TrainerExtensions
+ [BestFriend]
+ internal static class TrainerExtensions
{
///
/// Convenience train extension for the case where one has only a training set with no auxiliary information.
diff --git a/src/Microsoft.ML.Core/Prediction/TrainContext.cs b/src/Microsoft.ML.Core/Prediction/TrainContext.cs
index 7bd1509bb5..e5e4bbad1c 100644
--- a/src/Microsoft.ML.Core/Prediction/TrainContext.cs
+++ b/src/Microsoft.ML.Core/Prediction/TrainContext.cs
@@ -11,7 +11,8 @@ namespace Microsoft.ML.Runtime
/// into or .
/// This holds at least a training set, as well as optioonally a predictor.
///
- public sealed class TrainContext
+ [BestFriend]
+ internal sealed class TrainContext
{
///
/// The training set. Cannot be null.
diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
index 24b84d38c1..2f20462290 100644
--- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
@@ -21,7 +21,8 @@
namespace Microsoft.ML.Runtime.Data
{
- public sealed class CrossValidationCommand : DataCommand.ImplBase
+ [BestFriend]
+ internal sealed class CrossValidationCommand : DataCommand.ImplBase
{
// REVIEW: We need a way to specify different data sets, not just LabeledExamples.
public sealed class Arguments : DataCommand.ArgumentsBase
diff --git a/src/Microsoft.ML.Data/Commands/DataCommand.cs b/src/Microsoft.ML.Data/Commands/DataCommand.cs
index f617fb206f..5c5ceef6f7 100644
--- a/src/Microsoft.ML.Data/Commands/DataCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/DataCommand.cs
@@ -17,7 +17,8 @@ namespace Microsoft.ML.Runtime.Data
///
/// This holds useful base classes for commands that ingest a primary dataset and deal with associated model files.
///
- public static class DataCommand
+ [BestFriend]
+ internal static class DataCommand
{
public abstract class ArgumentsBase
{
@@ -56,7 +57,8 @@ public abstract class ArgumentsBase
public KeyValuePair>[] Transform;
}
- public abstract class ImplBase : ICommand
+ [BestFriend]
+ internal abstract class ImplBase : ICommand
where TArgs : ArgumentsBase
{
protected readonly IHost Host;
diff --git a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs
index 937f019c37..7ed4144262 100644
--- a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs
@@ -162,7 +162,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
}
}
- public sealed class EvaluateCommand : DataCommand.ImplBase
+ internal sealed class EvaluateCommand : DataCommand.ImplBase
{
public sealed class Arguments : DataCommand.ArgumentsBase
{
diff --git a/src/Microsoft.ML.Data/Commands/SaveDataCommand.cs b/src/Microsoft.ML.Data/Commands/SaveDataCommand.cs
index 577f51a4db..a721aee15a 100644
--- a/src/Microsoft.ML.Data/Commands/SaveDataCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/SaveDataCommand.cs
@@ -22,7 +22,7 @@
namespace Microsoft.ML.Runtime.Data
{
- public sealed class SaveDataCommand : DataCommand.ImplBase
+ internal sealed class SaveDataCommand : DataCommand.ImplBase
{
public sealed class Arguments : DataCommand.ArgumentsBase
{
@@ -87,7 +87,7 @@ private void RunCore(IChannel ch)
}
}
- public sealed class ShowDataCommand : DataCommand.ImplBase
+ internal sealed class ShowDataCommand : DataCommand.ImplBase
{
public sealed class Arguments : DataCommand.ArgumentsBase
{
diff --git a/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs b/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs
index e1057d18b4..633f871dd0 100644
--- a/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs
@@ -20,7 +20,7 @@
namespace Microsoft.ML.Runtime.Tools
{
- public sealed class SavePredictorCommand : ICommand
+ internal sealed class SavePredictorCommand : ICommand
{
public sealed class Arguments
{
diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
index d4c737f7b1..c0b2d7ecc5 100644
--- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
@@ -37,7 +37,7 @@ public interface IDataScorerTransform : IDataTransform, ITransformTemplate
public delegate void SignatureBindableMapper(IPredictor predictor);
- public sealed class ScoreCommand : DataCommand.ImplBase
+ internal sealed class ScoreCommand : DataCommand.ImplBase
{
public sealed class Arguments : DataCommand.ArgumentsBase
{
@@ -232,7 +232,8 @@ private bool ShouldAddColumn(Schema schema, int i, uint scoreSet, bool outputNam
}
}
- public static class ScoreUtils
+ [BestFriend]
+ internal static class ScoreUtils
{
public static IDataScorerTransform GetScorer(IPredictor predictor, RoleMappedData data, IHostEnvironment env, RoleMappedSchema trainSchema)
{
diff --git a/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs b/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs
index 5fbab5176a..20a9857a8f 100644
--- a/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/ShowSchemaCommand.cs
@@ -20,7 +20,7 @@
namespace Microsoft.ML.Runtime.Data
{
- public sealed class ShowSchemaCommand : DataCommand.ImplBase
+ internal sealed class ShowSchemaCommand : DataCommand.ImplBase
{
public sealed class Arguments : DataCommand.ArgumentsBase
{
diff --git a/src/Microsoft.ML.Data/Commands/TestCommand.cs b/src/Microsoft.ML.Data/Commands/TestCommand.cs
index 574bd4e00a..407f7e713d 100644
--- a/src/Microsoft.ML.Data/Commands/TestCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TestCommand.cs
@@ -14,8 +14,12 @@
namespace Microsoft.ML.Runtime.Data
{
- // This command is essentially chaining together Score and Evaluate, without the need to save the intermediary scored data.
- public sealed class TestCommand : DataCommand.ImplBase
+ /// <summary>
+ /// This command is essentially chaining together <see cref="ScoreCommand"/> and
+ /// <see cref="EvaluateCommand"/>, without the need to save the intermediary scored data.
+ /// </summary>
+ [BestFriend]
+ internal sealed class TestCommand : DataCommand.ImplBase
{
public sealed class Arguments : DataCommand.ArgumentsBase
{
diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs
index 53b025ef39..ff709de143 100644
--- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs
@@ -32,7 +32,8 @@ public enum NormalizeOption
Yes
}
- public sealed class TrainCommand : DataCommand.ImplBase
+ [BestFriend]
+ internal sealed class TrainCommand : DataCommand.ImplBase
{
public sealed class Arguments : DataCommand.ArgumentsBase
{
@@ -202,7 +203,8 @@ private void RunCore(IChannel ch, string cmd)
}
}
- public static class TrainUtils
+ [BestFriend]
+ internal static class TrainUtils
{
public static void CheckTrainer(IExceptionContext ectx, IComponentFactory trainer, string dataFile)
{
diff --git a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
index 450d680765..49f25375f5 100644
--- a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
@@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
+using System.IO;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Command;
using Microsoft.ML.Runtime.CommandLine;
@@ -16,7 +17,8 @@
namespace Microsoft.ML.Runtime.Data
{
- public sealed class TrainTestCommand : DataCommand.ImplBase
+ [BestFriend]
+ internal sealed class TrainTestCommand : DataCommand.ImplBase
{
public sealed class Arguments : DataCommand.ArgumentsBase
{
@@ -184,11 +186,12 @@ private void RunCore(IChannel ch, string cmd)
Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor, testDataUsedInTrainer);
IDataLoader testPipe;
- using (var file = !string.IsNullOrEmpty(Args.OutputModelFile) ?
- Host.CreateOutputFile(Args.OutputModelFile) : Host.CreateTempFile(".zip"))
+ bool hasOutfile = !string.IsNullOrEmpty(Args.OutputModelFile);
+ var tempFilePath = hasOutfile ? null : Path.GetTempFileName();
+
+ using (var file = new SimpleFileHandle(ch, hasOutfile ? Args.OutputModelFile : tempFilePath, true, !hasOutfile))
{
TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);
-
ch.Trace("Constructing the testing pipeline");
using (var stream = file.OpenReadStream())
using (var rep = RepositoryReader.Open(stream, ch))
diff --git a/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs b/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs
index 79c3295017..e9db784f1f 100644
--- a/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TypeInfoCommand.cs
@@ -17,7 +17,7 @@
namespace Microsoft.ML.Data.Commands
{
- public sealed class TypeInfoCommand : ICommand
+ internal sealed class TypeInfoCommand : ICommand
{
internal const string LoadName = "TypeInfo";
internal const string Summary = "Displays information about the standard primitive " +
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
index cf1e5f7a25..73dd62152b 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
@@ -2137,7 +2137,7 @@ public override ValueGetter GetIdGetter()
}
}
- public sealed class InfoCommand : ICommand
+ internal sealed class InfoCommand : ICommand
{
public const string LoadName = "IdvInfo";
diff --git a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs
index 6db939a877..f866ca4817 100644
--- a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs
+++ b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs
@@ -173,11 +173,6 @@ public interface IPredictorWithFeatureWeights : IHaveFeatureWeights
{
}
- public interface IHasLabelGains : ITrainer
- {
- Double[] GetLabelGains();
- }
-
///
/// Interface for mapping input values to corresponding feature contributions.
/// This interface is commonly implemented by predictors.
diff --git a/src/Microsoft.ML.Data/EntryPoints/Cache.cs b/src/Microsoft.ML.Data/EntryPoints/Cache.cs
index 1bcbcc0bf8..aa2693f4aa 100644
--- a/src/Microsoft.ML.Data/EntryPoints/Cache.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/Cache.cs
@@ -68,9 +68,11 @@ public static CacheOutput CacheData(IHostEnvironment env, CacheInput input)
cols.Add(i);
}
+#pragma warning disable CS0618 // This ought to be addressed. See #1287.
// We are not disposing the fileHandle because we want it to stay around for the execution of the graph.
// It will be disposed when the environment is disposed.
var fileHandle = host.CreateTempFile();
+#pragma warning restore CS0618
using (var stream = fileHandle.CreateWriteStream())
saver.SaveData(stream, input.Data, cols.ToArray());
diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
index 550191b157..3254289f12 100644
--- a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
@@ -99,7 +99,8 @@ public abstract class LearnerInputBaseWithGroupId : LearnerInputBaseWithWeight
public Optional GroupIdColumn = Optional.Implicit(DefaultColumnNames.GroupId);
}
- public static class LearnerEntryPointsUtils
+ [BestFriend]
+ internal static class LearnerEntryPointsUtils
{
public static string FindColumn(IExceptionContext ectx, ISchema schema, Optional value)
{
diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
index b6c650d77c..20d87fcbf6 100644
--- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs
+++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -690,7 +690,8 @@ public ValueMapper> GetWhatTheFeatureMapper(int
}
}
- public static class CalibratorUtils
+ [BestFriend]
+ internal static class CalibratorUtils
{
private static bool NeedCalibration(IHostEnvironment env, IChannel ch, ICalibratorTrainer calibrator,
ITrainer trainer, IPredictor predictor, RoleMappedSchema schema)
diff --git a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
index dde426f0d6..849318c1e0 100644
--- a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
+++ b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
@@ -7,12 +7,13 @@
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TestFramework" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.InferenceTesting" + PublicKey.TestValue)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransformTest" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Maml" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.ResultProcessor" + PublicKey.Value)]
-[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Data" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Api" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Ensemble" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.FastTree" + PublicKey.Value)]
diff --git a/src/Microsoft.ML.Data/Training/TrainerBase.cs b/src/Microsoft.ML.Data/Training/TrainerBase.cs
index ca2f2c7b64..d82dfb7ba6 100644
--- a/src/Microsoft.ML.Data/Training/TrainerBase.cs
+++ b/src/Microsoft.ML.Data/Training/TrainerBase.cs
@@ -19,7 +19,8 @@ public abstract class TrainerBase : ITrainer
public abstract PredictionKind PredictionKind { get; }
public abstract TrainerInfo Info { get; }
- protected TrainerBase(IHostEnvironment env, string name)
+ [BestFriend]
+ private protected TrainerBase(IHostEnvironment env, string name)
{
Contracts.CheckValue(env, nameof(env));
env.CheckNonEmpty(name, nameof(name));
@@ -30,6 +31,9 @@ protected TrainerBase(IHostEnvironment env, string name)
IPredictor ITrainer.Train(TrainContext context) => Train(context);
- public abstract TPredictor Train(TrainContext context);
+ TPredictor ITrainer.Train(TrainContext context) => Train(context);
+
+ [BestFriend]
+ private protected abstract TPredictor Train(TrainContext context);
}
}
diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
index 035d66c054..1e49c32ed3 100644
--- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
+++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
@@ -51,7 +51,8 @@ public abstract class TrainerEstimatorBase : ITrainerEstim
public abstract PredictionKind PredictionKind { get; }
- public TrainerEstimatorBase(IHost host,
+ [BestFriend]
+ private protected TrainerEstimatorBase(IHost host,
SchemaShape.Column feature,
SchemaShape.Column label,
SchemaShape.Column weight = null)
@@ -87,7 +88,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
///
protected abstract SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema);
- public TModel Train(TrainContext context)
+ TModel ITrainer.Train(TrainContext context)
{
Host.CheckValue(context, nameof(context));
return TrainModelCore(context);
@@ -149,14 +150,15 @@ protected TTransformer TrainTransformer(IDataView trainSet,
return MakeTransformer(pred, trainSet.Schema);
}
- protected abstract TModel TrainModelCore(TrainContext trainContext);
+ [BestFriend]
+ private protected abstract TModel TrainModelCore(TrainContext trainContext);
protected abstract TTransformer MakeTransformer(TModel model, Schema trainSchema);
protected virtual RoleMappedData MakeRoles(IDataView data) =>
new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, weight: WeightColumn?.Name);
- IPredictor ITrainer.Train(TrainContext context) => Train(context);
+ IPredictor ITrainer.Train(TrainContext context) => ((ITrainer)this).Train(context);
}
///
diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
index 612631fae7..95bc4d3a10 100644
--- a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
@@ -263,33 +263,6 @@ public static IDataTransform CreateMinMaxNormalizer(IHostEnvironment env, IDataV
return normalizer.Fit(input).MakeDataTransform(input);
}
- ///
- /// Potentially apply a min-max normalizer to the data's feature column, keeping all existing role
- /// mappings except for the feature role mapping.
- ///
- /// The host environment to use to potentially instantiate the transform
- /// The role-mapped data that is potentially going to be modified by this method.
- /// The trainer to query as to whether it wants normalization. If the
- /// 's is true
- /// True if the normalizer was applied and was modified
- public static bool CreateIfNeeded(IHostEnvironment env, ref RoleMappedData data, ITrainer trainer)
- {
- Contracts.CheckValue(env, nameof(env));
- env.CheckValue(data, nameof(data));
- env.CheckValue(trainer, nameof(trainer));
-
- // If the trainer does not need normalization, or if the features either don't exist
- // or are not normalized, return false.
- if (!trainer.Info.NeedNormalization || data.Schema.FeaturesAreNormalized() != false)
- return false;
- var featInfo = data.Schema.Feature;
- env.AssertValue(featInfo); // Should be defined, if FeaturesAreNormalized returned a definite value.
-
- var view = CreateMinMaxNormalizer(env, data.Data, name: featInfo.Name);
- data = new RoleMappedData(view, data.Schema.GetColumnRoleNames());
- return true;
- }
-
///
/// Public create method corresponding to SignatureDataTransform.
///
diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs
index 9aa364b386..e8ec5f4e53 100644
--- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs
@@ -9,6 +9,7 @@
using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Transforms;
+using System;
using System.Collections.Generic;
[assembly: LoadableClass(ScoringTransformer.Summary, typeof(IDataTransform), typeof(ScoringTransformer), typeof(ScoringTransformer.Arguments), typeof(SignatureDataTransform),
@@ -98,7 +99,11 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
}
}
- public static class TrainAndScoreTransformer
+ // Essentially, all trainer estimators when fitted return a transformer that produces scores -- which is to say, all machine
+ // learning algorithms actually behave more or less as this transform used to, so its presence is no longer necessary or helpful,
+ // from an API perspective, but this is still how things are invoked from the command line for now.
+ [BestFriend]
+ internal static class TrainAndScoreTransformer
{
public abstract class ArgumentsBase : TransformInputBase
{
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs
index 728cccb1f6..e5e9b79afc 100644
--- a/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs
@@ -11,7 +11,7 @@
namespace Microsoft.ML.Ensemble.EntryPoints
{
- public static class Ensemble
+ internal static class Ensemble
{
[TlcModule.EntryPoint(Name = "Trainers.EnsembleBinaryClassifier", Desc = "Train binary ensemble.", UserName = EnsembleTrainer.UserNameValue)]
public static CommonOutputs.BinaryClassificationOutput CreateBinaryEnsemble(IHostEnvironment env, EnsembleTrainer.Arguments input)
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs
index dbe1517f22..257499cd13 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseScalarStacking.cs
@@ -9,7 +9,7 @@
namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
{
- public abstract class BaseScalarStacking : BaseStacking
+ internal abstract class BaseScalarStacking : BaseStacking
{
internal BaseScalarStacking(IHostEnvironment env, string name, ArgumentsBase args)
: base(env, name, args)
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
index eb6dc3a650..d62650b7c2 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
@@ -16,7 +16,7 @@
namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
{
using ColumnRole = RoleMappedSchema.ColumnRole;
- public abstract class BaseStacking : IStackingTrainer
+ internal abstract class BaseStacking : IStackingTrainer
{
public abstract class ArgumentsBase
{
@@ -28,13 +28,13 @@ public abstract class ArgumentsBase
internal abstract IComponentFactory>> GetPredictorFactory();
}
- protected readonly IComponentFactory>> BasePredictorType;
- protected readonly IHost Host;
- protected IPredictorProducing Meta;
+ private protected readonly IComponentFactory>> BasePredictorType;
+ private protected readonly IHost Host;
+ private protected IPredictorProducing Meta;
public Single ValidationDatasetProportion { get; }
- internal BaseStacking(IHostEnvironment env, string name, ArgumentsBase args)
+ private protected BaseStacking(IHostEnvironment env, string name, ArgumentsBase args)
{
Contracts.AssertValue(env);
env.AssertNonWhiteSpace(name);
@@ -49,7 +49,7 @@ internal BaseStacking(IHostEnvironment env, string name, ArgumentsBase args)
Host.CheckValue(BasePredictorType, nameof(BasePredictorType));
}
- internal BaseStacking(IHostEnvironment env, string name, ModelLoadContext ctx)
+ private protected BaseStacking(IHostEnvironment env, string name, ModelLoadContext ctx)
{
Contracts.AssertValue(env);
env.AssertNonWhiteSpace(name);
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
index 4e352e8265..f9e3b246f7 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
@@ -21,7 +21,7 @@
namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
{
using TVectorPredictor = IPredictorProducing>;
- public sealed class MultiStacking : BaseStacking>, ICanSaveModel, IMultiClassOutputCombiner
+ internal sealed class MultiStacking : BaseStacking>, ICanSaveModel, IMultiClassOutputCombiner
{
public const string LoadName = "MultiStacking";
public const string LoaderSignature = "MultiStackingCombiner";
@@ -37,9 +37,11 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(MultiStacking).Assembly.FullName);
}
+#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
[TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)]
public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory
{
+ // REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer.
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))]
[TGUI(Label = "Base predictor")]
@@ -49,6 +51,7 @@ public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerF
public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiStacking(env, this);
}
+#pragma warning restore CS0649
public MultiStacking(IHostEnvironment env, Arguments args)
: base(env, LoaderSignature, args)
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
index 9adb9ba799..8c984613db 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
@@ -19,7 +19,7 @@ namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
{
using TScalarPredictor = IPredictorProducing;
- public sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner, ICanSaveModel
+ internal sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner, ICanSaveModel
{
public const string LoadName = "RegressionStacking";
public const string LoaderSignature = "RegressionStacking";
@@ -35,9 +35,11 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(RegressionStacking).Assembly.FullName);
}
+#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
[TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)]
public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerFactory
{
+ // REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer.
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))]
[TGUI(Label = "Base predictor")]
@@ -47,6 +49,7 @@ public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerF
public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this);
}
+#pragma warning restore CS0649
public RegressionStacking(IHostEnvironment env, Arguments args)
: base(env, LoaderSignature, args)
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
index 8c36cc866e..f44f987b05 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
@@ -16,7 +16,7 @@
namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
{
using TScalarPredictor = IPredictorProducing;
- public sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner, ICanSaveModel
+ internal sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner, ICanSaveModel
{
public const string UserName = "Stacking";
public const string LoadName = "Stacking";
@@ -33,9 +33,11 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(Stacking).Assembly.FullName);
}
+#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
[TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFactory
{
+ // REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer.
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
[TGUI(Label = "Base predictor")]
@@ -45,6 +47,7 @@ public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFacto
public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this);
}
+#pragma warning restore CS0649
public Stacking(IHostEnvironment env, Arguments args)
: base(env, LoaderSignature, args)
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
index cc03de02d4..a1befc112b 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
@@ -29,7 +29,7 @@ namespace Microsoft.ML.Runtime.Ensemble
///
/// A generic ensemble trainer for binary classification.
///
- public sealed class EnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor,
+ internal sealed class EnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor,
IBinarySubModelSelector, IBinaryOutputCombiner>,
IModelCombiner
{
@@ -48,6 +48,7 @@ public sealed class Arguments : ArgumentsBase
[TGUI(Label = "Output combiner", Description = "Output combiner type")]
public ISupportBinaryOutputCombinerFactory OutputCombiner = new MedianFactory();
+ // REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer.
[Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
public IComponentFactory>[] BasePredictors;
diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
index 9d7c6fab40..a8fc896c5b 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
@@ -101,7 +101,7 @@ private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env,
}
}
- public sealed override TPredictor Train(TrainContext context)
+ private protected sealed override TPredictor Train(TrainContext context)
{
Host.CheckValue(context, nameof(context));
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
index 3909fb1b07..1961e6a785 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
@@ -30,7 +30,7 @@ namespace Microsoft.ML.Runtime.Ensemble
///
/// A generic ensemble classifier for multi-class classification
///
- public sealed class MulticlassDataPartitionEnsembleTrainer :
+ internal sealed class MulticlassDataPartitionEnsembleTrainer :
EnsembleTrainerBase, EnsembleMultiClassPredictor,
IMulticlassSubModelSelector, IMultiClassOutputCombiner>,
IModelCombiner
@@ -49,6 +49,7 @@ public sealed class Arguments : ArgumentsBase
[TGUI(Label = "Output combiner", Description = "Output combiner type")]
public ISupportMulticlassOutputCombinerFactory OutputCombiner = new MultiMedian.Arguments();
+ // REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer.
[Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMultiClassClassifierTrainer))]
public IComponentFactory>[] BasePredictors;
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
index b7e63b8862..09d394d596 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs
@@ -26,7 +26,7 @@
namespace Microsoft.ML.Runtime.Ensemble
{
using TScalarPredictor = IPredictorProducing;
- public sealed class RegressionEnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor,
+ internal sealed class RegressionEnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor,
IRegressionSubModelSelector, IRegressionOutputCombiner>,
IModelCombiner
{
@@ -43,6 +43,7 @@ public sealed class Arguments : ArgumentsBase
[TGUI(Label = "Output combiner", Description = "Output combiner type")]
public ISupportRegressionOutputCombinerFactory OutputCombiner = new MedianFactory();
+ // REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer.
[Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))]
public IComponentFactory>[] BasePredictors;
diff --git a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs
index b56e92e4f7..f68108d2e8 100644
--- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs
@@ -17,7 +17,7 @@
namespace Microsoft.ML.Trainers.FastTree
{
[TlcModule.ComponentKind("FastTreeTrainer")]
- public interface IFastTreeTrainerFactory : IComponentFactory
+ internal interface IFastTreeTrainerFactory : IComponentFactory
{
}
@@ -31,7 +31,7 @@ public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory
[TGUI(Label = "Optimize for unbalanced")]
public bool UnbalancedSets = false;
- public ITrainer CreateComponent(IHostEnvironment env) => new FastTreeBinaryClassificationTrainer(env, this);
+ ITrainer IComponentFactory.CreateComponent(IHostEnvironment env) => new FastTreeBinaryClassificationTrainer(env, this);
}
}
@@ -45,7 +45,7 @@ public Arguments()
EarlyStoppingMetrics = 1; // Use L1 by default.
}
- public ITrainer CreateComponent(IHostEnvironment env) => new FastTreeRegressionTrainer(env, this);
+ ITrainer IComponentFactory.CreateComponent(IHostEnvironment env) => new FastTreeRegressionTrainer(env, this);
}
}
@@ -62,7 +62,7 @@ public sealed class Arguments : BoostedTreeArgs, IFastTreeTrainerFactory
"and intermediate values are compound Poisson loss.")]
public Double Index = 1.5;
- public ITrainer CreateComponent(IHostEnvironment env) => new FastTreeTweedieTrainer(env, this);
+ ITrainer IComponentFactory.CreateComponent(IHostEnvironment env) => new FastTreeTweedieTrainer(env, this);
}
}
@@ -111,7 +111,7 @@ public Arguments()
EarlyStoppingMetrics = 1;
}
- public ITrainer CreateComponent(IHostEnvironment env) => new FastTreeRankingTrainer(env, this);
+ ITrainer IComponentFactory.CreateComponent(IHostEnvironment env) => new FastTreeRankingTrainer(env, this);
internal override void Check(IExceptionContext ectx)
{
diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs
index 3b278a4e56..433cafa908 100644
--- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs
@@ -154,7 +154,7 @@ internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Arguments arg
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
- protected override IPredictorWithFeatureWeights TrainModelCore(TrainContext context)
+ private protected override IPredictorWithFeatureWeights TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var trainData = context.TrainingSet;
diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs
index 3eb2b83b84..2663c9df51 100644
--- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs
@@ -42,8 +42,7 @@ namespace Microsoft.ML.Trainers.FastTree
{
///
public sealed partial class FastTreeRankingTrainer
- : BoostingFastTreeTrainerBase, FastTreeRankingPredictor>,
- IHasLabelGains
+ : BoostingFastTreeTrainerBase, FastTreeRankingPredictor>
{
internal const string LoadNameValue = "FastTreeRanking";
internal const string UserNameValue = "FastTree (Boosted Trees) Ranking";
@@ -100,7 +99,7 @@ protected override float GetMaxLabel()
return GetLabelGains().Length - 1;
}
- protected override FastTreeRankingPredictor TrainModelCore(TrainContext context)
+ private protected override FastTreeRankingPredictor TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var trainData = context.TrainingSet;
@@ -117,7 +116,7 @@ protected override FastTreeRankingPredictor TrainModelCore(TrainContext context)
return new FastTreeRankingPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs);
}
- public Double[] GetLabelGains()
+ private Double[] GetLabelGains()
{
try
{
diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs
index 2b8e1b60c3..133f9c2bd4 100644
--- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs
@@ -84,7 +84,7 @@ internal FastTreeRegressionTrainer(IHostEnvironment env, Arguments args)
{
}
- protected override FastTreeRegressionPredictor TrainModelCore(TrainContext context)
+ private protected override FastTreeRegressionPredictor TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var trainData = context.TrainingSet;
diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
index 7b70bf8169..f49798aa22 100644
--- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
@@ -86,7 +86,7 @@ internal FastTreeTweedieTrainer(IHostEnvironment env, Arguments args)
Initialize();
}
- protected override FastTreeTweediePredictor TrainModelCore(TrainContext context)
+ private protected override FastTreeTweediePredictor TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var trainData = context.TrainingSet;
diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs
index 4afebf83ba..3f45cb7ed3 100644
--- a/src/Microsoft.ML.FastTree/GamClassification.cs
+++ b/src/Microsoft.ML.FastTree/GamClassification.cs
@@ -102,7 +102,7 @@ private static bool[] ConvertTargetsToBool(double[] targets)
return boolArray;
}
- protected override IPredictorProducing TrainModelCore(TrainContext context)
+ private protected override IPredictorProducing TrainModelCore(TrainContext context)
{
TrainBase(context);
var predictor = new BinaryClassGamPredictor(Host, InputLength, TrainSet,
diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs
index 9669bff805..575c1d9e24 100644
--- a/src/Microsoft.ML.FastTree/GamRegression.cs
+++ b/src/Microsoft.ML.FastTree/GamRegression.cs
@@ -70,7 +70,7 @@ internal override void CheckLabel(RoleMappedData data)
data.CheckRegressionLabel();
}
- protected override RegressionGamPredictor TrainModelCore(TrainContext context)
+ private protected override RegressionGamPredictor TrainModelCore(TrainContext context)
{
TrainBase(context);
return new RegressionGamPredictor(Host, InputLength, TrainSet, MeanEffect, BinEffects, FeatureMap);
diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs
index ffe8516b20..d77000d7ae 100644
--- a/src/Microsoft.ML.FastTree/GamTrainer.cs
+++ b/src/Microsoft.ML.FastTree/GamTrainer.cs
@@ -187,7 +187,7 @@ private protected GamTrainerBase(IHostEnvironment env, TArgs args, string name,
InitializeThreads();
}
- protected void TrainBase(TrainContext context)
+ private protected void TrainBase(TrainContext context)
{
using (var ch = Host.Start("Training"))
{
@@ -981,7 +981,7 @@ public void SaveSummary(TextWriter writer, RoleMappedSchema schema)
/// , it is convenient to have the command itself nested within the base
/// predictor class.
///
- public sealed class VisualizationCommand : DataCommand.ImplBase
+ internal sealed class VisualizationCommand : DataCommand.ImplBase
{
public const string Summary = "Loads a model trained with a GAM learner, and starts an interactive web session to visualize it.";
public const string LoadName = "GamVisualization";
diff --git a/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs b/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs
new file mode 100644
index 0000000000..a03d7bdab6
--- /dev/null
+++ b/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs
@@ -0,0 +1,10 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Runtime.CompilerServices;
+using Microsoft.ML;
+
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)]
+
+[assembly: WantsToBeBestFriends]
diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs
index b5f85bf919..a7aedc692a 100644
--- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs
+++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs
@@ -168,7 +168,7 @@ public FastForestClassification(IHostEnvironment env, Arguments args)
{
}
- protected override IPredictorWithFeatureWeights TrainModelCore(TrainContext context)
+ private protected override IPredictorWithFeatureWeights TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var trainData = context.TrainingSet;
diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs
index d805e9c750..4fe842214c 100644
--- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs
+++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs
@@ -189,7 +189,7 @@ public FastForestRegression(IHostEnvironment env, Arguments args)
{
}
- protected override FastForestRegressionPredictor TrainModelCore(TrainContext context)
+ private protected override FastForestRegressionPredictor TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var trainData = context.TrainingSet;
diff --git a/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs b/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs
index 4d02097941..14fcad759e 100644
--- a/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs
+++ b/src/Microsoft.ML.FastTree/SumupPerformanceCommand.cs
@@ -29,7 +29,7 @@ namespace Microsoft.ML.Trainers.FastTree
///
/// This is an internal utility command to measure the performance of the IntArray sumup operation.
///
- public sealed class SumupPerformanceCommand : ICommand
+ internal sealed class SumupPerformanceCommand : ICommand
{
public sealed class Arguments
{
diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
index 252ac42abe..b2a2973da7 100644
--- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
+++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
@@ -547,8 +547,10 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
}
///
- public static class TreeEnsembleFeaturizerTransform
+ [BestFriend]
+ internal static class TreeEnsembleFeaturizerTransform
{
+#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
public sealed class Arguments : TrainAndScoreTransformer.ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "Trainer to use", ShortName = "tr", NullName = "", SortOrder = 1, SignatureType = typeof(SignatureTreeEnsembleTrainer))]
@@ -586,6 +588,7 @@ public sealed class ArgumentsForEntryPoint : TransformInputBase
[Argument(ArgumentType.Required, HelpText = "Trainer to use", SortOrder = 10, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public IPredictorModel PredictorModel;
}
+#pragma warning restore CS0649
internal const string TreeEnsembleSummary =
"Trains a tree ensemble, or loads it from a file, then maps a numeric feature vector " +
@@ -801,8 +804,9 @@ private static IDataView AppendLabelTransform(IHostEnvironment env, IChannel ch,
}
}
- public static partial class TreeFeaturize
+ internal static partial class TreeFeaturize
{
+#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
[TlcModule.EntryPoint(Name = "Transforms.TreeLeafFeaturizer",
Desc = TreeEnsembleFeaturizerTransform.TreeEnsembleSummary,
UserName = TreeEnsembleFeaturizerTransform.UserName,
@@ -818,5 +822,6 @@ public static CommonOutputs.TransformOutput Featurizer(IHostEnvironment env, Tre
var xf = TreeEnsembleFeaturizerTransform.CreateForEntryPoint(env, input, input.Data);
return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf };
}
+#pragma warning restore CS0649
}
}
diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
index 48ff08b67b..c1dfab7957 100644
--- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
+++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
@@ -130,7 +130,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
private static Double ProbClamp(Double p)
=> Math.Max(0, Math.Min(p, 1));
- protected override OlsLinearRegressionPredictor TrainModelCore(TrainContext context)
+ private protected override OlsLinearRegressionPredictor TrainModelCore(TrainContext context)
{
using (var ch = Host.Start("Training"))
{
diff --git a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs
index a37dac9338..590aa06e27 100644
--- a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs
+++ b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs
@@ -133,7 +133,7 @@ private RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedDa
return examplesToFeedTrain;
}
- protected override TPredictor TrainModelCore(TrainContext context)
+ private protected override TPredictor TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
using (var ch = Host.Start("Training"))
diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
index 2682e50301..7e2bb4f807 100644
--- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
+++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
@@ -150,7 +150,7 @@ private KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args, Action
diff --git a/src/Microsoft.ML.Legacy/LearningPipeline.cs b/src/Microsoft.ML.Legacy/LearningPipeline.cs
index e56e49f6a3..7544d2f112 100644
--- a/src/Microsoft.ML.Legacy/LearningPipeline.cs
+++ b/src/Microsoft.ML.Legacy/LearningPipeline.cs
@@ -161,86 +161,84 @@ public PredictionModel Train()
where TInput : class
where TOutput : class, new()
{
- using (var environment = new ConsoleEnvironment(seed: _seed, conc: _conc))
+ var environment = new MLContext(seed: _seed, conc: _conc);
+ Experiment experiment = environment.CreateExperiment();
+ ILearningPipelineStep step = null;
+ List<ILearningPipelineLoader> loaders = new List<ILearningPipelineLoader>();
+ List<Var<ITransformModel>> transformModels = new List<Var<ITransformModel>>();
+ Var<ITransformModel> lastTransformModel = null;
+
+ foreach (ILearningPipelineItem currentItem in this)
{
- Experiment experiment = environment.CreateExperiment();
- ILearningPipelineStep step = null;
- List<ILearningPipelineLoader> loaders = new List<ILearningPipelineLoader>();
- List<Var<ITransformModel>> transformModels = new List<Var<ITransformModel>>();
- Var<ITransformModel> lastTransformModel = null;
+ if (currentItem is ILearningPipelineLoader loader)
+ loaders.Add(loader);
- foreach (ILearningPipelineItem currentItem in this)
+ step = currentItem.ApplyStep(step, experiment);
+ if (step is ILearningPipelineDataStep dataStep && dataStep.Model != null)
+ transformModels.Add(dataStep.Model);
+ else if (step is ILearningPipelinePredictorStep predictorDataStep)
{
- if (currentItem is ILearningPipelineLoader loader)
- loaders.Add(loader);
+ if (lastTransformModel != null)
+ transformModels.Insert(0, lastTransformModel);
- step = currentItem.ApplyStep(step, experiment);
- if (step is ILearningPipelineDataStep dataStep && dataStep.Model != null)
- transformModels.Add(dataStep.Model);
- else if (step is ILearningPipelinePredictorStep predictorDataStep)
+ Var<ITransformModel> predictorModel;
+ if (transformModels.Count != 0)
{
- if (lastTransformModel != null)
- transformModels.Insert(0, lastTransformModel);
-
- Var<ITransformModel> predictorModel;
- if (transformModels.Count != 0)
+ var localModelInput = new Transforms.ManyHeterogeneousModelCombiner
{
- var localModelInput = new Transforms.ManyHeterogeneousModelCombiner
- {
- PredictorModel = predictorDataStep.Model,
- TransformModels = new ArrayVar<ITransformModel>(transformModels.ToArray())
- };
- var localModelOutput = experiment.Add(localModelInput);
- predictorModel = localModelOutput.PredictorModel;
- }
- else
- predictorModel = predictorDataStep.Model;
-
- var scorer = new Transforms.Scorer
- {
- PredictorModel = predictorModel
+ PredictorModel = predictorDataStep.Model,
+ TransformModels = new ArrayVar<ITransformModel>(transformModels.ToArray())
};
-
- var scorerOutput = experiment.Add(scorer);
- lastTransformModel = scorerOutput.ScoringTransform;
- step = new ScorerPipelineStep(scorerOutput.ScoredData, scorerOutput.ScoringTransform);
- transformModels.Clear();
+ var localModelOutput = experiment.Add(localModelInput);
+ predictorModel = localModelOutput.PredictorModel;
}
- }
+ else
+ predictorModel = predictorDataStep.Model;
- if (transformModels.Count > 0)
- {
- if (lastTransformModel != null)
- transformModels.Insert(0, lastTransformModel);
-
- var modelInput = new Transforms.ModelCombiner
+ var scorer = new Transforms.Scorer
{
- Models = new ArrayVar<ITransformModel>(transformModels.ToArray())
+ PredictorModel = predictorModel
};
- var modelOutput = experiment.Add(modelInput);
- lastTransformModel = modelOutput.OutputModel;
+ var scorerOutput = experiment.Add(scorer);
+ lastTransformModel = scorerOutput.ScoringTransform;
+ step = new ScorerPipelineStep(scorerOutput.ScoredData, scorerOutput.ScoringTransform);
+ transformModels.Clear();
}
+ }
- experiment.Compile();
- foreach (ILearningPipelineLoader loader in loaders)
- {
- loader.SetInput(environment, experiment);
- }
- experiment.Run();
+ if (transformModels.Count > 0)
+ {
+ if (lastTransformModel != null)
+ transformModels.Insert(0, lastTransformModel);
- ITransformModel model = experiment.GetOutput(lastTransformModel);
- BatchPredictionEngine<TInput, TOutput> predictor;
- using (var memoryStream = new MemoryStream())
+ var modelInput = new Transforms.ModelCombiner
{
- model.Save(environment, memoryStream);
+ Models = new ArrayVar<ITransformModel>(transformModels.ToArray())
+ };
- memoryStream.Position = 0;
+ var modelOutput = experiment.Add(modelInput);
+ lastTransformModel = modelOutput.OutputModel;
+ }
- predictor = environment.CreateBatchPredictionEngine<TInput, TOutput>(memoryStream);
+ experiment.Compile();
+ foreach (ILearningPipelineLoader loader in loaders)
+ {
+ loader.SetInput(environment, experiment);
+ }
+ experiment.Run();
- return new PredictionModel<TInput, TOutput>(predictor, memoryStream);
- }
+ ITransformModel model = experiment.GetOutput(lastTransformModel);
+ BatchPredictionEngine<TInput, TOutput> predictor;
+ using (var memoryStream = new MemoryStream())
+ {
+ model.Save(environment, memoryStream);
+
+ memoryStream.Position = 0;
+
+ predictor = environment.CreateBatchPredictionEngine<TInput, TOutput>(memoryStream);
+
+ return new PredictionModel<TInput, TOutput>(predictor, memoryStream);
}
}
diff --git a/src/Microsoft.ML.Legacy/LearningPipelineDebugProxy.cs b/src/Microsoft.ML.Legacy/LearningPipelineDebugProxy.cs
index af99d8ff3c..586dbf99e3 100644
--- a/src/Microsoft.ML.Legacy/LearningPipelineDebugProxy.cs
+++ b/src/Microsoft.ML.Legacy/LearningPipelineDebugProxy.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Legacy.Transforms;
using System;
@@ -25,7 +26,7 @@ internal sealed class LearningPipelineDebugProxy
private const int MaxSlotNamesToDisplay = 100;
private readonly LearningPipeline _pipeline;
- private readonly ConsoleEnvironment _environment;
+ private readonly IHostEnvironment _environment;
private IDataView _preview;
private Exception _pipelineExecutionException;
private PipelineItemDebugColumn[] _columns;
@@ -39,7 +40,7 @@ public LearningPipelineDebugProxy(LearningPipeline pipeline)
_pipeline = new LearningPipeline();
// use a ConcurrencyFactor of 1 so other threads don't need to run in the debugger
- _environment = new ConsoleEnvironment(conc: 1);
+ _environment = new MLContext(conc: 1);
foreach (ILearningPipelineItem item in pipeline)
{
diff --git a/src/Microsoft.ML.Legacy/Microsoft.ML.Legacy.csproj b/src/Microsoft.ML.Legacy/Microsoft.ML.Legacy.csproj
index 6e4034aa24..1ccf77d19d 100644
--- a/src/Microsoft.ML.Legacy/Microsoft.ML.Legacy.csproj
+++ b/src/Microsoft.ML.Legacy/Microsoft.ML.Legacy.csproj
@@ -6,10 +6,6 @@
CORECLR
-
-
-
-
diff --git a/src/Microsoft.ML.Legacy/Models/BinaryClassificationEvaluator.cs b/src/Microsoft.ML.Legacy/Models/BinaryClassificationEvaluator.cs
index 71b32a323a..2f2bce8016 100644
--- a/src/Microsoft.ML.Legacy/Models/BinaryClassificationEvaluator.cs
+++ b/src/Microsoft.ML.Legacy/Models/BinaryClassificationEvaluator.cs
@@ -24,54 +24,52 @@ public sealed partial class BinaryClassificationEvaluator
///
public BinaryClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
{
- using (var environment = new ConsoleEnvironment())
- {
- environment.CheckValue(model, nameof(model));
- environment.CheckValue(testData, nameof(testData));
+ var environment = new MLContext();
+ environment.CheckValue(model, nameof(model));
+ environment.CheckValue(testData, nameof(testData));
- Experiment experiment = environment.CreateExperiment();
+ Experiment experiment = environment.CreateExperiment();
- ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment);
- if (!(testDataStep is ILearningPipelineDataStep testDataOutput))
- {
- throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep.");
- }
+ ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment);
+ if (!(testDataStep is ILearningPipelineDataStep testDataOutput))
+ {
+ throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep.");
+ }
- var datasetScorer = new DatasetTransformScorer
- {
- Data = testDataOutput.Data
- };
- DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer);
+ var datasetScorer = new DatasetTransformScorer
+ {
+ Data = testDataOutput.Data
+ };
+ DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer);
- Data = scoreOutput.ScoredData;
- Output evaluteOutput = experiment.Add(this);
+ Data = scoreOutput.ScoredData;
+ Output evaluteOutput = experiment.Add(this);
- experiment.Compile();
+ experiment.Compile();
- experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel);
- testData.SetInput(environment, experiment);
+ experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel);
+ testData.SetInput(environment, experiment);
- experiment.Run();
+ experiment.Run();
- IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics);
- if (overallMetrics == null)
- {
- throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate.");
- }
+ IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics);
+ if (overallMetrics == null)
+ {
+ throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate.");
+ }
- IDataView confusionMatrix = experiment.GetOutput(evaluteOutput.ConfusionMatrix);
- if (confusionMatrix == null)
- {
- throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate.");
- }
+ IDataView confusionMatrix = experiment.GetOutput(evaluteOutput.ConfusionMatrix);
+ if (confusionMatrix == null)
+ {
+ throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate.");
+ }
- var metric = BinaryClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix);
+ var metric = BinaryClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix);
- if (metric.Count != 1)
- throw environment.Except($"Exactly one metric set was expected but found {metric.Count} metrics");
+ if (metric.Count != 1)
+ throw environment.Except($"Exactly one metric set was expected but found {metric.Count} metrics");
- return metric[0];
- }
+ return metric[0];
}
}
}
diff --git a/src/Microsoft.ML.Legacy/Models/ClassificationEvaluator.cs b/src/Microsoft.ML.Legacy/Models/ClassificationEvaluator.cs
index 63e7f8055e..5d644baf32 100644
--- a/src/Microsoft.ML.Legacy/Models/ClassificationEvaluator.cs
+++ b/src/Microsoft.ML.Legacy/Models/ClassificationEvaluator.cs
@@ -25,54 +25,52 @@ public sealed partial class ClassificationEvaluator
///
public ClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
{
- using (var environment = new ConsoleEnvironment())
- {
- environment.CheckValue(model, nameof(model));
- environment.CheckValue(testData, nameof(testData));
+ var environment = new MLContext();
+ environment.CheckValue(model, nameof(model));
+ environment.CheckValue(testData, nameof(testData));
- Experiment experiment = environment.CreateExperiment();
+ Experiment experiment = environment.CreateExperiment();
- ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment);
- if (!(testDataStep is ILearningPipelineDataStep testDataOutput))
- {
- throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep.");
- }
+ ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment);
+ if (!(testDataStep is ILearningPipelineDataStep testDataOutput))
+ {
+ throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep.");
+ }
- var datasetScorer = new DatasetTransformScorer
- {
- Data = testDataOutput.Data,
- };
- DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer);
+ var datasetScorer = new DatasetTransformScorer
+ {
+ Data = testDataOutput.Data,
+ };
+ DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer);
- Data = scoreOutput.ScoredData;
- Output evaluteOutput = experiment.Add(this);
+ Data = scoreOutput.ScoredData;
+ Output evaluteOutput = experiment.Add(this);
- experiment.Compile();
+ experiment.Compile();
- experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel);
- testData.SetInput(environment, experiment);
+ experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel);
+ testData.SetInput(environment, experiment);
- experiment.Run();
+ experiment.Run();
- IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics);
- if (overallMetrics == null)
- {
- throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(ClassificationEvaluator)} Evaluate.");
- }
+ IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics);
+ if (overallMetrics == null)
+ {
+ throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(ClassificationEvaluator)} Evaluate.");
+ }
- IDataView confusionMatrix = experiment.GetOutput(evaluteOutput.ConfusionMatrix);
- if (confusionMatrix == null)
- {
- throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(ClassificationEvaluator)} Evaluate.");
- }
+ IDataView confusionMatrix = experiment.GetOutput(evaluteOutput.ConfusionMatrix);
+ if (confusionMatrix == null)
+ {
+ throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(ClassificationEvaluator)} Evaluate.");
+ }
- var metric = ClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix);
+ var metric = ClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix);
- if (metric.Count != 1)
- throw environment.Except($"Exactly one metric set was expected but found {metric.Count} metrics");
+ if (metric.Count != 1)
+ throw environment.Except($"Exactly one metric set was expected but found {metric.Count} metrics");
- return metric[0];
- }
+ return metric[0];
}
}
}
diff --git a/src/Microsoft.ML.Legacy/Models/ClusterEvaluator.cs b/src/Microsoft.ML.Legacy/Models/ClusterEvaluator.cs
index 411c85b176..5d12ad85f9 100644
--- a/src/Microsoft.ML.Legacy/Models/ClusterEvaluator.cs
+++ b/src/Microsoft.ML.Legacy/Models/ClusterEvaluator.cs
@@ -24,48 +24,46 @@ public sealed partial class ClusterEvaluator
///
public ClusterMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
{
- using (var environment = new ConsoleEnvironment())
- {
- environment.CheckValue(model, nameof(model));
- environment.CheckValue(testData, nameof(testData));
+ var environment = new MLContext();
+ environment.CheckValue(model, nameof(model));
+ environment.CheckValue(testData, nameof(testData));
- Experiment experiment = environment.CreateExperiment();
+ Experiment experiment = environment.CreateExperiment();
- ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment);
- if (!(testDataStep is ILearningPipelineDataStep testDataOutput))
- {
- throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep.");
- }
+ ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment);
+ if (!(testDataStep is ILearningPipelineDataStep testDataOutput))
+ {
+ throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep.");
+ }
- var datasetScorer = new DatasetTransformScorer
- {
- Data = testDataOutput.Data,
- };
- DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer);
+ var datasetScorer = new DatasetTransformScorer
+ {
+ Data = testDataOutput.Data,
+ };
+ DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer);
- Data = scoreOutput.ScoredData;
- Output evaluteOutput = experiment.Add(this);
+ Data = scoreOutput.ScoredData;
+ Output evaluteOutput = experiment.Add(this);
- experiment.Compile();
+ experiment.Compile();
- experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel);
- testData.SetInput(environment, experiment);
+ experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel);
+ testData.SetInput(environment, experiment);
- experiment.Run();
+ experiment.Run();
- IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics);
+ IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics);
- if (overallMetrics == null)
- {
- throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(ClusterEvaluator)} Evaluate.");
- }
+ if (overallMetrics == null)
+ {
+ throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(ClusterEvaluator)} Evaluate.");
+ }
- var metric = ClusterMetrics.FromOverallMetrics(environment, overallMetrics);
+ var metric = ClusterMetrics.FromOverallMetrics(environment, overallMetrics);
- Contracts.Assert(metric.Count == 1, $"Exactly one metric set was expected but found {metric.Count} metrics");
+ Contracts.Assert(metric.Count == 1, $"Exactly one metric set was expected but found {metric.Count} metrics");
- return metric[0];
- }
+ return metric[0];
}
}
}
diff --git a/src/Microsoft.ML.Legacy/Models/CrossValidator.cs b/src/Microsoft.ML.Legacy/Models/CrossValidator.cs
index d9d133a779..96be9db419 100644
--- a/src/Microsoft.ML.Legacy/Models/CrossValidator.cs
+++ b/src/Microsoft.ML.Legacy/Models/CrossValidator.cs
@@ -27,7 +27,7 @@ public CrossValidationOutput CrossValidate(Lea
where TInput : class
where TOutput : class, new()
{
- using (var environment = new ConsoleEnvironment())
+ var environment = new MLContext();
{
Experiment subGraph = environment.CreateExperiment();
ILearningPipelineStep step = null;
diff --git a/src/Microsoft.ML.Legacy/Models/OneVersusAll.cs b/src/Microsoft.ML.Legacy/Models/OneVersusAll.cs
index acd756556a..3269556f1f 100644
--- a/src/Microsoft.ML.Legacy/Models/OneVersusAll.cs
+++ b/src/Microsoft.ML.Legacy/Models/OneVersusAll.cs
@@ -52,26 +52,24 @@ public OvaPipelineItem(ITrainerInputWithLabel trainer, bool useProbabilities)
public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
{
- using (var env = new ConsoleEnvironment())
+ var env = new MLContext();
+ var subgraph = env.CreateExperiment();
+ subgraph.Add(_trainer);
+ var ova = new OneVersusAll();
+ if (previousStep != null)
{
- var subgraph = env.CreateExperiment();
- subgraph.Add(_trainer);
- var ova = new OneVersusAll();
- if (previousStep != null)
+ if (!(previousStep is ILearningPipelineDataStep dataStep))
{
- if (!(previousStep is ILearningPipelineDataStep dataStep))
- {
- throw new InvalidOperationException($"{ nameof(OneVersusAll)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
- }
-
- _data = dataStep.Data;
- ova.TrainingData = dataStep.Data;
- ova.UseProbabilities = _useProbabilities;
- ova.Nodes = subgraph;
+ throw new InvalidOperationException($"{ nameof(OneVersusAll)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
}
- Output output = experiment.Add(ova);
- return new OvaPipelineStep(output);
+
+ _data = dataStep.Data;
+ ova.TrainingData = dataStep.Data;
+ ova.UseProbabilities = _useProbabilities;
+ ova.Nodes = subgraph;
}
+ Output output = experiment.Add(ova);
+ return new OvaPipelineStep(output);
}
public Var GetInputData() => _data;
diff --git a/src/Microsoft.ML.Legacy/Models/OnnxConverter.cs b/src/Microsoft.ML.Legacy/Models/OnnxConverter.cs
index c49c71cf71..0c9eca3ee8 100644
--- a/src/Microsoft.ML.Legacy/Models/OnnxConverter.cs
+++ b/src/Microsoft.ML.Legacy/Models/OnnxConverter.cs
@@ -70,16 +70,14 @@ public sealed partial class OnnxConverter
/// Model that needs to be converted to ONNX format.
public void Convert(PredictionModel model)
{
- using (var environment = new ConsoleEnvironment())
- {
- environment.CheckValue(model, nameof(model));
+ var environment = new MLContext();
+ environment.CheckValue(model, nameof(model));
- Experiment experiment = environment.CreateExperiment();
- experiment.Add(this);
- experiment.Compile();
- experiment.SetInput(Model, model.PredictorModel);
- experiment.Run();
- }
+ Experiment experiment = environment.CreateExperiment();
+ experiment.Add(this);
+ experiment.Compile();
+ experiment.SetInput(Model, model.PredictorModel);
+ experiment.Run();
}
}
}
diff --git a/src/Microsoft.ML.Legacy/Models/RegressionEvaluator.cs b/src/Microsoft.ML.Legacy/Models/RegressionEvaluator.cs
index 35a7fd7500..ffee6108c6 100644
--- a/src/Microsoft.ML.Legacy/Models/RegressionEvaluator.cs
+++ b/src/Microsoft.ML.Legacy/Models/RegressionEvaluator.cs
@@ -24,49 +24,47 @@ public sealed partial class RegressionEvaluator
///
public RegressionMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
{
- using (var environment = new ConsoleEnvironment())
- {
- environment.CheckValue(model, nameof(model));
- environment.CheckValue(testData, nameof(testData));
+ var environment = new MLContext();
+ environment.CheckValue(model, nameof(model));
+ environment.CheckValue(testData, nameof(testData));
- Experiment experiment = environment.CreateExperiment();
+ Experiment experiment = environment.CreateExperiment();
- ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment);
- if (!(testDataStep is ILearningPipelineDataStep testDataOutput))
- {
- throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep.");
- }
+ ILearningPipelineStep testDataStep = testData.ApplyStep(previousStep: null, experiment);
+ if (!(testDataStep is ILearningPipelineDataStep testDataOutput))
+ {
+ throw environment.Except($"The {nameof(ILearningPipelineLoader)} did not return a {nameof(ILearningPipelineDataStep)} from ApplyStep.");
+ }
- var datasetScorer = new DatasetTransformScorer
- {
- Data = testDataOutput.Data,
- };
- DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer);
+ var datasetScorer = new DatasetTransformScorer
+ {
+ Data = testDataOutput.Data,
+ };
+ DatasetTransformScorer.Output scoreOutput = experiment.Add(datasetScorer);
- Data = scoreOutput.ScoredData;
- Output evaluteOutput = experiment.Add(this);
+ Data = scoreOutput.ScoredData;
+ Output evaluteOutput = experiment.Add(this);
- experiment.Compile();
+ experiment.Compile();
- experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel);
- testData.SetInput(environment, experiment);
+ experiment.SetInput(datasetScorer.TransformModel, model.PredictorModel);
+ testData.SetInput(environment, experiment);
- experiment.Run();
+ experiment.Run();
- IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics);
+ IDataView overallMetrics = experiment.GetOutput(evaluteOutput.OverallMetrics);
- if (overallMetrics == null)
- {
- throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(RegressionEvaluator)} Evaluate.");
- }
+ if (overallMetrics == null)
+ {
+ throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(RegressionEvaluator)} Evaluate.");
+ }
- var metric = RegressionMetrics.FromOverallMetrics(environment, overallMetrics);
+ var metric = RegressionMetrics.FromOverallMetrics(environment, overallMetrics);
- if (metric.Count != 1)
- throw environment.Except($"Exactly one metric set was expected but found {metric.Count} metrics");
+ if (metric.Count != 1)
+ throw environment.Except($"Exactly one metric set was expected but found {metric.Count} metrics");
- return metric[0];
- }
+ return metric[0];
}
}
}
diff --git a/src/Microsoft.ML.Legacy/Models/TrainTestEvaluator.cs b/src/Microsoft.ML.Legacy/Models/TrainTestEvaluator.cs
index 6972b8cf3e..dbf5df1b50 100644
--- a/src/Microsoft.ML.Legacy/Models/TrainTestEvaluator.cs
+++ b/src/Microsoft.ML.Legacy/Models/TrainTestEvaluator.cs
@@ -30,7 +30,7 @@ public TrainTestEvaluatorOutput TrainTestEvaluate
/// Returns labels that correspond to indices of the score array in the case of
@@ -40,7 +36,7 @@ internal TransformModel PredictorModel
public bool TryGetScoreLabelNames(out string[] names, string scoreColumnName = DefaultColumnNames.Score)
{
names = null;
- var schema = _predictorModel.OutputSchema;
+ var schema = PredictorModel.OutputSchema;
int colIndex = -1;
if (!schema.TryGetColumnIndex(scoreColumnName, out colIndex))
return false;
@@ -125,15 +121,13 @@ public static Task> ReadAsync(
if (stream == null)
throw new ArgumentNullException(nameof(stream));
- using (var environment = new ConsoleEnvironment())
- {
- AssemblyRegistration.RegisterAssemblies(environment);
+ var environment = new MLContext();
+ AssemblyRegistration.RegisterAssemblies(environment);
- BatchPredictionEngine<TInput, TOutput> predictor =
- environment.CreateBatchPredictionEngine<TInput, TOutput>(stream);
+ BatchPredictionEngine<TInput, TOutput> predictor =
+ environment.CreateBatchPredictionEngine<TInput, TOutput>(stream);
- return Task.FromResult(new PredictionModel<TInput, TOutput>(predictor, stream));
- }
+ return Task.FromResult(new PredictionModel<TInput, TOutput>(predictor, stream));
}
///
@@ -141,7 +135,7 @@ public static Task> ReadAsync(
///
/// Incoming IDataView
/// IDataView which contains predictions
- public IDataView Predict(IDataView input) => _predictorModel.Apply(_env, input);
+ public IDataView Predict(IDataView input) => PredictorModel.Apply(_env, input);
///
/// Save model to file.
@@ -168,7 +162,7 @@ public Task WriteAsync(Stream stream)
{
if (stream == null)
throw new ArgumentNullException(nameof(stream));
- _predictorModel.Save(_env, stream);
+ PredictorModel.Save(_env, stream);
return Task.CompletedTask;
}
}
diff --git a/src/Microsoft.ML.Legacy/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Legacy/Properties/AssemblyInfo.cs
index 52835b2f81..70fe21adb6 100644
--- a/src/Microsoft.ML.Legacy/Properties/AssemblyInfo.cs
+++ b/src/Microsoft.ML.Legacy/Properties/AssemblyInfo.cs
@@ -2,8 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System.Reflection;
+using Microsoft.ML;
using System.Runtime.CompilerServices;
-using System.Runtime.InteropServices;
-[assembly: InternalsVisibleTo("Microsoft.ML.Tests, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")]
+[assembly: InternalsVisibleTo("Microsoft.ML.Tests" + PublicKey.TestValue)]
+[assembly: InternalsVisibleTo("Microsoft.ML.Core.Tests" + PublicKey.TestValue)]
diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CodeGen/ModuleGenerator.cs b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CodeGen/ModuleGenerator.cs
index c612197f31..817706522b 100644
--- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CodeGen/ModuleGenerator.cs
+++ b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/CodeGen/ModuleGenerator.cs
@@ -21,7 +21,7 @@
namespace Microsoft.ML.Runtime.EntryPoints.CodeGen
{
- public class ModuleGenerator : IGenerator
+ internal sealed class ModuleGenerator : IGenerator
{
private readonly string _modulePrefix;
private readonly bool _generateModule;
diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/ExecuteGraphCommand.cs b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/ExecuteGraphCommand.cs
index a3d02d10e4..d6e95961ec 100644
--- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/ExecuteGraphCommand.cs
+++ b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/ExecuteGraphCommand.cs
@@ -21,7 +21,7 @@
namespace Microsoft.ML.Runtime.EntryPoints.JsonUtils
{
- public sealed class ExecuteGraphCommand : ICommand
+ internal sealed class ExecuteGraphCommand : ICommand
{
public sealed class Arguments
{
diff --git a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/GraphRunner.cs b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/GraphRunner.cs
index 9ebb25e301..af6e2b818d 100644
--- a/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/GraphRunner.cs
+++ b/src/Microsoft.ML.Legacy/Runtime/EntryPoints/JsonUtils/GraphRunner.cs
@@ -140,7 +140,7 @@ public void SetInput(string name, TInput input)
///
/// Get the data kind of a particular port.
///
- public TlcModule.DataKind GetPortDataKind(string name)
+ internal TlcModule.DataKind GetPortDataKind(string name)
{
_host.CheckNonEmpty(name, nameof(name));
EntryPointVariable variable;
diff --git a/src/Microsoft.ML.Legacy/Runtime/Internal/Tools/CSharpApiGenerator.cs b/src/Microsoft.ML.Legacy/Runtime/Internal/Tools/CSharpApiGenerator.cs
index 0bcf5a6776..0fd2aa9aa0 100644
--- a/src/Microsoft.ML.Legacy/Runtime/Internal/Tools/CSharpApiGenerator.cs
+++ b/src/Microsoft.ML.Legacy/Runtime/Internal/Tools/CSharpApiGenerator.cs
@@ -22,7 +22,7 @@
namespace Microsoft.ML.Runtime.Internal.Tools
{
- public sealed class CSharpApiGenerator : IGenerator
+ internal sealed class CSharpApiGenerator : IGenerator
{
public sealed class Arguments
{
diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
index 3a0c496868..53e6ba543b 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs
@@ -102,7 +102,7 @@ private protected LightGbmTrainerBase(IHostEnvironment env, string name, LightGb
InitParallelTraining();
}
- protected override TModel TrainModelCore(TrainContext context)
+ private protected override TModel TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
diff --git a/src/Microsoft.ML.Maml/ChainCommand.cs b/src/Microsoft.ML.Maml/ChainCommand.cs
index e49162b45b..c51a78952b 100644
--- a/src/Microsoft.ML.Maml/ChainCommand.cs
+++ b/src/Microsoft.ML.Maml/ChainCommand.cs
@@ -15,7 +15,8 @@ namespace Microsoft.ML.Runtime.Tools
{
using Stopwatch = System.Diagnostics.Stopwatch;
- public sealed class ChainCommand : ICommand
+ [BestFriend]
+ internal sealed class ChainCommand : ICommand
{
public sealed class Arguments
{
diff --git a/src/Microsoft.ML.Maml/HelpCommand.cs b/src/Microsoft.ML.Maml/HelpCommand.cs
index ee2adbb495..80ddd63e47 100644
--- a/src/Microsoft.ML.Maml/HelpCommand.cs
+++ b/src/Microsoft.ML.Maml/HelpCommand.cs
@@ -23,14 +23,15 @@
namespace Microsoft.ML.Runtime.Tools
{
- public interface IGenerator
+ [BestFriend]
+ internal interface IGenerator
{
void Generate(IEnumerable infos);
}
public delegate void SignatureModuleGenerator(string regenerate);
- public sealed class HelpCommand : ICommand
+ internal sealed class HelpCommand : ICommand
{
public sealed class Arguments
{
@@ -100,7 +101,9 @@ public void Run()
public void Run(int? columns)
{
+#pragma warning disable CS0618 // The help command should be entirely within the command line anyway.
AssemblyLoadingUtils.LoadAndRegister(_env, _extraAssemblies);
+#pragma warning restore CS0618
using (var ch = _env.Start("Help"))
using (var sw = new StringWriter(CultureInfo.InvariantCulture))
@@ -423,7 +426,7 @@ private void GenerateModule(List components)
}
}
- public sealed class XmlGenerator : IGenerator
+ internal sealed class XmlGenerator : IGenerator
{
public sealed class Arguments
{
diff --git a/src/Microsoft.ML.Maml/MAML.cs b/src/Microsoft.ML.Maml/MAML.cs
index ff207b38e4..a4eaa43491 100644
--- a/src/Microsoft.ML.Maml/MAML.cs
+++ b/src/Microsoft.ML.Maml/MAML.cs
@@ -58,7 +58,9 @@ private static int MainWithProgress(string args)
string currentDirectory = Path.GetDirectoryName(typeof(Maml).Module.FullyQualifiedName);
using (var env = CreateEnvironment())
+#pragma warning disable CS0618 // This is the command line project, so the usage here is OK.
using (AssemblyLoadingUtils.CreateAssemblyRegistrar(env, currentDirectory))
+#pragma warning restore CS0618
using (var progressCancel = new CancellationTokenSource())
{
var progressTrackerTask = Task.Run(() => TrackProgress(env, progressCancel.Token));
@@ -107,7 +109,7 @@ private static ConsoleEnvironment CreateEnvironment()
/// so we always write . If set to true though, this executable will also print stack traces from the
/// marked exceptions as well.
///
- internal static int MainCore(ConsoleEnvironment env, string args, bool alwaysPrintStacktrace)
+ internal static int MainCore(IHostEnvironment env, string args, bool alwaysPrintStacktrace)
{
// REVIEW: How should extra dlls, tracking, etc be handled? Should the args objects for
// all commands derive from a common base?
diff --git a/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj b/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj
index 219b6bf0b8..a366fb6f9c 100644
--- a/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj
+++ b/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj
@@ -7,10 +7,6 @@
netstandard2.0
-
-
-
-
diff --git a/src/Microsoft.ML.Maml/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Maml/Properties/AssemblyInfo.cs
index 2ddc9c4ffa..2226c6d7fc 100644
--- a/src/Microsoft.ML.Maml/Properties/AssemblyInfo.cs
+++ b/src/Microsoft.ML.Maml/Properties/AssemblyInfo.cs
@@ -2,9 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System.Reflection;
using System.Runtime.CompilerServices;
-using System.Runtime.InteropServices;
+using Microsoft.ML;
-[assembly: InternalsVisibleTo("Microsoft.ML.TestFramework, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")]
-[assembly: InternalsVisibleTo("Microsoft.ML.Benchmarks, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")]
\ No newline at end of file
+[assembly: InternalsVisibleTo("Microsoft.ML.TestFramework" + PublicKey.TestValue)]
+[assembly: InternalsVisibleTo("Microsoft.ML.Benchmarks" + PublicKey.TestValue)]
+
+[assembly: InternalsVisibleTo("Microsoft.ML.Legacy" + PublicKey.Value)]
+[assembly: InternalsVisibleTo("Microsoft.ML.ResultProcessor" + PublicKey.Value)]
diff --git a/src/Microsoft.ML.Maml/VersionCommand.cs b/src/Microsoft.ML.Maml/VersionCommand.cs
index a59f01e58b..7e20db2260 100644
--- a/src/Microsoft.ML.Maml/VersionCommand.cs
+++ b/src/Microsoft.ML.Maml/VersionCommand.cs
@@ -12,7 +12,7 @@
namespace Microsoft.ML.Runtime.Tools
{
- public sealed class VersionCommand : ICommand
+ internal sealed class VersionCommand : ICommand
{
internal const string Summary = "Prints the TLC version.";
diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs
index 486bf10eeb..6ee3c8a37e 100644
--- a/src/Microsoft.ML.PCA/PcaTrainer.cs
+++ b/src/Microsoft.ML.PCA/PcaTrainer.cs
@@ -136,7 +136,7 @@ private RandomizedPcaTrainer(IHostEnvironment env, Arguments args, string featur
}
//Note: the notations used here are the same as in https://web.stanford.edu/group/mmds/slides2010/Martinsson.pdf (pg. 9)
- protected override PcaPredictor TrainModelCore(TrainContext context)
+ private protected override PcaPredictor TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
diff --git a/src/Microsoft.ML.PipelineInference/AutoMlUtils.cs b/src/Microsoft.ML.PipelineInference/AutoMlUtils.cs
index ba4b1e3872..8837ac945b 100644
--- a/src/Microsoft.ML.PipelineInference/AutoMlUtils.cs
+++ b/src/Microsoft.ML.PipelineInference/AutoMlUtils.cs
@@ -338,7 +338,7 @@ public static AutoInference.LevelDependencyMap ComputeColumnResponsibilities(IDa
return mapping;
}
- public static TlcModule.SweepableParamAttribute[] GetSweepRanges(Type learnerInputType)
+ internal static TlcModule.SweepableParamAttribute[] GetSweepRanges(Type learnerInputType)
{
var paramSet = new List();
foreach (var prop in learnerInputType.GetProperties(BindingFlags.Instance |
@@ -370,7 +370,7 @@ public static TlcModule.SweepableParamAttribute[] GetSweepRanges(Type learnerInp
return paramSet.ToArray();
}
- public static IValueGenerator ToIValueGenerator(TlcModule.SweepableParamAttribute attr)
+ internal static IValueGenerator ToIValueGenerator(TlcModule.SweepableParamAttribute attr)
{
if (attr is TlcModule.SweepableLongParamAttribute sweepableLongParamAttr)
{
@@ -430,7 +430,7 @@ private static void SetValue(PropertyInfo pi, IComparable value, object entryPoi
///
/// Updates properties of entryPointObj instance based on the values in sweepParams
///
- public static bool UpdateProperties(object entryPointObj, TlcModule.SweepableParamAttribute[] sweepParams)
+ internal static bool UpdateProperties(object entryPointObj, TlcModule.SweepableParamAttribute[] sweepParams)
{
bool result = true;
foreach (var param in sweepParams)
@@ -501,7 +501,7 @@ public static void PopulateSweepableParams(RecipeInference.SuggestedRecipe.Sugge
}
}
- public static bool CheckEntryPointStateMatchesParamValues(object entryPointObj,
+ internal static bool CheckEntryPointStateMatchesParamValues(object entryPointObj,
TlcModule.SweepableParamAttribute[] sweepParams)
{
foreach (var param in sweepParams)
@@ -584,7 +584,7 @@ public static IRunResult[] ConvertToRunResults(PipelinePattern[] history, bool i
/// Method to convert set of sweepable hyperparameters into instances used
/// by the current smart hyperparameter sweepers.
///
- public static IComponentFactory[] ConvertToComponentFactories(TlcModule.SweepableParamAttribute[] hps)
+ internal static IComponentFactory[] ConvertToComponentFactories(TlcModule.SweepableParamAttribute[] hps)
{
var results = new IComponentFactory[hps.Length];
diff --git a/test/Microsoft.ML.InferenceTesting/GenerateSweepCandidatesCommand.cs b/src/Microsoft.ML.PipelineInference/GenerateSweepCandidatesCommand.cs
similarity index 96%
rename from test/Microsoft.ML.InferenceTesting/GenerateSweepCandidatesCommand.cs
rename to src/Microsoft.ML.PipelineInference/GenerateSweepCandidatesCommand.cs
index 61aad5729a..8d03868719 100644
--- a/test/Microsoft.ML.InferenceTesting/GenerateSweepCandidatesCommand.cs
+++ b/src/Microsoft.ML.PipelineInference/GenerateSweepCandidatesCommand.cs
@@ -10,23 +10,23 @@
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Utilities;
-using Microsoft.ML.Runtime.MLTesting.Inference;
using Microsoft.ML.Runtime.PipelineInference;
using Microsoft.ML.Runtime.Sweeper;
[assembly: LoadableClass(typeof(GenerateSweepCandidatesCommand), typeof(GenerateSweepCandidatesCommand.Arguments), typeof(SignatureCommand),
"Generate Experiment Candidates", "GenerateSweepCandidates", DocName = "command/GenerateSweepCandidates.md")]
-namespace Microsoft.ML.Runtime.MLTesting.Inference
+namespace Microsoft.ML.Runtime.PipelineInference
{
///
/// This is a command that takes as an input:
/// 1- the schema of a dataset, in the format produced by InferSchema
- /// 2- the path to the datafile
- /// and generates experiment candidates by combining all the transform recipes suggested for the dataset, with all the learners available for the task.
+ /// 2- the path to the datafile
+ /// and generates experiment candidates by combining all the transform recipes suggested for the dataset, with all the learners available for the task.
///
- public sealed class GenerateSweepCandidatesCommand : ICommand
+ internal sealed class GenerateSweepCandidatesCommand : ICommand
{
+#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
public sealed class Arguments
{
[Argument(ArgumentType.Required, HelpText = "Text file with data to analyze.", ShortName = "data")]
@@ -53,6 +53,7 @@ public sealed class Arguments
[Argument(ArgumentType.AtMostOnce, HelpText = "If this option is provided, the result RSP are indented and written in separate files for easier visual inspection. Otherwise all the generated RSPs are written to a single file, one RSP per line.")]
public bool Indent;
}
+#pragma warning restore CS0649
private readonly IHost _host;
private readonly string _dataFile;
diff --git a/src/Microsoft.ML.PipelineInference/InferenceUtils.cs b/src/Microsoft.ML.PipelineInference/InferenceUtils.cs
index acff2197d8..9510e94eb3 100644
--- a/src/Microsoft.ML.PipelineInference/InferenceUtils.cs
+++ b/src/Microsoft.ML.PipelineInference/InferenceUtils.cs
@@ -18,7 +18,7 @@ public static IDataView Take(this IDataView data, int count)
{
Contracts.CheckValue(data, nameof(data));
// REVIEW: This should take an env as a parameter, not create one.
- var env = new ConsoleEnvironment(0);
+ var env = new MLContext(seed: 0);
var take = SkipTakeFilter.Create(env, new SkipTakeFilter.TakeArguments { Count = count }, data);
return CacheCore(take, env);
}
@@ -27,7 +27,7 @@ public static IDataView Cache(this IDataView data)
{
Contracts.CheckValue(data, nameof(data));
// REVIEW: This should take an env as a parameter, not create one.
- return CacheCore(data, new ConsoleEnvironment(0));
+ return CacheCore(data, new MLContext(seed: 0));
}
private static IDataView CacheCore(IDataView data, IHostEnvironment env)
diff --git a/src/Microsoft.ML.PipelineInference/Interfaces/IPipelineNode.cs b/src/Microsoft.ML.PipelineInference/Interfaces/IPipelineNode.cs
index 30f93d2eb4..a8bee5d046 100644
--- a/src/Microsoft.ML.PipelineInference/Interfaces/IPipelineNode.cs
+++ b/src/Microsoft.ML.PipelineInference/Interfaces/IPipelineNode.cs
@@ -60,7 +60,7 @@ protected string GetEpName(Type type)
return epName;
}
- protected void PropagateParamSetValues(ParameterSet hyperParams,
+ private protected void PropagateParamSetValues(ParameterSet hyperParams,
TlcModule.SweepableParamAttribute[] sweepParams)
{
var spMap = sweepParams.ToDictionary(sp => sp.Name);
@@ -79,9 +79,9 @@ public sealed class TransformPipelineNode : PipelineNodeBase, IPipelineNode sweepParams = null,
CommonInputs.ITrainerInput subTrainerObj = null)
{
@@ -136,9 +136,10 @@ public sealed class TrainerPipelineNode : PipelineNodeBase, IPipelineNode sweepParams = null,
ParameterSet hyperParameterSet = null)
{
diff --git a/src/Microsoft.ML.PipelineInference/Properties/AssemblyInfo.cs b/src/Microsoft.ML.PipelineInference/Properties/AssemblyInfo.cs
new file mode 100644
index 0000000000..db1d151fa2
--- /dev/null
+++ b/src/Microsoft.ML.PipelineInference/Properties/AssemblyInfo.cs
@@ -0,0 +1,10 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Runtime.CompilerServices;
+using Microsoft.ML;
+
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests" + PublicKey.TestValue)]
+
+[assembly: WantsToBeBestFriends]
diff --git a/src/Microsoft.ML.PipelineInference/TextFileContents.cs b/src/Microsoft.ML.PipelineInference/TextFileContents.cs
index f3f45cc8b9..df2dbc4a7f 100644
--- a/src/Microsoft.ML.PipelineInference/TextFileContents.cs
+++ b/src/Microsoft.ML.PipelineInference/TextFileContents.cs
@@ -114,7 +114,7 @@ private static bool TryParseFile(IChannel ch, TextLoader.Arguments args, IMultiS
try
{
// No need to provide information from unsuccessful loader, so we create temporary environment and get information from it in case of success
- using (var loaderEnv = new ConsoleEnvironment(0, true))
+ using (var loaderEnv = new ConsoleEnvironment(0, verbose: true))
{
var messages = new ConcurrentBag();
loaderEnv.AddListener(
diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs
index 2d545bb9a5..de68f9be1e 100644
--- a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs
+++ b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs
@@ -244,7 +244,7 @@ public MatrixFactorizationTrainer(IHostEnvironment env,
/// Train a matrix factorization model based on training data, validation data, and so on in the given context.
///
/// The information collection needed for training. for details.
- public override MatrixFactorizationPredictor Train(TrainContext context)
+ private protected override MatrixFactorizationPredictor Train(TrainContext context)
{
Host.CheckValue(context, nameof(context));
diff --git a/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj b/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj
index e5610126df..e0f084d70b 100644
--- a/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj
+++ b/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj
@@ -7,10 +7,6 @@
true
-
-
-
-
diff --git a/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs b/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs
index b896e37bf6..85b9234856 100644
--- a/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs
+++ b/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs
@@ -151,7 +151,9 @@ public void GetDefaultSettingValues(IHostEnvironment env, string predictorName,
///
private Dictionary GetDefaultSettings(IHostEnvironment env, string predictorName, string[] extraAssemblies = null)
{
+#pragma warning disable CS0618 // The result processor is an internal command line processing utility anyway, so this is, while not great, OK.
AssemblyLoadingUtils.LoadAndRegister(env, extraAssemblies);
+#pragma warning restore CS0618
var cls = env.ComponentCatalog.GetLoadableClassInfo(predictorName);
if (cls == null)
@@ -1154,7 +1156,9 @@ public static int Main(string[] args)
{
string currentDirectory = Path.GetDirectoryName(typeof(ResultProcessor).Module.FullyQualifiedName);
using (var env = new ConsoleEnvironment(42))
+#pragma warning disable CS0618 // The result processor is an internal command line processing utility anyway, so this is, while not great, OK.
using (AssemblyLoadingUtils.CreateAssemblyRegistrar(env, currentDirectory))
+#pragma warning restore CS0618
return Main(env, args);
}
@@ -1197,7 +1201,9 @@ protected static void Run(IHostEnvironment env, string[] args)
if (cmd.IncludePerFoldResults)
cmd.PerFoldResultSeparator = "" + PredictionUtil.SepCharFromString(cmd.PerFoldResultSeparator);
+#pragma warning disable CS0618 // The result processor is an internal command line processing utility anyway, so this is, while not great, OK.
AssemblyLoadingUtils.LoadAndRegister(env, cmd.ExtraAssemblies);
+#pragma warning restore CS0618
if (cmd.Metrics.Length == 0)
cmd.Metrics = null;
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
index b730050a10..756c1964a6 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
@@ -429,7 +429,7 @@ private FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgress
return new FieldAwareFactorizationMachinePredictor(Host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned);
}
- public override FieldAwareFactorizationMachinePredictor Train(TrainContext context)
+ private protected override FieldAwareFactorizationMachinePredictor Train(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var initPredictor = context.InitialPredictor as FieldAwareFactorizationMachinePredictor;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
index 04f272683c..14dd0692cb 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs
@@ -378,7 +378,7 @@ protected virtual void PreTrainingProcessInstance(float label, in VBuffer
///
/// The basic training calls the optimizer
///
- protected override TModel TrainModelCore(TrainContext context)
+ private protected override TModel TrainModelCore(TrainContext context)
{
Contracts.CheckValue(context, nameof(context));
Host.CheckParam(context.InitialPredictor == null || context.InitialPredictor is TModel, nameof(context.InitialPredictor));
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
index 286fad1063..9da2a2b23d 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
@@ -135,7 +135,7 @@ protected TScalarTrainer GetTrainer()
///
/// The trainig context for this learner.
/// The trained model.
- public TModel Train(TrainContext context)
+ TModel ITrainer.Train(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var data = context.TrainingSet;
@@ -216,7 +216,7 @@ private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
return cols;
}
- IPredictor ITrainer.Train(TrainContext context) => Train(context);
+ IPredictor ITrainer.Train(TrainContext context) => ((ITrainer)this).Train(context);
///
/// Fits the data to the trainer.
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
index 756607deea..ce6b0566f0 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
@@ -91,7 +91,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
protected override MulticlassPredictionTransformer MakeTransformer(MultiClassNaiveBayesPredictor model, Schema trainSchema)
=> new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
- protected override MultiClassNaiveBayesPredictor TrainModelCore(TrainContext context)
+ private protected override MultiClassNaiveBayesPredictor TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var data = context.TrainingSet;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
index 1c25986c77..93f2361325 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
@@ -254,7 +254,7 @@ private protected static TArgs InvokeAdvanced(Action advancedSetti
return args;
}
- protected sealed override TModel TrainModelCore(TrainContext context)
+ private protected sealed override TModel TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var initPredictor = context.InitialPredictor;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs
index 3ac918501d..283d4272b9 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs
@@ -68,7 +68,7 @@ private protected LinearTrainerBase(IHostEnvironment env, string featureColumn,
{
}
- protected override TModel TrainModelCore(TrainContext context)
+ private protected override TModel TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
using (var ch = Host.Start("Training"))
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs
index 2640144cc6..ec39fc1e1c 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs
@@ -75,7 +75,7 @@ public BinaryPredictionTransformer Fit(IDataView input)
return new BinaryPredictionTransformer(Host, pred, input.Schema, featureColumn: null);
}
- public override RandomPredictor Train(TrainContext context)
+ private protected override RandomPredictor Train(TrainContext context)
{
Host.CheckValue(context, nameof(context));
return new RandomPredictor(Host, Host.Rand.Next());
@@ -273,7 +273,7 @@ public BinaryPredictionTransformer Fit(IDataView input)
return new BinaryPredictionTransformer(Host, pred, input.Schema, featureColumn: null);
}
- public override PriorPredictor Train(TrainContext context)
+ private protected override PriorPredictor Train(TrainContext context)
{
Contracts.CheckValue(context, nameof(context));
var data = context.TrainingSet;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs
index fe2e6b752b..d419a3a8a1 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs
@@ -28,7 +28,7 @@ public StochasticTrainerBase(IHost host, SchemaShape.Column feature, SchemaShape
private static readonly TrainerInfo _info = new TrainerInfo();
public override TrainerInfo Info => _info;
- protected override TModel TrainModelCore(TrainContext context)
+ private protected override TModel TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
using (var ch = Host.Start("Training"))
diff --git a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs
index d78b6eaa7d..f3bc7a362d 100644
--- a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs
+++ b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs
@@ -163,20 +163,20 @@ private FastForestRegressionPredictor FitModel(IEnumerable previousR
/// An array of ParamaterSets which are the candidate configurations to sweep.
private ParameterSet[] GenerateCandidateConfigurations(int numOfCandidates, IEnumerable previousRuns, FastForestRegressionPredictor forest)
{
- ParameterSet[] configs = new ParameterSet[numOfCandidates];
-
// Get k best previous runs ParameterSets.
ParameterSet[] bestKParamSets = GetKBestConfigurations(previousRuns, forest, _args.LocalSearchParentCount);
// Perform local searches using the k best previous run configurations.
ParameterSet[] eiChallengers = GreedyPlusRandomSearch(bestKParamSets, forest, (int)Math.Ceiling(numOfCandidates / 2.0F), previousRuns);
- // Generate another set of random configurations to interleave
+ // Generate another set of random configurations to interleave.
ParameterSet[] randomChallengers = _randomSweeper.ProposeSweeps(numOfCandidates - eiChallengers.Length, previousRuns);
- // Return interleaved challenger candidates with random candidates
- for (int j = 0; j < configs.Length; j++)
- configs[j] = j % 2 == 0 ? eiChallengers[j / 2] : randomChallengers[j / 2];
+ // Return interleaved challenger candidates with random candidates. Since the number of candidates from either can be less than
+ // the number asked for, since we only generate unique candidates, and the number from either method may vary considerably.
+ ParameterSet[] configs = new ParameterSet[eiChallengers.Length + randomChallengers.Length];
+ Array.Copy(eiChallengers, 0, configs, 0, eiChallengers.Length);
+ Array.Copy(randomChallengers, 0, configs, eiChallengers.Length, randomChallengers.Length);
return configs;
}
diff --git a/src/Microsoft.ML.Sweeper/ConfigRunner.cs b/src/Microsoft.ML.Sweeper/ConfigRunner.cs
index 3219d691b1..806af2f5e6 100644
--- a/src/Microsoft.ML.Sweeper/ConfigRunner.cs
+++ b/src/Microsoft.ML.Sweeper/ConfigRunner.cs
@@ -110,7 +110,9 @@ public virtual void Finish()
string currentDirectory = Path.GetDirectoryName(typeof(ExeConfigRunnerBase).Module.FullyQualifiedName);
using (var ch = Host.Start("Finish"))
+#pragma warning disable CS0618 // As this deals with invoking command lines, this may be OK, though this code has some other problems.
using (AssemblyLoadingUtils.CreateAssemblyRegistrar(Host, currentDirectory))
+#pragma warning restore CS0618
{
var runs = RunNums.ToArray();
var args = Utils.BuildArray(RunNums.Count + 2,
diff --git a/src/Microsoft.ML.Sweeper/Microsoft.ML.Sweeper.csproj b/src/Microsoft.ML.Sweeper/Microsoft.ML.Sweeper.csproj
index d48a762104..9ed5d25e0e 100644
--- a/src/Microsoft.ML.Sweeper/Microsoft.ML.Sweeper.csproj
+++ b/src/Microsoft.ML.Sweeper/Microsoft.ML.Sweeper.csproj
@@ -13,11 +13,6 @@
-
-
-
-
-
diff --git a/src/Microsoft.ML.Sweeper/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Sweeper/Properties/AssemblyInfo.cs
new file mode 100644
index 0000000000..160c67eb55
--- /dev/null
+++ b/src/Microsoft.ML.Sweeper/Properties/AssemblyInfo.cs
@@ -0,0 +1,9 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Runtime.CompilerServices;
+using Microsoft.ML;
+
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.PipelineInference" + PublicKey.Value)]
diff --git a/src/Microsoft.ML.Sweeper/SweepCommand.cs b/src/Microsoft.ML.Sweeper/SweepCommand.cs
index 51dc28559d..fe61dc3c6e 100644
--- a/src/Microsoft.ML.Sweeper/SweepCommand.cs
+++ b/src/Microsoft.ML.Sweeper/SweepCommand.cs
@@ -16,8 +16,10 @@
namespace Microsoft.ML.Runtime.Sweeper
{
- public sealed class SweepCommand : ICommand
+ [BestFriend]
+ internal sealed class SweepCommand : ICommand
{
+#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
public sealed class Arguments
{
[Argument(ArgumentType.Multiple, HelpText = "Config runner", ShortName = "run,ev,evaluator", SignatureType = typeof(SignatureConfigRunner))]
@@ -39,6 +41,7 @@ public sealed class Arguments
[Argument(ArgumentType.AtMostOnce, HelpText = "Random seed", ShortName = "seed")]
public int? RandomSeed;
}
+#pragma warning restore CS0649
internal const string Summary = "Given a command line template and sweep ranges, creates and runs a sweep.";
diff --git a/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs b/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs
index 09a054be25..e9a057feb6 100644
--- a/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs
+++ b/src/Microsoft.ML.Transforms/LearnerFeatureSelection.cs
@@ -21,10 +21,11 @@ namespace Microsoft.ML.Transforms
/// is greater than a threshold.
/// Instantiates a DropSlots transform to actually drop the slots.
///
- public static class LearnerFeatureSelectionTransform
+ internal static class LearnerFeatureSelectionTransform
{
internal const string Summary = "Selects the slots for which the absolute value of the corresponding weight in a linear learner is greater than a threshold.";
+#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
public sealed class Arguments
{
[Argument(ArgumentType.LastOccurenceWins, HelpText = "If the corresponding absolute value of the weight for a slot is greater than this threshold, the slot is preserved", ShortName = "ft", SortOrder = 2)]
@@ -33,6 +34,8 @@ public sealed class Arguments
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of slots to preserve", ShortName = "topk", SortOrder = 1)]
public int? NumSlotsToKeep;
+ // If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer, but the utility
+ // of this would be limited because estimators and transformers now act more or less like this transform used to.
[Argument(ArgumentType.Multiple, HelpText = "Filter", ShortName = "f", SortOrder = 1, SignatureType = typeof(SignatureFeatureScorerTrainer))]
public IComponentFactory>> Filter =
ComponentFactoryUtils.CreateFromFunction(env =>
@@ -74,6 +77,7 @@ internal void Check(IExceptionContext ectx)
ectx.CheckUserArg((NumSlotsToKeep ?? int.MaxValue) > 0, nameof(NumSlotsToKeep), "Must be positive");
}
}
+#pragma warning restore CS0649
internal static string RegistrationName = "LearnerFeatureSelectionTransform";
diff --git a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.json b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.json
index 880cd67637..d74ebe1c3f 100644
--- a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.json
+++ b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.json
@@ -166,24 +166,24 @@
],
"dataType": "FLOAT",
"floatData": [
- 0.6232001,
- 0.482400715,
- 0.495733529,
- 0.416533619,
- 0.425866365,
- 0.5456011,
- 0.477333277,
- 0.4338674,
- 0.20399937,
- 0.225407124,
- 0.111400835,
- 0.109446481,
- 0.120521255,
- 0.198697388,
- 0.121824168,
- 0.182410166,
- 0.108143583,
- 0.107166387
+ 0.5522167,
+ 0.3039403,
+ 0.319211155,
+ 0.261575729,
+ 0.320196062,
+ 0.344088882,
+ 0.293349,
+ 0.273151934,
+ 0.15763472,
+ 0.285144627,
+ 0.332245946,
+ 0.325724274,
+ 0.315217048,
+ 0.328623,
+ 0.3706516,
+ 0.41992715,
+ 0.307970464,
+ 0.164492577
],
"name": "C"
},
@@ -193,8 +193,8 @@
],
"dataType": "FLOAT",
"floatData": [
- 1.97708726,
- 0.200497344
+ 0.9740776,
+ 0.940771043
],
"name": "C2"
},
diff --git a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.onnx b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.onnx
deleted file mode 100644
index e1dc568fc4..0000000000
Binary files a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.onnx and /dev/null differ
diff --git a/test/BaselineOutput/SingleDebug/PCA/pca.tsv b/test/BaselineOutput/SingleDebug/PCA/pca.tsv
index ece1e59164..328ff1bf86 100644
--- a/test/BaselineOutput/SingleDebug/PCA/pca.tsv
+++ b/test/BaselineOutput/SingleDebug/PCA/pca.tsv
@@ -2,7 +2,7 @@
#@ sep=tab
#@ col=pca:R4:0-4
#@ }
-2.085487 0.09400085 2.58366132 -1.721405 -0.732070744
-0.9069792 0.7748574 0.6097196 1.07868779 0.453838825
--0.167718172 -0.92723 -0.19140324 0.243479848 -1.060547
-0.548309 0.5576686 -0.587472439 -1.38610959 0.9422219
+-2.085465 -0.09400512 -2.58367229 1.72141707 0.732049346
+-0.906982958 -0.774861753 -0.609727442 -1.07867944 -0.453824759
+0.167715371 0.927231133 0.191398591 -0.243467987 1.06056178
+-0.54830873 -0.5576661 0.587476134 1.38609958 -0.9422357
diff --git a/test/BaselineOutput/SingleRelease/PCA/pca.tsv b/test/BaselineOutput/SingleRelease/PCA/pca.tsv
index ece1e59164..328ff1bf86 100644
--- a/test/BaselineOutput/SingleRelease/PCA/pca.tsv
+++ b/test/BaselineOutput/SingleRelease/PCA/pca.tsv
@@ -2,7 +2,7 @@
#@ sep=tab
#@ col=pca:R4:0-4
#@ }
-2.085487 0.09400085 2.58366132 -1.721405 -0.732070744
-0.9069792 0.7748574 0.6097196 1.07868779 0.453838825
--0.167718172 -0.92723 -0.19140324 0.243479848 -1.060547
-0.548309 0.5576686 -0.587472439 -1.38610959 0.9422219
+-2.085465 -0.09400512 -2.58367229 1.72141707 0.732049346
+-0.906982958 -0.774861753 -0.609727442 -1.07867944 -0.453824759
+0.167715371 0.927231133 0.191398591 -0.243467987 1.06056178
+-0.54830873 -0.5576661 0.587476134 1.38609958 -0.9422357
diff --git a/test/Microsoft.ML.Benchmarks/Helpers/EnvironmentFactory.cs b/test/Microsoft.ML.Benchmarks/Helpers/EnvironmentFactory.cs
index ff2ddd1bbe..3b0b055d30 100644
--- a/test/Microsoft.ML.Benchmarks/Helpers/EnvironmentFactory.cs
+++ b/test/Microsoft.ML.Benchmarks/Helpers/EnvironmentFactory.cs
@@ -5,33 +5,36 @@
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Training;
using Microsoft.ML.Transforms;
namespace Microsoft.ML.Benchmarks
{
internal static class EnvironmentFactory
{
- internal static ConsoleEnvironment CreateClassificationEnvironment()
+ internal static MLContext CreateClassificationEnvironment()
where TLoader : IDataReader
where TTransformer : ITransformer
- where TTrainer : ITrainer
+ where TTrainer : ITrainerEstimator, IPredictor>
{
- var environment = new ConsoleEnvironment(verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance);
+ var ctx = new MLContext();
+ IHostEnvironment environment = ctx;
environment.ComponentCatalog.RegisterAssembly(typeof(TLoader).Assembly);
environment.ComponentCatalog.RegisterAssembly(typeof(TTransformer).Assembly);
environment.ComponentCatalog.RegisterAssembly(typeof(TTrainer).Assembly);
- return environment;
+ return ctx;
}
- internal static ConsoleEnvironment CreateRankingEnvironment()
+ internal static MLContext CreateRankingEnvironment()
where TEvaluator : IEvaluator
where TLoader : IDataReader
where TTransformer : ITransformer
- where TTrainer : ITrainer
+ where TTrainer : ITrainerEstimator, IPredictor>
{
- var environment = new ConsoleEnvironment(verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance);
+ var ctx = new MLContext();
+ IHostEnvironment environment = ctx;
environment.ComponentCatalog.RegisterAssembly(typeof(TEvaluator).Assembly);
environment.ComponentCatalog.RegisterAssembly(typeof(TLoader).Assembly);
@@ -40,7 +43,7 @@ internal static ConsoleEnvironment CreateRankingEnvironment
+ {
+ s.HasHeader = true;
+ s.Separator = ",";
+ });
- trans = NormalizeTransform.CreateMinMaxNormalizer(env, trans, "NumFeatures");
- trans = new ColumnConcatenatingTransformer(env, "Features", "NumFeatures", "CatFeatures").Transform(trans);
- trans = TrainAndScoreTransformer.Create(env, new TrainAndScoreTransformer.Arguments
- {
- Trainer = ComponentFactoryUtils.CreateFromFunction(host =>
- new KMeansPlusPlusTrainer(host, "Features", advancedSettings: s=>
- {
- s.K = 100;
- })),
- FeatureColumn = "Features"
- }, trans);
- trans = new ColumnConcatenatingTransformer(env, "Features", "Features", "Score").Transform(trans);
+ var estimatorPipeline = ml.Transforms.Categorical.OneHotEncoding("CatFeatures")
+ .Append(ml.Transforms.Normalize("NumFeatures"))
+ .Append(ml.Transforms.Concatenate("Features", "NumFeatures", "CatFeatures"))
+ .Append(ml.Clustering.Trainers.KMeans("Features"))
+ .Append(ml.Transforms.Concatenate("Features", "Features", "Score"))
+ .Append(ml.BinaryClassification.Trainers.LogisticRegression(advancedSettings: args => { args.EnforceNonNegativity = true; args.OptTol = 1e-3f; }));
- // Train
- var trainer = new LogisticRegression(env, "Label", "Features", advancedSettings: args => { args.EnforceNonNegativity = true; args.OptTol = 1e-3f; });
- var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
- return trainer.Train(trainRoles);
- }
+ var model = estimatorPipeline.Fit(input);
+ // Return the last model in the chain.
+ return model.LastTransformer.Model;
}
}
}
\ No newline at end of file
diff --git a/test/Microsoft.ML.Benchmarks/Numeric/Ranking.cs b/test/Microsoft.ML.Benchmarks/Numeric/Ranking.cs
index 42ea0c040f..adc1bccdce 100644
--- a/test/Microsoft.ML.Benchmarks/Numeric/Ranking.cs
+++ b/test/Microsoft.ML.Benchmarks/Numeric/Ranking.cs
@@ -24,7 +24,7 @@ public void SetupTrainingSpeedTests()
{
_mslrWeb10k_Validate = Path.GetFullPath(TestDatasets.MSLRWeb.validFilename);
_mslrWeb10k_Train = Path.GetFullPath(TestDatasets.MSLRWeb.trainFilename);
-
+
if (!File.Exists(_mslrWeb10k_Validate))
throw new FileNotFoundException(string.Format(Errors.DatasetNotFound, _mslrWeb10k_Validate));
@@ -42,10 +42,8 @@ public void TrainTest_Ranking_MSLRWeb10K_RawNumericFeatures_FastTreeRanking()
" xf=HashTransform{col=GroupId} xf=NAHandleTransform{col=Features}" +
" tr=FastTreeRanking{}";
- using (var environment = EnvironmentFactory.CreateRankingEnvironment())
- {
- Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
- }
+ var environment = EnvironmentFactory.CreateRankingEnvironment();
+ Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
}
[Benchmark]
@@ -59,10 +57,8 @@ public void TrainTest_Ranking_MSLRWeb10K_RawNumericFeatures_LightGBMRanking()
" xf=NAHandleTransform{col=Features}" +
" tr=LightGBMRanking{}";
- using (var environment = EnvironmentFactory.CreateRankingEnvironment())
- {
- Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
- }
+ var environment = EnvironmentFactory.CreateRankingEnvironment();
+ Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
}
}
@@ -100,10 +96,8 @@ public void SetupScoringSpeedTests()
" tr=FastTreeRanking{}" +
" out={" + _modelPath_MSLR + "}";
- using (var environment = EnvironmentFactory.CreateRankingEnvironment())
- {
- Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
- }
+ var environment = EnvironmentFactory.CreateRankingEnvironment();
+ Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
}
[Benchmark]
@@ -112,10 +106,8 @@ public void Test_Ranking_MSLRWeb10K_RawNumericFeatures_FastTreeRanking()
// This benchmark is profiling bulk scoring speed and not training speed.
string cmd = @"Test data=" + _mslrWeb10k_Test + " in=" + _modelPath_MSLR;
- using (var environment = EnvironmentFactory.CreateRankingEnvironment())
- {
- Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
- }
+ var environment = EnvironmentFactory.CreateRankingEnvironment();
+ Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
}
}
}
diff --git a/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs b/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs
index 4d728ab45b..39342cf224 100644
--- a/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs
+++ b/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs
@@ -36,32 +36,30 @@ public void SetupIrisPipeline()
string _irisDataPath = Program.GetInvariantCultureDataPath("iris.txt");
- using (var env = new ConsoleEnvironment(seed: 1, conc: 1, verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance))
- {
- var reader = new TextLoader(env,
- new TextLoader.Arguments()
+ var env = new MLContext(seed: 1, conc: 1);
+ var reader = new TextLoader(env,
+ new TextLoader.Arguments()
+ {
+ Separator = "\t",
+ HasHeader = true,
+ Column = new[]
{
- Separator = "\t",
- HasHeader = true,
- Column = new[]
- {
new TextLoader.Column("Label", DataKind.R4, 0),
new TextLoader.Column("SepalLength", DataKind.R4, 1),
new TextLoader.Column("SepalWidth", DataKind.R4, 2),
new TextLoader.Column("PetalLength", DataKind.R4, 3),
new TextLoader.Column("PetalWidth", DataKind.R4, 4),
- }
- });
+ }
+ });
- IDataView data = reader.Read(_irisDataPath);
+ IDataView data = reader.Read(_irisDataPath);
- var pipeline = new ColumnConcatenatingEstimator (env, "Features", new[] { "SepalLength", "SepalWidth", "PetalLength", "PetalWidth" })
- .Append(new SdcaMultiClassTrainer(env, "Label", "Features", advancedSettings: (s) => { s.NumThreads = 1; s.ConvergenceTolerance = 1e-2f; }));
+ var pipeline = new ColumnConcatenatingEstimator(env, "Features", new[] { "SepalLength", "SepalWidth", "PetalLength", "PetalWidth" })
+ .Append(new SdcaMultiClassTrainer(env, "Label", "Features", advancedSettings: (s) => { s.NumThreads = 1; s.ConvergenceTolerance = 1e-2f; }));
- var model = pipeline.Fit(data);
+ var model = pipeline.Fit(data);
- _irisModel = model.MakePredictionFunction(env);
- }
+ _irisModel = model.MakePredictionFunction(env);
}
[GlobalSetup(Target = nameof(MakeSentimentPredictions))]
@@ -74,9 +72,8 @@ public void SetupSentimentPipeline()
string _sentimentDataPath = Program.GetInvariantCultureDataPath("wikipedia-detox-250-line-data.tsv");
- using (var env = new ConsoleEnvironment(seed: 1, conc: 1, verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance))
- {
- var reader = new TextLoader(env,
+ var env = new MLContext(seed: 1, conc: 1);
+ var reader = new TextLoader(env,
new TextLoader.Arguments()
{
Separator = "\t",
@@ -88,15 +85,14 @@ public void SetupSentimentPipeline()
}
});
- IDataView data = reader.Read(_sentimentDataPath);
+ IDataView data = reader.Read(_sentimentDataPath);
- var pipeline = new TextFeaturizingEstimator(env, "SentimentText", "Features")
- .Append(new SdcaBinaryTrainer(env, "Label", "Features", advancedSettings: (s) => { s.NumThreads = 1; s.ConvergenceTolerance = 1e-2f; }));
+ var pipeline = new TextFeaturizingEstimator(env, "SentimentText", "Features")
+ .Append(new SdcaBinaryTrainer(env, "Label", "Features", advancedSettings: (s) => { s.NumThreads = 1; s.ConvergenceTolerance = 1e-2f; }));
- var model = pipeline.Fit(data);
+ var model = pipeline.Fit(data);
- _sentimentModel = model.MakePredictionFunction(env);
- }
+ _sentimentModel = model.MakePredictionFunction(env);
}
[GlobalSetup(Target = nameof(MakeBreastCancerPredictions))]
@@ -109,9 +105,8 @@ public void SetupBreastCancerPipeline()
string _breastCancerDataPath = Program.GetInvariantCultureDataPath("breast-cancer.txt");
- using (var env = new ConsoleEnvironment(seed: 1, conc: 1, verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance))
- {
- var reader = new TextLoader(env,
+ var env = new MLContext(seed: 1, conc: 1);
+ var reader = new TextLoader(env,
new TextLoader.Arguments()
{
Separator = "\t",
@@ -123,14 +118,13 @@ public void SetupBreastCancerPipeline()
}
});
- IDataView data = reader.Read(_breastCancerDataPath);
+ IDataView data = reader.Read(_breastCancerDataPath);
- var pipeline = new SdcaBinaryTrainer(env, "Label", "Features", advancedSettings: (s) => { s.NumThreads = 1; s.ConvergenceTolerance = 1e-2f; });
+ var pipeline = new SdcaBinaryTrainer(env, "Label", "Features", advancedSettings: (s) => { s.NumThreads = 1; s.ConvergenceTolerance = 1e-2f; });
- var model = pipeline.Fit(data);
+ var model = pipeline.Fit(data);
- _breastCancerModel = model.MakePredictionFunction(env);
- }
+ _breastCancerModel = model.MakePredictionFunction(env);
}
[Benchmark]
diff --git a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs
index 4227c8708f..7e2e5f8702 100644
--- a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs
+++ b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs
@@ -63,18 +63,17 @@ private Legacy.PredictionModel Train(string dataPath)
[Benchmark]
public void TrainSentiment()
{
- using (var env = new ConsoleEnvironment(seed: 1))
- {
- // Pipeline
- var loader = TextLoader.ReadFile(env,
- new TextLoader.Arguments()
+ var env = new MLContext(seed: 1);
+ // Pipeline
+ var loader = TextLoader.ReadFile(env,
+ new TextLoader.Arguments()
+ {
+ AllowQuoting = false,
+ AllowSparse = false,
+ Separator = "tab",
+ HasHeader = true,
+ Column = new[]
{
- AllowQuoting = false,
- AllowSparse = false,
- Separator = "tab",
- HasHeader = true,
- Column = new[]
- {
new TextLoader.Column()
{
Name = "Label",
@@ -88,14 +87,14 @@ public void TrainSentiment()
Source = new [] { new TextLoader.Range() { Min=1, Max=1} },
Type = DataKind.Text
}
- }
- }, new MultiFileSource(_sentimentDataPath));
+ }
+ }, new MultiFileSource(_sentimentDataPath));
- var text = TextFeaturizingEstimator.Create(env,
- new TextFeaturizingEstimator.Arguments()
+ var text = TextFeaturizingEstimator.Create(env,
+ new TextFeaturizingEstimator.Arguments()
+ {
+ Column = new TextFeaturizingEstimator.Column
{
- Column = new TextFeaturizingEstimator.Column
- {
Name = "WordEmbeddings",
Source = new[] { "SentimentText" }
},
@@ -121,13 +120,10 @@ public void TrainSentiment()
ModelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe,
}, text);
- // Train
- var trainer = new SdcaMultiClassTrainer(env, "Label", "Features", maxIterations: 20);
- var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
-
- var predicted = trainer.Train(trainRoles);
- _consumer.Consume(predicted);
- }
+ // Train
+ var trainer = new SdcaMultiClassTrainer(env, "Label", "Features", maxIterations: 20);
+ var predicted = trainer.Fit(trans);
+ _consumer.Consume(predicted);
}
[GlobalSetup(Targets = new string[] { nameof(PredictIris), nameof(PredictIrisBatchOf1), nameof(PredictIrisBatchOf2), nameof(PredictIrisBatchOf5) })]
diff --git a/test/Microsoft.ML.Benchmarks/Text/MultiClassClassification.cs b/test/Microsoft.ML.Benchmarks/Text/MultiClassClassification.cs
index 8918c7b2e5..d7705aaadb 100644
--- a/test/Microsoft.ML.Benchmarks/Text/MultiClassClassification.cs
+++ b/test/Microsoft.ML.Benchmarks/Text/MultiClassClassification.cs
@@ -25,13 +25,13 @@ public void SetupTrainingSpeedTests()
_dataPath_Wiki = Path.GetFullPath(TestDatasets.WikiDetox.trainFilename);
if (!File.Exists(_dataPath_Wiki))
- throw new FileNotFoundException(string.Format(Errors.DatasetNotFound, _dataPath_Wiki));
+ throw new FileNotFoundException(string.Format(Errors.DatasetNotFound, _dataPath_Wiki));
}
[Benchmark]
public void CV_Multiclass_WikiDetox_BigramsAndTrichar_OVAAveragedPerceptron()
{
- string cmd = @"CV k=5 data=" + _dataPath_Wiki +
+ string cmd = @"CV k=5 data=" + _dataPath_Wiki +
" loader=TextLoader{quote=- sparse=- col=Label:R4:0 col=rev_id:TX:1 col=comment:TX:2 col=logged_in:BL:4 col=ns:TX:5 col=sample:TX:6 col=split:TX:7 col=year:R4:3 header=+}" +
" xf=Convert{col=logged_in type=R4}" +
" xf=CategoricalTransform{col=ns}" +
@@ -39,10 +39,8 @@ public void CV_Multiclass_WikiDetox_BigramsAndTrichar_OVAAveragedPerceptron()
" xf=Concat{col=Features:FeaturesText,logged_in,ns}" +
" tr=OVA{p=AveragedPerceptron{iter=10}}";
- using (var environment = EnvironmentFactory.CreateClassificationEnvironment())
- {
- Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
- }
+ var environment = EnvironmentFactory.CreateClassificationEnvironment();
+ Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
}
[Benchmark]
@@ -56,10 +54,8 @@ public void CV_Multiclass_WikiDetox_BigramsAndTrichar_LightGBMMulticlass()
" xf=Concat{col=Features:FeaturesText,logged_in,ns}" +
" tr=LightGBMMulticlass{iter=10}";
- using (var environment = EnvironmentFactory.CreateClassificationEnvironment())
- {
- Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
- }
+ var environment = EnvironmentFactory.CreateClassificationEnvironment();
+ Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
}
[Benchmark]
@@ -74,10 +70,8 @@ public void CV_Multiclass_WikiDetox_WordEmbeddings_OVAAveragedPerceptron()
" xf=WordEmbeddingsTransform{col=FeaturesWordEmbedding:FeaturesText_TransformedText model=FastTextWikipedia300D}" +
" xf=Concat{col=Features:FeaturesText,FeaturesWordEmbedding,logged_in,ns}";
- using (var environment = EnvironmentFactory.CreateClassificationEnvironment())
- {
- Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
- }
+ var environment = EnvironmentFactory.CreateClassificationEnvironment();
+ Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
}
[Benchmark]
@@ -92,10 +86,8 @@ public void CV_Multiclass_WikiDetox_WordEmbeddings_SDCAMC()
" xf=WordEmbeddingsTransform{col=FeaturesWordEmbedding:FeaturesText_TransformedText model=FastTextWikipedia300D}" +
" xf=Concat{col=Features:FeaturesWordEmbedding,logged_in,ns}";
- using (var environment = EnvironmentFactory.CreateClassificationEnvironment())
- {
- Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
- }
+ var environment = EnvironmentFactory.CreateClassificationEnvironment();
+ Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
}
}
@@ -122,10 +114,8 @@ public void SetupScoringSpeedTests()
" tr=OVA{p=AveragedPerceptron{iter=10}}" +
" out={" + _modelPath_Wiki + "}";
- using (var environment = EnvironmentFactory.CreateClassificationEnvironment())
- {
- Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
- }
+ var environment = EnvironmentFactory.CreateClassificationEnvironment();
+ Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
}
[Benchmark]
@@ -135,10 +125,8 @@ public void Test_Multiclass_WikiDetox_BigramsAndTrichar_OVAAveragedPerceptron()
string modelpath = Path.Combine(Directory.GetCurrentDirectory(), @"WikiModel.fold000.zip");
string cmd = @"Test data=" + _dataPath_Wiki + " in=" + modelpath;
- using (var environment = EnvironmentFactory.CreateClassificationEnvironment())
- {
- Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
- }
+ var environment = EnvironmentFactory.CreateClassificationEnvironment();
+ Maml.MainCore(environment, cmd, alwaysPrintStacktrace: false);
}
}
}
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
index 5f6fdb1620..30cb1fcc99 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
@@ -25,99 +25,95 @@ public TestCSharpApi(ITestOutputHelper output) : base(output)
public void TestSimpleExperiment()
{
var dataPath = GetDataPath("adult.tiny.with-schema.txt");
- using (var env = new ConsoleEnvironment())
- {
- var experiment = env.CreateExperiment();
+ var env = new MLContext();
+ var experiment = env.CreateExperiment();
- var importInput = new Legacy.Data.TextLoader(dataPath);
- var importOutput = experiment.Add(importInput);
+ var importInput = new Legacy.Data.TextLoader(dataPath);
+ var importOutput = experiment.Add(importInput);
- var normalizeInput = new Legacy.Transforms.MinMaxNormalizer
- {
- Data = importOutput.Data
- };
- normalizeInput.AddColumn("NumericFeatures");
- var normalizeOutput = experiment.Add(normalizeInput);
-
- experiment.Compile();
- experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
- experiment.Run();
- var data = experiment.GetOutput(normalizeOutput.OutputData);
-
- var schema = data.Schema;
- Assert.Equal(5, schema.ColumnCount);
- var expected = new[] { "Label", "Workclass", "Categories", "NumericFeatures", "NumericFeatures" };
- for (int i = 0; i < schema.ColumnCount; i++)
- Assert.Equal(expected[i], schema.GetColumnName(i));
- }
+ var normalizeInput = new Legacy.Transforms.MinMaxNormalizer
+ {
+ Data = importOutput.Data
+ };
+ normalizeInput.AddColumn("NumericFeatures");
+ var normalizeOutput = experiment.Add(normalizeInput);
+
+ experiment.Compile();
+ experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
+ experiment.Run();
+ var data = experiment.GetOutput(normalizeOutput.OutputData);
+
+ var schema = data.Schema;
+ Assert.Equal(5, schema.ColumnCount);
+ var expected = new[] { "Label", "Workclass", "Categories", "NumericFeatures", "NumericFeatures" };
+ for (int i = 0; i < schema.ColumnCount; i++)
+ Assert.Equal(expected[i], schema.GetColumnName(i));
}
[Fact]
public void TestSimpleTrainExperiment()
{
var dataPath = GetDataPath("adult.tiny.with-schema.txt");
- using (var env = new ConsoleEnvironment())
- {
- var experiment = env.CreateExperiment();
-
- var importInput = new Legacy.Data.TextLoader(dataPath);
- var importOutput = experiment.Add(importInput);
-
- var catInput = new Legacy.Transforms.CategoricalOneHotVectorizer
- {
- Data = importOutput.Data
- };
- catInput.AddColumn("Categories");
- var catOutput = experiment.Add(catInput);
+ var env = new MLContext();
+ var experiment = env.CreateExperiment();
- var concatInput = new Legacy.Transforms.ColumnConcatenator
- {
- Data = catOutput.OutputData
- };
- concatInput.AddColumn("Features", "Categories", "NumericFeatures");
- var concatOutput = experiment.Add(concatInput);
-
- var sdcaInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier
- {
- TrainingData = concatOutput.OutputData,
- LossFunction = new HingeLossSDCAClassificationLossFunction() { Margin = 1.1f },
- NumThreads = 1,
- Shuffle = false
- };
- var sdcaOutput = experiment.Add(sdcaInput);
+ var importInput = new Legacy.Data.TextLoader(dataPath);
+ var importOutput = experiment.Add(importInput);
- var scoreInput = new Legacy.Transforms.DatasetScorer
- {
- Data = concatOutput.OutputData,
- PredictorModel = sdcaOutput.PredictorModel
- };
- var scoreOutput = experiment.Add(scoreInput);
+ var catInput = new Legacy.Transforms.CategoricalOneHotVectorizer
+ {
+ Data = importOutput.Data
+ };
+ catInput.AddColumn("Categories");
+ var catOutput = experiment.Add(catInput);
- var evalInput = new Legacy.Models.BinaryClassificationEvaluator
- {
- Data = scoreOutput.ScoredData
- };
- var evalOutput = experiment.Add(evalInput);
+ var concatInput = new Legacy.Transforms.ColumnConcatenator
+ {
+ Data = catOutput.OutputData
+ };
+ concatInput.AddColumn("Features", "Categories", "NumericFeatures");
+ var concatOutput = experiment.Add(concatInput);
- experiment.Compile();
- experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
- experiment.Run();
- var data = experiment.GetOutput(evalOutput.OverallMetrics);
+ var sdcaInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier
+ {
+ TrainingData = concatOutput.OutputData,
+ LossFunction = new HingeLossSDCAClassificationLossFunction() { Margin = 1.1f },
+ NumThreads = 1,
+ Shuffle = false
+ };
+ var sdcaOutput = experiment.Add(sdcaInput);
+
+ var scoreInput = new Legacy.Transforms.DatasetScorer
+ {
+ Data = concatOutput.OutputData,
+ PredictorModel = sdcaOutput.PredictorModel
+ };
+ var scoreOutput = experiment.Add(scoreInput);
- var schema = data.Schema;
- var b = schema.TryGetColumnIndex("AUC", out int aucCol);
+ var evalInput = new Legacy.Models.BinaryClassificationEvaluator
+ {
+ Data = scoreOutput.ScoredData
+ };
+ var evalOutput = experiment.Add(evalInput);
+
+ experiment.Compile();
+ experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
+ experiment.Run();
+ var data = experiment.GetOutput(evalOutput.OverallMetrics);
+
+ var schema = data.Schema;
+ var b = schema.TryGetColumnIndex("AUC", out int aucCol);
+ Assert.True(b);
+ using (var cursor = data.GetRowCursor(col => col == aucCol))
+ {
+ var getter = cursor.GetGetter(aucCol);
+ b = cursor.MoveNext();
Assert.True(b);
- using (var cursor = data.GetRowCursor(col => col == aucCol))
- {
- var getter = cursor.GetGetter(aucCol);
- b = cursor.MoveNext();
- Assert.True(b);
- double auc = 0;
- getter(ref auc);
- Assert.Equal(0.93, auc, 2);
- b = cursor.MoveNext();
- Assert.False(b);
- }
+ double auc = 0;
+ getter(ref auc);
+ Assert.Equal(0.93, auc, 2);
+ b = cursor.MoveNext();
+ Assert.False(b);
}
}
@@ -125,71 +121,69 @@ public void TestSimpleTrainExperiment()
public void TestTrainTestMacro()
{
var dataPath = GetDataPath("adult.tiny.with-schema.txt");
- using (var env = new ConsoleEnvironment())
- {
- var subGraph = env.CreateExperiment();
-
- var catInput = new Legacy.Transforms.CategoricalOneHotVectorizer();
- catInput.AddColumn("Categories");
- var catOutput = subGraph.Add(catInput);
-
- var concatInput = new Legacy.Transforms.ColumnConcatenator
- {
- Data = catOutput.OutputData
- };
- concatInput.AddColumn("Features", "Categories", "NumericFeatures");
- var concatOutput = subGraph.Add(concatInput);
+ var env = new MLContext();
+ var subGraph = env.CreateExperiment();
- var sdcaInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier
- {
- TrainingData = concatOutput.OutputData,
- LossFunction = new HingeLossSDCAClassificationLossFunction() { Margin = 1.1f },
- NumThreads = 1,
- Shuffle = false
- };
- var sdcaOutput = subGraph.Add(sdcaInput);
-
- var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
- {
- TransformModels = new ArrayVar(catOutput.Model, concatOutput.Model),
- PredictorModel = sdcaOutput.PredictorModel
- };
- var modelCombineOutput = subGraph.Add(modelCombine);
+ var catInput = new Legacy.Transforms.CategoricalOneHotVectorizer();
+ catInput.AddColumn("Categories");
+ var catOutput = subGraph.Add(catInput);
- var experiment = env.CreateExperiment();
+ var concatInput = new Legacy.Transforms.ColumnConcatenator
+ {
+ Data = catOutput.OutputData
+ };
+ concatInput.AddColumn("Features", "Categories", "NumericFeatures");
+ var concatOutput = subGraph.Add(concatInput);
- var importInput = new Legacy.Data.TextLoader(dataPath);
- var importOutput = experiment.Add(importInput);
+ var sdcaInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier
+ {
+ TrainingData = concatOutput.OutputData,
+ LossFunction = new HingeLossSDCAClassificationLossFunction() { Margin = 1.1f },
+ NumThreads = 1,
+ Shuffle = false
+ };
+ var sdcaOutput = subGraph.Add(sdcaInput);
+
+ var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
+ {
+ TransformModels = new ArrayVar(catOutput.Model, concatOutput.Model),
+ PredictorModel = sdcaOutput.PredictorModel
+ };
+ var modelCombineOutput = subGraph.Add(modelCombine);
- var trainTestInput = new Legacy.Models.TrainTestBinaryEvaluator
- {
- TrainingData = importOutput.Data,
- TestingData = importOutput.Data,
- Nodes = subGraph
- };
- trainTestInput.Inputs.Data = catInput.Data;
- trainTestInput.Outputs.Model = modelCombineOutput.PredictorModel;
- var trainTestOutput = experiment.Add(trainTestInput);
+ var experiment = env.CreateExperiment();
- experiment.Compile();
- experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
- experiment.Run();
- var data = experiment.GetOutput(trainTestOutput.OverallMetrics);
+ var importInput = new Legacy.Data.TextLoader(dataPath);
+ var importOutput = experiment.Add(importInput);
- var schema = data.Schema;
- var b = schema.TryGetColumnIndex("AUC", out int aucCol);
+ var trainTestInput = new Legacy.Models.TrainTestBinaryEvaluator
+ {
+ TrainingData = importOutput.Data,
+ TestingData = importOutput.Data,
+ Nodes = subGraph
+ };
+ trainTestInput.Inputs.Data = catInput.Data;
+ trainTestInput.Outputs.Model = modelCombineOutput.PredictorModel;
+ var trainTestOutput = experiment.Add(trainTestInput);
+
+ experiment.Compile();
+ experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
+ experiment.Run();
+ var data = experiment.GetOutput(trainTestOutput.OverallMetrics);
+
+ var schema = data.Schema;
+ var b = schema.TryGetColumnIndex("AUC", out int aucCol);
+ Assert.True(b);
+ using (var cursor = data.GetRowCursor(col => col == aucCol))
+ {
+ var getter = cursor.GetGetter(aucCol);
+ b = cursor.MoveNext();
Assert.True(b);
- using (var cursor = data.GetRowCursor(col => col == aucCol))
- {
- var getter = cursor.GetGetter(aucCol);
- b = cursor.MoveNext();
- Assert.True(b);
- double auc = 0;
- getter(ref auc);
- Assert.Equal(0.93, auc, 2);
- b = cursor.MoveNext();
- Assert.False(b);
- }
+ double auc = 0;
+ getter(ref auc);
+ Assert.Equal(0.93, auc, 2);
+ b = cursor.MoveNext();
+ Assert.False(b);
}
}
@@ -197,68 +191,66 @@ public void TestTrainTestMacro()
public void TestCrossValidationBinaryMacro()
{
var dataPath = GetDataPath("adult.tiny.with-schema.txt");
- using (var env = new ConsoleEnvironment())
- {
- var subGraph = env.CreateExperiment();
-
- var catInput = new Legacy.Transforms.CategoricalOneHotVectorizer();
- catInput.AddColumn("Categories");
- var catOutput = subGraph.Add(catInput);
+ var env = new MLContext();
+ var subGraph = env.CreateExperiment();
- var concatInput = new Legacy.Transforms.ColumnConcatenator
- {
- Data = catOutput.OutputData
- };
- concatInput.AddColumn("Features", "Categories", "NumericFeatures");
- var concatOutput = subGraph.Add(concatInput);
-
- var lrInput = new Legacy.Trainers.LogisticRegressionBinaryClassifier
- {
- TrainingData = concatOutput.OutputData,
- NumThreads = 1
- };
- var lrOutput = subGraph.Add(lrInput);
+ var catInput = new Legacy.Transforms.CategoricalOneHotVectorizer();
+ catInput.AddColumn("Categories");
+ var catOutput = subGraph.Add(catInput);
- var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
- {
- TransformModels = new ArrayVar(catOutput.Model, concatOutput.Model),
- PredictorModel = lrOutput.PredictorModel
- };
- var modelCombineOutput = subGraph.Add(modelCombine);
+ var concatInput = new Legacy.Transforms.ColumnConcatenator
+ {
+ Data = catOutput.OutputData
+ };
+ concatInput.AddColumn("Features", "Categories", "NumericFeatures");
+ var concatOutput = subGraph.Add(concatInput);
- var experiment = env.CreateExperiment();
+ var lrInput = new Legacy.Trainers.LogisticRegressionBinaryClassifier
+ {
+ TrainingData = concatOutput.OutputData,
+ NumThreads = 1
+ };
+ var lrOutput = subGraph.Add(lrInput);
- var importInput = new Legacy.Data.TextLoader(dataPath);
- var importOutput = experiment.Add(importInput);
+ var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
+ {
+ TransformModels = new ArrayVar(catOutput.Model, concatOutput.Model),
+ PredictorModel = lrOutput.PredictorModel
+ };
+ var modelCombineOutput = subGraph.Add(modelCombine);
- var crossValidateBinary = new Legacy.Models.BinaryCrossValidator
- {
- Data = importOutput.Data,
- Nodes = subGraph
- };
- crossValidateBinary.Inputs.Data = catInput.Data;
- crossValidateBinary.Outputs.Model = modelCombineOutput.PredictorModel;
- var crossValidateOutput = experiment.Add(crossValidateBinary);
+ var experiment = env.CreateExperiment();
- experiment.Compile();
- importInput.SetInput(env, experiment);
- experiment.Run();
- var data = experiment.GetOutput(crossValidateOutput.OverallMetrics[0]);
+ var importInput = new Legacy.Data.TextLoader(dataPath);
+ var importOutput = experiment.Add(importInput);
- var schema = data.Schema;
- var b = schema.TryGetColumnIndex("AUC", out int aucCol);
+ var crossValidateBinary = new Legacy.Models.BinaryCrossValidator
+ {
+ Data = importOutput.Data,
+ Nodes = subGraph
+ };
+ crossValidateBinary.Inputs.Data = catInput.Data;
+ crossValidateBinary.Outputs.Model = modelCombineOutput.PredictorModel;
+ var crossValidateOutput = experiment.Add(crossValidateBinary);
+
+ experiment.Compile();
+ importInput.SetInput(env, experiment);
+ experiment.Run();
+ var data = experiment.GetOutput(crossValidateOutput.OverallMetrics[0]);
+
+ var schema = data.Schema;
+ var b = schema.TryGetColumnIndex("AUC", out int aucCol);
+ Assert.True(b);
+ using (var cursor = data.GetRowCursor(col => col == aucCol))
+ {
+ var getter = cursor.GetGetter(aucCol);
+ b = cursor.MoveNext();
Assert.True(b);
- using (var cursor = data.GetRowCursor(col => col == aucCol))
- {
- var getter = cursor.GetGetter(aucCol);
- b = cursor.MoveNext();
- Assert.True(b);
- double auc = 0;
- getter(ref auc);
- Assert.Equal(0.87, auc, 1);
- b = cursor.MoveNext();
- Assert.False(b);
- }
+ double auc = 0;
+ getter(ref auc);
+ Assert.Equal(0.87, auc, 1);
+ b = cursor.MoveNext();
+ Assert.False(b);
}
}
@@ -266,42 +258,41 @@ public void TestCrossValidationBinaryMacro()
public void TestCrossValidationMacro()
{
var dataPath = GetDataPath(TestDatasets.generatedRegressionDatasetmacro.trainFilename);
- using (var env = new ConsoleEnvironment(42))
- {
- var subGraph = env.CreateExperiment();
+ var env = new MLContext(42);
+ var subGraph = env.CreateExperiment();
- var nop = new Legacy.Transforms.NoOperation();
- var nopOutput = subGraph.Add(nop);
+ var nop = new Legacy.Transforms.NoOperation();
+ var nopOutput = subGraph.Add(nop);
- var generate = new Legacy.Transforms.RandomNumberGenerator();
- generate.Column = new[] { new Legacy.Transforms.GenerateNumberTransformColumn() { Name = "Weight1" } };
- generate.Data = nopOutput.OutputData;
- var generateOutput = subGraph.Add(generate);
+ var generate = new Legacy.Transforms.RandomNumberGenerator();
+ generate.Column = new[] { new Legacy.Transforms.GenerateNumberTransformColumn() { Name = "Weight1" } };
+ generate.Data = nopOutput.OutputData;
+ var generateOutput = subGraph.Add(generate);
- var learnerInput = new Legacy.Trainers.PoissonRegressor
- {
- TrainingData = generateOutput.OutputData,
- NumThreads = 1,
- WeightColumn = "Weight1"
- };
- var learnerOutput = subGraph.Add(learnerInput);
+ var learnerInput = new Legacy.Trainers.PoissonRegressor
+ {
+ TrainingData = generateOutput.OutputData,
+ NumThreads = 1,
+ WeightColumn = "Weight1"
+ };
+ var learnerOutput = subGraph.Add(learnerInput);
- var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
- {
- TransformModels = new ArrayVar(nopOutput.Model, generateOutput.Model),
- PredictorModel = learnerOutput.PredictorModel
- };
- var modelCombineOutput = subGraph.Add(modelCombine);
+ var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
+ {
+ TransformModels = new ArrayVar<TransformModel>(nopOutput.Model, generateOutput.Model),
+ PredictorModel = learnerOutput.PredictorModel
+ };
+ var modelCombineOutput = subGraph.Add(modelCombine);
- var experiment = env.CreateExperiment();
- var importInput = new Legacy.Data.TextLoader(dataPath)
+ var experiment = env.CreateExperiment();
+ var importInput = new Legacy.Data.TextLoader(dataPath)
+ {
+ Arguments = new Legacy.Data.TextLoaderArguments
+ {
+ Separator = new[] { ';' },
+ HasHeader = true,
+ Column = new[]
{
- Arguments = new Legacy.Data.TextLoaderArguments
- {
- Separator = new[] { ';' },
- HasHeader = true,
- Column = new[]
- {
new TextLoaderColumn()
{
Name = "Label",
@@ -316,95 +307,94 @@ public void TestCrossValidationMacro()
Type = Legacy.Data.DataKind.Num
}
}
- }
- };
- var importOutput = experiment.Add(importInput);
+ }
+ };
+ var importOutput = experiment.Add(importInput);
- var crossValidate = new Legacy.Models.CrossValidator
+ var crossValidate = new Legacy.Models.CrossValidator
+ {
+ Data = importOutput.Data,
+ Nodes = subGraph,
+ Kind = Legacy.Models.MacroUtilsTrainerKinds.SignatureRegressorTrainer,
+ TransformModel = null,
+ WeightColumn = "Weight1"
+ };
+ crossValidate.Inputs.Data = nop.Data;
+ crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
+ var crossValidateOutput = experiment.Add(crossValidate);
+
+ experiment.Compile();
+ importInput.SetInput(env, experiment);
+ experiment.Run();
+ var data = experiment.GetOutput(crossValidateOutput.OverallMetrics);
+
+ var schema = data.Schema;
+ var b = schema.TryGetColumnIndex("L1(avg)", out int metricCol);
+ Assert.True(b);
+ b = schema.TryGetColumnIndex("Fold Index", out int foldCol);
+ Assert.True(b);
+ b = schema.TryGetColumnIndex("IsWeighted", out int isWeightedCol);
+ using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol || col == isWeightedCol))
+ {
+ var getter = cursor.GetGetter<double>(metricCol);
+ var foldGetter = cursor.GetGetter<ReadOnlyMemory<char>>(foldCol);
+ ReadOnlyMemory<char> fold = default;
+ var isWeightedGetter = cursor.GetGetter<bool>(isWeightedCol);
+ bool isWeighted = default;
+ double avg = 0;
+ double weightedAvg = 0;
+ for (int w = 0; w < 2; w++)
{
- Data = importOutput.Data,
- Nodes = subGraph,
- Kind = Legacy.Models.MacroUtilsTrainerKinds.SignatureRegressorTrainer,
- TransformModel = null,
- WeightColumn = "Weight1"
- };
- crossValidate.Inputs.Data = nop.Data;
- crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
- var crossValidateOutput = experiment.Add(crossValidate);
-
- experiment.Compile();
- importInput.SetInput(env, experiment);
- experiment.Run();
- var data = experiment.GetOutput(crossValidateOutput.OverallMetrics);
+ // Get the average.
+ b = cursor.MoveNext();
+ Assert.True(b);
+ if (w == 1)
+ getter(ref weightedAvg);
+ else
+ getter(ref avg);
+ foldGetter(ref fold);
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold));
+ isWeightedGetter(ref isWeighted);
+ Assert.True(isWeighted == (w == 1));
- var schema = data.Schema;
- var b = schema.TryGetColumnIndex("L1(avg)", out int metricCol);
- Assert.True(b);
- b = schema.TryGetColumnIndex("Fold Index", out int foldCol);
- Assert.True(b);
- b = schema.TryGetColumnIndex("IsWeighted", out int isWeightedCol);
- using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol || col == isWeightedCol))
+ // Get the standard deviation.
+ b = cursor.MoveNext();
+ Assert.True(b);
+ double stdev = 0;
+ getter(ref stdev);
+ foldGetter(ref fold);
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold));
+ if (w == 1)
+ Assert.Equal(1.585, stdev, 3);
+ else
+ Assert.Equal(1.39, stdev, 2);
+ isWeightedGetter(ref isWeighted);
+ Assert.True(isWeighted == (w == 1));
+ }
+ double sum = 0;
+ double weightedSum = 0;
+ for (int f = 0; f < 2; f++)
{
- var getter = cursor.GetGetter(metricCol);
- var foldGetter = cursor.GetGetter>(foldCol);
- ReadOnlyMemory fold = default;
- var isWeightedGetter = cursor.GetGetter(isWeightedCol);
- bool isWeighted = default;
- double avg = 0;
- double weightedAvg = 0;
for (int w = 0; w < 2; w++)
{
- // Get the average.
b = cursor.MoveNext();
Assert.True(b);
- if (w == 1)
- getter(ref weightedAvg);
- else
- getter(ref avg);
- foldGetter(ref fold);
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold));
- isWeightedGetter(ref isWeighted);
- Assert.True(isWeighted == (w == 1));
-
- // Get the standard deviation.
- b = cursor.MoveNext();
- Assert.True(b);
- double stdev = 0;
- getter(ref stdev);
+ double val = 0;
+ getter(ref val);
foldGetter(ref fold);
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold));
if (w == 1)
- Assert.Equal(1.585, stdev, 3);
+ weightedSum += val;
else
- Assert.Equal(1.39, stdev, 2);
+ sum += val;
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold));
isWeightedGetter(ref isWeighted);
Assert.True(isWeighted == (w == 1));
}
- double sum = 0;
- double weightedSum = 0;
- for (int f = 0; f < 2; f++)
- {
- for (int w = 0; w < 2; w++)
- {
- b = cursor.MoveNext();
- Assert.True(b);
- double val = 0;
- getter(ref val);
- foldGetter(ref fold);
- if (w == 1)
- weightedSum += val;
- else
- sum += val;
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold));
- isWeightedGetter(ref isWeighted);
- Assert.True(isWeighted == (w == 1));
- }
- }
- Assert.Equal(weightedAvg, weightedSum / 2);
- Assert.Equal(avg, sum / 2);
- b = cursor.MoveNext();
- Assert.False(b);
}
+ Assert.Equal(weightedAvg, weightedSum / 2);
+ Assert.Equal(avg, sum / 2);
+ b = cursor.MoveNext();
+ Assert.False(b);
}
}
@@ -412,207 +402,203 @@ public void TestCrossValidationMacro()
public void TestCrossValidationMacroWithMultiClass()
{
var dataPath = GetDataPath(@"Train-Tiny-28x28.txt");
- using (var env = new ConsoleEnvironment(42))
- {
- var subGraph = env.CreateExperiment();
+ var env = new MLContext(42);
+ var subGraph = env.CreateExperiment();
- var nop = new Legacy.Transforms.NoOperation();
- var nopOutput = subGraph.Add(nop);
+ var nop = new Legacy.Transforms.NoOperation();
+ var nopOutput = subGraph.Add(nop);
- var learnerInput = new Legacy.Trainers.StochasticDualCoordinateAscentClassifier
- {
- TrainingData = nopOutput.OutputData,
- NumThreads = 1
- };
- var learnerOutput = subGraph.Add(learnerInput);
-
- var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
- {
- TransformModels = new ArrayVar(nopOutput.Model),
- PredictorModel = learnerOutput.PredictorModel
- };
- var modelCombineOutput = subGraph.Add(modelCombine);
+ var learnerInput = new Legacy.Trainers.StochasticDualCoordinateAscentClassifier
+ {
+ TrainingData = nopOutput.OutputData,
+ NumThreads = 1
+ };
+ var learnerOutput = subGraph.Add(learnerInput);
- var experiment = env.CreateExperiment();
- var importInput = new Legacy.Data.TextLoader(dataPath);
- var importOutput = experiment.Add(importInput);
+ var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
+ {
+ TransformModels = new ArrayVar<TransformModel>(nopOutput.Model),
+ PredictorModel = learnerOutput.PredictorModel
+ };
+ var modelCombineOutput = subGraph.Add(modelCombine);
- var crossValidate = new Legacy.Models.CrossValidator
- {
- Data = importOutput.Data,
- Nodes = subGraph,
- Kind = Legacy.Models.MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer,
- TransformModel = null
- };
- crossValidate.Inputs.Data = nop.Data;
- crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
- var crossValidateOutput = experiment.Add(crossValidate);
+ var experiment = env.CreateExperiment();
+ var importInput = new Legacy.Data.TextLoader(dataPath);
+ var importOutput = experiment.Add(importInput);
- experiment.Compile();
- importInput.SetInput(env, experiment);
- experiment.Run();
- var data = experiment.GetOutput(crossValidateOutput.OverallMetrics);
+ var crossValidate = new Legacy.Models.CrossValidator
+ {
+ Data = importOutput.Data,
+ Nodes = subGraph,
+ Kind = Legacy.Models.MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer,
+ TransformModel = null
+ };
+ crossValidate.Inputs.Data = nop.Data;
+ crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
+ var crossValidateOutput = experiment.Add(crossValidate);
+
+ experiment.Compile();
+ importInput.SetInput(env, experiment);
+ experiment.Run();
+ var data = experiment.GetOutput(crossValidateOutput.OverallMetrics);
+
+ var schema = data.Schema;
+ var b = schema.TryGetColumnIndex("Accuracy(micro-avg)", out int metricCol);
+ Assert.True(b);
+ b = schema.TryGetColumnIndex("Fold Index", out int foldCol);
+ Assert.True(b);
+ using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol))
+ {
+ var getter = cursor.GetGetter<double>(metricCol);
+ var foldGetter = cursor.GetGetter<ReadOnlyMemory<char>>(foldCol);
+ ReadOnlyMemory<char> fold = default;
- var schema = data.Schema;
- var b = schema.TryGetColumnIndex("Accuracy(micro-avg)", out int metricCol);
- Assert.True(b);
- b = schema.TryGetColumnIndex("Fold Index", out int foldCol);
+ // Get the average.
+ b = cursor.MoveNext();
Assert.True(b);
- using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol))
- {
- var getter = cursor.GetGetter(metricCol);
- var foldGetter = cursor.GetGetter>(foldCol);
- ReadOnlyMemory fold = default;
+ double avg = 0;
+ getter(ref avg);
+ foldGetter(ref fold);
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold));
- // Get the average.
- b = cursor.MoveNext();
- Assert.True(b);
- double avg = 0;
- getter(ref avg);
- foldGetter(ref fold);
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold));
+ // Get the standard deviation.
+ b = cursor.MoveNext();
+ Assert.True(b);
+ double stdev = 0;
+ getter(ref stdev);
+ foldGetter(ref fold);
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold));
+ Assert.Equal(0.015, stdev, 3);
- // Get the standard deviation.
+ double sum = 0;
+ double val = 0;
+ for (int f = 0; f < 2; f++)
+ {
b = cursor.MoveNext();
Assert.True(b);
- double stdev = 0;
- getter(ref stdev);
+ getter(ref val);
foldGetter(ref fold);
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold));
- Assert.Equal(0.025, stdev, 3);
-
- double sum = 0;
- double val = 0;
- for (int f = 0; f < 2; f++)
- {
- b = cursor.MoveNext();
- Assert.True(b);
- getter(ref val);
- foldGetter(ref fold);
- sum += val;
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold));
- }
- Assert.Equal(avg, sum / 2);
- b = cursor.MoveNext();
- Assert.False(b);
+ sum += val;
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold));
}
+ Assert.Equal(avg, sum / 2);
+ b = cursor.MoveNext();
+ Assert.False(b);
+ }
- var confusion = experiment.GetOutput(crossValidateOutput.ConfusionMatrix);
- schema = confusion.Schema;
- b = schema.TryGetColumnIndex("Count", out int countCol);
- Assert.True(b);
- b = schema.TryGetColumnIndex("Fold Index", out foldCol);
- Assert.True(b);
- var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, countCol);
- Assert.True(type is VectorType vecType && vecType.ItemType is TextType && vecType.Size == 10);
- var slotNames = default(VBuffer>);
- schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countCol, ref slotNames);
- Assert.True(slotNames.Values.Select((s, i) => ReadOnlyMemoryUtils.EqualsStr(i.ToString(), s)).All(x => x));
- using (var curs = confusion.GetRowCursor(col => true))
- {
- var countGetter = curs.GetGetter>(countCol);
- var foldGetter = curs.GetGetter>(foldCol);
- var confCount = default(VBuffer);
- var foldIndex = default(ReadOnlyMemory);
- int rowCount = 0;
- var foldCur = "Fold 0";
- while (curs.MoveNext())
+ var confusion = experiment.GetOutput(crossValidateOutput.ConfusionMatrix);
+ schema = confusion.Schema;
+ b = schema.TryGetColumnIndex("Count", out int countCol);
+ Assert.True(b);
+ b = schema.TryGetColumnIndex("Fold Index", out foldCol);
+ Assert.True(b);
+ var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, countCol);
+ Assert.True(type is VectorType vecType && vecType.ItemType is TextType && vecType.Size == 10);
+ var slotNames = default(VBuffer<ReadOnlyMemory<char>>);
+ schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countCol, ref slotNames);
+ Assert.True(slotNames.Values.Select((s, i) => ReadOnlyMemoryUtils.EqualsStr(i.ToString(), s)).All(x => x));
+ using (var curs = confusion.GetRowCursor(col => true))
+ {
+ var countGetter = curs.GetGetter<VBuffer<double>>(countCol);
+ var foldGetter = curs.GetGetter<ReadOnlyMemory<char>>(foldCol);
+ var confCount = default(VBuffer<double>);
+ var foldIndex = default(ReadOnlyMemory<char>);
+ int rowCount = 0;
+ var foldCur = "Fold 0";
+ while (curs.MoveNext())
+ {
+ countGetter(ref confCount);
+ foldGetter(ref foldIndex);
+ rowCount++;
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr(foldCur, foldIndex));
+ if (rowCount == 10)
{
- countGetter(ref confCount);
- foldGetter(ref foldIndex);
- rowCount++;
- Assert.True(ReadOnlyMemoryUtils.EqualsStr(foldCur, foldIndex));
- if (rowCount == 10)
- {
- rowCount = 0;
- foldCur = "Fold 1";
- }
+ rowCount = 0;
+ foldCur = "Fold 1";
}
- Assert.Equal(0, rowCount);
}
-
- var warnings = experiment.GetOutput(crossValidateOutput.Warnings);
- using (var cursor = warnings.GetRowCursor(col => true))
- Assert.False(cursor.MoveNext());
+ Assert.Equal(0, rowCount);
}
+
+ var warnings = experiment.GetOutput(crossValidateOutput.Warnings);
+ using (var cursor = warnings.GetRowCursor(col => true))
+ Assert.False(cursor.MoveNext());
}
[Fact]
public void TestCrossValidationMacroMultiClassWithWarnings()
{
var dataPath = GetDataPath(@"Train-Tiny-28x28.txt");
- using (var env = new ConsoleEnvironment(42))
- {
- var subGraph = env.CreateExperiment();
+ var env = new MLContext(42);
+ var subGraph = env.CreateExperiment();
- var nop = new Legacy.Transforms.NoOperation();
- var nopOutput = subGraph.Add(nop);
+ var nop = new Legacy.Transforms.NoOperation();
+ var nopOutput = subGraph.Add(nop);
- var learnerInput = new Legacy.Trainers.LogisticRegressionClassifier
- {
- TrainingData = nopOutput.OutputData,
- NumThreads = 1
- };
- var learnerOutput = subGraph.Add(learnerInput);
-
- var experiment = env.CreateExperiment();
- var importInput = new Legacy.Data.TextLoader(dataPath);
- var importOutput = experiment.Add(importInput);
-
- var filter = new Legacy.Transforms.RowRangeFilter();
- filter.Data = importOutput.Data;
- filter.Column = "Label";
- filter.Min = 0;
- filter.Max = 5;
- var filterOutput = experiment.Add(filter);
-
- var term = new Legacy.Transforms.TextToKeyConverter();
- term.Column = new[]
- {
+ var learnerInput = new Legacy.Trainers.LogisticRegressionClassifier
+ {
+ TrainingData = nopOutput.OutputData,
+ NumThreads = 1
+ };
+ var learnerOutput = subGraph.Add(learnerInput);
+
+ var experiment = env.CreateExperiment();
+ var importInput = new Legacy.Data.TextLoader(dataPath);
+ var importOutput = experiment.Add(importInput);
+
+ var filter = new Legacy.Transforms.RowRangeFilter();
+ filter.Data = importOutput.Data;
+ filter.Column = "Label";
+ filter.Min = 0;
+ filter.Max = 5;
+ var filterOutput = experiment.Add(filter);
+
+ var term = new Legacy.Transforms.TextToKeyConverter();
+ term.Column = new[]
+ {
new Legacy.Transforms.ValueToKeyMappingTransformerColumn()
{
Source = "Label", Name = "Strat", Sort = Legacy.Transforms.ValueToKeyMappingTransformerSortOrder.Value
}
};
- term.Data = filterOutput.OutputData;
- var termOutput = experiment.Add(term);
+ term.Data = filterOutput.OutputData;
+ var termOutput = experiment.Add(term);
- var crossValidate = new Legacy.Models.CrossValidator
- {
- Data = termOutput.OutputData,
- Nodes = subGraph,
- Kind = Legacy.Models.MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer,
- TransformModel = null,
- StratificationColumn = "Strat"
- };
- crossValidate.Inputs.Data = nop.Data;
- crossValidate.Outputs.PredictorModel = learnerOutput.PredictorModel;
- var crossValidateOutput = experiment.Add(crossValidate);
-
- experiment.Compile();
- importInput.SetInput(env, experiment);
- experiment.Run();
- var warnings = experiment.GetOutput(crossValidateOutput.Warnings);
+ var crossValidate = new Legacy.Models.CrossValidator
+ {
+ Data = termOutput.OutputData,
+ Nodes = subGraph,
+ Kind = Legacy.Models.MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer,
+ TransformModel = null,
+ StratificationColumn = "Strat"
+ };
+ crossValidate.Inputs.Data = nop.Data;
+ crossValidate.Outputs.PredictorModel = learnerOutput.PredictorModel;
+ var crossValidateOutput = experiment.Add(crossValidate);
+
+ experiment.Compile();
+ importInput.SetInput(env, experiment);
+ experiment.Run();
+ var warnings = experiment.GetOutput(crossValidateOutput.Warnings);
+
+ var schema = warnings.Schema;
+ var b = schema.TryGetColumnIndex("WarningText", out int warningCol);
+ Assert.True(b);
+ using (var cursor = warnings.GetRowCursor(col => col == warningCol))
+ {
+ var getter = cursor.GetGetter<ReadOnlyMemory<char>>(warningCol);
- var schema = warnings.Schema;
- var b = schema.TryGetColumnIndex("WarningText", out int warningCol);
+ b = cursor.MoveNext();
Assert.True(b);
- using (var cursor = warnings.GetRowCursor(col => col == warningCol))
- {
- var getter = cursor.GetGetter>(warningCol);
-
- b = cursor.MoveNext();
- Assert.True(b);
- var warning = default(ReadOnlyMemory);
- getter(ref warning);
- Assert.Contains("test instances with class values not seen in the training set.", warning.ToString());
- b = cursor.MoveNext();
- Assert.True(b);
- getter(ref warning);
- Assert.Contains("Detected columns of variable length: SortedScores, SortedClasses", warning.ToString());
- b = cursor.MoveNext();
- Assert.False(b);
- }
+ var warning = default(ReadOnlyMemory<char>);
+ getter(ref warning);
+ Assert.Contains("test instances with class values not seen in the training set.", warning.ToString());
+ b = cursor.MoveNext();
+ Assert.True(b);
+ getter(ref warning);
+ Assert.Contains("Detected columns of variable length: SortedScores, SortedClasses", warning.ToString());
+ b = cursor.MoveNext();
+ Assert.False(b);
}
}
@@ -620,95 +606,93 @@ public void TestCrossValidationMacroMultiClassWithWarnings()
public void TestCrossValidationMacroWithStratification()
{
var dataPath = GetDataPath(@"breast-cancer.txt");
- using (var env = new ConsoleEnvironment(42))
- {
- var subGraph = env.CreateExperiment();
+ var env = new MLContext(42);
+ var subGraph = env.CreateExperiment();
- var nop = new Legacy.Transforms.NoOperation();
- var nopOutput = subGraph.Add(nop);
+ var nop = new Legacy.Transforms.NoOperation();
+ var nopOutput = subGraph.Add(nop);
- var learnerInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier
- {
- TrainingData = nopOutput.OutputData,
- NumThreads = 1
- };
- var learnerOutput = subGraph.Add(learnerInput);
-
- var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
- {
- TransformModels = new ArrayVar(nopOutput.Model),
- PredictorModel = learnerOutput.PredictorModel
- };
- var modelCombineOutput = subGraph.Add(modelCombine);
+ var learnerInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier
+ {
+ TrainingData = nopOutput.OutputData,
+ NumThreads = 1
+ };
+ var learnerOutput = subGraph.Add(learnerInput);
- var experiment = env.CreateExperiment();
- var importInput = new Legacy.Data.TextLoader(dataPath);
- importInput.Arguments.Column = new Legacy.Data.TextLoaderColumn[]
- {
+ var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
+ {
+ TransformModels = new ArrayVar<TransformModel>(nopOutput.Model),
+ PredictorModel = learnerOutput.PredictorModel
+ };
+ var modelCombineOutput = subGraph.Add(modelCombine);
+
+ var experiment = env.CreateExperiment();
+ var importInput = new Legacy.Data.TextLoader(dataPath);
+ importInput.Arguments.Column = new Legacy.Data.TextLoaderColumn[]
+ {
new Legacy.Data.TextLoaderColumn { Name = "Label", Source = new[] { new Legacy.Data.TextLoaderRange(0) } },
new Legacy.Data.TextLoaderColumn { Name = "Strat", Source = new[] { new Legacy.Data.TextLoaderRange(1) } },
new Legacy.Data.TextLoaderColumn { Name = "Features", Source = new[] { new Legacy.Data.TextLoaderRange(2, 9) } }
- };
- var importOutput = experiment.Add(importInput);
+ };
+ var importOutput = experiment.Add(importInput);
- var crossValidate = new Legacy.Models.CrossValidator
- {
- Data = importOutput.Data,
- Nodes = subGraph,
- TransformModel = null,
- StratificationColumn = "Strat"
- };
- crossValidate.Inputs.Data = nop.Data;
- crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
- var crossValidateOutput = experiment.Add(crossValidate);
- experiment.Compile();
- experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
- experiment.Run();
- var data = experiment.GetOutput(crossValidateOutput.OverallMetrics);
-
- var schema = data.Schema;
- var b = schema.TryGetColumnIndex("AUC", out int metricCol);
- Assert.True(b);
- b = schema.TryGetColumnIndex("Fold Index", out int foldCol);
+ var crossValidate = new Legacy.Models.CrossValidator
+ {
+ Data = importOutput.Data,
+ Nodes = subGraph,
+ TransformModel = null,
+ StratificationColumn = "Strat"
+ };
+ crossValidate.Inputs.Data = nop.Data;
+ crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
+ var crossValidateOutput = experiment.Add(crossValidate);
+ experiment.Compile();
+ experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
+ experiment.Run();
+ var data = experiment.GetOutput(crossValidateOutput.OverallMetrics);
+
+ var schema = data.Schema;
+ var b = schema.TryGetColumnIndex("AUC", out int metricCol);
+ Assert.True(b);
+ b = schema.TryGetColumnIndex("Fold Index", out int foldCol);
+ Assert.True(b);
+ using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol))
+ {
+ var getter = cursor.GetGetter<double>(metricCol);
+ var foldGetter = cursor.GetGetter<ReadOnlyMemory<char>>(foldCol);
+ ReadOnlyMemory<char> fold = default;
+
+ // Get the average.
+ b = cursor.MoveNext();
Assert.True(b);
- using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol))
- {
- var getter = cursor.GetGetter(metricCol);
- var foldGetter = cursor.GetGetter>(foldCol);
- ReadOnlyMemory fold = default;
+ double avg = 0;
+ getter(ref avg);
+ foldGetter(ref fold);
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold));
- // Get the verage.
- b = cursor.MoveNext();
- Assert.True(b);
- double avg = 0;
- getter(ref avg);
- foldGetter(ref fold);
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold));
+ // Get the standard deviation.
+ b = cursor.MoveNext();
+ Assert.True(b);
+ double stdev = 0;
+ getter(ref stdev);
+ foldGetter(ref fold);
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold));
+ Assert.Equal(0.00488, stdev, 5);
- // Get the standard deviation.
+ double sum = 0;
+ double val = 0;
+ for (int f = 0; f < 2; f++)
+ {
b = cursor.MoveNext();
Assert.True(b);
- double stdev = 0;
- getter(ref stdev);
+ getter(ref val);
foldGetter(ref fold);
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold));
- Assert.Equal(0.00485, stdev, 5);
-
- double sum = 0;
- double val = 0;
- for (int f = 0; f < 2; f++)
- {
- b = cursor.MoveNext();
- Assert.True(b);
- getter(ref val);
- foldGetter(ref fold);
- sum += val;
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold));
- }
- Assert.Equal(avg, sum / 2);
- b = cursor.MoveNext();
- Assert.False(b);
+ sum += val;
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold));
}
+ Assert.Equal(avg, sum / 2);
+ b = cursor.MoveNext();
+ Assert.False(b);
}
}
@@ -716,127 +700,125 @@ public void TestCrossValidationMacroWithStratification()
public void TestCrossValidationMacroWithNonDefaultNames()
{
string dataPath = GetDataPath(@"adult.tiny.with-schema.txt");
- using (var env = new ConsoleEnvironment(42))
- {
- var subGraph = env.CreateExperiment();
+ var env = new MLContext(42);
+ var subGraph = env.CreateExperiment();
- var textToKey = new Legacy.Transforms.TextToKeyConverter();
- textToKey.Column = new[] { new Legacy.Transforms.ValueToKeyMappingTransformerColumn() { Name = "Label1", Source = "Label" } };
- var textToKeyOutput = subGraph.Add(textToKey);
+ var textToKey = new Legacy.Transforms.TextToKeyConverter();
+ textToKey.Column = new[] { new Legacy.Transforms.ValueToKeyMappingTransformerColumn() { Name = "Label1", Source = "Label" } };
+ var textToKeyOutput = subGraph.Add(textToKey);
- var hash = new Legacy.Transforms.HashConverter();
- hash.Column = new[] { new Legacy.Transforms.HashJoiningTransformColumn() { Name = "GroupId1", Source = "Workclass" } };
- hash.Data = textToKeyOutput.OutputData;
- var hashOutput = subGraph.Add(hash);
+ var hash = new Legacy.Transforms.HashConverter();
+ hash.Column = new[] { new Legacy.Transforms.HashJoiningTransformColumn() { Name = "GroupId1", Source = "Workclass" } };
+ hash.Data = textToKeyOutput.OutputData;
+ var hashOutput = subGraph.Add(hash);
- var learnerInput = new Legacy.Trainers.FastTreeRanker
- {
- TrainingData = hashOutput.OutputData,
- NumThreads = 1,
- LabelColumn = "Label1",
- GroupIdColumn = "GroupId1"
- };
- var learnerOutput = subGraph.Add(learnerInput);
-
- var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
- {
- TransformModels = new ArrayVar(textToKeyOutput.Model, hashOutput.Model),
- PredictorModel = learnerOutput.PredictorModel
- };
- var modelCombineOutput = subGraph.Add(modelCombine);
-
- var experiment = env.CreateExperiment();
- var importInput = new Legacy.Data.TextLoader(dataPath);
- importInput.Arguments.HasHeader = true;
- importInput.Arguments.Column = new TextLoaderColumn[]
- {
+ var learnerInput = new Legacy.Trainers.FastTreeRanker
+ {
+ TrainingData = hashOutput.OutputData,
+ NumThreads = 1,
+ LabelColumn = "Label1",
+ GroupIdColumn = "GroupId1"
+ };
+ var learnerOutput = subGraph.Add(learnerInput);
+
+ var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
+ {
+ TransformModels = new ArrayVar<TransformModel>(textToKeyOutput.Model, hashOutput.Model),
+ PredictorModel = learnerOutput.PredictorModel
+ };
+ var modelCombineOutput = subGraph.Add(modelCombine);
+
+ var experiment = env.CreateExperiment();
+ var importInput = new Legacy.Data.TextLoader(dataPath);
+ importInput.Arguments.HasHeader = true;
+ importInput.Arguments.Column = new TextLoaderColumn[]
+ {
new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } },
new TextLoaderColumn { Name = "Workclass", Source = new[] { new TextLoaderRange(1) }, Type = Legacy.Data.DataKind.Text },
new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(9, 14) } }
- };
- var importOutput = experiment.Add(importInput);
+ };
+ var importOutput = experiment.Add(importInput);
- var crossValidate = new Legacy.Models.CrossValidator
- {
- Data = importOutput.Data,
- Nodes = subGraph,
- TransformModel = null,
- LabelColumn = "Label1",
- GroupColumn = "GroupId1",
- NameColumn = "Workclass",
- Kind = Legacy.Models.MacroUtilsTrainerKinds.SignatureRankerTrainer
- };
- crossValidate.Inputs.Data = textToKey.Data;
- crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
- var crossValidateOutput = experiment.Add(crossValidate);
- experiment.Compile();
- experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
- experiment.Run();
- var data = experiment.GetOutput(crossValidateOutput.OverallMetrics);
-
- var schema = data.Schema;
- var b = schema.TryGetColumnIndex("NDCG", out int metricCol);
+ var crossValidate = new Legacy.Models.CrossValidator
+ {
+ Data = importOutput.Data,
+ Nodes = subGraph,
+ TransformModel = null,
+ LabelColumn = "Label1",
+ GroupColumn = "GroupId1",
+ NameColumn = "Workclass",
+ Kind = Legacy.Models.MacroUtilsTrainerKinds.SignatureRankerTrainer
+ };
+ crossValidate.Inputs.Data = textToKey.Data;
+ crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
+ var crossValidateOutput = experiment.Add(crossValidate);
+ experiment.Compile();
+ experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
+ experiment.Run();
+ var data = experiment.GetOutput(crossValidateOutput.OverallMetrics);
+
+ var schema = data.Schema;
+ var b = schema.TryGetColumnIndex("NDCG", out int metricCol);
+ Assert.True(b);
+ b = schema.TryGetColumnIndex("Fold Index", out int foldCol);
+ Assert.True(b);
+ using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol))
+ {
+ var getter = cursor.GetGetter<VBuffer<double>>(metricCol);
+ var foldGetter = cursor.GetGetter<ReadOnlyMemory<char>>(foldCol);
+ ReadOnlyMemory<char> fold = default;
+
+ // Get the average.
+ b = cursor.MoveNext();
Assert.True(b);
- b = schema.TryGetColumnIndex("Fold Index", out int foldCol);
+ var avg = default(VBuffer<double>);
+ getter(ref avg);
+ foldGetter(ref fold);
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold));
+
+ // Get the standard deviation.
+ b = cursor.MoveNext();
Assert.True(b);
- using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol))
+ var stdev = default(VBuffer<double>);
+ getter(ref stdev);
+ foldGetter(ref fold);
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold));
+ Assert.Equal(2.462, stdev.Values[0], 3);
+ Assert.Equal(2.763, stdev.Values[1], 3);
+ Assert.Equal(3.273, stdev.Values[2], 3);
+
+ var sumBldr = new BufferBuilder<double>(R8Adder.Instance);
+ sumBldr.Reset(avg.Length, true);
+ var val = default(VBuffer<double>);
+ for (int f = 0; f < 2; f++)
{
- var getter = cursor.GetGetter>(metricCol);
- var foldGetter = cursor.GetGetter>(foldCol);
- ReadOnlyMemory fold = default;
-
- // Get the verage.
b = cursor.MoveNext();
Assert.True(b);
- var avg = default(VBuffer);
- getter(ref avg);
+ getter(ref val);
foldGetter(ref fold);
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold));
-
- // Get the standard deviation.
- b = cursor.MoveNext();
- Assert.True(b);
- var stdev = default(VBuffer);
- getter(ref stdev);
- foldGetter(ref fold);
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold));
- Assert.Equal(2.462, stdev.Values[0], 3);
- Assert.Equal(2.763, stdev.Values[1], 3);
- Assert.Equal(3.273, stdev.Values[2], 3);
-
- var sumBldr = new BufferBuilder(R8Adder.Instance);
- sumBldr.Reset(avg.Length, true);
- var val = default(VBuffer);
- for (int f = 0; f < 2; f++)
- {
- b = cursor.MoveNext();
- Assert.True(b);
- getter(ref val);
- foldGetter(ref fold);
- sumBldr.AddFeatures(0, in val);
- Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold));
- }
- var sum = default(VBuffer);
- sumBldr.GetResult(ref sum);
- for (int i = 0; i < avg.Length; i++)
- Assert.Equal(avg.Values[i], sum.Values[i] / 2);
- b = cursor.MoveNext();
- Assert.False(b);
+ sumBldr.AddFeatures(0, in val);
+ Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold));
}
+ var sum = default(VBuffer<double>);
+ sumBldr.GetResult(ref sum);
+ for (int i = 0; i < avg.Length; i++)
+ Assert.Equal(avg.Values[i], sum.Values[i] / 2);
+ b = cursor.MoveNext();
+ Assert.False(b);
+ }
- data = experiment.GetOutput(crossValidateOutput.PerInstanceMetrics);
- Assert.True(data.Schema.TryGetColumnIndex("Instance", out int nameCol));
- using (var cursor = data.GetRowCursor(col => col == nameCol))
- {
- var getter = cursor.GetGetter>(nameCol);
- while (cursor.MoveNext())
- {
- ReadOnlyMemory name = default;
- getter(ref name);
- Assert.Subset(new HashSet() { "Private", "?", "Federal-gov" }, new HashSet() { name.ToString() });
- if (cursor.Position > 4)
- break;
- }
+ data = experiment.GetOutput(crossValidateOutput.PerInstanceMetrics);
+ Assert.True(data.Schema.TryGetColumnIndex("Instance", out int nameCol));
+ using (var cursor = data.GetRowCursor(col => col == nameCol))
+ {
+ var getter = cursor.GetGetter<ReadOnlyMemory<char>>(nameCol);
+ while (cursor.MoveNext())
+ {
+ ReadOnlyMemory<char> name = default;
+ getter(ref name);
+ Assert.Subset(new HashSet<string>() { "Private", "?", "Federal-gov" }, new HashSet<string>() { name.ToString() });
+ if (cursor.Position > 4)
+ break;
}
}
}
@@ -845,58 +827,56 @@ public void TestCrossValidationMacroWithNonDefaultNames()
public void TestOvaMacro()
{
var dataPath = GetDataPath(@"iris.txt");
- using (var env = new ConsoleEnvironment(42))
- {
- // Specify subgraph for OVA
- var subGraph = env.CreateExperiment();
- var learnerInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier { NumThreads = 1 };
- var learnerOutput = subGraph.Add(learnerInput);
- // Create pipeline with OVA and multiclass scoring.
- var experiment = env.CreateExperiment();
- var importInput = new Legacy.Data.TextLoader(dataPath);
- importInput.Arguments.Column = new TextLoaderColumn[]
- {
+ var env = new MLContext(42);
+ // Specify subgraph for OVA
+ var subGraph = env.CreateExperiment();
+ var learnerInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier { NumThreads = 1 };
+ var learnerOutput = subGraph.Add(learnerInput);
+ // Create pipeline with OVA and multiclass scoring.
+ var experiment = env.CreateExperiment();
+ var importInput = new Legacy.Data.TextLoader(dataPath);
+ importInput.Arguments.Column = new TextLoaderColumn[]
+ {
new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } },
new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(1,4) } }
- };
- var importOutput = experiment.Add(importInput);
- var oneVersusAll = new Legacy.Models.OneVersusAll
- {
- TrainingData = importOutput.Data,
- Nodes = subGraph,
- UseProbabilities = true,
- };
- var ovaOutput = experiment.Add(oneVersusAll);
- var scoreInput = new Legacy.Transforms.DatasetScorer
- {
- Data = importOutput.Data,
- PredictorModel = ovaOutput.PredictorModel
- };
- var scoreOutput = experiment.Add(scoreInput);
- var evalInput = new Legacy.Models.ClassificationEvaluator
- {
- Data = scoreOutput.ScoredData
- };
- var evalOutput = experiment.Add(evalInput);
- experiment.Compile();
- experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
- experiment.Run();
-
- var data = experiment.GetOutput(evalOutput.OverallMetrics);
- var schema = data.Schema;
- var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol);
+ };
+ var importOutput = experiment.Add(importInput);
+ var oneVersusAll = new Legacy.Models.OneVersusAll
+ {
+ TrainingData = importOutput.Data,
+ Nodes = subGraph,
+ UseProbabilities = true,
+ };
+ var ovaOutput = experiment.Add(oneVersusAll);
+ var scoreInput = new Legacy.Transforms.DatasetScorer
+ {
+ Data = importOutput.Data,
+ PredictorModel = ovaOutput.PredictorModel
+ };
+ var scoreOutput = experiment.Add(scoreInput);
+ var evalInput = new Legacy.Models.ClassificationEvaluator
+ {
+ Data = scoreOutput.ScoredData
+ };
+ var evalOutput = experiment.Add(evalInput);
+ experiment.Compile();
+ experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
+ experiment.Run();
+
+ var data = experiment.GetOutput(evalOutput.OverallMetrics);
+ var schema = data.Schema;
+ var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol);
+ Assert.True(b);
+ using (var cursor = data.GetRowCursor(col => col == accCol))
+ {
+ var getter = cursor.GetGetter(accCol);
+ b = cursor.MoveNext();
Assert.True(b);
- using (var cursor = data.GetRowCursor(col => col == accCol))
- {
- var getter = cursor.GetGetter(accCol);
- b = cursor.MoveNext();
- Assert.True(b);
- double acc = 0;
- getter(ref acc);
- Assert.Equal(0.96, acc, 2);
- b = cursor.MoveNext();
- Assert.False(b);
- }
+ double acc = 0;
+ getter(ref acc);
+ Assert.Equal(0.96, acc, 2);
+ b = cursor.MoveNext();
+ Assert.False(b);
}
}
@@ -904,58 +884,56 @@ public void TestOvaMacro()
public void TestOvaMacroWithUncalibratedLearner()
{
var dataPath = GetDataPath(@"iris.txt");
- using (var env = new ConsoleEnvironment(42))
- {
- // Specify subgraph for OVA
- var subGraph = env.CreateExperiment();
- var learnerInput = new Legacy.Trainers.AveragedPerceptronBinaryClassifier { Shuffle = false };
- var learnerOutput = subGraph.Add(learnerInput);
- // Create pipeline with OVA and multiclass scoring.
- var experiment = env.CreateExperiment();
- var importInput = new Legacy.Data.TextLoader(dataPath);
- importInput.Arguments.Column = new TextLoaderColumn[]
- {
+ var env = new MLContext(42);
+ // Specify subgraph for OVA
+ var subGraph = env.CreateExperiment();
+ var learnerInput = new Legacy.Trainers.AveragedPerceptronBinaryClassifier { Shuffle = false };
+ var learnerOutput = subGraph.Add(learnerInput);
+ // Create pipeline with OVA and multiclass scoring.
+ var experiment = env.CreateExperiment();
+ var importInput = new Legacy.Data.TextLoader(dataPath);
+ importInput.Arguments.Column = new TextLoaderColumn[]
+ {
new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } },
new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(1,4) } }
- };
- var importOutput = experiment.Add(importInput);
- var oneVersusAll = new Legacy.Models.OneVersusAll
- {
- TrainingData = importOutput.Data,
- Nodes = subGraph,
- UseProbabilities = true,
- };
- var ovaOutput = experiment.Add(oneVersusAll);
- var scoreInput = new Legacy.Transforms.DatasetScorer
- {
- Data = importOutput.Data,
- PredictorModel = ovaOutput.PredictorModel
- };
- var scoreOutput = experiment.Add(scoreInput);
- var evalInput = new Legacy.Models.ClassificationEvaluator
- {
- Data = scoreOutput.ScoredData
- };
- var evalOutput = experiment.Add(evalInput);
- experiment.Compile();
- experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
- experiment.Run();
-
- var data = experiment.GetOutput(evalOutput.OverallMetrics);
- var schema = data.Schema;
- var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol);
+ };
+ var importOutput = experiment.Add(importInput);
+ var oneVersusAll = new Legacy.Models.OneVersusAll
+ {
+ TrainingData = importOutput.Data,
+ Nodes = subGraph,
+ UseProbabilities = true,
+ };
+ var ovaOutput = experiment.Add(oneVersusAll);
+ var scoreInput = new Legacy.Transforms.DatasetScorer
+ {
+ Data = importOutput.Data,
+ PredictorModel = ovaOutput.PredictorModel
+ };
+ var scoreOutput = experiment.Add(scoreInput);
+ var evalInput = new Legacy.Models.ClassificationEvaluator
+ {
+ Data = scoreOutput.ScoredData
+ };
+ var evalOutput = experiment.Add(evalInput);
+ experiment.Compile();
+ experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
+ experiment.Run();
+
+ var data = experiment.GetOutput(evalOutput.OverallMetrics);
+ var schema = data.Schema;
+ var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol);
+ Assert.True(b);
+ using (var cursor = data.GetRowCursor(col => col == accCol))
+ {
+ var getter = cursor.GetGetter(accCol);
+ b = cursor.MoveNext();
Assert.True(b);
- using (var cursor = data.GetRowCursor(col => col == accCol))
- {
- var getter = cursor.GetGetter(accCol);
- b = cursor.MoveNext();
- Assert.True(b);
- double acc = 0;
- getter(ref acc);
- Assert.Equal(0.71, acc, 2);
- b = cursor.MoveNext();
- Assert.False(b);
- }
+ double acc = 0;
+ getter(ref acc);
+ Assert.Equal(0.71, acc, 2);
+ b = cursor.MoveNext();
+ Assert.False(b);
}
}
@@ -963,37 +941,35 @@ public void TestOvaMacroWithUncalibratedLearner()
public void TestTensorFlowEntryPoint()
{
var dataPath = GetDataPath("Train-Tiny-28x28.txt");
- using (var env = new ConsoleEnvironment(42))
- {
- var experiment = env.CreateExperiment();
+ var env = new MLContext(42);
+ var experiment = env.CreateExperiment();
- var importInput = new Legacy.Data.TextLoader(dataPath);
- importInput.Arguments.Column = new TextLoaderColumn[]
- {
+ var importInput = new Legacy.Data.TextLoader(dataPath);
+ importInput.Arguments.Column = new TextLoaderColumn[]
+ {
new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } },
new TextLoaderColumn { Name = "Placeholder", Source = new[] { new TextLoaderRange(1, 784) } }
- };
- var importOutput = experiment.Add(importInput);
+ };
+ var importOutput = experiment.Add(importInput);
- var tfTransformInput = new Legacy.Transforms.TensorFlowScorer
- {
- Data = importOutput.Data,
- ModelLocation = "mnist_model/frozen_saved_model.pb",
- InputColumns = new[] { "Placeholder" },
- OutputColumns = new[] { "Softmax" },
- };
- var tfTransformOutput = experiment.Add(tfTransformInput);
-
- experiment.Compile();
- experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
- experiment.Run();
- var data = experiment.GetOutput(tfTransformOutput.OutputData);
-
- var schema = data.Schema;
- Assert.Equal(3, schema.ColumnCount);
- Assert.Equal("Softmax", schema.GetColumnName(2));
- Assert.Equal(10, (schema.GetColumnType(2) as VectorType)?.Size);
- }
+ var tfTransformInput = new Legacy.Transforms.TensorFlowScorer
+ {
+ Data = importOutput.Data,
+ ModelLocation = "mnist_model/frozen_saved_model.pb",
+ InputColumns = new[] { "Placeholder" },
+ OutputColumns = new[] { "Softmax" },
+ };
+ var tfTransformOutput = experiment.Add(tfTransformInput);
+
+ experiment.Compile();
+ experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
+ experiment.Run();
+ var data = experiment.GetOutput(tfTransformOutput.OutputData);
+
+ var schema = data.Schema;
+ Assert.Equal(3, schema.ColumnCount);
+ Assert.Equal("Softmax", schema.GetColumnName(2));
+ Assert.Equal(10, (schema.GetColumnType(2) as VectorType)?.Size);
}
}
}
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestContracts.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestContracts.cs
index 6b4fd5297f..ca1315b83b 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestContracts.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestContracts.cs
@@ -3,8 +3,6 @@
// See the LICENSE file in the project root for more information.
using System;
-using Microsoft.ML.Runtime;
-using Microsoft.ML.Runtime.Data;
using Xunit;
namespace Microsoft.ML.Runtime.RunTests
{
@@ -42,7 +40,7 @@ private void Helper(IExceptionContext ectx, MessageSensitivity expected)
[Fact]
public void ExceptionSensitivity()
{
- var env = new ConsoleEnvironment();
+ var env = new MLContext();
// Default sensitivity should be unknown, that is, all bits set.
Helper(null, MessageSensitivity.Unknown);
// If we set it to be not sensitive, then the messages should be marked insensitive,
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs
index 3798246fa8..688287dda2 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs
@@ -13,7 +13,7 @@ public sealed class TestEarlyStoppingCriteria
{
private IEarlyStoppingCriterion CreateEarlyStoppingCriterion(string name, string args, bool lowerIsBetter)
{
- var env = new ConsoleEnvironment()
+ var env = new MLContext()
.AddStandardComponents();
var sub = new SubComponent(name, args);
return sub.CreateInstance(env, lowerIsBetter);
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs
index ce01945bcb..2f1df54285 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestHosts.cs
@@ -20,7 +20,7 @@ public class TestHosts
[Fact]
public void TestCancellation()
{
- var env = new ConsoleEnvironment(seed: 42);
+ IHostEnvironment env = new MLContext(seed: 42);
for (int z = 0; z < 1000; z++)
{
var mainHost = env.Register("Main");
diff --git a/test/Microsoft.ML.InferenceTesting/Microsoft.ML.InferenceTesting.csproj b/test/Microsoft.ML.InferenceTesting/Microsoft.ML.InferenceTesting.csproj
deleted file mode 100644
index 60b52497a0..0000000000
--- a/test/Microsoft.ML.InferenceTesting/Microsoft.ML.InferenceTesting.csproj
+++ /dev/null
@@ -1,22 +0,0 @@
-
-
-
- true
- CORECLR
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs b/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs
index 00927936b1..9c8bc34a2f 100644
--- a/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs
+++ b/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs
@@ -166,42 +166,40 @@ public void OnnxStatic()
var modelFile = "squeezenet/00000001/model.onnx";
- using (var env = new ConsoleEnvironment(null, false, 0, 1, null, null))
+ var env = new MLContext(conc: 1);
+ var imageHeight = 224;
+ var imageWidth = 224;
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+
+ var data = TextLoader.CreateReader(env, ctx => (
+ imagePath: ctx.LoadText(0),
+ name: ctx.LoadText(1)))
+ .Read(dataFile);
+
+ // Note that CamelCase column names are there to match the TF graph node names.
+ var pipe = data.MakeNewEstimator()
+ .Append(row => (
+ row.name,
+ data_0: row.imagePath.LoadAsImage(imageFolder).Resize(imageHeight, imageWidth).ExtractPixels(interleaveArgb: true)))
+ .Append(row => (row.name, softmaxout_1: row.data_0.ApplyOnnxModel(modelFile)));
+
+ TestEstimatorCore(pipe.AsDynamic, data.AsDynamic);
+
+ var result = pipe.Fit(data).Transform(data).AsDynamic;
+ result.Schema.TryGetColumnIndex("softmaxout_1", out int output);
+ using (var cursor = result.GetRowCursor(col => col == output))
{
- var imageHeight = 224;
- var imageWidth = 224;
- var dataFile = GetDataPath("images/images.tsv");
- var imageFolder = Path.GetDirectoryName(dataFile);
-
- var data = TextLoader.CreateReader(env, ctx => (
- imagePath: ctx.LoadText(0),
- name: ctx.LoadText(1)))
- .Read(dataFile);
-
- // Note that CamelCase column names are there to match the TF graph node names.
- var pipe = data.MakeNewEstimator()
- .Append(row => (
- row.name,
- data_0: row.imagePath.LoadAsImage(imageFolder).Resize(imageHeight, imageWidth).ExtractPixels(interleaveArgb: true)))
- .Append(row => (row.name, softmaxout_1: row.data_0.ApplyOnnxModel(modelFile)));
-
- TestEstimatorCore(pipe.AsDynamic, data.AsDynamic);
-
- var result = pipe.Fit(data).Transform(data).AsDynamic;
- result.Schema.TryGetColumnIndex("softmaxout_1", out int output);
- using (var cursor = result.GetRowCursor(col => col == output))
+ var buffer = default(VBuffer);
+ var getter = cursor.GetGetter>(output);
+ var numRows = 0;
+ while (cursor.MoveNext())
{
- var buffer = default(VBuffer);
- var getter = cursor.GetGetter>(output);
- var numRows = 0;
- while (cursor.MoveNext())
- {
- getter(ref buffer);
- Assert.Equal(1000, buffer.Length);
- numRows += 1;
- }
- Assert.Equal(3, numRows);
+ getter(ref buffer);
+ Assert.Equal(1000, buffer.Length);
+ numRows += 1;
}
+ Assert.Equal(3, numRows);
}
}
@@ -211,11 +209,9 @@ void TestCommandLine()
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
return;
- using (var env = new ConsoleEnvironment())
- {
- var x = Maml.Main(new[] { @"showschema loader=Text{col=data_0:R4:0-150527} xf=Onnx{InputColumn=data_0 OutputColumn=softmaxout_1 model={squeezenet/00000001/model.onnx}}" });
- Assert.Equal(0, x);
- }
+ var env = new MLContext();
+ var x = Maml.Main(new[] { @"showschema loader=Text{col=data_0:R4:0-150527} xf=Onnx{InputColumn=data_0 OutputColumn=softmaxout_1 model={squeezenet/00000001/model.onnx}}" });
+ Assert.Equal(0, x);
}
}
}
diff --git a/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLine.cs b/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLine.cs
index 0839a57ed3..3f4fa341dd 100644
--- a/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLine.cs
+++ b/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLine.cs
@@ -97,7 +97,7 @@ public void CmdParsingSingle()
///
private static void Init(IndentingTextWriter wrt, object defaults)
{
- var env = new ConsoleEnvironment(seed: 42);
+ var env = new MLContext(seed: 42);
wrt.WriteLine("Usage:");
wrt.WriteLine(CmdParser.ArgumentsUsage(env, defaults.GetType(), defaults, false, 200));
}
@@ -107,7 +107,7 @@ private static void Init(IndentingTextWriter wrt, object defaults)
///
private static void Process(IndentingTextWriter wrt, string text, ArgsBase defaults)
{
- var env = new ConsoleEnvironment(seed: 42);
+ var env = new MLContext(seed: 42);
using (wrt.Nest())
{
var args1 = defaults.Clone();
diff --git a/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs b/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs
index 6d95829454..51330648ee 100644
--- a/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs
+++ b/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs
@@ -19,7 +19,7 @@ public class CmdLineReverseTests
[TestCategory("Cmd Parsing")]
public void ArgumentParseTest()
{
- var env = new ConsoleEnvironment(seed: 42);
+ var env = new MLContext(seed: 42);
var innerArg1 = new SimpleArg()
{
required = -2,
diff --git a/test/Microsoft.ML.InferenceTesting/InferRecipesCommand.cs b/test/Microsoft.ML.Predictor.Tests/Commands/InferRecipesCommand.cs
similarity index 96%
rename from test/Microsoft.ML.InferenceTesting/InferRecipesCommand.cs
rename to test/Microsoft.ML.Predictor.Tests/Commands/InferRecipesCommand.cs
index a6dbae8248..d84b6f9a7d 100644
--- a/test/Microsoft.ML.InferenceTesting/InferRecipesCommand.cs
+++ b/test/Microsoft.ML.Predictor.Tests/Commands/InferRecipesCommand.cs
@@ -22,8 +22,9 @@ namespace Microsoft.ML.Runtime.MLTesting.Inference
/// This command generates a suggested RSP to load the text file and recipes it prior to training.
/// The results are output to the console and also to the RSP file, if it's specified.
///
- public sealed class InferRecipesCommand : ICommand
+ internal sealed class InferRecipesCommand : ICommand
{
+#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
public sealed class Arguments
{
[Argument(ArgumentType.Required, HelpText = "Text file with data to analyze", ShortName = "data")]
@@ -35,6 +36,7 @@ public sealed class Arguments
[Argument(ArgumentType.AtMostOnce, HelpText = "Optional path to the schema definition file generated by the InferSchema command", ShortName = "schema")]
public string SchemaDefinitionFile;
}
+#pragma warning restore CS0649
private readonly IHost _host;
private readonly string _dataFile;
diff --git a/test/Microsoft.ML.InferenceTesting/InferSchemaCommand.cs b/test/Microsoft.ML.Predictor.Tests/Commands/InferSchemaCommand.cs
similarity index 95%
rename from test/Microsoft.ML.InferenceTesting/InferSchemaCommand.cs
rename to test/Microsoft.ML.Predictor.Tests/Commands/InferSchemaCommand.cs
index 7a96008bd0..269da88b19 100644
--- a/test/Microsoft.ML.InferenceTesting/InferSchemaCommand.cs
+++ b/test/Microsoft.ML.Predictor.Tests/Commands/InferSchemaCommand.cs
@@ -22,8 +22,9 @@ namespace Microsoft.ML.Runtime.MLTesting.Inference
/// This command generates a suggested RSP to load the text file and recipes it prior to training.
/// The results are output to the console and also to the RSP file, if it's specified.
///
- public sealed class InferSchemaCommand : ICommand
+ internal sealed class InferSchemaCommand : ICommand
{
+#pragma warning disable CS0649 // The fields will still be set via the reflection driven mechanisms.
public sealed class Arguments
{
[Argument(ArgumentType.Required, HelpText = "Text file with data to analyze", ShortName = "data")]
@@ -32,6 +33,7 @@ public sealed class Arguments
[Argument(ArgumentType.AtMostOnce, HelpText = "Path to the output json file describing the columns", ShortName = "out")]
public string OutputFile;
}
+#pragma warning restore CS0649
private readonly IHost _host;
private readonly string _dataFile;
diff --git a/test/Microsoft.ML.Predictor.Tests/Microsoft.ML.Predictor.Tests.csproj b/test/Microsoft.ML.Predictor.Tests/Microsoft.ML.Predictor.Tests.csproj
index d089946b87..9cc27b17e2 100644
--- a/test/Microsoft.ML.Predictor.Tests/Microsoft.ML.Predictor.Tests.csproj
+++ b/test/Microsoft.ML.Predictor.Tests/Microsoft.ML.Predictor.Tests.csproj
@@ -16,7 +16,6 @@
-
diff --git a/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs b/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs
index 48f6336ecb..84cbea24ba 100644
--- a/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs
+++ b/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs
@@ -28,111 +28,102 @@ public TestAutoInference(ITestOutputHelper helper)
[TestCategory("EntryPoints")]
public void TestLearn()
{
- using (var env = new ConsoleEnvironment()
- .AddStandardComponents()) // AutoInference.InferPipelines uses ComponentCatalog to read text data
- {
- string pathData = GetDataPath("adult.train");
- string pathDataTest = GetDataPath("adult.test");
- int numOfSampleRows = 1000;
- int batchSize = 5;
- int numIterations = 10;
- int numTransformLevels = 3;
- SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);
-
- // Using the simple, uniform random sampling (with replacement) engine
- PipelineOptimizerBase autoMlEngine = new UniformRandomEngine(env);
-
- // Test initial learning
- var amls = AutoInference.InferPipelines(env, autoMlEngine, pathData, "", out var schema, numTransformLevels, batchSize,
- metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations / 2), MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer);
- env.Check(amls.GetAllEvaluatedPipelines().Length == numIterations / 2);
-
- // Resume learning
- amls.UpdateTerminator(new IterationTerminator(numIterations));
- bestPipeline = amls.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
- env.Check(amls.GetAllEvaluatedPipelines().Length == numIterations);
-
- // Use best pipeline for another task
- var inputFileTrain = new SimpleFileHandle(env, pathData, false, false);
+ var env = new MLContext().AddStandardComponents(); // AutoInference uses ComponentCatalog to find all learners
+ string pathData = GetDataPath("adult.train");
+ string pathDataTest = GetDataPath("adult.test");
+ int numOfSampleRows = 1000;
+ int batchSize = 5;
+ int numIterations = 10;
+ int numTransformLevels = 3;
+ SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);
+
+ // Using the simple, uniform random sampling (with replacement) engine
+ PipelineOptimizerBase autoMlEngine = new UniformRandomEngine(env);
+
+ // Test initial learning
+ var amls = AutoInference.InferPipelines(env, autoMlEngine, pathData, "", out var schema, numTransformLevels, batchSize,
+ metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations / 2), MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer);
+ env.Check(amls.GetAllEvaluatedPipelines().Length == numIterations / 2);
+
+ // Resume learning
+ amls.UpdateTerminator(new IterationTerminator(numIterations));
+ bestPipeline = amls.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
+ env.Check(amls.GetAllEvaluatedPipelines().Length == numIterations);
+
+ // Use best pipeline for another task
+ var inputFileTrain = new SimpleFileHandle(env, pathData, false, false);
#pragma warning disable 0618
- var datasetTrain = ImportTextData.ImportText(env,
- new ImportTextData.Input { InputFile = inputFileTrain, CustomSchema = schema }).Data;
- var inputFileTest = new SimpleFileHandle(env, pathDataTest, false, false);
- var datasetTest = ImportTextData.ImportText(env,
- new ImportTextData.Input { InputFile = inputFileTest, CustomSchema = schema }).Data;
+ var datasetTrain = ImportTextData.ImportText(env,
+ new ImportTextData.Input { InputFile = inputFileTrain, CustomSchema = schema }).Data;
+ var inputFileTest = new SimpleFileHandle(env, pathDataTest, false, false);
+ var datasetTest = ImportTextData.ImportText(env,
+ new ImportTextData.Input { InputFile = inputFileTest, CustomSchema = schema }).Data;
#pragma warning restore 0618
- // REVIEW: Theoretically, it could be the case that a new, very bad learner is introduced and
- // we get unlucky and only select it every time, such that this test fails. Not
- // likely at all, but a non-zero probability. Should be ok, since all current learners are returning d > .80.
- bestPipeline.RunTrainTestExperiment(datasetTrain, datasetTest, metric, MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer,
- out var testMetricValue, out var trainMtericValue);
- env.Check(testMetricValue > 0.2);
- }
+ // REVIEW: Theoretically, it could be the case that a new, very bad learner is introduced and
+ // we get unlucky and only select it every time, such that this test fails. Not
+ // likely at all, but a non-zero probability. Should be ok, since all current learners are returning d > .80.
+ bestPipeline.RunTrainTestExperiment(datasetTrain, datasetTest, metric, MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer,
+ out var testMetricValue, out var trainMtericValue);
+ env.Check(testMetricValue > 0.2);
Done();
}
[Fact(Skip = "Need CoreTLC specific baseline update")]
public void TestTextDatasetLearn()
{
- using (var env = new ConsoleEnvironment()
- .AddStandardComponents()) // AutoInference uses ComponentCatalog to find all learners
- {
- string pathData = GetDataPath(@"../UnitTest/tweets_labeled_10k_test_validation.tsv");
- int batchSize = 5;
- int numIterations = 35;
- int numTransformLevels = 1;
- int numSampleRows = 100;
- SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.AccuracyMicro);
-
- // Using the simple, uniform random sampling (with replacement) engine
- PipelineOptimizerBase autoMlEngine = new UniformRandomEngine(env);
-
- // Test initial learning
- var amls = AutoInference.InferPipelines(env, autoMlEngine, pathData, "", out var _, numTransformLevels, batchSize,
- metric, out var _, numSampleRows, new IterationTerminator(numIterations),
- MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer);
- env.Check(amls.GetAllEvaluatedPipelines().Length == numIterations);
- }
+ var env = new MLContext().AddStandardComponents(); // AutoInference uses ComponentCatalog to find all learners
+ string pathData = GetDataPath(@"../UnitTest/tweets_labeled_10k_test_validation.tsv");
+ int batchSize = 5;
+ int numIterations = 35;
+ int numTransformLevels = 1;
+ int numSampleRows = 100;
+ SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.AccuracyMicro);
+
+ // Using the simple, uniform random sampling (with replacement) engine
+ PipelineOptimizerBase autoMlEngine = new UniformRandomEngine(env);
+
+ // Test initial learning
+ var amls = AutoInference.InferPipelines(env, autoMlEngine, pathData, "", out var _, numTransformLevels, batchSize,
+ metric, out var _, numSampleRows, new IterationTerminator(numIterations),
+ MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer);
+ env.Check(amls.GetAllEvaluatedPipelines().Length == numIterations);
Done();
}
[Fact]
public void TestPipelineNodeCloning()
{
- using (var env = new ConsoleEnvironment()
- .AddStandardComponents()) // RecipeInference.AllowedLearners uses ComponentCatalog to find all learners
- {
- var lr1 = RecipeInference
+ var env = new MLContext().AddStandardComponents(); // AutoInference uses ComponentCatalog to find all learners
+ var lr1 = RecipeInference
.AllowedLearners(env, MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer)
.First(learner => learner.PipelineNode != null && learner.LearnerName.Contains("LogisticRegression"));
- var sdca1 = RecipeInference
- .AllowedLearners(env, MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer)
- .First(learner => learner.PipelineNode != null && learner.LearnerName.Contains("StochasticDualCoordinateAscent"));
-
- // Clone and change hyperparam values
- var lr2 = lr1.Clone();
- lr1.PipelineNode.SweepParams[0].RawValue = 1.2f;
- lr2.PipelineNode.SweepParams[0].RawValue = 3.5f;
- var sdca2 = sdca1.Clone();
- sdca1.PipelineNode.SweepParams[0].RawValue = 3;
- sdca2.PipelineNode.SweepParams[0].RawValue = 0;
-
- // Make sure the changes are propagated to entry point objects
- env.Check(lr1.PipelineNode.UpdateProperties());
- env.Check(lr2.PipelineNode.UpdateProperties());
- env.Check(sdca1.PipelineNode.UpdateProperties());
- env.Check(sdca2.PipelineNode.UpdateProperties());
- env.Check(lr1.PipelineNode.CheckEntryPointStateMatchesParamValues());
- env.Check(lr2.PipelineNode.CheckEntryPointStateMatchesParamValues());
- env.Check(sdca1.PipelineNode.CheckEntryPointStateMatchesParamValues());
- env.Check(sdca2.PipelineNode.CheckEntryPointStateMatchesParamValues());
-
- // Make sure second object's set of changes didn't overwrite first object's
- env.Check(!lr1.PipelineNode.SweepParams[0].RawValue.Equals(lr2.PipelineNode.SweepParams[0].RawValue));
- env.Check(!sdca2.PipelineNode.SweepParams[0].RawValue.Equals(sdca1.PipelineNode.SweepParams[0].RawValue));
- }
+ var sdca1 = RecipeInference
+ .AllowedLearners(env, MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer)
+ .First(learner => learner.PipelineNode != null && learner.LearnerName.Contains("StochasticDualCoordinateAscent"));
+
+ // Clone and change hyperparam values
+ var lr2 = lr1.Clone();
+ lr1.PipelineNode.SweepParams[0].RawValue = 1.2f;
+ lr2.PipelineNode.SweepParams[0].RawValue = 3.5f;
+ var sdca2 = sdca1.Clone();
+ sdca1.PipelineNode.SweepParams[0].RawValue = 3;
+ sdca2.PipelineNode.SweepParams[0].RawValue = 0;
+
+ // Make sure the changes are propagated to entry point objects
+ env.Check(lr1.PipelineNode.UpdateProperties());
+ env.Check(lr2.PipelineNode.UpdateProperties());
+ env.Check(sdca1.PipelineNode.UpdateProperties());
+ env.Check(sdca2.PipelineNode.UpdateProperties());
+ env.Check(lr1.PipelineNode.CheckEntryPointStateMatchesParamValues());
+ env.Check(lr2.PipelineNode.CheckEntryPointStateMatchesParamValues());
+ env.Check(sdca1.PipelineNode.CheckEntryPointStateMatchesParamValues());
+ env.Check(sdca2.PipelineNode.CheckEntryPointStateMatchesParamValues());
+
+ // Make sure second object's set of changes didn't overwrite first object's
+ env.Check(!lr1.PipelineNode.SweepParams[0].RawValue.Equals(lr2.PipelineNode.SweepParams[0].RawValue));
+ env.Check(!sdca2.PipelineNode.SweepParams[0].RawValue.Equals(sdca1.PipelineNode.SweepParams[0].RawValue));
}
[Fact]
@@ -143,45 +134,42 @@ public void TestHyperparameterFreezing()
int batchSize = 1;
int numIterations = 10;
int numTransformLevels = 3;
- using (var env = new ConsoleEnvironment()
- .AddStandardComponents()) // AutoInference uses ComponentCatalog to find all learners
- {
- SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);
-
- // Using the simple, uniform random sampling (with replacement) brain
- PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env);
-
- // Run initial experiments
- var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize,
- metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations),
- MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer);
-
- // Clear results
- amls.ClearEvaluatedPipelines();
-
- // Get space, remove transforms and all but one learner, freeze hyperparameters on learner.
- var space = amls.GetSearchSpace();
- var transforms = space.Item1.Where(t =>
- t.ExpertType != typeof(TransformInference.Experts.Categorical)).ToArray();
- var learners = new[] { space.Item2.First() };
- var hyperParam = learners[0].PipelineNode.SweepParams.First();
- var frozenParamValue = hyperParam.RawValue;
- hyperParam.Frozen = true;
- amls.UpdateSearchSpace(learners, transforms);
-
- // Allow for one more iteration
- amls.UpdateTerminator(new IterationTerminator(numIterations + 1));
-
- // Do learning. Only retained learner should be left in all pipelines.
- bestPipeline = amls.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
-
- // Make sure all pipelines have retained learner
- Assert.True(amls.GetAllEvaluatedPipelines().All(p => p.Learner.LearnerName == learners[0].LearnerName));
-
- // Make sure hyperparameter value did not change
- Assert.NotNull(bestPipeline);
- Assert.Equal(bestPipeline.Learner.PipelineNode.SweepParams.First().RawValue, frozenParamValue);
- }
+ var env = new MLContext().AddStandardComponents(); // AutoInference uses ComponentCatalog to find all learners
+ SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);
+
+ // Using the simple, uniform random sampling (with replacement) brain
+ PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env);
+
+ // Run initial experiments
+ var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize,
+ metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations),
+ MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer);
+
+ // Clear results
+ amls.ClearEvaluatedPipelines();
+
+ // Get space, remove transforms and all but one learner, freeze hyperparameters on learner.
+ var space = amls.GetSearchSpace();
+ var transforms = space.Item1.Where(t =>
+ t.ExpertType != typeof(TransformInference.Experts.Categorical)).ToArray();
+ var learners = new[] { space.Item2.First() };
+ var hyperParam = learners[0].PipelineNode.SweepParams.First();
+ var frozenParamValue = hyperParam.RawValue;
+ hyperParam.Frozen = true;
+ amls.UpdateSearchSpace(learners, transforms);
+
+ // Allow for one more iteration
+ amls.UpdateTerminator(new IterationTerminator(numIterations + 1));
+
+ // Do learning. Only retained learner should be left in all pipelines.
+ bestPipeline = amls.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
+
+ // Make sure all pipelines have retained learner
+ Assert.True(amls.GetAllEvaluatedPipelines().All(p => p.Learner.LearnerName == learners[0].LearnerName));
+
+ // Make sure hyperparameter value did not change
+ Assert.NotNull(bestPipeline);
+ Assert.Equal(bestPipeline.Learner.PipelineNode.SweepParams.First().RawValue, frozenParamValue);
}
[Fact(Skip = "Dataset not available.")]
@@ -192,30 +180,27 @@ public void TestRegressionPipelineWithMinimizingMetric()
int batchSize = 5;
int numIterations = 10;
int numTransformLevels = 1;
- using (var env = new ConsoleEnvironment()
- .AddStandardComponents()) // AutoInference uses ComponentCatalog to find all learners
- {
- SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.AccuracyMicro);
+ var env = new MLContext().AddStandardComponents(); // AutoInference uses ComponentCatalog to find all learners
+ SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.AccuracyMicro);
- // Using the simple, uniform random sampling (with replacement) brain
- PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env);
+ // Using the simple, uniform random sampling (with replacement) brain
+ PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env);
- // Run initial experiments
- var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize,
- metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations),
- MacroUtils.TrainerKinds.SignatureRegressorTrainer);
+ // Run initial experiments
+ var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize,
+ metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations),
+ MacroUtils.TrainerKinds.SignatureRegressorTrainer);
- // Allow for one more iteration
- amls.UpdateTerminator(new IterationTerminator(numIterations + 1));
+ // Allow for one more iteration
+ amls.UpdateTerminator(new IterationTerminator(numIterations + 1));
- // Do learning. Only retained learner should be left in all pipelines.
- bestPipeline = amls.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
+ // Do learning. Only retained learner should be left in all pipelines.
+ bestPipeline = amls.InferPipelines(numTransformLevels, batchSize, numOfSampleRows);
- // Make sure hyperparameter value did not change
- Assert.NotNull(bestPipeline);
- Assert.True(amls.GetAllEvaluatedPipelines().All(
- p => p.PerformanceSummary.MetricValue >= bestPipeline.PerformanceSummary.MetricValue));
- }
+ // Make sure the best pipeline is non-null and scores no worse than any evaluated pipeline (minimizing metric)
+ Assert.NotNull(bestPipeline);
+ Assert.True(amls.GetAllEvaluatedPipelines().All(
+ p => p.PerformanceSummary.MetricValue >= bestPipeline.PerformanceSummary.MetricValue));
}
[Fact]
@@ -227,27 +212,24 @@ public void TestLearnerConstrainingByName()
int numIterations = 1;
int numTransformLevels = 2;
var retainedLearnerNames = new[] { $"LogisticRegressionBinaryClassifier", $"FastTreeBinaryClassifier" };
- using (var env = new ConsoleEnvironment()
- .AddStandardComponents()) // AutoInference uses ComponentCatalog to find all learners
- {
- SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);
-
- // Using the simple, uniform random sampling (with replacement) brain.
- PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env);
-
- // Run initial experiment.
- var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _,
- numTransformLevels, batchSize, metric, out var _, numOfSampleRows,
- new IterationTerminator(numIterations), MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer);
-
- // Keep only logistic regression and FastTree.
- amls.KeepSelectedLearners(retainedLearnerNames);
- var space = amls.GetSearchSpace();
-
- // Make sure only learners left are those retained.
- Assert.Equal(retainedLearnerNames.Length, space.Item2.Length);
- Assert.True(space.Item2.All(l => retainedLearnerNames.Any(r => r == l.LearnerName)));
- }
+ var env = new MLContext().AddStandardComponents(); // AutoInference uses ComponentCatalog to find all learners
+ SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);
+
+ // Using the simple, uniform random sampling (with replacement) brain.
+ PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env);
+
+ // Run initial experiment.
+ var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _,
+ numTransformLevels, batchSize, metric, out var _, numOfSampleRows,
+ new IterationTerminator(numIterations), MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer);
+
+ // Keep only logistic regression and FastTree.
+ amls.KeepSelectedLearners(retainedLearnerNames);
+ var space = amls.GetSearchSpace();
+
+ // Make sure only learners left are those retained.
+ Assert.Equal(retainedLearnerNames.Length, space.Item2.Length);
+ Assert.True(space.Item2.All(l => retainedLearnerNames.Any(r => r == l.LearnerName)));
}
[Fact]
diff --git a/test/Microsoft.ML.Predictor.Tests/TestDatasetInference.cs b/test/Microsoft.ML.Predictor.Tests/TestDatasetInference.cs
index 0b92eb7b30..5cf5a86ceb 100644
--- a/test/Microsoft.ML.Predictor.Tests/TestDatasetInference.cs
+++ b/test/Microsoft.ML.Predictor.Tests/TestDatasetInference.cs
@@ -25,7 +25,7 @@ public TestDatasetInference(ITestOutputHelper helper)
{
}
- [Fact(Skip="Disabled")]
+ [Fact(Skip = "Disabled")]
public void DatasetInferenceTest()
{
var datasets = new[]
@@ -35,51 +35,49 @@ public void DatasetInferenceTest()
GetDataPath(@"..\UnitTest\breast-cancer.txt"),
};
- using (var env = new ConsoleEnvironment())
+ IHostEnvironment env = new MLContext();
+ var h = env.Register("InferDatasetFeatures", seed: 0, verbose: false);
+
+ using (var ch = h.Start("InferDatasetFeatures"))
{
- var h = env.Register("InferDatasetFeatures", seed: 0, verbose: false);
- using (var ch = h.Start("InferDatasetFeatures"))
+ for (int i = 0; i < datasets.Length; i++)
{
+ var sample = TextFileSample.CreateFromFullFile(h, datasets[i]);
+ var splitResult = TextFileContents.TrySplitColumns(h, sample, TextFileContents.DefaultSeparators);
+ if (!splitResult.IsSuccess)
+ throw ch.ExceptDecode("Couldn't detect separator.");
- for (int i = 0; i < datasets.Length; i++)
- {
- var sample = TextFileSample.CreateFromFullFile(h, datasets[i]);
- var splitResult = TextFileContents.TrySplitColumns(h, sample, TextFileContents.DefaultSeparators);
- if (!splitResult.IsSuccess)
- throw ch.ExceptDecode("Couldn't detect separator.");
-
- var typeInfResult = ColumnTypeInference.InferTextFileColumnTypes(Env, sample,
- new ColumnTypeInference.Arguments
- {
- Separator = splitResult.Separator,
- AllowSparse = splitResult.AllowSparse,
- AllowQuote = splitResult.AllowQuote,
- ColumnCount = splitResult.ColumnCount
- });
-
- if (!typeInfResult.IsSuccess)
- return;
-
- ColumnGroupingInference.GroupingColumn[] columns = null;
- bool hasHeader = false;
- columns = InferenceUtils.InferColumnPurposes(ch, h, sample, splitResult, out hasHeader);
- Guid id = new Guid("60C77F4E-DB62-4351-8311-9B392A12968E");
- var commandArgs = new DatasetFeatureInference.Arguments(typeInfResult.Data,
- columns.Select(
- col =>
- new DatasetFeatureInference.Column(col.SuggestedName, col.Purpose, col.ItemKind,
- col.ColumnRangeSelector)).ToArray(), sample.FullFileSize, sample.ApproximateRowCount,
- false, id, true);
-
- string jsonString = DatasetFeatureInference.InferDatasetFeatures(env, commandArgs);
- var outFile = string.Format("dataset-inference-result-{0:00}.txt", i);
- string dataPath = GetOutputPath(@"..\Common\Inference", outFile);
- using (var sw = new StreamWriter(File.Create(dataPath)))
- sw.WriteLine(jsonString);
-
- CheckEquality(@"..\Common\Inference", outFile);
- }
+ var typeInfResult = ColumnTypeInference.InferTextFileColumnTypes(Env, sample,
+ new ColumnTypeInference.Arguments
+ {
+ Separator = splitResult.Separator,
+ AllowSparse = splitResult.AllowSparse,
+ AllowQuote = splitResult.AllowQuote,
+ ColumnCount = splitResult.ColumnCount
+ });
+
+ if (!typeInfResult.IsSuccess)
+ return;
+
+ ColumnGroupingInference.GroupingColumn[] columns = null;
+ bool hasHeader = false;
+ columns = InferenceUtils.InferColumnPurposes(ch, h, sample, splitResult, out hasHeader);
+ Guid id = new Guid("60C77F4E-DB62-4351-8311-9B392A12968E");
+ var commandArgs = new DatasetFeatureInference.Arguments(typeInfResult.Data,
+ columns.Select(
+ col =>
+ new DatasetFeatureInference.Column(col.SuggestedName, col.Purpose, col.ItemKind,
+ col.ColumnRangeSelector)).ToArray(), sample.FullFileSize, sample.ApproximateRowCount,
+ false, id, true);
+
+ string jsonString = DatasetFeatureInference.InferDatasetFeatures(env, commandArgs);
+ var outFile = string.Format("dataset-inference-result-{0:00}.txt", i);
+ string dataPath = GetOutputPath(@"..\Common\Inference", outFile);
+ using (var sw = new StreamWriter(File.Create(dataPath)))
+ sw.WriteLine(jsonString);
+
+ CheckEquality(@"..\Common\Inference", outFile);
}
}
Done();
@@ -93,26 +91,24 @@ public void InferSchemaCommandTest()
GetDataPath(Path.Combine("..", "data", "wikipedia-detox-250-line-data.tsv"))
};
- using (var env = new ConsoleEnvironment())
+ IHostEnvironment env = new MLContext();
+ var h = env.Register("InferSchemaCommandTest", seed: 0, verbose: false);
+ using (var ch = h.Start("InferSchemaCommandTest"))
{
- var h = env.Register("InferSchemaCommandTest", seed: 0, verbose: false);
- using (var ch = h.Start("InferSchemaCommandTest"))
+ for (int i = 0; i < datasets.Length; i++)
{
- for (int i = 0; i < datasets.Length; i++)
+ var outFile = string.Format("dataset-infer-schema-result-{0:00}.txt", i);
+ string dataPath = GetOutputPath(Path.Combine("..", "Common", "Inference"), outFile);
+ var args = new InferSchemaCommand.Arguments()
{
- var outFile = string.Format("dataset-infer-schema-result-{0:00}.txt", i);
- string dataPath = GetOutputPath(Path.Combine("..", "Common", "Inference"), outFile);
- var args = new InferSchemaCommand.Arguments()
- {
- DataFile = datasets[i],
- OutputFile = dataPath,
- };
+ DataFile = datasets[i],
+ OutputFile = dataPath,
+ };
- var cmd = new InferSchemaCommand(Env, args);
- cmd.Run();
+ var cmd = new InferSchemaCommand(Env, args);
+ cmd.Run();
- CheckEquality(Path.Combine("..", "Common", "Inference"), outFile);
- }
+ CheckEquality(Path.Combine("..", "Common", "Inference"), outFile);
}
}
Done();
@@ -128,26 +124,24 @@ public void InferRecipesCommandTest()
GetDataPath(Path.Combine("..", "data", "wikipedia-detox-250-line-data-schema.txt")))
};
- using (var env = new ConsoleEnvironment())
+ IHostEnvironment env = new MLContext();
+ var h = env.Register("InferRecipesCommandTest", seed: 0, verbose: false);
+ using (var ch = h.Start("InferRecipesCommandTest"))
{
- var h = env.Register("InferRecipesCommandTest", seed: 0, verbose: false);
- using (var ch = h.Start("InferRecipesCommandTest"))
+ for (int i = 0; i < datasets.Length; i++)
{
- for (int i = 0; i < datasets.Length; i++)
+ var outFile = string.Format("dataset-infer-recipe-result-{0:00}.txt", i);
+ string dataPath = GetOutputPath(Path.Combine("..", "Common", "Inference"), outFile);
+ var args = new InferRecipesCommand.Arguments()
{
- var outFile = string.Format("dataset-infer-recipe-result-{0:00}.txt", i);
- string dataPath = GetOutputPath(Path.Combine("..", "Common", "Inference"), outFile);
- var args = new InferRecipesCommand.Arguments()
- {
- DataFile = datasets[i].Item1,
- SchemaDefinitionFile = datasets[i].Item2,
- RspOutputFile = dataPath
- };
- var cmd = new InferRecipesCommand(Env, args);
- cmd.Run();
-
- CheckEquality(Path.Combine("..", "Common", "Inference"), outFile);
- }
+ DataFile = datasets[i].Item1,
+ SchemaDefinitionFile = datasets[i].Item2,
+ RspOutputFile = dataPath
+ };
+ var cmd = new InferRecipesCommand(Env, args);
+ cmd.Run();
+
+ CheckEquality(Path.Combine("..", "Common", "Inference"), outFile);
}
}
Done();
diff --git a/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs b/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs
index b72c9773d3..5e277daa05 100644
--- a/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs
+++ b/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs
@@ -123,42 +123,40 @@ public void PipelineSweeperNoTransforms()
const int batchSize = 5;
const int numIterations = 20;
const int numTransformLevels = 2;
- using (var env = new ConsoleEnvironment())
- {
- SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);
+ var env = new MLContext();
+ SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);
- // Using the simple, uniform random sampling (with replacement) engine
- PipelineOptimizerBase autoMlEngine = new UniformRandomEngine(Env);
+ // Using the simple, uniform random sampling (with replacement) engine
+ PipelineOptimizerBase autoMlEngine = new UniformRandomEngine(Env);
- // Create search object
- var amls = new AutoInference.AutoMlMlState(Env, metric, autoMlEngine, new IterationTerminator(numIterations),
- MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer, datasetTrain, datasetTest);
+ // Create search object
+ var amls = new AutoInference.AutoMlMlState(Env, metric, autoMlEngine, new IterationTerminator(numIterations),
+ MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer, datasetTrain, datasetTest);
- // Infer search space
- amls.InferSearchSpace(numTransformLevels);
+ // Infer search space
+ amls.InferSearchSpace(numTransformLevels);
- // Create macro object
- var pipelineSweepInput = new Microsoft.ML.Legacy.Models.PipelineSweeper()
- {
- BatchSize = batchSize,
- };
-
- var exp = new Experiment(Env);
- var output = exp.Add(pipelineSweepInput);
- exp.Compile();
- exp.SetInput(pipelineSweepInput.TrainingData, datasetTrain);
- exp.SetInput(pipelineSweepInput.TestingData, datasetTest);
- exp.SetInput(pipelineSweepInput.State, amls);
- exp.SetInput(pipelineSweepInput.CandidateOutputs, new IDataView[0]);
- exp.Run();
-
- // Make sure you get back an AutoMlState, and that it ran for correct number of iterations
- // with at least minimal performance values (i.e., best should have AUC better than 0.1 on this dataset).
- AutoInference.AutoMlMlState amlsOut = (AutoInference.AutoMlMlState)exp.GetOutput(output.State);
- Assert.NotNull(amlsOut);
- Assert.Equal(amlsOut.GetAllEvaluatedPipelines().Length, numIterations);
- Assert.True(amlsOut.GetBestPipeline().PerformanceSummary.MetricValue > 0.8);
- }
+ // Create macro object
+ var pipelineSweepInput = new Microsoft.ML.Legacy.Models.PipelineSweeper()
+ {
+ BatchSize = batchSize,
+ };
+
+ var exp = new Experiment(Env);
+ var output = exp.Add(pipelineSweepInput);
+ exp.Compile();
+ exp.SetInput(pipelineSweepInput.TrainingData, datasetTrain);
+ exp.SetInput(pipelineSweepInput.TestingData, datasetTest);
+ exp.SetInput(pipelineSweepInput.State, amls);
+ exp.SetInput(pipelineSweepInput.CandidateOutputs, new IDataView[0]);
+ exp.Run();
+
+ // Make sure you get back an AutoMlState, and that it ran for correct number of iterations
+ // with at least minimal performance values (i.e., best should have AUC better than 0.8 on this dataset).
+ AutoInference.AutoMlMlState amlsOut = (AutoInference.AutoMlMlState)exp.GetOutput(output.State);
+ Assert.NotNull(amlsOut);
+ Assert.Equal(amlsOut.GetAllEvaluatedPipelines().Length, numIterations);
+ Assert.True(amlsOut.GetBestPipeline().PerformanceSummary.MetricValue > 0.8);
}
[Fact]
diff --git a/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs b/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs
index 3b28fc1cfa..bdb9677495 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs
+++ b/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs
@@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.ImageAnalytics;
using Xunit;
@@ -20,7 +19,7 @@ public ImageAnalyticsTests(ITestOutputHelper output)
[Fact]
public void SimpleImageSmokeTest()
{
- var env = new ConsoleEnvironment(0, verbose: true);
+ var env = new MLContext(0);
var reader = TextLoader.CreateReader(env,
ctx => ctx.LoadText(0).LoadAsImage().AsGrayscale().Resize(10, 8).ExtractPixels());
diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
index b5fb351b99..ab7049a58a 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
+++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
@@ -58,7 +58,7 @@ private void CheckSchemaHasColumn(ISchema schema, string name, out int idx)
[Fact]
public void SimpleTextLoaderCopyColumnsTest()
{
- var env = new ConsoleEnvironment(0, verbose: true);
+ var env = new MLContext(0);
const string data = "0 hello 3.14159 -0 2\n"
+ "1 1 2 4 15";
@@ -159,7 +159,7 @@ private static Obnoxious3 MakeObnoxious3(Scalar hi, Obnoxious1 my, T
[Fact]
public void SimpleTextLoaderObnoxiousTypeTest()
{
- var env = new ConsoleEnvironment(0, verbose: true);
+ var env = new MLContext(0);
const string data = "0 hello 3.14159 -0 2\n"
+ "1 1 2 4 15";
@@ -206,7 +206,7 @@ private static KeyValuePair P(string name, ColumnType type)
[Fact]
public void AssertStaticSimple()
{
- var env = new ConsoleEnvironment(0, verbose: true);
+ var env = new MLContext(0);
var schema = SimpleSchemaUtils.Create(env,
P("hello", TextType.Instance),
P("my", new VectorType(NumberType.I8, 5)),
@@ -230,7 +230,7 @@ public void AssertStaticSimple()
[Fact]
public void AssertStaticSimpleFailure()
{
- var env = new ConsoleEnvironment(0, verbose: true);
+ var env = new MLContext(0);
var schema = SimpleSchemaUtils.Create(env,
P("hello", TextType.Instance),
P("my", new VectorType(NumberType.I8, 5)),
@@ -262,7 +262,7 @@ private sealed class MetaCounted : ICounted
[Fact]
public void AssertStaticKeys()
{
- var env = new ConsoleEnvironment(0, verbose: true);
+ var env = new MLContext(0);
var counted = new MetaCounted();
// We'll test a few things here. First, the case where the key-value metadata is text.
@@ -369,7 +369,7 @@ public void AssertStaticKeys()
[Fact]
public void Normalizer()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(0);
var dataPath = GetDataPath("generated_regression_dataset.csv");
var dataSource = new MultiFileSource(dataPath);
@@ -394,7 +394,7 @@ public void Normalizer()
[Fact]
public void NormalizerWithOnFit()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(0);
var dataPath = GetDataPath("generated_regression_dataset.csv");
var dataSource = new MultiFileSource(dataPath);
@@ -438,7 +438,7 @@ public void NormalizerWithOnFit()
[Fact]
public void ToKey()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(0);
var dataPath = GetDataPath("iris.data");
var reader = TextLoader.CreateReader(env,
c => (label: c.LoadText(4), values: c.LoadFloat(0, 3)),
@@ -476,7 +476,7 @@ public void ToKey()
[Fact]
public void ConcatWith()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(0);
var dataPath = GetDataPath("iris.data");
var reader = TextLoader.CreateReader(env,
c => (label: c.LoadText(4), values: c.LoadFloat(0, 3), value: c.LoadFloat(2)),
@@ -514,7 +514,7 @@ public void ConcatWith()
[Fact]
public void Tokenize()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(0);
var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv");
var reader = TextLoader.CreateReader(env, ctx => (
label: ctx.LoadBool(0),
@@ -543,7 +543,7 @@ public void Tokenize()
[Fact]
public void NormalizeTextAndRemoveStopWords()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(0);
var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv");
var reader = TextLoader.CreateReader(env, ctx => (
label: ctx.LoadBool(0),
@@ -572,7 +572,7 @@ public void NormalizeTextAndRemoveStopWords()
[Fact]
public void ConvertToWordBag()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(0);
var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv");
var reader = TextLoader.CreateReader(env, ctx => (
label: ctx.LoadBool(0),
@@ -601,7 +601,7 @@ public void ConvertToWordBag()
[Fact]
public void Ngrams()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(0);
var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv");
var reader = TextLoader.CreateReader(env, ctx => (
label: ctx.LoadBool(0),
@@ -631,7 +631,7 @@ public void Ngrams()
[Fact]
public void LpGcNormAndWhitening()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(0);
var dataPath = GetDataPath("generated_regression_dataset.csv");
var dataSource = new MultiFileSource(dataPath);
@@ -669,7 +669,7 @@ public void LpGcNormAndWhitening()
[Fact(Skip = "LDA transform cannot be trained on empty data, schema propagation fails")]
public void LdaTopicModel()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(0);
var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv");
var reader = TextLoader.CreateReader(env, ctx => (
label: ctx.LoadBool(0),
@@ -696,7 +696,7 @@ public void LdaTopicModel()
[Fact(Skip = "FeatureSeclection transform cannot be trained on empty data, schema propagation fails")]
public void FeatureSelection()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(0);
var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv");
var reader = TextLoader.CreateReader(env, ctx => (
label: ctx.LoadBool(0),
@@ -725,7 +725,7 @@ public void FeatureSelection()
[Fact]
public void TrainTestSplit()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(0);
var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
var dataSource = new MultiFileSource(dataPath);
@@ -755,7 +755,7 @@ public void TrainTestSplit()
[Fact]
public void PrincipalComponentAnalysis()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(0);
var dataPath = GetDataPath("generated_regression_dataset.csv");
var dataSource = new MultiFileSource(dataPath);
@@ -778,10 +778,10 @@ public void PrincipalComponentAnalysis()
[Fact]
public void NAIndicatorStatic()
{
- var Env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(0);
string dataPath = GetDataPath("breast-cancer.txt");
- var reader = TextLoader.CreateReader(Env, ctx => (
+ var reader = TextLoader.CreateReader(env, ctx => (
ScalarFloat: ctx.LoadFloat(1),
ScalarDouble: ctx.LoadDouble(1),
VectorFloat: ctx.LoadFloat(1, 4),
@@ -798,12 +798,12 @@ public void NAIndicatorStatic()
D: row.VectorDoulbe.IsMissingValue()
));
- IDataView newData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4);
+ IDataView newData = TakeFilter.Create(env, est.Fit(data).Transform(data).AsDynamic, 4);
Assert.NotNull(newData);
- bool[] ScalarFloat = newData.GetColumn(Env, "A").ToArray();
- bool[] ScalarDouble = newData.GetColumn(Env, "B").ToArray();
- bool[][] VectorFloat = newData.GetColumn(Env, "C").ToArray();
- bool[][] VectorDoulbe = newData.GetColumn(Env, "D").ToArray();
+ bool[] ScalarFloat = newData.GetColumn(env, "A").ToArray();
+ bool[] ScalarDouble = newData.GetColumn(env, "B").ToArray();
+ bool[][] VectorFloat = newData.GetColumn(env, "C").ToArray();
+ bool[][] VectorDoulbe = newData.GetColumn(env, "D").ToArray();
Assert.NotNull(ScalarFloat);
Assert.NotNull(ScalarDouble);
@@ -822,7 +822,7 @@ public void NAIndicatorStatic()
[Fact]
public void TextNormalizeStatic()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(0);
var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv");
var reader = TextLoader.CreateReader(env, ctx => (
label: ctx.LoadBool(0),
@@ -862,7 +862,7 @@ public void TextNormalizeStatic()
[Fact]
public void TestPcaStatic()
{
- var env = new ConsoleEnvironment(seed: 1);
+ var env = new MLContext(0);
var dataSource = GetDataPath("generated_regression_dataset.csv");
var reader = TextLoader.CreateReader(env,
c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)),
diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs
index c689cba935..b59c2118b5 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs
+++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs
@@ -32,7 +32,7 @@ public Training(ITestOutputHelper output) : base(output)
[Fact]
public void SdcaRegression()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0, conc: 1);
var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
var dataSource = new MultiFileSource(dataPath);
@@ -45,7 +45,8 @@ public void SdcaRegression()
LinearRegressionPredictor pred = null;
var est = reader.MakeNewEstimator()
- .Append(r => (r.label, score: ctx.Trainers.Sdca(r.label, r.features, maxIterations: 2, onFit: p => pred = p)));
+ .Append(r => (r.label, score: ctx.Trainers.Sdca(r.label, r.features, maxIterations: 2,
+ onFit: p => pred = p, advancedSettings: s => s.NumThreads = 1)));
var pipe = reader.Append(est);
@@ -74,7 +75,7 @@ public void SdcaRegression()
[Fact]
public void SdcaRegressionNameCollision()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
var dataSource = new MultiFileSource(dataPath);
var ctx = new RegressionContext(env);
@@ -85,7 +86,7 @@ public void SdcaRegressionNameCollision()
separator: ';', hasHeader: true);
var est = reader.MakeNewEstimator()
- .Append(r => (r.label, r.Score, score: ctx.Trainers.Sdca(r.label, r.features, maxIterations: 2)));
+ .Append(r => (r.label, r.Score, score: ctx.Trainers.Sdca(r.label, r.features, maxIterations: 2, advancedSettings: s => s.NumThreads = 1)));
var pipe = reader.Append(est);
@@ -104,7 +105,7 @@ public void SdcaRegressionNameCollision()
[Fact]
public void SdcaBinaryClassification()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
var dataSource = new MultiFileSource(dataPath);
var ctx = new BinaryClassificationContext(env);
@@ -118,7 +119,8 @@ public void SdcaBinaryClassification()
var est = reader.MakeNewEstimator()
.Append(r => (r.label, preds: ctx.Trainers.Sdca(r.label, r.features,
maxIterations: 2,
- onFit: (p, c) => { pred = p; cali = c; })));
+ onFit: (p, c) => { pred = p; cali = c; },
+ advancedSettings: s => s.NumThreads = 1)));
var pipe = reader.Append(est);
@@ -149,7 +151,7 @@ public void SdcaBinaryClassification()
[Fact]
public void SdcaBinaryClassificationNoCalibration()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
var dataSource = new MultiFileSource(dataPath);
var ctx = new BinaryClassificationContext(env);
@@ -165,7 +167,8 @@ public void SdcaBinaryClassificationNoCalibration()
var est = reader.MakeNewEstimator()
.Append(r => (r.label, preds: ctx.Trainers.Sdca(r.label, r.features,
maxIterations: 2,
- loss: loss, onFit: p => pred = p)));
+ loss: loss, onFit: p => pred = p,
+ advancedSettings: s => s.NumThreads = 1)));
var pipe = reader.Append(est);
@@ -192,7 +195,7 @@ public void SdcaBinaryClassificationNoCalibration()
[Fact]
public void AveragePerceptronNoCalibration()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
var dataSource = new MultiFileSource(dataPath);
var ctx = new BinaryClassificationContext(env);
@@ -228,7 +231,7 @@ public void AveragePerceptronNoCalibration()
[Fact]
public void FfmBinaryClassification()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
var dataSource = new MultiFileSource(dataPath);
var ctx = new BinaryClassificationContext(env);
@@ -260,7 +263,7 @@ public void FfmBinaryClassification()
[Fact]
public void SdcaMulticlass()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
var dataSource = new MultiFileSource(dataPath);
@@ -310,7 +313,7 @@ public void SdcaMulticlass()
[Fact]
public void CrossValidate()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
var dataSource = new MultiFileSource(dataPath);
@@ -334,7 +337,7 @@ public void CrossValidate()
[Fact]
public void FastTreeBinaryClassification()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
var dataSource = new MultiFileSource(dataPath);
var ctx = new BinaryClassificationContext(env);
@@ -373,7 +376,7 @@ public void FastTreeBinaryClassification()
[Fact]
public void FastTreeRegression()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
var dataSource = new MultiFileSource(dataPath);
@@ -415,7 +418,7 @@ public void FastTreeRegression()
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only
public void LightGbmBinaryClassification()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
var dataSource = new MultiFileSource(dataPath);
var ctx = new BinaryClassificationContext(env);
@@ -455,7 +458,7 @@ public void LightGbmBinaryClassification()
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only
public void LightGbmRegression()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
var dataSource = new MultiFileSource(dataPath);
@@ -497,7 +500,7 @@ public void LightGbmRegression()
[Fact]
public void PoissonRegression()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
var dataSource = new MultiFileSource(dataPath);
@@ -513,7 +516,8 @@ public void PoissonRegression()
.Append(r => (r.label, score: ctx.Trainers.PoissonRegression(r.label, r.features,
l1Weight: 2,
enoforceNoNegativity: true,
- onFit: (p) => { pred = p; })));
+ onFit: (p) => { pred = p; },
+ advancedSettings: s => s.NumThreads = 1)));
var pipe = reader.Append(est);
@@ -539,7 +543,7 @@ public void PoissonRegression()
[Fact]
public void LogisticRegressionBinaryClassification()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
var dataSource = new MultiFileSource(dataPath);
var ctx = new BinaryClassificationContext(env);
@@ -552,7 +556,8 @@ public void LogisticRegressionBinaryClassification()
var est = reader.MakeNewEstimator()
.Append(r => (r.label, preds: ctx.Trainers.LogisticRegressionBinaryClassifier(r.label, r.features,
l1Weight: 10,
- onFit: (p) => { pred = p; })));
+ onFit: (p) => { pred = p; },
+ advancedSettings: s => s.NumThreads = 1)));
var pipe = reader.Append(est);
@@ -577,7 +582,7 @@ public void LogisticRegressionBinaryClassification()
[Fact]
public void MulticlassLogisticRegression()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
var dataSource = new MultiFileSource(dataPath);
@@ -592,7 +597,8 @@ public void MulticlassLogisticRegression()
.Append(r => (label: r.label.ToKey(), r.features))
.Append(r => (r.label, preds: ctx.Trainers.MultiClassLogisticRegression(
r.label,
- r.features, onFit: p => pred = p)));
+ r.features, onFit: p => pred = p,
+ advancedSettings: s => s.NumThreads = 1)));
var pipe = reader.Append(est);
@@ -620,7 +626,7 @@ public void MulticlassLogisticRegression()
[Fact]
public void OnlineGradientDescent()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
var dataSource = new MultiFileSource(dataPath);
@@ -663,11 +669,10 @@ public void OnlineGradientDescent()
[Fact]
public void KMeans()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0, conc: 1);
var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
var dataSource = new MultiFileSource(dataPath);
- var ctx = new ClusteringContext(env);
var reader = TextLoader.CreateReader(env,
c => (label: c.LoadText(0), features: c.LoadFloat(1, 4)));
@@ -675,7 +680,7 @@ public void KMeans()
var est = reader.MakeNewEstimator()
.Append(r => (label: r.label.ToKey(), r.features))
- .Append(r => (r.label, r.features, preds: ctx.Trainers.KMeans(r.features, clustersCount: 3, onFit: p => pred = p)));
+ .Append(r => (r.label, r.features, preds: env.Clustering.Trainers.KMeans(r.features, clustersCount: 3, onFit: p => pred = p, advancedSettings: s => s.NumThreads = 1)));
var pipe = reader.Append(est);
@@ -691,23 +696,23 @@ public void KMeans()
var data = model.Read(dataSource);
- var metrics = ctx.Evaluate(data, r => r.preds.score, r => r.label, r => r.features);
+ var metrics = env.Clustering.Evaluate(data, r => r.preds.score, r => r.label, r => r.features);
Assert.NotNull(metrics);
Assert.InRange(metrics.AvgMinScore, 0.5262, 0.5264);
Assert.InRange(metrics.Nmi, 0.73, 0.77);
Assert.InRange(metrics.Dbi, 0.662, 0.667);
- metrics = ctx.Evaluate(data, r => r.preds.score, label: r => r.label);
+ metrics = env.Clustering.Evaluate(data, r => r.preds.score, label: r => r.label);
Assert.NotNull(metrics);
Assert.InRange(metrics.AvgMinScore, 0.5262, 0.5264);
Assert.True(metrics.Dbi == 0.0);
- metrics = ctx.Evaluate(data, r => r.preds.score, features: r => r.features);
+ metrics = env.Clustering.Evaluate(data, r => r.preds.score, features: r => r.features);
Assert.True(double.IsNaN(metrics.Nmi));
- metrics = ctx.Evaluate(data, r => r.preds.score);
+ metrics = env.Clustering.Evaluate(data, r => r.preds.score);
Assert.NotNull(metrics);
Assert.InRange(metrics.AvgMinScore, 0.5262, 0.5264);
Assert.True(double.IsNaN(metrics.Nmi));
@@ -718,7 +723,7 @@ public void KMeans()
[Fact]
public void FastTreeRanking()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.adultRanking.trainFilename);
var dataSource = new MultiFileSource(dataPath);
@@ -759,7 +764,7 @@ public void FastTreeRanking()
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only
public void LightGBMRanking()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.adultRanking.trainFilename);
var dataSource = new MultiFileSource(dataPath);
@@ -800,7 +805,7 @@ public void LightGBMRanking()
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only
public void MultiClassLightGBM()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
var dataSource = new MultiFileSource(dataPath);
@@ -838,7 +843,7 @@ public void MultiClassLightGBM()
[Fact]
public void MultiClassNaiveBayesTrainer()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
var dataSource = new MultiFileSource(dataPath);
@@ -883,7 +888,7 @@ public void MultiClassNaiveBayesTrainer()
[Fact]
public void HogwildSGDBinaryClassification()
{
- var env = new ConsoleEnvironment(seed: 0);
+ var env = new MLContext(seed: 0);
var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
var dataSource = new MultiFileSource(dataPath);
var ctx = new BinaryClassificationContext(env);
@@ -896,7 +901,8 @@ public void HogwildSGDBinaryClassification()
var est = reader.MakeNewEstimator()
.Append(r => (r.label, preds: ctx.Trainers.StochasticGradientDescentClassificationTrainer(r.label, r.features,
l2Weight: 0,
- onFit: (p) => { pred = p; })));
+ onFit: (p) => { pred = p; },
+ advancedSettings: s => s.NumThreads = 1)));
var pipe = reader.Append(est);
diff --git a/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs b/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs
index 2f19a5c4f3..1cd0065703 100644
--- a/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs
+++ b/test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs
@@ -20,19 +20,16 @@ public void UniformRandomSweeperReturnsDistinctValuesWhenProposeSweep()
{
DiscreteValueGenerator valueGenerator = CreateDiscreteValueGenerator();
- using (var writer = new StreamWriter(new MemoryStream()))
- using (var env = new ConsoleEnvironment(42, outWriter: writer, errWriter: writer))
- {
- var sweeper = new UniformRandomSweeper(env,
+ var env = new MLContext(42);
+ var sweeper = new UniformRandomSweeper(env,
new SweeperBase.ArgumentsBase(),
new[] { valueGenerator });
- var results = sweeper.ProposeSweeps(3);
- Assert.NotNull(results);
+ var results = sweeper.ProposeSweeps(3);
+ Assert.NotNull(results);
- int length = results.Length;
- Assert.Equal(2, length);
- }
+ int length = results.Length;
+ Assert.Equal(2, length);
}
[Fact]
@@ -40,19 +37,16 @@ public void RandomGridSweeperReturnsDistinctValuesWhenProposeSweep()
{
DiscreteValueGenerator valueGenerator = CreateDiscreteValueGenerator();
- using (var writer = new StreamWriter(new MemoryStream()))
- using (var env = new ConsoleEnvironment(42, outWriter: writer, errWriter: writer))
- {
- var sweeper = new RandomGridSweeper(env,
- new RandomGridSweeper.Arguments(),
- new[] { valueGenerator });
+ var env = new MLContext(42);
+ var sweeper = new RandomGridSweeper(env,
+ new RandomGridSweeper.Arguments(),
+ new[] { valueGenerator });
- var results = sweeper.ProposeSweeps(3);
- Assert.NotNull(results);
+ var results = sweeper.ProposeSweeps(3);
+ Assert.NotNull(results);
- int length = results.Length;
- Assert.Equal(2, length);
- }
+ int length = results.Length;
+ Assert.Equal(2, length);
}
private static DiscreteValueGenerator CreateDiscreteValueGenerator()
diff --git a/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs b/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs
index d191084ba8..0718a7edde 100644
--- a/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs
+++ b/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs
@@ -94,39 +94,37 @@ public void TestDiscreteValueSweep(double normalizedValue, string expected)
[Fact]
public void TestRandomSweeper()
{
- using (var env = new ConsoleEnvironment(42))
+ var env = new MLContext(42);
+ var args = new SweeperBase.ArgumentsBase()
{
- var args = new SweeperBase.ArgumentsBase()
- {
- SweptParameters = new[] {
+ SweptParameters = new[] {
ComponentFactoryUtils.CreateFromFunction(
environ => new LongValueGenerator(new LongParamArguments() { Name = "foo", Min = 10, Max = 20 })),
ComponentFactoryUtils.CreateFromFunction(
environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 100, Max = 200 }))
}
- };
+ };
- var sweeper = new UniformRandomSweeper(env, args);
- var initialList = sweeper.ProposeSweeps(5, new List<RunResult>());
- Assert.Equal(5, initialList.Length);
- foreach (var parameterSet in initialList)
+ var sweeper = new UniformRandomSweeper(env, args);
+ var initialList = sweeper.ProposeSweeps(5, new List<RunResult>());
+ Assert.Equal(5, initialList.Length);
+ foreach (var parameterSet in initialList)
+ {
+ foreach (var parameterValue in parameterSet)
{
- foreach (var parameterValue in parameterSet)
+ if (parameterValue.Name == "foo")
{
- if (parameterValue.Name == "foo")
- {
- var val = long.Parse(parameterValue.ValueText);
- Assert.InRange(val, 10, 20);
- }
- else if (parameterValue.Name == "bar")
- {
- var val = long.Parse(parameterValue.ValueText);
- Assert.InRange(val, 100, 200);
- }
- else
- {
- Assert.True(false, "Wrong parameter");
- }
+ var val = long.Parse(parameterValue.ValueText);
+ Assert.InRange(val, 10, 20);
+ }
+ else if (parameterValue.Name == "bar")
+ {
+ var val = long.Parse(parameterValue.ValueText);
+ Assert.InRange(val, 100, 200);
+ }
+ else
+ {
+ Assert.True(false, "Wrong parameter");
}
}
}
@@ -136,282 +134,273 @@ public void TestRandomSweeper()
public void TestSimpleSweeperAsync()
{
var random = new Random(42);
- using (var env = new ConsoleEnvironment(42))
+ var env = new MLContext(42);
+ const int sweeps = 100;
+ var sweeper = new SimpleAsyncSweeper(env, new SweeperBase.ArgumentsBase
{
- int sweeps = 100;
- var sweeper = new SimpleAsyncSweeper(env, new SweeperBase.ArgumentsBase
- {
- SweptParameters = new IComponentFactory<IValueGenerator>[] {
+ SweptParameters = new IComponentFactory<IValueGenerator>[] {
ComponentFactoryUtils.CreateFromFunction(
environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5 })),
ComponentFactoryUtils.CreateFromFunction(
environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
}
- });
+ });
- var paramSets = new List<ParameterSet>();
- for (int i = 0; i < sweeps; i++)
- {
- var task = sweeper.Propose();
- Assert.True(task.IsCompleted);
- paramSets.Add(task.Result.ParameterSet);
- var result = new RunResult(task.Result.ParameterSet, random.NextDouble(), true);
- sweeper.Update(task.Result.Id, result);
- }
- Assert.Equal(sweeps, paramSets.Count);
- CheckAsyncSweeperResult(paramSets);
+ var paramSets = new List<ParameterSet>();
+ for (int i = 0; i < sweeps; i++)
+ {
+ var task = sweeper.Propose();
+ Assert.True(task.IsCompleted);
+ paramSets.Add(task.Result.ParameterSet);
+ var result = new RunResult(task.Result.ParameterSet, random.NextDouble(), true);
+ sweeper.Update(task.Result.Id, result);
+ }
+ Assert.Equal(sweeps, paramSets.Count);
+ CheckAsyncSweeperResult(paramSets);
- // Test consumption without ever calling Update.
- var gridArgs = new RandomGridSweeper.Arguments();
- gridArgs.SweptParameters = new IComponentFactory<IValueGenerator>[] {
+ // Test consumption without ever calling Update.
+ var gridArgs = new RandomGridSweeper.Arguments();
+ gridArgs.SweptParameters = new IComponentFactory<IValueGenerator>[] {
ComponentFactoryUtils.CreateFromFunction(
environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
ComponentFactoryUtils.CreateFromFunction(
environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 100, LogBase = true }))
};
- var gridSweeper = new SimpleAsyncSweeper(env, gridArgs);
- paramSets.Clear();
- for (int i = 0; i < sweeps; i++)
- {
- var task = gridSweeper.Propose();
- Assert.True(task.IsCompleted);
- paramSets.Add(task.Result.ParameterSet);
- }
- Assert.Equal(sweeps, paramSets.Count);
- CheckAsyncSweeperResult(paramSets);
+ var gridSweeper = new SimpleAsyncSweeper(env, gridArgs);
+ paramSets.Clear();
+ for (int i = 0; i < sweeps; i++)
+ {
+ var task = gridSweeper.Propose();
+ Assert.True(task.IsCompleted);
+ paramSets.Add(task.Result.ParameterSet);
}
+ Assert.Equal(sweeps, paramSets.Count);
+ CheckAsyncSweeperResult(paramSets);
}
[Fact]
public void TestDeterministicSweeperAsyncCancellation()
{
var random = new Random(42);
- using (var env = new ConsoleEnvironment(42))
- {
- var args = new DeterministicSweeperAsync.Arguments();
- args.BatchSize = 5;
- args.Relaxation = 1;
-
- args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
- environ => new KdoSweeper(environ,
- new KdoSweeper.Arguments()
- {
- SweptParameters = new IComponentFactory<IValueGenerator>[] {
+ var env = new MLContext(42);
+ var args = new DeterministicSweeperAsync.Arguments();
+ args.BatchSize = 5;
+ args.Relaxation = 1;
+
+ args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
+ environ => new KdoSweeper(environ,
+ new KdoSweeper.Arguments()
+ {
+ SweptParameters = new IComponentFactory<IValueGenerator>[] {
ComponentFactoryUtils.CreateFromFunction(
t => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
ComponentFactoryUtils.CreateFromFunction(
t => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
- }
- }));
+ }
+ }));
- var sweeper = new DeterministicSweeperAsync(env, args);
+ var sweeper = new DeterministicSweeperAsync(env, args);
- int sweeps = 20;
- var tasks = new List<Task<ParameterSetWithId>>();
- int numCompleted = 0;
- for (int i = 0; i < sweeps; i++)
- {
- var task = sweeper.Propose();
- if (i < args.BatchSize - args.Relaxation)
- {
- Assert.True(task.IsCompleted);
- sweeper.Update(task.Result.Id, new RunResult(task.Result.ParameterSet, random.NextDouble(), true));
- numCompleted++;
- }
- else
- tasks.Add(task);
- }
- // Cancel after the first barrier and check if the number of registered actions
- // is indeed 2 * batchSize.
- sweeper.Cancel();
- Task.WaitAll(tasks.ToArray());
- foreach (var task in tasks)
+ int sweeps = 20;
+ var tasks = new List<Task<ParameterSetWithId>>();
+ int numCompleted = 0;
+ for (int i = 0; i < sweeps; i++)
+ {
+ var task = sweeper.Propose();
+ if (i < args.BatchSize - args.Relaxation)
{
- if (task.Result != null)
- numCompleted++;
+ Assert.True(task.IsCompleted);
+ sweeper.Update(task.Result.Id, new RunResult(task.Result.ParameterSet, random.NextDouble(), true));
+ numCompleted++;
}
- Assert.Equal(args.BatchSize + args.BatchSize, numCompleted);
+ else
+ tasks.Add(task);
}
+ // Cancel after the first barrier and check if the number of registered actions
+ // is indeed 2 * batchSize.
+ sweeper.Cancel();
+ Task.WaitAll(tasks.ToArray());
+ foreach (var task in tasks)
+ {
+ if (task.Result != null)
+ numCompleted++;
+ }
+ Assert.Equal(args.BatchSize + args.BatchSize, numCompleted);
}
[Fact]
public void TestDeterministicSweeperAsync()
{
var random = new Random(42);
- using (var env = new ConsoleEnvironment(42))
- {
- var args = new DeterministicSweeperAsync.Arguments();
- args.BatchSize = 5;
- args.Relaxation = args.BatchSize - 1;
-
- args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
- environ => new SmacSweeper(environ,
- new SmacSweeper.Arguments()
- {
- SweptParameters = new IComponentFactory<IValueGenerator>[] {
+ var env = new MLContext(42);
+ var args = new DeterministicSweeperAsync.Arguments();
+ args.BatchSize = 5;
+ args.Relaxation = args.BatchSize - 1;
+
+ args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
+ environ => new SmacSweeper(environ,
+ new SmacSweeper.Arguments()
+ {
+ SweptParameters = new IComponentFactory<IValueGenerator>[] {
ComponentFactoryUtils.CreateFromFunction(
t => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
ComponentFactoryUtils.CreateFromFunction(
t => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
- }
- }));
-
- var sweeper = new DeterministicSweeperAsync(env, args);
+ }
+ }));
- // Test single-threaded consumption.
- int sweeps = 10;
- var paramSets = new List<ParameterSet>();
- for (int i = 0; i < sweeps; i++)
- {
- var task = sweeper.Propose();
- Assert.True(task.IsCompleted);
- paramSets.Add(task.Result.ParameterSet);
- var result = new RunResult(task.Result.ParameterSet, random.NextDouble(), true);
- sweeper.Update(task.Result.Id, result);
- }
- Assert.Equal(sweeps, paramSets.Count);
- CheckAsyncSweeperResult(paramSets);
-
- // Create two batches and test if the 2nd batch is executed after the synchronization barrier is reached.
- object mlock = new object();
- var tasks = new Task<ParameterSetWithId>[sweeps];
- args.Relaxation = args.Relaxation - 1;
- sweeper = new DeterministicSweeperAsync(env, args);
- paramSets.Clear();
- var results = new List<KeyValuePair<int, IRunResult>>();
- for (int i = 0; i < args.BatchSize; i++)
- {
- var task = sweeper.Propose();
- Assert.True(task.IsCompleted);
- tasks[i] = task;
- if (task.Result == null)
- continue;
- results.Add(new KeyValuePair<int, IRunResult>(task.Result.Id, new RunResult(task.Result.ParameterSet, 0.42, true)));
- }
- // Register consumers for the 2nd batch. Those consumers will await until at least one run
- // in the previous batch has been posted to the sweeper.
- for (int i = args.BatchSize; i < 2 * args.BatchSize; i++)
- {
- var task = sweeper.Propose();
- Assert.False(task.IsCompleted);
- tasks[i] = task;
- }
- // Call update to unblock the 2nd batch.
- foreach (var run in results)
- sweeper.Update(run.Key, run.Value);
+ var sweeper = new DeterministicSweeperAsync(env, args);
- Task.WaitAll(tasks);
- tasks.All(t => t.IsCompleted);
+ // Test single-threaded consumption.
+ int sweeps = 10;
+ var paramSets = new List<ParameterSet>();
+ for (int i = 0; i < sweeps; i++)
+ {
+ var task = sweeper.Propose();
+ Assert.True(task.IsCompleted);
+ paramSets.Add(task.Result.ParameterSet);
+ var result = new RunResult(task.Result.ParameterSet, random.NextDouble(), true);
+ sweeper.Update(task.Result.Id, result);
+ }
+ Assert.Equal(sweeps, paramSets.Count);
+ CheckAsyncSweeperResult(paramSets);
+
+ // Create two batches and test if the 2nd batch is executed after the synchronization barrier is reached.
+ object mlock = new object();
+ var tasks = new Task<ParameterSetWithId>[sweeps];
+ args.Relaxation = args.Relaxation - 1;
+ sweeper = new DeterministicSweeperAsync(env, args);
+ paramSets.Clear();
+ var results = new List<KeyValuePair<int, IRunResult>>();
+ for (int i = 0; i < args.BatchSize; i++)
+ {
+ var task = sweeper.Propose();
+ Assert.True(task.IsCompleted);
+ tasks[i] = task;
+ if (task.Result == null)
+ continue;
+ results.Add(new KeyValuePair<int, IRunResult>(task.Result.Id, new RunResult(task.Result.ParameterSet, 0.42, true)));
+ }
+ // Register consumers for the 2nd batch. Those consumers will await until at least one run
+ // in the previous batch has been posted to the sweeper.
+ for (int i = args.BatchSize; i < 2 * args.BatchSize; i++)
+ {
+ var task = sweeper.Propose();
+ Assert.False(task.IsCompleted);
+ tasks[i] = task;
}
+ // Call update to unblock the 2nd batch.
+ foreach (var run in results)
+ sweeper.Update(run.Key, run.Value);
+
+ Task.WaitAll(tasks);
+ tasks.All(t => t.IsCompleted);
}
[Fact]
public void TestDeterministicSweeperAsyncParallel()
{
var random = new Random(42);
- using (var env = new ConsoleEnvironment(42))
- {
- int batchSize = 5;
- int sweeps = 20;
- var paramSets = new List<ParameterSet>();
- var args = new DeterministicSweeperAsync.Arguments();
- args.BatchSize = batchSize;
- args.Relaxation = batchSize - 2;
-
- args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
- environ => new SmacSweeper(environ,
- new SmacSweeper.Arguments()
- {
- SweptParameters = new IComponentFactory<IValueGenerator>[] {
+ var env = new MLContext(42);
+ const int batchSize = 5;
+ const int sweeps = 20;
+ var paramSets = new List<ParameterSet>();
+ var args = new DeterministicSweeperAsync.Arguments();
+ args.BatchSize = batchSize;
+ args.Relaxation = batchSize - 2;
+
+ args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
+ environ => new SmacSweeper(environ,
+ new SmacSweeper.Arguments()
+ {
+ SweptParameters = new IComponentFactory<IValueGenerator>[] {
ComponentFactoryUtils.CreateFromFunction(
t => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
ComponentFactoryUtils.CreateFromFunction(
t => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
- }
- }));
+ }
+ }));
- var sweeper = new DeterministicSweeperAsync(env, args);
+ var sweeper = new DeterministicSweeperAsync(env, args);
- var mlock = new object();
- var options = new ParallelOptions();
- options.MaxDegreeOfParallelism = 4;
+ var mlock = new object();
+ var options = new ParallelOptions();
+ options.MaxDegreeOfParallelism = 4;
- // Sleep randomly to simulate doing work.
- int[] sleeps = new int[sweeps];
- for (int i = 0; i < sleeps.Length; i++)
- sleeps[i] = random.Next(10, 100);
- var r = Parallel.For(0, sweeps, options, (int i) =>
- {
- var task = sweeper.Propose();
- task.Wait();
- Assert.Equal(TaskStatus.RanToCompletion, task.Status);
- var paramWithId = task.Result;
- if (paramWithId == null)
- return;
- Thread.Sleep(sleeps[i]);
- var result = new RunResult(paramWithId.ParameterSet, 0.42, true);
- sweeper.Update(paramWithId.Id, result);
- lock (mlock)
- paramSets.Add(paramWithId.ParameterSet);
- });
- Assert.True(paramSets.Count <= sweeps);
- CheckAsyncSweeperResult(paramSets);
- }
+ // Sleep randomly to simulate doing work.
+ int[] sleeps = new int[sweeps];
+ for (int i = 0; i < sleeps.Length; i++)
+ sleeps[i] = random.Next(10, 100);
+ var r = Parallel.For(0, sweeps, options, (int i) =>
+ {
+ var task = sweeper.Propose();
+ task.Wait();
+ Assert.Equal(TaskStatus.RanToCompletion, task.Status);
+ var paramWithId = task.Result;
+ if (paramWithId == null)
+ return;
+ Thread.Sleep(sleeps[i]);
+ var result = new RunResult(paramWithId.ParameterSet, 0.42, true);
+ sweeper.Update(paramWithId.Id, result);
+ lock (mlock)
+ paramSets.Add(paramWithId.ParameterSet);
+ });
+ Assert.True(paramSets.Count <= sweeps);
+ CheckAsyncSweeperResult(paramSets);
}
[Fact]
public async Task TestNelderMeadSweeperAsync()
{
var random = new Random(42);
- using (var env = new ConsoleEnvironment(42))
- {
- int batchSize = 5;
- int sweeps = 40;
- var paramSets = new List<ParameterSet>();
- var args = new DeterministicSweeperAsync.Arguments();
- args.BatchSize = batchSize;
- args.Relaxation = 0;
-
- args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
- environ => {
- var param = new IComponentFactory<INumericValueGenerator>[] {
+ var env = new MLContext(42);
+ const int batchSize = 5;
+ const int sweeps = 40;
+ var paramSets = new List<ParameterSet>();
+ var args = new DeterministicSweeperAsync.Arguments();
+ args.BatchSize = batchSize;
+ args.Relaxation = 0;
+
+ args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
+ environ =>
+ {
+ var param = new IComponentFactory<INumericValueGenerator>[] {
ComponentFactoryUtils.CreateFromFunction(
innerEnviron => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
ComponentFactoryUtils.CreateFromFunction(
innerEnviron => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
- };
+ };
- var nelderMeadSweeperArgs = new NelderMeadSweeper.Arguments()
- {
- SweptParameters = param,
- FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction(
- (firstBatchSweeperEnviron, firstBatchSweeperArgs) =>
- new RandomGridSweeper(environ, new RandomGridSweeper.Arguments() { SweptParameters = param }))
- };
+ var nelderMeadSweeperArgs = new NelderMeadSweeper.Arguments()
+ {
+ SweptParameters = param,
+ FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction(
+ (firstBatchSweeperEnviron, firstBatchSweeperArgs) =>
+ new RandomGridSweeper(environ, new RandomGridSweeper.Arguments() { SweptParameters = param }))
+ };
- return new NelderMeadSweeper(environ, nelderMeadSweeperArgs);
- }
- );
+ return new NelderMeadSweeper(environ, nelderMeadSweeperArgs);
+ }
+ );
- var sweeper = new DeterministicSweeperAsync(env, args);
- var mlock = new object();
- double[] metrics = new double[sweeps];
- for (int i = 0; i < metrics.Length; i++)
- metrics[i] = random.NextDouble();
+ var sweeper = new DeterministicSweeperAsync(env, args);
+ var mlock = new object();
+ double[] metrics = new double[sweeps];
+ for (int i = 0; i < metrics.Length; i++)
+ metrics[i] = random.NextDouble();
- for (int i = 0; i < sweeps; i++)
- {
- var paramWithId = await sweeper.Propose();
- if (paramWithId == null)
- return;
- var result = new RunResult(paramWithId.ParameterSet, metrics[i], true);
- sweeper.Update(paramWithId.Id, result);
- lock (mlock)
- paramSets.Add(paramWithId.ParameterSet);
- }
- Assert.True(paramSets.Count <= sweeps);
- CheckAsyncSweeperResult(paramSets);
+ for (int i = 0; i < sweeps; i++)
+ {
+ var paramWithId = await sweeper.Propose();
+ if (paramWithId == null)
+ return;
+ var result = new RunResult(paramWithId.ParameterSet, metrics[i], true);
+ sweeper.Update(paramWithId.Id, result);
+ lock (mlock)
+ paramSets.Add(paramWithId.ParameterSet);
}
+ Assert.True(paramSets.Count <= sweeps);
+ CheckAsyncSweeperResult(paramSets);
}
private void CheckAsyncSweeperResult(List<ParameterSet> paramSets)
@@ -442,275 +431,266 @@ private void CheckAsyncSweeperResult(List<ParameterSet> paramSets)
[Fact]
public void TestRandomGridSweeper()
{
- using (var env = new ConsoleEnvironment(42))
+ var env = new MLContext(42);
+ var args = new RandomGridSweeper.Arguments()
{
- var args = new RandomGridSweeper.Arguments()
- {
- SweptParameters = new[] {
+ SweptParameters = new[] {
ComponentFactoryUtils.CreateFromFunction(
environ => new LongValueGenerator(new LongParamArguments() { Name = "foo", Min = 10, Max = 20, NumSteps = 3 })),
ComponentFactoryUtils.CreateFromFunction(
environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 100, Max = 10000, LogBase = true, StepSize = 10 }))
}
- };
- var sweeper = new RandomGridSweeper(env, args);
- var initialList = sweeper.ProposeSweeps(5, new List<RunResult>());
- Assert.Equal(5, initialList.Length);
- var gridPoint = new bool[3][] {
+ };
+ var sweeper = new RandomGridSweeper(env, args);
+ var initialList = sweeper.ProposeSweeps(5, new List<RunResult>());
+ Assert.Equal(5, initialList.Length);
+ var gridPoint = new bool[3][] {
new bool[3],
new bool[3],
new bool[3]
};
- int i = 0;
- int j = 0;
- foreach (var parameterSet in initialList)
+ int i = 0;
+ int j = 0;
+ foreach (var parameterSet in initialList)
+ {
+ foreach (var parameterValue in parameterSet)
{
- foreach (var parameterValue in parameterSet)
+ if (parameterValue.Name == "foo")
{
- if (parameterValue.Name == "foo")
- {
- var val = long.Parse(parameterValue.ValueText);
- Assert.True(val == 10 || val == 15 || val == 20);
- i = (val == 10) ? 0 : (val == 15) ? 1 : 2;
- }
- else if (parameterValue.Name == "bar")
- {
- var val = long.Parse(parameterValue.ValueText);
- Assert.True(val == 100 || val == 1000 || val == 10000);
- j = (val == 100) ? 0 : (val == 1000) ? 1 : 2;
- }
- else
- {
- Assert.True(false, "Wrong parameter");
- }
+ var val = long.Parse(parameterValue.ValueText);
+ Assert.True(val == 10 || val == 15 || val == 20);
+ i = (val == 10) ? 0 : (val == 15) ? 1 : 2;
+ }
+ else if (parameterValue.Name == "bar")
+ {
+ var val = long.Parse(parameterValue.ValueText);
+ Assert.True(val == 100 || val == 1000 || val == 10000);
+ j = (val == 100) ? 0 : (val == 1000) ? 1 : 2;
+ }
+ else
+ {
+ Assert.True(false, "Wrong parameter");
}
- Assert.False(gridPoint[i][j]);
- gridPoint[i][j] = true;
}
+ Assert.False(gridPoint[i][j]);
+ gridPoint[i][j] = true;
+ }
- var nextList = sweeper.ProposeSweeps(5, initialList.Select(p => new RunResult(p)));
- Assert.Equal(4, nextList.Length);
- foreach (var parameterSet in nextList)
+ var nextList = sweeper.ProposeSweeps(5, initialList.Select(p => new RunResult(p)));
+ Assert.Equal(4, nextList.Length);
+ foreach (var parameterSet in nextList)
+ {
+ foreach (var parameterValue in parameterSet)
{
- foreach (var parameterValue in parameterSet)
+ if (parameterValue.Name == "foo")
{
- if (parameterValue.Name == "foo")
- {
- var val = long.Parse(parameterValue.ValueText);
- Assert.True(val == 10 || val == 15 || val == 20);
- i = (val == 10) ? 0 : (val == 15) ? 1 : 2;
- }
- else if (parameterValue.Name == "bar")
- {
- var val = long.Parse(parameterValue.ValueText);
- Assert.True(val == 100 || val == 1000 || val == 10000);
- j = (val == 100) ? 0 : (val == 1000) ? 1 : 2;
- }
- else
- {
- Assert.True(false, "Wrong parameter");
- }
+ var val = long.Parse(parameterValue.ValueText);
+ Assert.True(val == 10 || val == 15 || val == 20);
+ i = (val == 10) ? 0 : (val == 15) ? 1 : 2;
+ }
+ else if (parameterValue.Name == "bar")
+ {
+ var val = long.Parse(parameterValue.ValueText);
+ Assert.True(val == 100 || val == 1000 || val == 10000);
+ j = (val == 100) ? 0 : (val == 1000) ? 1 : 2;
+ }
+ else
+ {
+ Assert.True(false, "Wrong parameter");
}
- Assert.False(gridPoint[i][j]);
- gridPoint[i][j] = true;
}
+ Assert.False(gridPoint[i][j]);
+ gridPoint[i][j] = true;
+ }
- gridPoint = new bool[3][] {
+ gridPoint = new bool[3][] {
new bool[3],
new bool[3],
new bool[3]
};
- var lastList = sweeper.ProposeSweeps(10, null);
- Assert.Equal(9, lastList.Length);
- foreach (var parameterSet in lastList)
+ var lastList = sweeper.ProposeSweeps(10, null);
+ Assert.Equal(9, lastList.Length);
+ foreach (var parameterSet in lastList)
+ {
+ foreach (var parameterValue in parameterSet)
{
- foreach (var parameterValue in parameterSet)
+ if (parameterValue.Name == "foo")
{
- if (parameterValue.Name == "foo")
- {
- var val = long.Parse(parameterValue.ValueText);
- Assert.True(val == 10 || val == 15 || val == 20);
- i = (val == 10) ? 0 : (val == 15) ? 1 : 2;
- }
- else if (parameterValue.Name == "bar")
- {
- var val = long.Parse(parameterValue.ValueText);
- Assert.True(val == 100 || val == 1000 || val == 10000);
- j = (val == 100) ? 0 : (val == 1000) ? 1 : 2;
- }
- else
- {
- Assert.True(false, "Wrong parameter");
- }
+ var val = long.Parse(parameterValue.ValueText);
+ Assert.True(val == 10 || val == 15 || val == 20);
+ i = (val == 10) ? 0 : (val == 15) ? 1 : 2;
+ }
+ else if (parameterValue.Name == "bar")
+ {
+ var val = long.Parse(parameterValue.ValueText);
+ Assert.True(val == 100 || val == 1000 || val == 10000);
+ j = (val == 100) ? 0 : (val == 1000) ? 1 : 2;
+ }
+ else
+ {
+ Assert.True(false, "Wrong parameter");
}
- Assert.False(gridPoint[i][j]);
- gridPoint[i][j] = true;
}
- Assert.True(gridPoint.All(bArray => bArray.All(b => b)));
+ Assert.False(gridPoint[i][j]);
+ gridPoint[i][j] = true;
}
+ Assert.True(gridPoint.All(bArray => bArray.All(b => b)));
}
[Fact]
public void TestNelderMeadSweeper()
{
var random = new Random(42);
- using (var env = new ConsoleEnvironment(42))
- {
- var param = new IComponentFactory[] {
+ var env = new MLContext(42);
+ var param = new IComponentFactory[] {
ComponentFactoryUtils.CreateFromFunction(
environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
ComponentFactoryUtils.CreateFromFunction(
environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
};
- var args = new NelderMeadSweeper.Arguments()
- {
- SweptParameters = param,
- FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction(
- (environ, firstBatchArgs) => {
- return new RandomGridSweeper(environ, new RandomGridSweeper.Arguments() { SweptParameters = param });
- }
- )
- };
- var sweeper = new NelderMeadSweeper(env, args);
- var sweeps = sweeper.ProposeSweeps(5, new List());
- Assert.Equal(3, sweeps.Length);
-
- var results = new List();
- for (int i = 1; i < 10; i++)
+ var args = new NelderMeadSweeper.Arguments()
+ {
+ SweptParameters = param,
+ FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction(
+ (environ, firstBatchArgs) =>
+ {
+ return new RandomGridSweeper(environ, new RandomGridSweeper.Arguments() { SweptParameters = param });
+ }
+ )
+ };
+ var sweeper = new NelderMeadSweeper(env, args);
+ var sweeps = sweeper.ProposeSweeps(5, new List());
+ Assert.Equal(3, sweeps.Length);
+
+ var results = new List();
+ for (int i = 1; i < 10; i++)
+ {
+ foreach (var parameterSet in sweeps)
{
- foreach (var parameterSet in sweeps)
+ foreach (var parameterValue in parameterSet)
{
- foreach (var parameterValue in parameterSet)
+ if (parameterValue.Name == "foo")
+ {
+ var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture);
+ Assert.InRange(val, 1, 5);
+ }
+ else if (parameterValue.Name == "bar")
{
- if (parameterValue.Name == "foo")
- {
- var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture);
- Assert.InRange(val, 1, 5);
- }
- else if (parameterValue.Name == "bar")
- {
- var val = long.Parse(parameterValue.ValueText);
- Assert.InRange(val, 1, 1000);
- }
- else
- {
- Assert.True(false, "Wrong parameter");
- }
+ var val = long.Parse(parameterValue.ValueText);
+ Assert.InRange(val, 1, 1000);
+ }
+ else
+ {
+ Assert.True(false, "Wrong parameter");
}
- results.Add(new RunResult(parameterSet, random.NextDouble(), true));
}
-
- sweeps = sweeper.ProposeSweeps(5, results);
+ results.Add(new RunResult(parameterSet, random.NextDouble(), true));
}
- Assert.True(sweeps.Length <= 5);
+
+ sweeps = sweeper.ProposeSweeps(5, results);
}
+ Assert.True(sweeps.Length <= 5);
}
[Fact]
public void TestNelderMeadSweeperWithDefaultFirstBatchSweeper()
{
var random = new Random(42);
- using (var env = new ConsoleEnvironment(42))
- {
- var param = new IComponentFactory[] {
+ var env = new MLContext(42);
+ var param = new IComponentFactory[] {
ComponentFactoryUtils.CreateFromFunction(
environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
ComponentFactoryUtils.CreateFromFunction(
environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
};
- var args = new NelderMeadSweeper.Arguments();
- args.SweptParameters = param;
- var sweeper = new NelderMeadSweeper(env, args);
- var sweeps = sweeper.ProposeSweeps(5, new List());
- Assert.Equal(3, sweeps.Length);
+ var args = new NelderMeadSweeper.Arguments();
+ args.SweptParameters = param;
+ var sweeper = new NelderMeadSweeper(env, args);
+ var sweeps = sweeper.ProposeSweeps(5, new List());
+ Assert.Equal(3, sweeps.Length);
- var results = new List();
- for (int i = 1; i < 10; i++)
+ var results = new List();
+ for (int i = 1; i < 10; i++)
+ {
+ foreach (var parameterSet in sweeps)
{
- foreach (var parameterSet in sweeps)
+ foreach (var parameterValue in parameterSet)
{
- foreach (var parameterValue in parameterSet)
+ if (parameterValue.Name == "foo")
+ {
+ var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture);
+ Assert.InRange(val, 1, 5);
+ }
+ else if (parameterValue.Name == "bar")
{
- if (parameterValue.Name == "foo")
- {
- var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture);
- Assert.InRange(val, 1, 5);
- }
- else if (parameterValue.Name == "bar")
- {
- var val = long.Parse(parameterValue.ValueText);
- Assert.InRange(val, 1, 1000);
- }
- else
- {
- Assert.True(false, "Wrong parameter");
- }
+ var val = long.Parse(parameterValue.ValueText);
+ Assert.InRange(val, 1, 1000);
+ }
+ else
+ {
+ Assert.True(false, "Wrong parameter");
}
- results.Add(new RunResult(parameterSet, random.NextDouble(), true));
}
-
- sweeps = sweeper.ProposeSweeps(5, results);
+ results.Add(new RunResult(parameterSet, random.NextDouble(), true));
}
- Assert.True(sweeps == null || sweeps.Length <= 5);
+
+ sweeps = sweeper.ProposeSweeps(5, results);
}
+ Assert.True(sweeps == null || sweeps.Length <= 5);
}
[Fact]
public void TestSmacSweeper()
{
- RunMTAThread(() =>
+ var random = new Random(42);
+ var env = new MLContext(42);
+ const int maxInitSweeps = 5;
+ var args = new SmacSweeper.Arguments()
{
- var random = new Random(42);
- using (var env = new ConsoleEnvironment(42))
- {
- int maxInitSweeps = 5;
- var args = new SmacSweeper.Arguments()
- {
- NumberInitialPopulation = 20,
- SweptParameters = new IComponentFactory[] {
+ NumberInitialPopulation = 20,
+ SweptParameters = new IComponentFactory[] {
ComponentFactoryUtils.CreateFromFunction(
environ => new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5})),
ComponentFactoryUtils.CreateFromFunction(
environ => new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 100, LogBase = true }))
}
- };
+ };
- var sweeper = new SmacSweeper(env, args);
- var results = new List();
- var sweeps = sweeper.ProposeSweeps(maxInitSweeps, results);
- Assert.Equal(Math.Min(args.NumberInitialPopulation, maxInitSweeps), sweeps.Length);
+ var sweeper = new SmacSweeper(env, args);
+ var results = new List();
+ var sweeps = sweeper.ProposeSweeps(maxInitSweeps, results);
+ Assert.Equal(Math.Min(args.NumberInitialPopulation, maxInitSweeps), sweeps.Length);
- for (int i = 1; i < 10; i++)
+ for (int i = 1; i < 10; i++)
+ {
+ foreach (var parameterSet in sweeps)
+ {
+ foreach (var parameterValue in parameterSet)
{
- foreach (var parameterSet in sweeps)
+ if (parameterValue.Name == "foo")
{
- foreach (var parameterValue in parameterSet)
- {
- if (parameterValue.Name == "foo")
- {
- var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture);
- Assert.InRange(val, 1, 5);
- }
- else if (parameterValue.Name == "bar")
- {
- var val = long.Parse(parameterValue.ValueText);
- Assert.InRange(val, 1, 1000);
- }
- else
- {
- Assert.True(false, "Wrong parameter");
- }
- }
- results.Add(new RunResult(parameterSet, random.NextDouble(), true));
+ var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture);
+ Assert.InRange(val, 1, 5);
+ }
+ else if (parameterValue.Name == "bar")
+ {
+ var val = long.Parse(parameterValue.ValueText);
+ Assert.InRange(val, 1, 1000);
+ }
+ else
+ {
+ Assert.True(false, "Wrong parameter");
}
-
- sweeps = sweeper.ProposeSweeps(5, results);
}
- Assert.Equal(5, sweeps.Length);
+ results.Add(new RunResult(parameterSet, random.NextDouble(), true));
}
- });
+
+ sweeps = sweeper.ProposeSweeps(5, results);
+ }
+ // Because only unique configurations are considered, the number asked for may exceed the number actually returned.
+ Assert.True(sweeps.Length <= 5);
}
}
}
diff --git a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
index 273028ddaa..0a1e7ada2f 100644
--- a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
+++ b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
@@ -74,7 +74,8 @@ protected BaseTestBaseline(ITestOutputHelper output) : base(output)
// The writer to write to test log files.
protected StreamWriter LogWriter;
- protected ConsoleEnvironment Env;
+ private protected ConsoleEnvironment _env;
+ protected IHostEnvironment Env => _env;
protected MLContext ML;
private bool _normal;
private bool _passed;
@@ -96,7 +97,7 @@ protected override void Initialize()
string logPath = Path.Combine(logDir, FullTestName + LogSuffix);
LogWriter = OpenWriter(logPath);
_passed = true;
- Env = new ConsoleEnvironment(42, outWriter: LogWriter, errWriter: LogWriter)
+ _env = new ConsoleEnvironment(42, outWriter: LogWriter, errWriter: LogWriter)
.AddStandardComponents();
ML = new MLContext(42);
}
@@ -106,9 +107,8 @@ protected override void Initialize()
// It is called as a first step in test clean up.
protected override void Cleanup()
{
- if (Env != null)
- Env.Dispose();
- Env = null;
+ _env?.Dispose();
+ _env = null;
Contracts.Assert(IsActive);
Log("Test {0}: {1}: {2}", TestName,
@@ -535,7 +535,7 @@ private bool MatchNumberWithTolerance(MatchCollection firstCollection, MatchColl
double f1 = double.Parse(firstCollection[i].ToString());
double f2 = double.Parse(secondCollection[i].ToString());
- if(!CompareNumbersWithTolerance(f1, f2, i, digitsOfPrecision))
+ if (!CompareNumbersWithTolerance(f1, f2, i, digitsOfPrecision))
{
return false;
}
@@ -562,23 +562,23 @@ public bool CompareNumbersWithTolerance(double expected, double actual, int? ite
// would fail the inRange == true check, but would suceed the following, and we doconsider those two numbers
// (1.82844949 - 1.8284502) = -0.00000071
- double delta2 = 0;
- if (!inRange)
- {
- delta2 = Math.Round(expected - actual, digitsOfPrecision);
- inRange = delta2 >= -allowedVariance && delta2 <= allowedVariance;
- }
+ double delta2 = 0;
+ if (!inRange)
+ {
+ delta2 = Math.Round(expected - actual, digitsOfPrecision);
+ inRange = delta2 >= -allowedVariance && delta2 <= allowedVariance;
+ }
- if (!inRange)
- {
- var message = iterationOnCollection != null ? "" : $"Output and baseline mismatch at line {iterationOnCollection}." + Environment.NewLine;
+ if (!inRange)
+ {
+ var message = iterationOnCollection != null ? "" : $"Output and baseline mismatch at line {iterationOnCollection}." + Environment.NewLine;
- Fail(_allowMismatch, message +
- $"Values to compare are {expected} and {actual}" + Environment.NewLine +
- $"\t AllowedVariance: {allowedVariance}" + Environment.NewLine +
- $"\t delta: {delta}" + Environment.NewLine +
- $"\t delta2: {delta2}" + Environment.NewLine);
- }
+ Fail(_allowMismatch, message +
+ $"Values to compare are {expected} and {actual}" + Environment.NewLine +
+ $"\t AllowedVariance: {allowedVariance}" + Environment.NewLine +
+ $"\t delta: {delta}" + Environment.NewLine +
+ $"\t delta2: {delta2}" + Environment.NewLine);
+ }
return inRange;
}
@@ -817,11 +817,8 @@ protected static StreamReader OpenReader(string path)
///
protected static int MainForTest(string args)
{
- using (var env = new ConsoleEnvironment())
- {
- int result = Maml.MainCore(env, args, false);
- return result;
- }
+ var env = new MLContext();
+ return Maml.MainCore(env, args, false);
}
}
diff --git a/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs b/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs
index 8eb4edb1b5..cf0303bc05 100644
--- a/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs
+++ b/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs
@@ -158,7 +158,7 @@ protected void Run(RunContext ctx, int digitsOfPrecision = DigitsOfPrecision)
{
// Not capturing into a specific log.
Log("*** Start raw predictor output");
- res = MainForTest(Env, LogWriter, string.Join(" ", ctx.Command, runcmd), ctx.BaselineProgress);
+ res = MainForTest(_env, LogWriter, string.Join(" ", ctx.Command, runcmd), ctx.BaselineProgress);
Log("*** End raw predictor output, return={0}", res);
return;
}
@@ -189,7 +189,7 @@ protected void Run(RunContext ctx, int digitsOfPrecision = DigitsOfPrecision)
Log(" Saving ini file: {0}", str);
}
- MainForTest(Env, LogWriter, str);
+ MainForTest(_env, LogWriter, str);
files.ForEach(file => CheckEqualityNormalized(dir, file, digitsOfPrecision: digitsOfPrecision));
}
diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
index 306a8a9920..5e5b9d4cd6 100644
--- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
+++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
@@ -123,7 +123,7 @@ public void SavePipeLabelParsers()
string name = TestName + "4-out.txt";
string pathOut = DeleteOutputPath("SavePipe", name);
using (var writer = OpenWriter(pathOut))
- using (Env.RedirectChannelOutput(writer, writer))
+ using (_env.RedirectChannelOutput(writer, writer))
{
TestCore(pathData, true,
new[] {
@@ -133,7 +133,7 @@ public void SavePipeLabelParsers()
"xf=SelectColumns{keepcol=RawLabel keepcol=FileLabelNum keepcol=FileLabelKey hidden=-}"
}, suffix: "4");
writer.WriteLine(ProgressLogLine);
- Env.PrintProgress();
+ _env.PrintProgress();
}
CheckEqualityNormalized("SavePipe", name);
diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
index ff9c83160b..03ca7d8600 100644
--- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
+++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
@@ -172,19 +172,16 @@ private void CheckSameSchemaShape(SchemaShape promised, SchemaShape delivered)
///
protected IDataLoader TestCore(string pathData, bool keepHidden, string[] argsPipe,
Action actLoader = null, string suffix = "", string suffixBase = null, bool checkBaseline = true,
- bool forceDense = false, bool logCurs = false, ConsoleEnvironment env = null, bool roundTripText = true,
+ bool forceDense = false, bool logCurs = false, bool roundTripText = true,
bool checkTranspose = false, bool checkId = true, bool baselineSchema = true)
{
Contracts.AssertValue(Env);
- if (env == null)
- env = Env;
MultiFileSource files;
IDataLoader compositeLoader;
- var pipe1 = compositeLoader = CreatePipeDataLoader(env, pathData, argsPipe, out files);
+ var pipe1 = compositeLoader = CreatePipeDataLoader(_env, pathData, argsPipe, out files);
- if (actLoader != null)
- actLoader(compositeLoader);
+ actLoader?.Invoke(compositeLoader);
// Re-apply pipe to the loader and check equality.
var comp = compositeLoader as CompositeDataLoader;
@@ -194,7 +191,7 @@ protected IDataLoader TestCore(string pathData, bool keepHidden, string[] argsPi
srcLoader = comp.View;
while (srcLoader is IDataTransform)
srcLoader = ((IDataTransform)srcLoader).Source;
- var reappliedPipe = ApplyTransformUtils.ApplyAllTransformsToData(env, comp.View, srcLoader);
+ var reappliedPipe = ApplyTransformUtils.ApplyAllTransformsToData(_env, comp.View, srcLoader);
if (!CheckMetadataTypes(reappliedPipe.Schema))
Failed();
@@ -210,12 +207,12 @@ protected IDataLoader TestCore(string pathData, bool keepHidden, string[] argsPi
string pathLog = DeleteOutputPath("SavePipe", name);
using (var writer = OpenWriter(pathLog))
- using (env.RedirectChannelOutput(writer, writer))
+ using (_env.RedirectChannelOutput(writer, writer))
{
long count = 0;
// Set the concurrency to 1 for this; restore later.
- int conc = env.ConcurrencyFactor;
- env.ConcurrencyFactor = 1;
+ int conc = _env.ConcurrencyFactor;
+ _env.ConcurrencyFactor = 1;
using (var curs = pipe1.GetRowCursor(c => true, null))
{
while (curs.MoveNext())
@@ -224,14 +221,14 @@ protected IDataLoader TestCore(string pathData, bool keepHidden, string[] argsPi
}
}
writer.WriteLine("Cursored through {0} rows", count);
- env.ConcurrencyFactor = conc;
+ _env.ConcurrencyFactor = conc;
}
CheckEqualityNormalized("SavePipe", name);
}
var pathModel = SavePipe(pipe1, suffix);
- var pipe2 = LoadPipe(pathModel, env, files);
+ var pipe2 = LoadPipe(pathModel, _env, files);
if (!CheckMetadataTypes(pipe2.Schema))
Failed();
@@ -243,21 +240,21 @@ protected IDataLoader TestCore(string pathData, bool keepHidden, string[] argsPi
if (pipe1.Schema.ColumnCount > 0)
{
// The text saver fails if there are no columns, so we cannot check in that case.
- if (!SaveLoadText(pipe1, env, keepHidden, suffix, suffixBase, checkBaseline, forceDense, roundTripText))
+ if (!SaveLoadText(pipe1, _env, keepHidden, suffix, suffixBase, checkBaseline, forceDense, roundTripText))
Failed();
// The transpose saver likewise fails for the same reason.
- if (checkTranspose && !SaveLoadTransposed(pipe1, env, suffix))
+ if (checkTranspose && !SaveLoadTransposed(pipe1, _env, suffix))
Failed();
}
- if (!SaveLoad(pipe1, env, suffix))
+ if (!SaveLoad(pipe1, _env, suffix))
Failed();
// Check that the pipe doesn't shuffle when it cannot :).
if (srcLoader != null)
{
// First we need to cache the data so it can be shuffled.
- var cachedData = new CacheDataView(env, srcLoader, null);
- var newPipe = ApplyTransformUtils.ApplyAllTransformsToData(env, comp.View, cachedData);
+ var cachedData = new CacheDataView(_env, srcLoader, null);
+ var newPipe = ApplyTransformUtils.ApplyAllTransformsToData(_env, comp.View, cachedData);
if (!newPipe.CanShuffle)
{
using (var c1 = newPipe.GetRowCursor(col => true, new SysRandom(123)))
diff --git a/test/Microsoft.ML.TestFramework/ModelHelper.cs b/test/Microsoft.ML.TestFramework/ModelHelper.cs
index 94341e4a85..3fe189183b 100644
--- a/test/Microsoft.ML.TestFramework/ModelHelper.cs
+++ b/test/Microsoft.ML.TestFramework/ModelHelper.cs
@@ -13,7 +13,7 @@ namespace Microsoft.ML.TestFramework
{
public static class ModelHelper
{
- private static ConsoleEnvironment s_environment = new ConsoleEnvironment(seed: 1);
+ private static IHostEnvironment s_environment = new MLContext(seed: 1);
private static ITransformModel s_housePriceModel;
public static void WriteKcHousePriceModel(string dataPath, string outputModelPath)
@@ -35,7 +35,6 @@ public static void WriteKcHousePriceModel(string dataPath, Stream stream)
{
s_housePriceModel = CreateKcHousePricePredictorModel(dataPath);
}
-
s_housePriceModel.Save(s_environment, stream);
}
diff --git a/test/Microsoft.ML.TestFramework/TestCommandBase.cs b/test/Microsoft.ML.TestFramework/TestCommandBase.cs
index 75f8d5d1e0..4e5c04395f 100644
--- a/test/Microsoft.ML.TestFramework/TestCommandBase.cs
+++ b/test/Microsoft.ML.TestFramework/TestCommandBase.cs
@@ -292,10 +292,10 @@ protected bool TestCore(RunContextBase ctx, string cmdName, string args, int dig
Contracts.AssertValueOrNull(args);
OutputPath outputPath = ctx.StdoutPath();
using (var newWriter = OpenWriter(outputPath.Path))
- using (Env.RedirectChannelOutput(newWriter, newWriter))
+ using (_env.RedirectChannelOutput(newWriter, newWriter))
{
- Env.ResetProgressChannel();
- int res = MainForTest(Env, newWriter, string.Format("{0} {1}", cmdName, args), ctx.BaselineProgress);
+ _env.ResetProgressChannel();
+ int res = MainForTest(_env, newWriter, string.Format("{0} {1}", cmdName, args), ctx.BaselineProgress);
if (res != 0)
Log("*** Predictor returned {0}", res);
}
@@ -322,7 +322,7 @@ protected bool TestCore(RunContextBase ctx, string cmdName, string args, int dig
///
/// The arguments for MAML.
/// Whether to print the progress summary. If true, progress summary will appear in the end of baseline output file.
- protected static int MainForTest(ConsoleEnvironment env, TextWriter writer, string args, bool printProgress = false)
+ private protected static int MainForTest(ConsoleEnvironment env, TextWriter writer, string args, bool printProgress = false)
{
Contracts.AssertValue(env);
Contracts.AssertValue(writer);
@@ -364,7 +364,7 @@ private bool TestCoreCore(RunContextBase ctx, string cmdName, string dataPath, P
return TestCoreCore(ctx, cmdName, dataPath, situation, inModelPath, outModelPath, loaderArgs, extraArgs, DigitsOfPrecision, toCompare);
}
- private bool TestCoreCore(RunContextBase ctx, string cmdName, string dataPath, PathArgument.Usage situation,
+ private bool TestCoreCore(RunContextBase ctx, string cmdName, string dataPath, PathArgument.Usage situation,
OutputPath inModelPath, OutputPath outModelPath, string loaderArgs, string extraArgs, int digitsOfPrecision, params PathArgument[] toCompare)
{
Contracts.AssertNonEmpty(cmdName);
@@ -503,24 +503,22 @@ private string DataArg(string dataPath)
protected void TestPipeFromModel(string dataPath, OutputPath model)
{
- using (var env = new ConsoleEnvironment(42))
- {
- var files = new MultiFileSource(dataPath);
+ var env = new MLContext(seed: 42);
+ var files = new MultiFileSource(dataPath);
- bool tmp;
- IDataView pipe;
- using (var file = Env.OpenInputFile(model.Path))
- using (var strm = file.OpenReadStream())
- using (var rep = RepositoryReader.Open(strm, env))
- {
- ModelLoadContext.LoadModel(env,
- out pipe, rep, ModelFileUtils.DirDataLoaderModel, files);
- }
-
- using (var c = pipe.GetRowCursor(col => true))
- tmp = CheckSameValues(c, pipe, true, true, true);
- Check(tmp, "Single value same failed");
+ bool tmp;
+ IDataView pipe;
+ using (var file = Env.OpenInputFile(model.Path))
+ using (var strm = file.OpenReadStream())
+ using (var rep = RepositoryReader.Open(strm, env))
+ {
+ ModelLoadContext.LoadModel(env,
+ out pipe, rep, ModelFileUtils.DirDataLoaderModel, files);
}
+
+ using (var c = pipe.GetRowCursor(col => true))
+ tmp = CheckSameValues(c, pipe, true, true, true);
+ Check(tmp, "Single value same failed");
}
}
@@ -1969,7 +1967,7 @@ public void CommandTrainingBinaryFieldAwareFactorizationMachineWithInitializatio
string data = GetDataPath("breast-cancer.txt");
OutputPath model = ModelPath();
- TestCore("traintest", data, loaderArgs, extraArgs + " test=" + data, digitsOfPrecision:5);
+ TestCore("traintest", data, loaderArgs, extraArgs + " test=" + data, digitsOfPrecision: 5);
_step++;
TestInOutCore("traintest", data, model, extraArgs + " " + loaderArgs + " " + "cont+" + " " + "test=" + data);
@@ -1988,17 +1986,17 @@ public void CommandTrainingBinaryFactorizationMachineWithValidation()
string args = $"{loaderArgs} data={trainData} valid={validData} test={validData} {extraArgs} out={model}";
OutputPath outputPath = StdoutPath();
using (var newWriter = OpenWriter(outputPath.Path))
- using (Env.RedirectChannelOutput(newWriter, newWriter))
+ using (_env.RedirectChannelOutput(newWriter, newWriter))
{
- Env.ResetProgressChannel();
- int res = MainForTest(Env, newWriter, string.Format("{0} {1}", "traintest", args), true);
+ _env.ResetProgressChannel();
+ int res = MainForTest(_env, newWriter, string.Format("{0} {1}", "traintest", args), true);
Assert.True(res == 0);
}
// see https://github.com/dotnet/machinelearning/issues/404
// in Linux, the clang sqrt() results vary highly from the ones in mac and Windows.
if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
- Assert.True(outputPath.CheckEqualityNormalized(digitsOfPrecision:4));
+ Assert.True(outputPath.CheckEqualityNormalized(digitsOfPrecision: 4));
else
Assert.True(outputPath.CheckEqualityNormalized());
@@ -2017,10 +2015,10 @@ public void CommandTrainingBinaryFieldAwareFactorizationMachineWithValidation()
string args = $"{loaderArgs} data={trainData} valid={validData} test={validData} {extraArgs} out={model}";
OutputPath outputPath = StdoutPath();
using (var newWriter = OpenWriter(outputPath.Path))
- using (Env.RedirectChannelOutput(newWriter, newWriter))
+ using (_env.RedirectChannelOutput(newWriter, newWriter))
{
- Env.ResetProgressChannel();
- int res = MainForTest(Env, newWriter, string.Format("{0} {1}", "traintest", args), true);
+ _env.ResetProgressChannel();
+ int res = MainForTest(_env, newWriter, string.Format("{0} {1}", "traintest", args), true);
Assert.Equal(0, res);
}
@@ -2044,15 +2042,15 @@ public void CommandTrainingBinaryFactorizationMachineWithValidationAndInitializa
OutputPath outputPath = StdoutPath();
string args = $"data={data} test={data} valid={data} in={model.Path} cont+" + " " + loaderArgs + " " + extraArgs;
using (var newWriter = OpenWriter(outputPath.Path))
- using (Env.RedirectChannelOutput(newWriter, newWriter))
+ using (_env.RedirectChannelOutput(newWriter, newWriter))
{
- Env.ResetProgressChannel();
- int res = MainForTest(Env, newWriter, string.Format("{0} {1}", "traintest", args), true);
+ _env.ResetProgressChannel();
+ int res = MainForTest(_env, newWriter, string.Format("{0} {1}", "traintest", args), true);
Assert.True(res == 0);
}
if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
- Assert.True(outputPath.CheckEqualityNormalized(digitsOfPrecision:4));
+ Assert.True(outputPath.CheckEqualityNormalized(digitsOfPrecision: 4));
else
Assert.True(outputPath.CheckEqualityNormalized());
@@ -2074,10 +2072,10 @@ public void CommandTrainingBinaryFieldAwareFactorizationMachineWithValidationAnd
OutputPath outputPath = StdoutPath();
string args = $"data={data} test={data} valid={data} in={model.Path} cont+" + " " + loaderArgs + " " + extraArgs;
using (var newWriter = OpenWriter(outputPath.Path))
- using (Env.RedirectChannelOutput(newWriter, newWriter))
+ using (_env.RedirectChannelOutput(newWriter, newWriter))
{
- Env.ResetProgressChannel();
- int res = MainForTest(Env, newWriter, string.Format("{0} {1}", "traintest", args), true);
+ _env.ResetProgressChannel();
+ int res = MainForTest(_env, newWriter, string.Format("{0} {1}", "traintest", args), true);
Assert.True(res == 0);
}
diff --git a/test/Microsoft.ML.TestFramework/TestSparseDataView.cs b/test/Microsoft.ML.TestFramework/TestSparseDataView.cs
index a674618dd8..75cd3ca516 100644
--- a/test/Microsoft.ML.TestFramework/TestSparseDataView.cs
+++ b/test/Microsoft.ML.TestFramework/TestSparseDataView.cs
@@ -48,28 +48,26 @@ private void GenericSparseDataView(T[] v1, T[] v2)
new SparseExample() { X = new VBuffer (5, 3, v1, new int[] { 0, 2, 4 }) },
new SparseExample() { X = new VBuffer (5, 3, v2, new int[] { 0, 1, 3 }) }
};
- using (var host = new ConsoleEnvironment())
+ var env = new MLContext();
+ var data = env.CreateStreamingDataView(inputs);
+ var value = new VBuffer();
+ int n = 0;
+ using (var cur = data.GetRowCursor(i => true))
{
- var data = host.CreateStreamingDataView(inputs);
- var value = new VBuffer();
- int n = 0;
- using (var cur = data.GetRowCursor(i => true))
+ var getter = cur.GetGetter>(0);
+ while (cur.MoveNext())
{
- var getter = cur.GetGetter>(0);
- while (cur.MoveNext())
- {
- getter(ref value);
- Assert.True(value.GetValues().Length == 3);
- ++n;
- }
- }
- Assert.True(n == 2);
- var iter = data.AsEnumerable>(host, false).GetEnumerator();
- n = 0;
- while (iter.MoveNext())
+ getter(ref value);
+ Assert.True(value.GetValues().Length == 3);
++n;
- Assert.True(n == 2);
+ }
}
+ Assert.True(n == 2);
+ var iter = data.AsEnumerable>(env, false).GetEnumerator();
+ n = 0;
+ while (iter.MoveNext())
+ ++n;
+ Assert.True(n == 2);
}
[Fact]
@@ -90,28 +88,26 @@ private void GenericDenseDataView(T[] v1, T[] v2)
new DenseExample() { X = v1 },
new DenseExample() { X = v2 }
};
- using (var host = new ConsoleEnvironment())
+ var env = new MLContext();
+ var data = env.CreateStreamingDataView(inputs);
+ var value = new VBuffer();
+ int n = 0;
+ using (var cur = data.GetRowCursor(i => true))
{
- var data = host.CreateStreamingDataView(inputs);
- var value = new VBuffer();
- int n = 0;
- using (var cur = data.GetRowCursor(i => true))
+ var getter = cur.GetGetter>(0);
+ while (cur.MoveNext())
{
- var getter = cur.GetGetter>(0);
- while (cur.MoveNext())
- {
- getter(ref value);
- Assert.True(value.GetValues().Length == 3);
- ++n;
- }
- }
- Assert.True(n == 2);
- var iter = data.AsEnumerable>(host, false).GetEnumerator();
- n = 0;
- while (iter.MoveNext())
+ getter(ref value);
+ Assert.True(value.GetValues().Length == 3);
++n;
- Assert.True(n == 2);
+ }
}
+ Assert.True(n == 2);
+ var iter = data.AsEnumerable>(env, false).GetEnumerator();
+ n = 0;
+ while (iter.MoveNext())
+ ++n;
+ Assert.True(n == 2);
}
}
}
diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs
index 2e3b71ab01..8862864885 100644
--- a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs
+++ b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs
@@ -59,15 +59,13 @@ public void CheckConstructor()
public void CanSuccessfullyApplyATransform()
{
var collection = CollectionDataSource.Create(new List() { new Input { Number1 = 1, String1 = "1" } });
- using (var environment = new ConsoleEnvironment())
- {
+ var environment = new MLContext();
Experiment experiment = environment.CreateExperiment();
Legacy.ILearningPipelineDataStep output = (Legacy.ILearningPipelineDataStep)collection.ApplyStep(null, experiment);
Assert.NotNull(output.Data);
Assert.NotNull(output.Data.VarName);
Assert.Null(output.Model);
- }
}
[Fact]
@@ -79,9 +77,8 @@ public void CanSuccessfullyEnumerated()
new Input { Number1 = 3, String1 = "3" }
});
- using (var environment = new ConsoleEnvironment())
- {
- Experiment experiment = environment.CreateExperiment();
+ var environment = new MLContext();
+ Experiment experiment = environment.CreateExperiment();
Legacy.ILearningPipelineDataStep output = collection.ApplyStep(null, experiment) as Legacy.ILearningPipelineDataStep;
experiment.Compile();
@@ -128,7 +125,6 @@ public void CanSuccessfullyEnumerated()
Assert.False(cursor.MoveNext());
}
- }
}
[Fact]
@@ -294,7 +290,7 @@ public class ConversionSimpleClass
public float fFloat;
public double fDouble;
public bool fBool;
- public string fString="";
+ public string fString = "";
}
public bool CompareObjectValues(object x, object y, Type type)
@@ -418,17 +414,15 @@ public void RoundTripConversionWithBasicTypes()
new ConversionSimpleClass()
};
- using (var env = new ConsoleEnvironment())
+ var env = new MLContext();
+ var dataView = ComponentCreation.CreateDataView(env, data);
+ var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator();
+ var originalEnumerator = data.GetEnumerator();
+ while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
{
- var dataView = ComponentCreation.CreateDataView(env, data);
- var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator();
- var originalEnumerator = data.GetEnumerator();
- while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
- {
- Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
- }
- Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
+ Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
}
+ Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
}
public class ConversionNotSupportedMinValueClass
@@ -442,27 +436,25 @@ public class ConversionNotSupportedMinValueClass
[Fact]
public void ConversionExceptionsBehavior()
{
- using (var env = new ConsoleEnvironment())
+ var env = new MLContext();
+ var data = new ConversionNotSupportedMinValueClass[1];
+ foreach (var field in typeof(ConversionNotSupportedMinValueClass).GetFields())
{
- var data = new ConversionNotSupportedMinValueClass[1];
- foreach (var field in typeof(ConversionNotSupportedMinValueClass).GetFields())
+ data[0] = new ConversionNotSupportedMinValueClass();
+ FieldInfo fi;
+ if ((fi = field.FieldType.GetField("MinValue")) != null)
+ {
+ field.SetValue(data[0], fi.GetValue(null));
+ }
+ var dataView = ComponentCreation.CreateDataView(env, data);
+ var enumerator = dataView.AsEnumerable(env, false).GetEnumerator();
+ try
+ {
+ enumerator.MoveNext();
+ Assert.True(false);
+ }
+ catch
{
- data[0] = new ConversionNotSupportedMinValueClass();
- FieldInfo fi;
- if ((fi = field.FieldType.GetField("MinValue")) != null)
- {
- field.SetValue(data[0], fi.GetValue(null));
- }
- var dataView = ComponentCreation.CreateDataView(env, data);
- var enumerator = dataView.AsEnumerable(env, false).GetEnumerator();
- try
- {
- enumerator.MoveNext();
- Assert.True(false);
- }
- catch
- {
- }
}
}
}
@@ -496,15 +488,13 @@ public void ClassWithConstFieldsConversion()
new ClassWithConstField(){ fInt=-1, fString ="" },
};
- using (var env = new ConsoleEnvironment())
- {
- var dataView = ComponentCreation.CreateDataView(env, data);
- var enumeratorSimple = dataView.AsEnumerable<ClassWithConstField>(env, false).GetEnumerator();
- var originalEnumerator = data.GetEnumerator();
- while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
- Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
- Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
- }
+ var env = new MLContext();
+ var dataView = ComponentCreation.CreateDataView(env, data);
+ var enumeratorSimple = dataView.AsEnumerable<ClassWithConstField>(env, false).GetEnumerator();
+ var originalEnumerator = data.GetEnumerator();
+ while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
+ Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
+ Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
}
@@ -524,15 +514,13 @@ public void ClassWithMixOfFieldsAndPropertiesConversion()
new ClassWithMixOfFieldsAndProperties(){ IntProp=-1, fString ="" },
};
- using (var env = new ConsoleEnvironment())
- {
- var dataView = ComponentCreation.CreateDataView(env, data);
- var enumeratorSimple = dataView.AsEnumerable<ClassWithMixOfFieldsAndProperties>(env, false).GetEnumerator();
- var originalEnumerator = data.GetEnumerator();
- while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
- Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
- Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
- }
+ var env = new MLContext();
+ var dataView = ComponentCreation.CreateDataView(env, data);
+ var enumeratorSimple = dataView.AsEnumerable<ClassWithMixOfFieldsAndProperties>(env, false).GetEnumerator();
+ var originalEnumerator = data.GetEnumerator();
+ while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
+ Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
+ Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
}
public abstract class BaseClassWithInheritedProperties
@@ -580,28 +568,25 @@ public void ClassWithPrivateFieldsAndPropertiesConversion()
new ClassWithPrivateFieldsAndProperties(){ StringProp ="baba" }
};
- using (var env = new ConsoleEnvironment())
+ var env = new MLContext();
+ var dataView = ComponentCreation.CreateDataView(env, data);
+ var enumeratorSimple = dataView.AsEnumerable<ClassWithPrivateFieldsAndProperties>(env, false).GetEnumerator();
+ var originalEnumerator = data.GetEnumerator();
+ while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
{
- var dataView = ComponentCreation.CreateDataView(env, data);
- var enumeratorSimple = dataView.AsEnumerable<ClassWithPrivateFieldsAndProperties>(env, false).GetEnumerator();
- var originalEnumerator = data.GetEnumerator();
- while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
- {
- Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
- Assert.True(enumeratorSimple.Current.UnusedPropertyWithPrivateSetter == 100);
- }
- Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
+ Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
+ Assert.True(enumeratorSimple.Current.UnusedPropertyWithPrivateSetter == 100);
}
+ Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
}
public class ClassWithInheritedProperties : BaseClassWithInheritedProperties
{
- private int _fInt;
private long _fLong;
private byte _fByte2;
- public int IntProp { get { return _fInt; } set { _fInt = value; } }
- public override long LongProp { get { return _fLong; } set { _fLong = value; } }
- public override byte ByteProp { get { return _fByte2; } set { _fByte2 = value; } }
+ public int IntProp { get; set; }
+ public override long LongProp { get => _fLong; set => _fLong = value; }
+ public override byte ByteProp { get => _fByte2; set => _fByte2 = value; }
}
[Fact]
@@ -613,15 +598,13 @@ public void ClassWithInheritedPropertiesConversion()
new ClassWithInheritedProperties(){ IntProp=-1, StringProp ="", LongProp=2, ByteProp=4 },
};
- using (var env = new ConsoleEnvironment())
- {
- var dataView = ComponentCreation.CreateDataView(env, data);
- var enumeratorSimple = dataView.AsEnumerable<ClassWithInheritedProperties>(env, false).GetEnumerator();
- var originalEnumerator = data.GetEnumerator();
- while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
- Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
- Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
- }
+ var env = new MLContext();
+ var dataView = ComponentCreation.CreateDataView(env, data);
+ var enumeratorSimple = dataView.AsEnumerable<ClassWithInheritedProperties>(env, false).GetEnumerator();
+ var originalEnumerator = data.GetEnumerator();
+ while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
+ Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
+ Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
}
public class ClassWithArrays
@@ -666,44 +649,30 @@ public void RoundTripConversionWithArrays()
};
- using (var env = new ConsoleEnvironment())
+ var env = new MLContext();
+ var dataView = ComponentCreation.CreateDataView(env, data);
+ var enumeratorSimple = dataView.AsEnumerable<ClassWithArrays>(env, false).GetEnumerator();
+ var originalEnumerator = data.GetEnumerator();
+ while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
{
- var dataView = ComponentCreation.CreateDataView(env, data);
- var enumeratorSimple = dataView.AsEnumerable<ClassWithArrays>(env, false).GetEnumerator();
- var originalEnumerator = data.GetEnumerator();
- while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
- {
- Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
- }
- Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
+ Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
}
+ Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
}
public class ClassWithArrayProperties
{
- private string[] _fString;
- private int[] _fInt;
- private uint[] _fuInt;
- private short[] _fShort;
- private ushort[] _fuShort;
- private sbyte[] _fsByte;
- private byte[] _fByte;
- private long[] _fLong;
- private ulong[] _fuLong;
- private float[] _fFloat;
- private double[] _fDouble;
- private bool[] _fBool;
- public string[] StringProp { get { return _fString; } set { _fString = value; } }
- public int[] IntProp { get { return _fInt; } set { _fInt = value; } }
- public uint[] UIntProp { get { return _fuInt; } set { _fuInt = value; } }
- public short[] ShortProp { get { return _fShort; } set { _fShort = value; } }
- public ushort[] UShortProp { get { return _fuShort; } set { _fuShort = value; } }
- public sbyte[] SByteProp { get { return _fsByte; } set { _fsByte = value; } }
- public byte[] ByteProp { get { return _fByte; } set { _fByte = value; } }
- public long[] LongProp { get { return _fLong; } set { _fLong = value; } }
- public ulong[] ULongProp { get { return _fuLong; } set { _fuLong = value; } }
- public float[] FloatProp { get { return _fFloat; } set { _fFloat = value; } }
- public double[] DobuleProp { get { return _fDouble; } set { _fDouble = value; } }
- public bool[] BoolProp { get { return _fBool; } set { _fBool = value; } }
+ public string[] StringProp { get; set; }
+ public int[] IntProp { get; set; }
+ public uint[] UIntProp { get; set; }
+ public short[] ShortProp { get; set; }
+ public ushort[] UShortProp { get; set; }
+ public sbyte[] SByteProp { get; set; }
+ public byte[] ByteProp { get; set; }
+ public long[] LongProp { get; set; }
+ public ulong[] ULongProp { get; set; }
+ public float[] FloatProp { get; set; }
+ public double[] DobuleProp { get; set; }
+ public bool[] BoolProp { get; set; }
}
[Fact]
@@ -731,27 +700,23 @@ public void RoundTripConversionWithArrayPropertiess()
new ClassWithArrayProperties()
};
- using (var env = new ConsoleEnvironment())
- {
- var dataView = ComponentCreation.CreateDataView(env, data);
- var enumeratorSimple = dataView.AsEnumerable<ClassWithArrayProperties>(env, false).GetEnumerator();
- var originalEnumerator = data.GetEnumerator();
- while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
- {
- Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
- }
- Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
- }
+ var env = new MLContext();
+ var dataView = ComponentCreation.CreateDataView(env, data);
+ var enumeratorSimple = dataView.AsEnumerable<ClassWithArrayProperties>(env, false).GetEnumerator();
+ var originalEnumerator = data.GetEnumerator();
+ while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
+ Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current));
+ Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext());
}
- class ClassWithGetter
+ private sealed class ClassWithGetter
{
private DateTime _dateTime = DateTime.Now;
- public float Day { get { return _dateTime.Day; } }
- public int Hour { get { return _dateTime.Hour; } }
+ public float Day => _dateTime.Day;
+ public int Hour => _dateTime.Hour;
}
- class ClassWithSetter
+ private sealed class ClassWithSetter
{
public float Day { private get; set; }
public int Hour { private get; set; }
@@ -772,16 +737,14 @@ public void PrivateGetSetProperties()
new ClassWithGetter()
};
- using (var env = new ConsoleEnvironment())
+ var env = new MLContext();
+ var dataView = ComponentCreation.CreateDataView(env, data);
+ var enumeratorSimple = dataView.AsEnumerable<ClassWithSetter>(env, false).GetEnumerator();
+ var originalEnumerator = data.GetEnumerator();
+ while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
{
- var dataView = ComponentCreation.CreateDataView(env, data);
- var enumeratorSimple = dataView.AsEnumerable<ClassWithSetter>(env, false).GetEnumerator();
- var originalEnumerator = data.GetEnumerator();
- while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext())
- {
- Assert.True(enumeratorSimple.Current.GetDay == originalEnumerator.Current.Day &&
- enumeratorSimple.Current.GetHour == originalEnumerator.Current.Hour);
- }
+ Assert.True(enumeratorSimple.Current.GetDay == originalEnumerator.Current.Day &&
+ enumeratorSimple.Current.GetHour == originalEnumerator.Current.Hour);
}
}
}
diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs
index 3c4522bf1d..cde998e2b4 100644
--- a/test/Microsoft.ML.Tests/ImagesTests.cs
+++ b/test/Microsoft.ML.Tests/ImagesTests.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.ImageAnalytics;
@@ -25,72 +26,69 @@ public ImageTests(ITestOutputHelper output) : base(output)
[Fact]
public void TestEstimatorChain()
{
- using (var env = new ConsoleEnvironment())
+ var env = new MLContext();
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
{
- var dataFile = GetDataPath("images/images.tsv");
- var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ Column = new[]
{
- Column = new[]
- {
new TextLoader.Column("ImagePath", DataKind.TX, 0),
new TextLoader.Column("Name", DataKind.TX, 1),
}
- }, new MultiFileSource(dataFile));
- var invalidData = TextLoader.Create(env, new TextLoader.Arguments()
+ }, new MultiFileSource(dataFile));
+ var invalidData = TextLoader.Create(env, new TextLoader.Arguments()
+ {
+ Column = new[]
{
- Column = new[]
- {
new TextLoader.Column("ImagePath", DataKind.R4, 0),
}
- }, new MultiFileSource(dataFile));
+ }, new MultiFileSource(dataFile));
- var pipe = new ImageLoadingEstimator(env, imageFolder, ("ImagePath", "ImageReal"))
- .Append(new ImageResizingEstimator(env, "ImageReal", "ImageReal", 100, 100))
- .Append(new ImagePixelExtractingEstimator(env, "ImageReal", "ImagePixels"))
- .Append(new ImageGrayscalingEstimator(env, ("ImageReal", "ImageGray")));
+ var pipe = new ImageLoadingEstimator(env, imageFolder, ("ImagePath", "ImageReal"))
+ .Append(new ImageResizingEstimator(env, "ImageReal", "ImageReal", 100, 100))
+ .Append(new ImagePixelExtractingEstimator(env, "ImageReal", "ImagePixels"))
+ .Append(new ImageGrayscalingEstimator(env, ("ImageReal", "ImageGray")));
- TestEstimatorCore(pipe, data, null, invalidData);
- }
+ TestEstimatorCore(pipe, data, null, invalidData);
Done();
}
[Fact]
public void TestEstimatorSaveLoad()
{
- using (var env = new ConsoleEnvironment())
+ IHostEnvironment env = new MLContext();
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
{
- var dataFile = GetDataPath("images/images.tsv");
- var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ Column = new[]
{
- Column = new[]
- {
new TextLoader.Column("ImagePath", DataKind.TX, 0),
new TextLoader.Column("Name", DataKind.TX, 1),
}
- }, new MultiFileSource(dataFile));
+ }, new MultiFileSource(dataFile));
- var pipe = new ImageLoadingEstimator(env, imageFolder, ("ImagePath", "ImageReal"))
- .Append(new ImageResizingEstimator(env, "ImageReal", "ImageReal", 100, 100))
- .Append(new ImagePixelExtractingEstimator(env, "ImageReal", "ImagePixels"))
- .Append(new ImageGrayscalingEstimator(env, ("ImageReal", "ImageGray")));
+ var pipe = new ImageLoadingEstimator(env, imageFolder, ("ImagePath", "ImageReal"))
+ .Append(new ImageResizingEstimator(env, "ImageReal", "ImageReal", 100, 100))
+ .Append(new ImagePixelExtractingEstimator(env, "ImageReal", "ImagePixels"))
+ .Append(new ImageGrayscalingEstimator(env, ("ImageReal", "ImageGray")));
- pipe.GetOutputSchema(Core.Data.SchemaShape.Create(data.Schema));
- var model = pipe.Fit(data);
-
- using (var file = env.CreateTempFile())
- {
- using (var fs = file.CreateWriteStream())
- model.SaveTo(env, fs);
- var model2 = TransformerChain.LoadFrom(env, file.OpenReadStream());
+ pipe.GetOutputSchema(Core.Data.SchemaShape.Create(data.Schema));
+ var model = pipe.Fit(data);
- var newCols = ((ImageLoaderTransform)model2.First()).Columns;
- var oldCols = ((ImageLoaderTransform)model.First()).Columns;
- Assert.True(newCols
- .Zip(oldCols, (x, y) => x == y)
- .All(x => x));
- }
+ var tempPath = Path.GetTempFileName();
+ using (var file = new SimpleFileHandle(env, tempPath, true, true))
+ {
+ using (var fs = file.CreateWriteStream())
+ model.SaveTo(env, fs);
+ var model2 = TransformerChain.LoadFrom(env, file.OpenReadStream());
+
+ var newCols = ((ImageLoaderTransform)model2.First()).Columns;
+ var oldCols = ((ImageLoaderTransform)model.First()).Columns;
+ Assert.True(newCols
+ .Zip(oldCols, (x, y) => x == y)
+ .All(x => x));
}
Done();
}
@@ -98,50 +96,48 @@ public void TestEstimatorSaveLoad()
[Fact]
public void TestSaveImages()
{
- using (var env = new ConsoleEnvironment())
+ var env = new MLContext();
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
{
- var dataFile = GetDataPath("images/images.tsv");
- var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ Column = new[]
{
- Column = new[]
- {
new TextLoader.Column("ImagePath", DataKind.TX, 0),
new TextLoader.Column("Name", DataKind.TX, 1),
}
- }, new MultiFileSource(dataFile));
- var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ }, new MultiFileSource(dataFile));
+ var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ {
+ Column = new ImageLoaderTransform.Column[1]
{
- Column = new ImageLoaderTransform.Column[1]
- {
new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" }
- },
- ImageFolder = imageFolder
- }, data);
+ },
+ ImageFolder = imageFolder
+ }, data);
- IDataView cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
- {
- Column = new ImageResizerTransform.Column[1]{
+ IDataView cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
+ {
+ Column = new ImageResizerTransform.Column[1]{
new ImageResizerTransform.Column() { Name= "ImageCropped", Source = "ImageReal", ImageHeight =100, ImageWidth = 100, Resizing = ImageResizerTransform.ResizingKind.IsoPad}
}
- }, images);
+ }, images);
- cropped.Schema.TryGetColumnIndex("ImagePath", out int pathColumn);
- cropped.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
- using (var cursor = cropped.GetRowCursor((x) => true))
- {
- var pathGetter = cursor.GetGetter<ReadOnlyMemory<char>>(pathColumn);
- ReadOnlyMemory<char> path = default;
- var bitmapCropGetter = cursor.GetGetter<Bitmap>(cropBitmapColumn);
- Bitmap bitmap = default;
- while (cursor.MoveNext())
- {
- pathGetter(ref path);
- bitmapCropGetter(ref bitmap);
- Assert.NotNull(bitmap);
- var fileToSave = GetOutputPath(Path.GetFileNameWithoutExtension(path.ToString()) + ".cropped.jpg");
- bitmap.Save(fileToSave, System.Drawing.Imaging.ImageFormat.Jpeg);
- }
+ cropped.Schema.TryGetColumnIndex("ImagePath", out int pathColumn);
+ cropped.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
+ using (var cursor = cropped.GetRowCursor((x) => true))
+ {
+ var pathGetter = cursor.GetGetter<ReadOnlyMemory<char>>(pathColumn);
+ ReadOnlyMemory<char> path = default;
+ var bitmapCropGetter = cursor.GetGetter<Bitmap>(cropBitmapColumn);
+ Bitmap bitmap = default;
+ while (cursor.MoveNext())
+ {
+ pathGetter(ref path);
+ bitmapCropGetter(ref bitmap);
+ Assert.NotNull(bitmap);
+ var fileToSave = GetOutputPath(Path.GetFileNameWithoutExtension(path.ToString()) + ".cropped.jpg");
+ bitmap.Save(fileToSave, System.Drawing.Imaging.ImageFormat.Jpeg);
}
}
Done();
@@ -150,68 +146,66 @@ public void TestSaveImages()
[Fact]
public void TestGreyscaleTransformImages()
{
- using (var env = new ConsoleEnvironment())
+ IHostEnvironment env = new MLContext();
+ var imageHeight = 150;
+ var imageWidth = 100;
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
{
- var imageHeight = 150;
- var imageWidth = 100;
- var dataFile = GetDataPath("images/images.tsv");
- var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ Column = new[]
{
- Column = new[]
- {
new TextLoader.Column("ImagePath", DataKind.TX, 0),
new TextLoader.Column("Name", DataKind.TX, 1),
}
- }, new MultiFileSource(dataFile));
- var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ }, new MultiFileSource(dataFile));
+ var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ {
+ Column = new ImageLoaderTransform.Column[1]
{
- Column = new ImageLoaderTransform.Column[1]
- {
new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" }
- },
- ImageFolder = imageFolder
- }, data);
- var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
- {
- Column = new ImageResizerTransform.Column[1]{
+ },
+ ImageFolder = imageFolder
+ }, data);
+ var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
+ {
+ Column = new ImageResizerTransform.Column[1]{
new ImageResizerTransform.Column() { Name= "ImageCropped", Source = "ImageReal", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop}
}
- }, images);
+ }, images);
- IDataView grey = ImageGrayscaleTransform.Create(env, new ImageGrayscaleTransform.Arguments()
- {
- Column = new ImageGrayscaleTransform.Column[1]{
+ IDataView grey = ImageGrayscaleTransform.Create(env, new ImageGrayscaleTransform.Arguments()
+ {
+ Column = new ImageGrayscaleTransform.Column[1]{
new ImageGrayscaleTransform.Column() { Name= "ImageGrey", Source = "ImageCropped"}
}
- }, cropped);
+ }, cropped);
- var fname = nameof(TestGreyscaleTransformImages) + "_model.zip";
+ var fname = nameof(TestGreyscaleTransformImages) + "_model.zip";
- var fh = env.CreateOutputFile(fname);
- using (var ch = env.Start("save"))
- TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(grey));
+ var fh = env.CreateOutputFile(fname);
+ using (var ch = env.Start("save"))
+ TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(grey));
- grey = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
- DeleteOutputPath(fname);
+ grey = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
+ DeleteOutputPath(fname);
- grey.Schema.TryGetColumnIndex("ImageGrey", out int greyColumn);
- using (var cursor = grey.GetRowCursor((x) => true))
- {
- var bitmapGetter = cursor.GetGetter<Bitmap>(greyColumn);
- Bitmap bitmap = default;
- while (cursor.MoveNext())
- {
- bitmapGetter(ref bitmap);
- Assert.NotNull(bitmap);
- for (int x = 0; x < imageWidth; x++)
- for (int y = 0; y < imageHeight; y++)
- {
- var pixel = bitmap.GetPixel(x, y);
- // greyscale image has same values for R,G and B
- Assert.True(pixel.R == pixel.G && pixel.G == pixel.B);
- }
- }
+ grey.Schema.TryGetColumnIndex("ImageGrey", out int greyColumn);
+ using (var cursor = grey.GetRowCursor((x) => true))
+ {
+ var bitmapGetter = cursor.GetGetter<Bitmap>(greyColumn);
+ Bitmap bitmap = default;
+ while (cursor.MoveNext())
+ {
+ bitmapGetter(ref bitmap);
+ Assert.NotNull(bitmap);
+ for (int x = 0; x < imageWidth; x++)
+ for (int y = 0; y < imageHeight; y++)
+ {
+ var pixel = bitmap.GetPixel(x, y);
+ // greyscale image has same values for R,G and B
+ Assert.True(pixel.R == pixel.G && pixel.G == pixel.B);
+ }
}
}
Done();
@@ -220,88 +214,86 @@ public void TestGreyscaleTransformImages()
[Fact]
public void TestBackAndForthConversionWithAlphaInterleave()
{
- using (var env = new ConsoleEnvironment())
+ IHostEnvironment env = new MLContext();
+ const int imageHeight = 100;
+ const int imageWidth = 130;
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
{
- var imageHeight = 100;
- var imageWidth = 130;
- var dataFile = GetDataPath("images/images.tsv");
- var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ Column = new[]
{
- Column = new[]
- {
new TextLoader.Column("ImagePath", DataKind.TX, 0),
new TextLoader.Column("Name", DataKind.TX, 1),
}
- }, new MultiFileSource(dataFile));
- var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ }, new MultiFileSource(dataFile));
+ var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ {
+ Column = new ImageLoaderTransform.Column[1]
{
- Column = new ImageLoaderTransform.Column[1]
- {
new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" }
- },
- ImageFolder = imageFolder
- }, data);
- var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
- {
- Column = new ImageResizerTransform.Column[1]{
+ },
+ ImageFolder = imageFolder
+ }, data);
+ var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
+ {
+ Column = new ImageResizerTransform.Column[1]{
new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop}
}
- }, images);
+ }, images);
- var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
- {
- InterleaveArgb = true,
- Offset = 127.5f,
- Scale = 2f / 255,
- Column = new ImagePixelExtractorTransform.Column[1]{
+ var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
+ {
+ InterleaveArgb = true,
+ Offset = 127.5f,
+ Scale = 2f / 255,
+ Column = new ImagePixelExtractorTransform.Column[1]{
new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=true}
}
- }, cropped);
+ }, cropped);
- IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
- {
- InterleaveArgb = true,
- Offset = -1f,
- Scale = 255f / 2,
- Column = new VectorToImageTransform.Column[1]{
+ IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
+ {
+ InterleaveArgb = true,
+ Offset = -1f,
+ Scale = 255f / 2,
+ Column = new VectorToImageTransform.Column[1]{
new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true}
}
- }, pixels);
+ }, pixels);
- var fname = nameof(TestBackAndForthConversionWithAlphaInterleave) + "_model.zip";
+ var fname = nameof(TestBackAndForthConversionWithAlphaInterleave) + "_model.zip";
- var fh = env.CreateOutputFile(fname);
- using (var ch = env.Start("save"))
- TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
+ var fh = env.CreateOutputFile(fname);
+ using (var ch = env.Start("save"))
+ TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
- backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
- DeleteOutputPath(fname);
+ backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
+ DeleteOutputPath(fname);
- backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
- backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
- using (var cursor = backToBitmaps.GetRowCursor((x) => true))
- {
- var bitmapGetter = cursor.GetGetter<Bitmap>(bitmapColumn);
- Bitmap restoredBitmap = default;
-
- var bitmapCropGetter = cursor.GetGetter<Bitmap>(cropBitmapColumn);
- Bitmap croppedBitmap = default;
- while (cursor.MoveNext())
- {
- bitmapGetter(ref restoredBitmap);
- Assert.NotNull(restoredBitmap);
- bitmapCropGetter(ref croppedBitmap);
- Assert.NotNull(croppedBitmap);
- for (int x = 0; x < imageWidth; x++)
- for (int y = 0; y < imageHeight; y++)
- {
- var c = croppedBitmap.GetPixel(x, y);
- var r = restoredBitmap.GetPixel(x, y);
- Assert.True(c == r);
- }
- }
+ backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
+ backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
+ using (var cursor = backToBitmaps.GetRowCursor((x) => true))
+ {
+ var bitmapGetter = cursor.GetGetter<Bitmap>(bitmapColumn);
+ Bitmap restoredBitmap = default;
+
+ var bitmapCropGetter = cursor.GetGetter<Bitmap>(cropBitmapColumn);
+ Bitmap croppedBitmap = default;
+ while (cursor.MoveNext())
+ {
+ bitmapGetter(ref restoredBitmap);
+ Assert.NotNull(restoredBitmap);
+ bitmapCropGetter(ref croppedBitmap);
+ Assert.NotNull(croppedBitmap);
+ for (int x = 0; x < imageWidth; x++)
+ for (int y = 0; y < imageHeight; y++)
+ {
+ var c = croppedBitmap.GetPixel(x, y);
+ var r = restoredBitmap.GetPixel(x, y);
+ Assert.True(c == r);
+ }
}
}
Done();
@@ -310,88 +302,86 @@ public void TestBackAndForthConversionWithAlphaInterleave()
[Fact]
public void TestBackAndForthConversionWithoutAlphaInterleave()
{
- using (var env = new ConsoleEnvironment())
+ IHostEnvironment env = new MLContext();
+ const int imageHeight = 100;
+ const int imageWidth = 130;
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
{
- var imageHeight = 100;
- var imageWidth = 130;
- var dataFile = GetDataPath("images/images.tsv");
- var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ Column = new[]
{
- Column = new[]
- {
new TextLoader.Column("ImagePath", DataKind.TX, 0),
new TextLoader.Column("Name", DataKind.TX, 1),
}
- }, new MultiFileSource(dataFile));
- var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ }, new MultiFileSource(dataFile));
+ var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ {
+ Column = new ImageLoaderTransform.Column[1]
{
- Column = new ImageLoaderTransform.Column[1]
- {
new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" }
- },
- ImageFolder = imageFolder
- }, data);
- var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
- {
- Column = new ImageResizerTransform.Column[1]{
+ },
+ ImageFolder = imageFolder
+ }, data);
+ var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
+ {
+ Column = new ImageResizerTransform.Column[1]{
new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop}
}
- }, images);
+ }, images);
- var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
- {
- InterleaveArgb = true,
- Offset = 127.5f,
- Scale = 2f / 255,
- Column = new ImagePixelExtractorTransform.Column[1]{
+ var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
+ {
+ InterleaveArgb = true,
+ Offset = 127.5f,
+ Scale = 2f / 255,
+ Column = new ImagePixelExtractorTransform.Column[1]{
new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=false}
}
- }, cropped);
+ }, cropped);
- IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
- {
- InterleaveArgb = true,
- Offset = -1f,
- Scale = 255f / 2,
- Column = new VectorToImageTransform.Column[1]{
+ IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
+ {
+ InterleaveArgb = true,
+ Offset = -1f,
+ Scale = 255f / 2,
+ Column = new VectorToImageTransform.Column[1]{
new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=false}
}
- }, pixels);
+ }, pixels);
- var fname = nameof(TestBackAndForthConversionWithoutAlphaInterleave) + "_model.zip";
+ var fname = nameof(TestBackAndForthConversionWithoutAlphaInterleave) + "_model.zip";
- var fh = env.CreateOutputFile(fname);
- using (var ch = env.Start("save"))
- TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
+ var fh = env.CreateOutputFile(fname);
+ using (var ch = env.Start("save"))
+ TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
- backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
- DeleteOutputPath(fname);
+ backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
+ DeleteOutputPath(fname);
- backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
- backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
- using (var cursor = backToBitmaps.GetRowCursor((x) => true))
- {
- var bitmapGetter = cursor.GetGetter<Bitmap>(bitmapColumn);
- Bitmap restoredBitmap = default;
-
- var bitmapCropGetter = cursor.GetGetter<Bitmap>(cropBitmapColumn);
- Bitmap croppedBitmap = default;
- while (cursor.MoveNext())
- {
- bitmapGetter(ref restoredBitmap);
- Assert.NotNull(restoredBitmap);
- bitmapCropGetter(ref croppedBitmap);
- Assert.NotNull(croppedBitmap);
- for (int x = 0; x < imageWidth; x++)
- for (int y = 0; y < imageHeight; y++)
- {
- var c = croppedBitmap.GetPixel(x, y);
- var r = restoredBitmap.GetPixel(x, y);
- Assert.True(c.R == r.R && c.G == r.G && c.B == r.B);
- }
- }
+ backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
+ backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
+ using (var cursor = backToBitmaps.GetRowCursor((x) => true))
+ {
+ var bitmapGetter = cursor.GetGetter<Bitmap>(bitmapColumn);
+ Bitmap restoredBitmap = default;
+
+ var bitmapCropGetter = cursor.GetGetter<Bitmap>(cropBitmapColumn);
+ Bitmap croppedBitmap = default;
+ while (cursor.MoveNext())
+ {
+ bitmapGetter(ref restoredBitmap);
+ Assert.NotNull(restoredBitmap);
+ bitmapCropGetter(ref croppedBitmap);
+ Assert.NotNull(croppedBitmap);
+ for (int x = 0; x < imageWidth; x++)
+ for (int y = 0; y < imageHeight; y++)
+ {
+ var c = croppedBitmap.GetPixel(x, y);
+ var r = restoredBitmap.GetPixel(x, y);
+ Assert.True(c.R == r.R && c.G == r.G && c.B == r.B);
+ }
}
}
Done();
@@ -400,88 +390,86 @@ public void TestBackAndForthConversionWithoutAlphaInterleave()
[Fact]
public void TestBackAndForthConversionWithAlphaNoInterleave()
{
- using (var env = new ConsoleEnvironment())
+ IHostEnvironment env = new MLContext();
+ const int imageHeight = 100;
+ const int imageWidth = 130;
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
{
- var imageHeight = 100;
- var imageWidth = 130;
- var dataFile = GetDataPath("images/images.tsv");
- var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ Column = new[]
{
- Column = new[]
- {
new TextLoader.Column("ImagePath", DataKind.TX, 0),
new TextLoader.Column("Name", DataKind.TX, 1),
}
- }, new MultiFileSource(dataFile));
- var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ }, new MultiFileSource(dataFile));
+ var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ {
+ Column = new ImageLoaderTransform.Column[1]
{
- Column = new ImageLoaderTransform.Column[1]
- {
new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" }
- },
- ImageFolder = imageFolder
- }, data);
- var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
- {
- Column = new ImageResizerTransform.Column[1]{
+ },
+ ImageFolder = imageFolder
+ }, data);
+ var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
+ {
+ Column = new ImageResizerTransform.Column[1]{
new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop}
}
- }, images);
+ }, images);
- var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
- {
- InterleaveArgb = false,
- Offset = 127.5f,
- Scale = 2f / 255,
- Column = new ImagePixelExtractorTransform.Column[1]{
+ var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
+ {
+ InterleaveArgb = false,
+ Offset = 127.5f,
+ Scale = 2f / 255,
+ Column = new ImagePixelExtractorTransform.Column[1]{
new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=true}
}
- }, cropped);
+ }, cropped);
- IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
- {
- InterleaveArgb = false,
- Offset = -1f,
- Scale = 255f / 2,
- Column = new VectorToImageTransform.Column[1]{
+ IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
+ {
+ InterleaveArgb = false,
+ Offset = -1f,
+ Scale = 255f / 2,
+ Column = new VectorToImageTransform.Column[1]{
new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true}
}
- }, pixels);
+ }, pixels);
- var fname = nameof(TestBackAndForthConversionWithAlphaNoInterleave) + "_model.zip";
+ var fname = nameof(TestBackAndForthConversionWithAlphaNoInterleave) + "_model.zip";
- var fh = env.CreateOutputFile(fname);
- using (var ch = env.Start("save"))
- TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
+ var fh = env.CreateOutputFile(fname);
+ using (var ch = env.Start("save"))
+ TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
- backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
- DeleteOutputPath(fname);
+ backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
+ DeleteOutputPath(fname);
- backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
- backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
- using (var cursor = backToBitmaps.GetRowCursor((x) => true))
- {
- var bitmapGetter = cursor.GetGetter(bitmapColumn);
- Bitmap restoredBitmap = default;
-
- var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn);
- Bitmap croppedBitmap = default;
- while (cursor.MoveNext())
- {
- bitmapGetter(ref restoredBitmap);
- Assert.NotNull(restoredBitmap);
- bitmapCropGetter(ref croppedBitmap);
- Assert.NotNull(croppedBitmap);
- for (int x = 0; x < imageWidth; x++)
- for (int y = 0; y < imageHeight; y++)
- {
- var c = croppedBitmap.GetPixel(x, y);
- var r = restoredBitmap.GetPixel(x, y);
- Assert.True(c == r);
- }
- }
+ backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
+ backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
+ using (var cursor = backToBitmaps.GetRowCursor((x) => true))
+ {
+ var bitmapGetter = cursor.GetGetter(bitmapColumn);
+ Bitmap restoredBitmap = default;
+
+ var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn);
+ Bitmap croppedBitmap = default;
+ while (cursor.MoveNext())
+ {
+ bitmapGetter(ref restoredBitmap);
+ Assert.NotNull(restoredBitmap);
+ bitmapCropGetter(ref croppedBitmap);
+ Assert.NotNull(croppedBitmap);
+ for (int x = 0; x < imageWidth; x++)
+ for (int y = 0; y < imageHeight; y++)
+ {
+ var c = croppedBitmap.GetPixel(x, y);
+ var r = restoredBitmap.GetPixel(x, y);
+ Assert.True(c == r);
+ }
}
}
Done();
@@ -490,88 +478,86 @@ public void TestBackAndForthConversionWithAlphaNoInterleave()
[Fact]
public void TestBackAndForthConversionWithoutAlphaNoInterleave()
{
- using (var env = new ConsoleEnvironment())
+ IHostEnvironment env = new MLContext();
+ const int imageHeight = 100;
+ const int imageWidth = 130;
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
{
- var imageHeight = 100;
- var imageWidth = 130;
- var dataFile = GetDataPath("images/images.tsv");
- var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ Column = new[]
{
- Column = new[]
- {
new TextLoader.Column("ImagePath", DataKind.TX, 0),
new TextLoader.Column("Name", DataKind.TX, 1),
}
- }, new MultiFileSource(dataFile));
- var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ }, new MultiFileSource(dataFile));
+ var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ {
+ Column = new ImageLoaderTransform.Column[1]
{
- Column = new ImageLoaderTransform.Column[1]
- {
new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" }
- },
- ImageFolder = imageFolder
- }, data);
- var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
- {
- Column = new ImageResizerTransform.Column[1]{
+ },
+ ImageFolder = imageFolder
+ }, data);
+ var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
+ {
+ Column = new ImageResizerTransform.Column[1]{
new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop}
}
- }, images);
+ }, images);
- var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
- {
- InterleaveArgb = false,
- Offset = 127.5f,
- Scale = 2f / 255,
- Column = new ImagePixelExtractorTransform.Column[1]{
+ var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
+ {
+ InterleaveArgb = false,
+ Offset = 127.5f,
+ Scale = 2f / 255,
+ Column = new ImagePixelExtractorTransform.Column[1]{
new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=false}
}
- }, cropped);
+ }, cropped);
- IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
- {
- InterleaveArgb = false,
- Offset = -1f,
- Scale = 255f / 2,
- Column = new VectorToImageTransform.Column[1]{
+ IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
+ {
+ InterleaveArgb = false,
+ Offset = -1f,
+ Scale = 255f / 2,
+ Column = new VectorToImageTransform.Column[1]{
new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=false}
}
- }, pixels);
-
- var fname = nameof(TestBackAndForthConversionWithoutAlphaNoInterleave) + "_model.zip";
+ }, pixels);
- var fh = env.CreateOutputFile(fname);
- using (var ch = env.Start("save"))
- TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
+ var fname = nameof(TestBackAndForthConversionWithoutAlphaNoInterleave) + "_model.zip";
- backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
- DeleteOutputPath(fname);
+ var fh = env.CreateOutputFile(fname);
+ using (var ch = env.Start("save"))
+ TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
+ backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
+ DeleteOutputPath(fname);
- backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
- backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
- using (var cursor = backToBitmaps.GetRowCursor((x) => true))
- {
- var bitmapGetter = cursor.GetGetter(bitmapColumn);
- Bitmap restoredBitmap = default;
- var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn);
- Bitmap croppedBitmap = default;
- while (cursor.MoveNext())
- {
- bitmapGetter(ref restoredBitmap);
- Assert.NotNull(restoredBitmap);
- bitmapCropGetter(ref croppedBitmap);
- Assert.NotNull(croppedBitmap);
- for (int x = 0; x < imageWidth; x++)
- for (int y = 0; y < imageHeight; y++)
- {
- var c = croppedBitmap.GetPixel(x, y);
- var r = restoredBitmap.GetPixel(x, y);
- Assert.True(c.R == r.R && c.G == r.G && c.B == r.B);
- }
- }
+ backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
+ backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
+ using (var cursor = backToBitmaps.GetRowCursor((x) => true))
+ {
+ var bitmapGetter = cursor.GetGetter(bitmapColumn);
+ Bitmap restoredBitmap = default;
+
+ var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn);
+ Bitmap croppedBitmap = default;
+ while (cursor.MoveNext())
+ {
+ bitmapGetter(ref restoredBitmap);
+ Assert.NotNull(restoredBitmap);
+ bitmapCropGetter(ref croppedBitmap);
+ Assert.NotNull(croppedBitmap);
+ for (int x = 0; x < imageWidth; x++)
+ for (int y = 0; y < imageHeight; y++)
+ {
+ var c = croppedBitmap.GetPixel(x, y);
+ var r = restoredBitmap.GetPixel(x, y);
+ Assert.True(c.R == r.R && c.G == r.G && c.B == r.B);
+ }
}
}
Done();
@@ -580,84 +566,82 @@ public void TestBackAndForthConversionWithoutAlphaNoInterleave()
[Fact]
public void TestBackAndForthConversionWithAlphaInterleaveNoOffset()
{
- using (var env = new ConsoleEnvironment())
+ IHostEnvironment env = new MLContext();
+ const int imageHeight = 100;
+ const int imageWidth = 130;
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
{
- var imageHeight = 100;
- var imageWidth = 130;
- var dataFile = GetDataPath("images/images.tsv");
- var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ Column = new[]
{
- Column = new[]
- {
new TextLoader.Column("ImagePath", DataKind.TX, 0),
new TextLoader.Column("Name", DataKind.TX, 1),
}
- }, new MultiFileSource(dataFile));
- var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ }, new MultiFileSource(dataFile));
+ var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ {
+ Column = new ImageLoaderTransform.Column[1]
{
- Column = new ImageLoaderTransform.Column[1]
- {
new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" }
- },
- ImageFolder = imageFolder
- }, data);
- var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
- {
- Column = new ImageResizerTransform.Column[1]{
+ },
+ ImageFolder = imageFolder
+ }, data);
+ var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
+ {
+ Column = new ImageResizerTransform.Column[1]{
new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop}
}
- }, images);
+ }, images);
- var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
- {
- InterleaveArgb = true,
- Column = new ImagePixelExtractorTransform.Column[1]{
+ var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
+ {
+ InterleaveArgb = true,
+ Column = new ImagePixelExtractorTransform.Column[1]{
new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=true}
}
- }, cropped);
+ }, cropped);
- IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
- {
- InterleaveArgb = true,
- Column = new VectorToImageTransform.Column[1]{
+ IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
+ {
+ InterleaveArgb = true,
+ Column = new VectorToImageTransform.Column[1]{
new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true}
}
- }, pixels);
+ }, pixels);
- var fname = nameof(TestBackAndForthConversionWithAlphaInterleaveNoOffset) + "_model.zip";
+ var fname = nameof(TestBackAndForthConversionWithAlphaInterleaveNoOffset) + "_model.zip";
- var fh = env.CreateOutputFile(fname);
- using (var ch = env.Start("save"))
- TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
+ var fh = env.CreateOutputFile(fname);
+ using (var ch = env.Start("save"))
+ TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
- backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
- DeleteOutputPath(fname);
+ backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
+ DeleteOutputPath(fname);
- backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
- backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
- using (var cursor = backToBitmaps.GetRowCursor((x) => true))
- {
- var bitmapGetter = cursor.GetGetter(bitmapColumn);
- Bitmap restoredBitmap = default;
-
- var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn);
- Bitmap croppedBitmap = default;
- while (cursor.MoveNext())
- {
- bitmapGetter(ref restoredBitmap);
- Assert.NotNull(restoredBitmap);
- bitmapCropGetter(ref croppedBitmap);
- Assert.NotNull(croppedBitmap);
- for (int x = 0; x < imageWidth; x++)
- for (int y = 0; y < imageHeight; y++)
- {
- var c = croppedBitmap.GetPixel(x, y);
- var r = restoredBitmap.GetPixel(x, y);
- Assert.True(c == r);
- }
- }
+ backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
+ backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
+ using (var cursor = backToBitmaps.GetRowCursor((x) => true))
+ {
+ var bitmapGetter = cursor.GetGetter(bitmapColumn);
+ Bitmap restoredBitmap = default;
+
+ var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn);
+ Bitmap croppedBitmap = default;
+ while (cursor.MoveNext())
+ {
+ bitmapGetter(ref restoredBitmap);
+ Assert.NotNull(restoredBitmap);
+ bitmapCropGetter(ref croppedBitmap);
+ Assert.NotNull(croppedBitmap);
+ for (int x = 0; x < imageWidth; x++)
+ for (int y = 0; y < imageHeight; y++)
+ {
+ var c = croppedBitmap.GetPixel(x, y);
+ var r = restoredBitmap.GetPixel(x, y);
+ Assert.True(c == r);
+ }
}
}
Done();
@@ -666,84 +650,82 @@ public void TestBackAndForthConversionWithAlphaInterleaveNoOffset()
[Fact]
public void TestBackAndForthConversionWithoutAlphaInterleaveNoOffset()
{
- using (var env = new ConsoleEnvironment())
+ IHostEnvironment env = new MLContext();
+ const int imageHeight = 100;
+ const int imageWidth = 130;
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
{
- var imageHeight = 100;
- var imageWidth = 130;
- var dataFile = GetDataPath("images/images.tsv");
- var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ Column = new[]
{
- Column = new[]
- {
new TextLoader.Column("ImagePath", DataKind.TX, 0),
new TextLoader.Column("Name", DataKind.TX, 1),
}
- }, new MultiFileSource(dataFile));
- var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ }, new MultiFileSource(dataFile));
+ var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ {
+ Column = new ImageLoaderTransform.Column[1]
{
- Column = new ImageLoaderTransform.Column[1]
- {
new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" }
- },
- ImageFolder = imageFolder
- }, data);
- var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
- {
- Column = new ImageResizerTransform.Column[1]{
+ },
+ ImageFolder = imageFolder
+ }, data);
+ var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
+ {
+ Column = new ImageResizerTransform.Column[1]{
new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop}
}
- }, images);
+ }, images);
- var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
- {
- InterleaveArgb = true,
- Column = new ImagePixelExtractorTransform.Column[1]{
+ var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
+ {
+ InterleaveArgb = true,
+ Column = new ImagePixelExtractorTransform.Column[1]{
new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=false}
}
- }, cropped);
+ }, cropped);
- IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
- {
- InterleaveArgb = true,
- Column = new VectorToImageTransform.Column[1]{
+ IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
+ {
+ InterleaveArgb = true,
+ Column = new VectorToImageTransform.Column[1]{
new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=false}
}
- }, pixels);
+ }, pixels);
- var fname = nameof(TestBackAndForthConversionWithoutAlphaInterleaveNoOffset) + "_model.zip";
+ var fname = nameof(TestBackAndForthConversionWithoutAlphaInterleaveNoOffset) + "_model.zip";
- var fh = env.CreateOutputFile(fname);
- using (var ch = env.Start("save"))
- TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
+ var fh = env.CreateOutputFile(fname);
+ using (var ch = env.Start("save"))
+ TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
- backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
- DeleteOutputPath(fname);
+ backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
+ DeleteOutputPath(fname);
- backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
- backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
- using (var cursor = backToBitmaps.GetRowCursor((x) => true))
- {
- var bitmapGetter = cursor.GetGetter(bitmapColumn);
- Bitmap restoredBitmap = default;
-
- var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn);
- Bitmap croppedBitmap = default;
- while (cursor.MoveNext())
- {
- bitmapGetter(ref restoredBitmap);
- Assert.NotNull(restoredBitmap);
- bitmapCropGetter(ref croppedBitmap);
- Assert.NotNull(croppedBitmap);
- for (int x = 0; x < imageWidth; x++)
- for (int y = 0; y < imageHeight; y++)
- {
- var c = croppedBitmap.GetPixel(x, y);
- var r = restoredBitmap.GetPixel(x, y);
- Assert.True(c.R == r.R && c.G == r.G && c.B == r.B);
- }
- }
+ backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
+ backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
+ using (var cursor = backToBitmaps.GetRowCursor((x) => true))
+ {
+ var bitmapGetter = cursor.GetGetter(bitmapColumn);
+ Bitmap restoredBitmap = default;
+
+ var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn);
+ Bitmap croppedBitmap = default;
+ while (cursor.MoveNext())
+ {
+ bitmapGetter(ref restoredBitmap);
+ Assert.NotNull(restoredBitmap);
+ bitmapCropGetter(ref croppedBitmap);
+ Assert.NotNull(croppedBitmap);
+ for (int x = 0; x < imageWidth; x++)
+ for (int y = 0; y < imageHeight; y++)
+ {
+ var c = croppedBitmap.GetPixel(x, y);
+ var r = restoredBitmap.GetPixel(x, y);
+ Assert.True(c.R == r.R && c.G == r.G && c.B == r.B);
+ }
}
}
Done();
@@ -752,84 +734,82 @@ public void TestBackAndForthConversionWithoutAlphaInterleaveNoOffset()
[Fact]
public void TestBackAndForthConversionWithAlphaNoInterleaveNoOffset()
{
- using (var env = new ConsoleEnvironment())
+ IHostEnvironment env = new MLContext();
+ const int imageHeight = 100;
+ var imageWidth = 130;
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
{
- var imageHeight = 100;
- var imageWidth = 130;
- var dataFile = GetDataPath("images/images.tsv");
- var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ Column = new[]
{
- Column = new[]
- {
new TextLoader.Column("ImagePath", DataKind.TX, 0),
new TextLoader.Column("Name", DataKind.TX, 1),
}
- }, new MultiFileSource(dataFile));
- var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ }, new MultiFileSource(dataFile));
+ var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ {
+ Column = new ImageLoaderTransform.Column[1]
{
- Column = new ImageLoaderTransform.Column[1]
- {
new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" }
- },
- ImageFolder = imageFolder
- }, data);
- var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
- {
- Column = new ImageResizerTransform.Column[1]{
+ },
+ ImageFolder = imageFolder
+ }, data);
+ var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
+ {
+ Column = new ImageResizerTransform.Column[1]{
new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop}
}
- }, images);
+ }, images);
- var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
- {
- InterleaveArgb = false,
- Column = new ImagePixelExtractorTransform.Column[1]{
+ var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
+ {
+ InterleaveArgb = false,
+ Column = new ImagePixelExtractorTransform.Column[1]{
new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=true}
}
- }, cropped);
+ }, cropped);
- IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
- {
- InterleaveArgb = false,
- Column = new VectorToImageTransform.Column[1]{
+ IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
+ {
+ InterleaveArgb = false,
+ Column = new VectorToImageTransform.Column[1]{
new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true}
}
- }, pixels);
+ }, pixels);
- var fname = nameof(TestBackAndForthConversionWithAlphaNoInterleaveNoOffset) + "_model.zip";
+ var fname = nameof(TestBackAndForthConversionWithAlphaNoInterleaveNoOffset) + "_model.zip";
- var fh = env.CreateOutputFile(fname);
- using (var ch = env.Start("save"))
- TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
+ var fh = env.CreateOutputFile(fname);
+ using (var ch = env.Start("save"))
+ TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
- backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
- DeleteOutputPath(fname);
+ backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
+ DeleteOutputPath(fname);
- backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
- backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
- using (var cursor = backToBitmaps.GetRowCursor((x) => true))
- {
- var bitmapGetter = cursor.GetGetter(bitmapColumn);
- Bitmap restoredBitmap = default;
-
- var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn);
- Bitmap croppedBitmap = default;
- while (cursor.MoveNext())
- {
- bitmapGetter(ref restoredBitmap);
- Assert.NotNull(restoredBitmap);
- bitmapCropGetter(ref croppedBitmap);
- Assert.NotNull(croppedBitmap);
- for (int x = 0; x < imageWidth; x++)
- for (int y = 0; y < imageHeight; y++)
- {
- var c = croppedBitmap.GetPixel(x, y);
- var r = restoredBitmap.GetPixel(x, y);
- Assert.True(c == r);
- }
- }
+ backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
+ backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
+ using (var cursor = backToBitmaps.GetRowCursor((x) => true))
+ {
+ var bitmapGetter = cursor.GetGetter(bitmapColumn);
+ Bitmap restoredBitmap = default;
+
+ var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn);
+ Bitmap croppedBitmap = default;
+ while (cursor.MoveNext())
+ {
+ bitmapGetter(ref restoredBitmap);
+ Assert.NotNull(restoredBitmap);
+ bitmapCropGetter(ref croppedBitmap);
+ Assert.NotNull(croppedBitmap);
+ for (int x = 0; x < imageWidth; x++)
+ for (int y = 0; y < imageHeight; y++)
+ {
+ var c = croppedBitmap.GetPixel(x, y);
+ var r = restoredBitmap.GetPixel(x, y);
+ Assert.True(c == r);
+ }
}
}
Done();
@@ -838,87 +818,85 @@ public void TestBackAndForthConversionWithAlphaNoInterleaveNoOffset()
[Fact]
public void TestBackAndForthConversionWithoutAlphaNoInterleaveNoOffset()
{
- using (var env = new ConsoleEnvironment())
+ IHostEnvironment env = new MLContext();
+ const int imageHeight = 100;
+ const int imageWidth = 130;
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
{
- var imageHeight = 100;
- var imageWidth = 130;
- var dataFile = GetDataPath("images/images.tsv");
- var imageFolder = Path.GetDirectoryName(dataFile);
- var data = TextLoader.Create(env, new TextLoader.Arguments()
+ Column = new[]
{
- Column = new[]
- {
new TextLoader.Column("ImagePath", DataKind.TX, 0),
new TextLoader.Column("Name", DataKind.TX, 1),
}
- }, new MultiFileSource(dataFile));
- var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ }, new MultiFileSource(dataFile));
+ var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ {
+ Column = new ImageLoaderTransform.Column[1]
{
- Column = new ImageLoaderTransform.Column[1]
- {
new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" }
- },
- ImageFolder = imageFolder
- }, data);
- var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
- {
- Column = new ImageResizerTransform.Column[1]{
+ },
+ ImageFolder = imageFolder
+ }, data);
+ var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
+ {
+ Column = new ImageResizerTransform.Column[1]{
new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop}
}
- }, images);
+ }, images);
- var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
- {
- InterleaveArgb = false,
- Column = new ImagePixelExtractorTransform.Column[1]{
+ var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
+ {
+ InterleaveArgb = false,
+ Column = new ImagePixelExtractorTransform.Column[1]{
new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "ImagePixels", UseAlpha=false}
}
- }, cropped);
+ }, cropped);
- IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
- {
- InterleaveArgb = false,
- Column = new VectorToImageTransform.Column[1]{
+ IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments()
+ {
+ InterleaveArgb = false,
+ Column = new VectorToImageTransform.Column[1]{
new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=false}
}
- }, pixels);
-
- var fname = nameof(TestBackAndForthConversionWithoutAlphaNoInterleaveNoOffset) + "_model.zip";
+ }, pixels);
- var fh = env.CreateOutputFile(fname);
- using (var ch = env.Start("save"))
- TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
+ var fname = nameof(TestBackAndForthConversionWithoutAlphaNoInterleaveNoOffset) + "_model.zip";
- backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
- DeleteOutputPath(fname);
+ var fh = env.CreateOutputFile(fname);
+ using (var ch = env.Start("save"))
+ TrainUtils.SaveModel(env, ch, fh, null, new RoleMappedData(backToBitmaps));
+ backToBitmaps = ModelFileUtils.LoadPipeline(env, fh.OpenReadStream(), new MultiFileSource(dataFile));
+ DeleteOutputPath(fname);
- backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
- backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
- using (var cursor = backToBitmaps.GetRowCursor((x) => true))
- {
- var bitmapGetter = cursor.GetGetter(bitmapColumn);
- Bitmap restoredBitmap = default;
- var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn);
- Bitmap croppedBitmap = default;
- while (cursor.MoveNext())
- {
- bitmapGetter(ref restoredBitmap);
- Assert.NotNull(restoredBitmap);
- bitmapCropGetter(ref croppedBitmap);
- Assert.NotNull(croppedBitmap);
- for (int x = 0; x < imageWidth; x++)
- for (int y = 0; y < imageHeight; y++)
- {
- var c = croppedBitmap.GetPixel(x, y);
- var r = restoredBitmap.GetPixel(x, y);
- Assert.True(c.R == r.R && c.G == r.G && c.B == r.B);
- }
- }
+ backToBitmaps.Schema.TryGetColumnIndex("ImageRestored", out int bitmapColumn);
+ backToBitmaps.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn);
+ using (var cursor = backToBitmaps.GetRowCursor((x) => true))
+ {
+ var bitmapGetter = cursor.GetGetter(bitmapColumn);
+ Bitmap restoredBitmap = default;
+
+ var bitmapCropGetter = cursor.GetGetter(cropBitmapColumn);
+ Bitmap croppedBitmap = default;
+ while (cursor.MoveNext())
+ {
+ bitmapGetter(ref restoredBitmap);
+ Assert.NotNull(restoredBitmap);
+ bitmapCropGetter(ref croppedBitmap);
+ Assert.NotNull(croppedBitmap);
+ for (int x = 0; x < imageWidth; x++)
+ for (int y = 0; y < imageHeight; y++)
+ {
+ var c = croppedBitmap.GetPixel(x, y);
+ var r = restoredBitmap.GetPixel(x, y);
+ Assert.True(c.R == r.R && c.G == r.G && c.B == r.B);
+ }
}
+ Done();
}
- Done();
}
}
}
diff --git a/test/Microsoft.ML.Tests/OnnxTests.cs b/test/Microsoft.ML.Tests/OnnxTests.cs
index 1cac50611c..8e513a5a93 100644
--- a/test/Microsoft.ML.Tests/OnnxTests.cs
+++ b/test/Microsoft.ML.Tests/OnnxTests.cs
@@ -82,70 +82,68 @@ public class BreastCancerClusterPrediction
[Fact]
public void InitializerCreationTest()
{
- using (var env = new ConsoleEnvironment())
- {
- // Create the actual implementation
- var ctxImpl = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "com.test", Runtime.Model.Onnx.OnnxVersion.Stable);
-
- // Use implementation as in the actual conversion code
- var ctx = ctxImpl as OnnxContext;
- ctx.AddInitializer(9.4f, "float");
- ctx.AddInitializer(17L, "int64");
- ctx.AddInitializer("36", "string");
- ctx.AddInitializer(new List<float> { 9.4f, 1.7f, 3.6f }, new List<long> { 1, 3 }, "floats");
- ctx.AddInitializer(new List<long> { 94L, 17L, 36L }, new List<long> { 1, 3 }, "int64s");
- ctx.AddInitializer(new List<string> { "94" , "17", "36" }, new List<long> { 1, 3 }, "strings");
-
- var model = ctxImpl.MakeModel();
-
- var floatScalar = model.Graph.Initializer[0];
- Assert.True(floatScalar.Name == "float");
- Assert.True(floatScalar.Dims.Count == 0);
- Assert.True(floatScalar.FloatData.Count == 1);
- Assert.True(floatScalar.FloatData[0] == 9.4f);
-
- var int64Scalar = model.Graph.Initializer[1];
- Assert.True(int64Scalar.Name == "int64");
- Assert.True(int64Scalar.Dims.Count == 0);
- Assert.True(int64Scalar.Int64Data.Count == 1);
- Assert.True(int64Scalar.Int64Data[0] == 17L);
-
- var stringScalar = model.Graph.Initializer[2];
- Assert.True(stringScalar.Name == "string");
- Assert.True(stringScalar.Dims.Count == 0);
- Assert.True(stringScalar.StringData.Count == 1);
- Assert.True(stringScalar.StringData[0].ToStringUtf8() == "36");
-
- var floatsTensor = model.Graph.Initializer[3];
- Assert.True(floatsTensor.Name == "floats");
- Assert.True(floatsTensor.Dims.Count == 2);
- Assert.True(floatsTensor.Dims[0] == 1);
- Assert.True(floatsTensor.Dims[1] == 3);
- Assert.True(floatsTensor.FloatData.Count == 3);
- Assert.True(floatsTensor.FloatData[0] == 9.4f);
- Assert.True(floatsTensor.FloatData[1] == 1.7f);
- Assert.True(floatsTensor.FloatData[2] == 3.6f);
-
- var int64sTensor = model.Graph.Initializer[4];
- Assert.True(int64sTensor.Name == "int64s");
- Assert.True(int64sTensor.Dims.Count == 2);
- Assert.True(int64sTensor.Dims[0] == 1);
- Assert.True(int64sTensor.Dims[1] == 3);
- Assert.True(int64sTensor.Int64Data.Count == 3);
- Assert.True(int64sTensor.Int64Data[0] == 94L);
- Assert.True(int64sTensor.Int64Data[1] == 17L);
- Assert.True(int64sTensor.Int64Data[2] == 36L);
-
- var stringsTensor = model.Graph.Initializer[5];
- Assert.True(stringsTensor.Name == "strings");
- Assert.True(stringsTensor.Dims.Count == 2);
- Assert.True(stringsTensor.Dims[0] == 1);
- Assert.True(stringsTensor.Dims[1] == 3);
- Assert.True(stringsTensor.StringData.Count == 3);
- Assert.True(stringsTensor.StringData[0].ToStringUtf8() == "94");
- Assert.True(stringsTensor.StringData[1].ToStringUtf8() == "17");
- Assert.True(stringsTensor.StringData[2].ToStringUtf8() == "36");
- }
+ var env = new MLContext();
+ // Create the actual implementation
+ var ctxImpl = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "com.test", Runtime.Model.Onnx.OnnxVersion.Stable);
+
+ // Use implementation as in the actual conversion code
+ var ctx = ctxImpl as OnnxContext;
+ ctx.AddInitializer(9.4f, "float");
+ ctx.AddInitializer(17L, "int64");
+ ctx.AddInitializer("36", "string");
+ ctx.AddInitializer(new List<float> { 9.4f, 1.7f, 3.6f }, new List<long> { 1, 3 }, "floats");
+ ctx.AddInitializer(new List<long> { 94L, 17L, 36L }, new List<long> { 1, 3 }, "int64s");
+ ctx.AddInitializer(new List<string> { "94", "17", "36" }, new List<long> { 1, 3 }, "strings");
+
+ var model = ctxImpl.MakeModel();
+
+ var floatScalar = model.Graph.Initializer[0];
+ Assert.True(floatScalar.Name == "float");
+ Assert.True(floatScalar.Dims.Count == 0);
+ Assert.True(floatScalar.FloatData.Count == 1);
+ Assert.True(floatScalar.FloatData[0] == 9.4f);
+
+ var int64Scalar = model.Graph.Initializer[1];
+ Assert.True(int64Scalar.Name == "int64");
+ Assert.True(int64Scalar.Dims.Count == 0);
+ Assert.True(int64Scalar.Int64Data.Count == 1);
+ Assert.True(int64Scalar.Int64Data[0] == 17L);
+
+ var stringScalar = model.Graph.Initializer[2];
+ Assert.True(stringScalar.Name == "string");
+ Assert.True(stringScalar.Dims.Count == 0);
+ Assert.True(stringScalar.StringData.Count == 1);
+ Assert.True(stringScalar.StringData[0].ToStringUtf8() == "36");
+
+ var floatsTensor = model.Graph.Initializer[3];
+ Assert.True(floatsTensor.Name == "floats");
+ Assert.True(floatsTensor.Dims.Count == 2);
+ Assert.True(floatsTensor.Dims[0] == 1);
+ Assert.True(floatsTensor.Dims[1] == 3);
+ Assert.True(floatsTensor.FloatData.Count == 3);
+ Assert.True(floatsTensor.FloatData[0] == 9.4f);
+ Assert.True(floatsTensor.FloatData[1] == 1.7f);
+ Assert.True(floatsTensor.FloatData[2] == 3.6f);
+
+ var int64sTensor = model.Graph.Initializer[4];
+ Assert.True(int64sTensor.Name == "int64s");
+ Assert.True(int64sTensor.Dims.Count == 2);
+ Assert.True(int64sTensor.Dims[0] == 1);
+ Assert.True(int64sTensor.Dims[1] == 3);
+ Assert.True(int64sTensor.Int64Data.Count == 3);
+ Assert.True(int64sTensor.Int64Data[0] == 94L);
+ Assert.True(int64sTensor.Int64Data[1] == 17L);
+ Assert.True(int64sTensor.Int64Data[2] == 36L);
+
+ var stringsTensor = model.Graph.Initializer[5];
+ Assert.True(stringsTensor.Name == "strings");
+ Assert.True(stringsTensor.Dims.Count == 2);
+ Assert.True(stringsTensor.Dims[0] == 1);
+ Assert.True(stringsTensor.Dims[1] == 3);
+ Assert.True(stringsTensor.StringData.Count == 3);
+ Assert.True(stringsTensor.StringData[0].ToStringUtf8() == "94");
+ Assert.True(stringsTensor.StringData[1].ToStringUtf8() == "17");
+ Assert.True(stringsTensor.StringData[2].ToStringUtf8() == "36");
}
[Fact]
@@ -259,8 +257,12 @@ public void KeyToVectorWithBagTest()
});
var vectorizer = new CategoricalOneHotVectorizer();
- var categoricalColumn = new OneHotEncodingTransformerColumn() {
- OutputKind = OneHotEncodingTransformerOutputKind.Bag, Name = "F2", Source = "F2" };
+ var categoricalColumn = new OneHotEncodingTransformerColumn()
+ {
+ OutputKind = OneHotEncodingTransformerOutputKind.Bag,
+ Name = "F2",
+ Source = "F2"
+ };
vectorizer.Column = new OneHotEncodingTransformerColumn[1] { categoricalColumn };
pipeline.Add(vectorizer);
pipeline.Add(new ColumnConcatenator("Features", "F1", "F2"));
@@ -306,7 +308,7 @@ public void WordEmbeddingsTest()
{
Separator = new[] { '\t' },
HasHeader = false,
- Column = new []
+ Column = new[]
{
new TextLoaderColumn()
{
@@ -317,7 +319,7 @@ public void WordEmbeddingsTest()
}
}
});
-
+
var modelPath = GetDataPath(@"shortsentiment.emd");
var embed = new WordEmbeddings() { CustomLookupTable = modelPath };
embed.AddColumn("Cat", "Cat");
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs
index 2f5ba899fc..a9a13385a8 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs
@@ -153,7 +153,7 @@ private void TrainRegression(string trainDataPath, string testDataPath, string m
[Fact]
public void TrainRegressionModel()
=> TrainRegression(GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename), GetDataPath(TestDatasets.generatedRegressionDataset.testFilename),
- DeleteOutputPath("cook_model.zip"));
+ DeleteOutputPath("cook_model_static.zip"));
private ITransformer TrainOnIris(string irisDataPath)
{
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs
index 1bfa09377b..32486abb27 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs
@@ -38,7 +38,7 @@ public void New_TrainWithInitialPredictor()
var secondTrainer = ml.BinaryClassification.Trainers.AveragedPerceptron("Label","Features");
var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");
- var finalModel = secondTrainer.Train(new TrainContext(trainRoles, initialPredictor: firstModel.Model));
+ var finalModel = ((ITrainer)secondTrainer).Train(new TrainContext(trainRoles, initialPredictor: firstModel.Model));
}
}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs
index 9507d4bce3..b3086a68a1 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs
@@ -62,80 +62,78 @@ private class TwoIChannelsOnlyOneWithAttribute
[Fact]
public void CursorChannelExposedInMapTransform()
{
- using (var env = new ConsoleEnvironment(0))
- {
- // Correct use of CursorChannel attribute.
- var data1 = Utils.CreateArray(10, new OneIChannelWithAttribute());
- var idv1 = env.CreateDataView(data1);
- Assert.Null(data1[0].Channel);
-
- var filter1 = LambdaTransform.CreateFilter<OneIChannelWithAttribute, object>(env, idv1,
- (input, state) =>
- {
- Assert.NotNull(input.Channel);
- return false;
- }, null);
- filter1.GetRowCursor(col => true).MoveNext();
-
- // Error case: non-IChannel field marked with attribute.
- var data2 = Utils.CreateArray(10, new OneStringWithAttribute());
- var idv2 = env.CreateDataView(data2);
- Assert.Null(data2[0].Channel);
-
- var filter2 = LambdaTransform.CreateFilter<OneStringWithAttribute, object>(env, idv2,
- (input, state) =>
- {
- Assert.Null(input.Channel);
- return false;
- }, null);
- try
+ var env = new MLContext(seed: 0);
+ // Correct use of CursorChannel attribute.
+ var data1 = Utils.CreateArray(10, new OneIChannelWithAttribute());
+ var idv1 = env.CreateDataView(data1);
+ Assert.Null(data1[0].Channel);
+
+ var filter1 = LambdaTransform.CreateFilter<OneIChannelWithAttribute, object>(env, idv1,
+ (input, state) =>
{
- filter2.GetRowCursor(col => true).MoveNext();
- Assert.True(false, "Throw an error if attribute is applied to a field that is not an IChannel.");
- }
- catch (InvalidOperationException ex)
+ Assert.NotNull(input.Channel);
+ return false;
+ }, null);
+ filter1.GetRowCursor(col => true).MoveNext();
+
+ // Error case: non-IChannel field marked with attribute.
+ var data2 = Utils.CreateArray(10, new OneStringWithAttribute());
+ var idv2 = env.CreateDataView(data2);
+ Assert.Null(data2[0].Channel);
+
+ var filter2 = LambdaTransform.CreateFilter<OneStringWithAttribute, object>(env, idv2,
+ (input, state) =>
{
- Assert.True(ex.IsMarked());
- }
+ Assert.Null(input.Channel);
+ return false;
+ }, null);
+ try
+ {
+ filter2.GetRowCursor(col => true).MoveNext();
+ Assert.True(false, "Throw an error if attribute is applied to a field that is not an IChannel.");
+ }
+ catch (InvalidOperationException ex)
+ {
+ Assert.True(ex.IsMarked());
+ }
- // Error case: multiple fields marked with attributes.
- var data3 = Utils.CreateArray(10, new TwoIChannelsWithAttributes());
- var idv3 = env.CreateDataView(data3);
- Assert.Null(data3[0].ChannelOne);
- Assert.Null(data3[2].ChannelTwo);
+ // Error case: multiple fields marked with attributes.
+ var data3 = Utils.CreateArray(10, new TwoIChannelsWithAttributes());
+ var idv3 = env.CreateDataView(data3);
+ Assert.Null(data3[0].ChannelOne);
+ Assert.Null(data3[2].ChannelTwo);
- var filter3 = LambdaTransform.CreateFilter<TwoIChannelsWithAttributes, object>(env, idv3,
- (input, state) =>
- {
- Assert.Null(input.ChannelOne);
- Assert.Null(input.ChannelTwo);
- return false;
- }, null);
- try
+ var filter3 = LambdaTransform.CreateFilter<TwoIChannelsWithAttributes, object>(env, idv3,
+ (input, state) =>
{
- filter3.GetRowCursor(col => true).MoveNext();
- Assert.True(false, "Throw an error if attribute is applied to a field that is not an IChannel.");
- }
- catch (InvalidOperationException ex)
- {
- Assert.True(ex.IsMarked());
- }
+ Assert.Null(input.ChannelOne);
+ Assert.Null(input.ChannelTwo);
+ return false;
+ }, null);
+ try
+ {
+ filter3.GetRowCursor(col => true).MoveNext();
+ Assert.True(false, "Throw an error if attribute is applied to a field that is not an IChannel.");
+ }
+ catch (InvalidOperationException ex)
+ {
+ Assert.True(ex.IsMarked());
+ }
- // Correct case: non-marked IChannel field is not touched.
- var example4 = new TwoIChannelsOnlyOneWithAttribute();
- Assert.Null(example4.ChannelTwo);
- Assert.Null(example4.ChannelOne);
- var idv4 = env.CreateDataView(Utils.CreateArray(10, example4));
+ // Correct case: non-marked IChannel field is not touched.
+ var example4 = new TwoIChannelsOnlyOneWithAttribute();
+ Assert.Null(example4.ChannelTwo);
+ Assert.Null(example4.ChannelOne);
+ var idv4 = env.CreateDataView(Utils.CreateArray(10, example4));
- var filter4 = LambdaTransform.CreateFilter<TwoIChannelsOnlyOneWithAttribute, object>(env, idv4,
- (input, state) =>
- {
- Assert.Null(input.ChannelOne);
- Assert.NotNull(input.ChannelTwo);
- return false;
- }, null);
- filter1.GetRowCursor(col => true).MoveNext();
- }
+ var filter4 = LambdaTransform.CreateFilter<TwoIChannelsOnlyOneWithAttribute, object>(env, idv4,
+ (input, state) =>
+ {
+ Assert.Null(input.ChannelOne);
+ Assert.NotNull(input.ChannelTwo);
+ return false;
+ }, null);
+ filter1.GetRowCursor(col => true).MoveNext();
}
public class BreastCancerExample
@@ -149,44 +147,40 @@ public class BreastCancerExample
[Fact]
public void LambdaTransformCreate()
{
- using (var env = new ConsoleEnvironment(42))
- {
- var data = ReadBreastCancerExamples();
- var idv = env.CreateDataView(data);
+ var env = new MLContext(seed: 42);
+ var data = ReadBreastCancerExamples();
+ var idv = env.CreateDataView(data);
- var filter = LambdaTransform.CreateFilter<BreastCancerExample, object>(env, idv,
- (input, state) => input.Label == 0, null);
+ var filter = LambdaTransform.CreateFilter<BreastCancerExample, object>(env, idv,
+ (input, state) => input.Label == 0, null);
- Assert.Null(filter.GetRowCount());
+ Assert.Null(filter.GetRowCount());
- // test re-apply
- var applied = env.CreateDataView(data);
- applied = ApplyTransformUtils.ApplyAllTransformsToData(env, filter, applied);
+ // test re-apply
+ var applied = env.CreateDataView(data);
+ applied = ApplyTransformUtils.ApplyAllTransformsToData(env, filter, applied);
- var saver = new TextSaver(env, new TextSaver.Arguments());
- Assert.True(applied.Schema.TryGetColumnIndex("Label", out int label));
- using (var fs = File.Create(GetOutputPath(OutputRelativePath, "lambda-output.tsv")))
- saver.SaveData(fs, applied, label);
- }
+ var saver = new TextSaver(env, new TextSaver.Arguments());
+ Assert.True(applied.Schema.TryGetColumnIndex("Label", out int label));
+ using (var fs = File.Create(GetOutputPath(OutputRelativePath, "lambda-output.tsv")))
+ saver.SaveData(fs, applied, label);
}
[Fact]
public void TrainAveragedPerceptronWithCache()
{
- using (var env = new ConsoleEnvironment(0))
- {
- var dataFile = GetDataPath("breast-cancer.txt");
- var loader = TextLoader.Create(env, new TextLoader.Arguments(), new MultiFileSource(dataFile));
- var globalCounter = 0;
- var xf = LambdaTransform.CreateFilter