Skip to content

Renamed TlcEnvironment to Console. Also introduced LocalEnvironment #923

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 17, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Api/TypedCursor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ public static ICursorable<TRow> AsCursorable<TRow>(this IDataView data, bool ign
where TRow : class, new()
{
// REVIEW: Take an env as a parameter.
var env = new TlcEnvironment();
var env = new ConsoleEnvironment();
return data.AsCursorable<TRow>(env, ignoreMissingColumns, schemaDefinition);
}

Expand Down Expand Up @@ -699,7 +699,7 @@ public static IEnumerable<TRow> AsEnumerable<TRow>(this IDataView data, bool reu
where TRow : class, new()
{
// REVIEW: Take an env as a parameter.
var env = new TlcEnvironment();
var env = new ConsoleEnvironment();
return data.AsEnumerable<TRow>(env, reuseRowObject, ignoreMissingColumns, schemaDefinition);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Data/ProgressReporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace Microsoft.ML.Runtime.Data
public static class ProgressReporting
{
/// <summary>
/// The progress channel for <see cref="TlcEnvironment"/>.
/// The progress channel for <see cref="ConsoleEnvironment"/>.
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 15, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ConsoleEnvironment [](start = 48, length = 18)

Is it really only TlcEnvironment specific or can be used in any HostEnvironmentBase descendant? #Closed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it could be used elsewhere, but I'm not sure, and I would rather not try now.


In reply to: 217869313 [](ancestors = 217869313)

/// This is coupled with a <see cref="ProgressTracker"/> that aggregates all events and returns them on demand.
/// </summary>
public sealed class ProgressChannel : IProgressChannel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ namespace Microsoft.ML.Runtime.Data
{
using Stopwatch = System.Diagnostics.Stopwatch;

public sealed class TlcEnvironment : HostEnvironmentBase<TlcEnvironment>
public sealed class ConsoleEnvironment : HostEnvironmentBase<ConsoleEnvironment>
{
public const string ComponentHistoryKey = "ComponentHistory";

private sealed class ConsoleWriter
{
private readonly object _lock;
private readonly TlcEnvironment _parent;
private readonly ConsoleEnvironment _parent;
private readonly TextWriter _out;
private readonly TextWriter _err;

Expand All @@ -34,7 +34,7 @@ private sealed class ConsoleWriter
private const int _maxDots = 50;
private int _dots;

public ConsoleWriter(TlcEnvironment parent, TextWriter outWriter, TextWriter errWriter)
public ConsoleWriter(ConsoleEnvironment parent, TextWriter outWriter, TextWriter errWriter)
{
Contracts.AssertValue(parent);
Contracts.AssertValue(outWriter);
Expand Down Expand Up @@ -331,7 +331,7 @@ private bool PrintDot()
private sealed class Channel : ChannelBase
{
public readonly Stopwatch Watch;
public Channel(TlcEnvironment root, ChannelProviderBase parent, string shortName,
public Channel(ConsoleEnvironment root, ChannelProviderBase parent, string shortName,
Action<IMessageSource, ChannelMessage> dispatch)
: base(root, parent, shortName, dispatch)
{
Expand All @@ -358,20 +358,36 @@ protected override void DisposeCore()
private volatile ConsoleWriter _consoleWriter;
private readonly MessageSensitivity _sensitivityFlags;

public TlcEnvironment(int? seed = null, bool verbose = false,
/// <summary>
/// Create an ML.NET <see cref="IHostEnvironment"/> for local execution, with console feedback.
/// </summary>
/// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param>
/// <param name="verbose">Set to <c>true</c> for fully verbose logging.</param>
/// <param name="sensitivity">Allowed message sensitivity.</param>
/// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param>
/// <param name="outWriter">Text writer to print normal messages to.</param>
/// <param name="errWriter">Text writer to print error messages to.</param>
public ConsoleEnvironment(int? seed = null, bool verbose = false,
MessageSensitivity sensitivity = MessageSensitivity.All, int conc = 0,
TextWriter outWriter = null, TextWriter errWriter = null)
: this(RandomUtils.Create(seed), verbose, sensitivity, conc, outWriter, errWriter)
{
}

// REVIEW: do we really care about custom random? If we do, let's make this ctor public.
/// <summary>
/// This takes ownership of the random number generator.
/// Create an ML.NET environment for local execution, with console feedback.
/// </summary>
public TlcEnvironment(IRandom rand, bool verbose = false,
/// <param name="rand">A custom source of randomness to use in the environment.</param>
/// <param name="verbose">Set to <c>true</c> for fully verbose logging.</param>
/// <param name="sensitivity">Allowed message sensitivity.</param>
/// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param>
/// <param name="outWriter">Text writer to print normal messages to.</param>
/// <param name="errWriter">Text writer to print error messages to.</param>
private ConsoleEnvironment(IRandom rand, bool verbose = false,
MessageSensitivity sensitivity = MessageSensitivity.All, int conc = 0,
TextWriter outWriter = null, TextWriter errWriter = null)
: base(rand, verbose, conc, nameof(TlcEnvironment))
: base(rand, verbose, conc, nameof(ConsoleEnvironment))
{
Contracts.CheckValueOrNull(outWriter);
Contracts.CheckValueOrNull(errWriter);
Expand Down Expand Up @@ -401,7 +417,7 @@ protected override IFileHandle CreateTempFileCore(IHostEnvironment env, string s
return base.CreateTempFileCore(env, suffix, "TLC_" + prefix);
}

protected override IHost RegisterCore(HostEnvironmentBase<TlcEnvironment> source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc)
protected override IHost RegisterCore(HostEnvironmentBase<ConsoleEnvironment> source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc)
{
Contracts.AssertValue(rand);
Contracts.AssertValueOrNull(parentFullName);
Expand All @@ -413,15 +429,15 @@ protected override IHost RegisterCore(HostEnvironmentBase<TlcEnvironment> source
protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name)
{
Contracts.AssertValue(parent);
Contracts.Assert(parent is TlcEnvironment);
Contracts.Assert(parent is ConsoleEnvironment);
Contracts.AssertNonEmpty(name);
return new Channel(this, parent, name, GetDispatchDelegate<ChannelMessage>());
}

protected override IPipe<TMessage> CreatePipe<TMessage>(ChannelProviderBase parent, string name)
{
Contracts.AssertValue(parent);
Contracts.Assert(parent is TlcEnvironment);
Contracts.Assert(parent is ConsoleEnvironment);
Contracts.AssertNonEmpty(name);
return new Pipe<TMessage>(parent, name, GetDispatchDelegate<TMessage>());
}
Expand All @@ -439,11 +455,11 @@ internal IDisposable RedirectChannelOutput(TextWriter newOutWriter, TextWriter n

private sealed class OutputRedirector : IDisposable
{
private readonly TlcEnvironment _root;
private readonly ConsoleEnvironment _root;
private ConsoleWriter _oldConsoleWriter;
private readonly ConsoleWriter _newConsoleWriter;

public OutputRedirector(TlcEnvironment env, TextWriter newOutWriter, TextWriter newErrWriter)
public OutputRedirector(ConsoleEnvironment env, TextWriter newOutWriter, TextWriter newErrWriter)
{
Contracts.AssertValue(env);
Contracts.AssertValue(newOutWriter);
Expand All @@ -467,7 +483,7 @@ public void Dispose()

private sealed class Host : HostBase
{
public Host(HostEnvironmentBase<TlcEnvironment> source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc)
public Host(HostEnvironmentBase<ConsoleEnvironment> source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc)
: base(source, shortName, parentFullName, rand, verbose, conc)
{
IsCancelled = source.IsCancelled;
Expand All @@ -489,7 +505,7 @@ protected override IPipe<TMessage> CreatePipe<TMessage>(ChannelProviderBase pare
return new Pipe<TMessage>(parent, name, GetDispatchDelegate<TMessage>());
}

protected override IHost RegisterCore(HostEnvironmentBase<TlcEnvironment> source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc)
protected override IHost RegisterCore(HostEnvironmentBase<ConsoleEnvironment> source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc)
{
return new Host(source, shortName, parentFullName, rand, verbose, conc);
}
Expand Down
131 changes: 131 additions & 0 deletions src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;

namespace Microsoft.ML.Runtime.Data
{
using Stopwatch = System.Diagnostics.Stopwatch;
Copy link
Member

@eerhardt eerhardt Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What other kind of Stopwatch is there? Can we just have a using System.Diagnostics; at the top? #ByDesign

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

System.Diagnostics is a pretty big namespace with some names that for whatever reason cleave pretty tightly to common terms in ML (e.g., we have 90% of our code bent around processing instance data, and sure enough we get Process and InstanceData... :P), after several instances of developer confusion, we decided to adopt this practice. My innate tendency would be to keep up our existing practice of, when you want only a stopwatch, just import only that, rather than to relax it just yet. (Or better yet C# could have a less awkward way to import a single class from a namespace.)

Of course we could relax it. That practice dates from the times before we had channels, where literally every component instantiated stopwatches to time themselves, and there were lots of namespace collisions, and so lots of cost. Now that we do have channels, few things produce their own stopwatches -- use of System.Diagnostics is quite limited. So the amount of pain that would be generated now by your suggestion would be much less.


In reply to: 218094176 [](ancestors = 218094176)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes to the above


In reply to: 218118934 [](ancestors = 218118934,218094176)


/// <summary>
/// An ML.NET environment for local execution.
/// </summary>
public sealed class LocalEnvironment : HostEnvironmentBase<LocalEnvironment>
Copy link
Member

@eerhardt eerhardt Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have some tests that cover this class. #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made API scenario tests use it.


In reply to: 218093190 [](ancestors = 218093190)

{
private sealed class Channel : ChannelBase
{
public readonly Stopwatch Watch;
/// <summary>
/// Creates the channel, starts <see cref="Watch"/>, and immediately dispatches a
/// "Channel started" trace message with <see cref="MessageSensitivity.None"/>.
/// </summary>
/// <param name="root">The owning <see cref="LocalEnvironment"/>; forwarded to the base constructor.</param>
/// <param name="parent">The channel provider this channel belongs to; forwarded to the base constructor.</param>
/// <param name="shortName">Short name of the channel; forwarded to the base constructor.</param>
/// <param name="dispatch">Delegate used to deliver messages; forwarded to the base constructor.</param>
public Channel(LocalEnvironment root, ChannelProviderBase parent, string shortName,
Action<IMessageSource, ChannelMessage> dispatch)
: base(root, parent, shortName, dispatch)
{
// Start timing before announcing the channel, so the elapsed time covers the whole channel lifetime.
Watch = Stopwatch.StartNew();
Dispatch(this, new ChannelMessage(ChannelMessageKind.Trace, MessageSensitivity.None, "Channel started"));
}

/// <summary>
/// Stops the stopwatch, dispatches the "Channel finished" trace message,
/// and then forwards to the base implementation.
/// </summary>
public override void Done()
{
// Stop first so the elapsed time reported by ChannelFinished is final.
Watch.Stop();
ChannelFinished();
base.Done();
}

/// <summary>
/// Dispatches a trace message announcing that the channel has finished, including the
/// time measured by <see cref="Watch"/>.
/// </summary>
/// <remarks>
/// The placeholder is written as <c>{0:c}</c> with no whitespace inside the braces. In .NET
/// composite formatting the argument index must directly follow the opening brace, and any
/// trailing space would become part of the <see cref="TimeSpan"/> format specifier, so the
/// previous <c>{ 0:c }</c> form could not be rendered.
/// NOTE(review): this assumes <c>ChannelMessage</c> renders its text through composite
/// formatting (string.Format) — confirm against the ChannelMessage implementation.
/// </remarks>
private void ChannelFinished()
    => Dispatch(this, new ChannelMessage(ChannelMessageKind.Trace, MessageSensitivity.None, "Channel finished. Elapsed {0:c}.", Watch.Elapsed));

protected override void DisposeCore()
Copy link
Member

@eerhardt eerhardt Sep 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI - I logged #930 to implement the Dispose Pattern correctly on these types. #Resolved

{
if (IsActive)
{
ChannelFinished();
Watch.Stop();
}

Dispatch(this, new ChannelMessage(ChannelMessageKind.Trace, MessageSensitivity.None, "Channel disposed"));
base.DisposeCore();
}
}

/// <summary>
/// Create an ML.NET <see cref="IHostEnvironment"/> for local execution.
/// </summary>
/// <remarks>
/// The random number generator is obtained from <c>RandomUtils.Create(seed)</c>, and the base
/// environment is always constructed with <c>verbose: false</c>.
/// </remarks>
/// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param>
/// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param>
public LocalEnvironment(int? seed = null, int conc = 0)
: base(RandomUtils.Create(seed), verbose: false, conc)
{
}

/// <summary>
/// Add a custom listener to the messages of ML.NET components.
/// </summary>
/// <param name="listener">Callback taking the message source and the <see cref="ChannelMessage"/>;
/// it is registered through the generic <c>AddListener&lt;ChannelMessage&gt;</c> overload.</param>
public void AddListener(Action<IMessageSource, ChannelMessage> listener)
=> AddListener<ChannelMessage>(listener);

/// <summary>
/// Remove a previously added custom listener.
/// </summary>
/// <param name="listener">The callback to unregister; it is passed to the generic
/// <c>RemoveListener&lt;ChannelMessage&gt;</c> overload.</param>
public void RemoveListener(Action<IMessageSource, ChannelMessage> listener)
=> RemoveListener<ChannelMessage>(listener);

/// <summary>
/// Creates a temporary file through the base implementation, prepending "Local_" to the
/// requested prefix.
/// </summary>
protected override IFileHandle CreateTempFileCore(IHostEnvironment env, string suffix = null, string prefix = null)
{
    // Concatenating with a null prefix simply yields "Local_".
    string localPrefix = "Local_" + prefix;
    return base.CreateTempFileCore(env, suffix, localPrefix);
}

/// <summary>
/// Creates a new <see cref="Host"/> for the given source, forwarding all arguments to the
/// <see cref="Host"/> constructor.
/// </summary>
protected override IHost RegisterCore(HostEnvironmentBase<LocalEnvironment> source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc)
{
Contracts.AssertValue(rand);
Contracts.AssertValueOrNull(parentFullName);
Contracts.AssertNonEmpty(shortName);
// The source is expected to be either this environment itself or one of its nested hosts.
Contracts.Assert(source == this || source is Host);
return new Host(source, shortName, parentFullName, rand, verbose, conc);
}

/// <summary>
/// Builds a <see cref="Channel"/> rooted at this environment that delivers its messages
/// through the environment's <see cref="ChannelMessage"/> dispatch delegate.
/// </summary>
protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name)
{
    Contracts.AssertValue(parent);
    Contracts.Assert(parent is LocalEnvironment);
    Contracts.AssertNonEmpty(name);

    var dispatcher = GetDispatchDelegate<ChannelMessage>();
    var channel = new Channel(this, parent, name, dispatcher);
    return channel;
}

/// <summary>
/// Builds a <see cref="Pipe{TMessage}"/> for the given parent that delivers its messages
/// through the environment's dispatch delegate for <typeparamref name="TMessage"/>.
/// </summary>
protected override IPipe<TMessage> CreatePipe<TMessage>(ChannelProviderBase parent, string name)
{
    Contracts.AssertValue(parent);
    Contracts.Assert(parent is LocalEnvironment);
    Contracts.AssertNonEmpty(name);

    var dispatcher = GetDispatchDelegate<TMessage>();
    var pipe = new Pipe<TMessage>(parent, name, dispatcher);
    return pipe;
}

/// <summary>
/// Host registered under a <see cref="LocalEnvironment"/> (or under another such host).
/// It creates the same channel and pipe types as the environment, and registers further
/// nested hosts of the same type.
/// </summary>
private sealed class Host : HostBase
{
    /// <summary>
    /// Creates the host and copies the cancellation flag from <paramref name="source"/>.
    /// </summary>
    public Host(HostEnvironmentBase<LocalEnvironment> source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc)
        : base(source, shortName, parentFullName, rand, verbose, conc)
    {
        IsCancelled = source.IsCancelled;
    }

    /// <summary>
    /// Builds a <see cref="Channel"/> attached to the root environment for this host.
    /// </summary>
    protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name)
    {
        Contracts.AssertValue(parent);
        Contracts.Assert(parent is Host);
        Contracts.AssertNonEmpty(name);

        var channel = new Channel(Root, parent, name, GetDispatchDelegate<ChannelMessage>());
        return channel;
    }

    /// <summary>
    /// Builds a <see cref="Pipe{TMessage}"/> for this host.
    /// </summary>
    protected override IPipe<TMessage> CreatePipe<TMessage>(ChannelProviderBase parent, string name)
    {
        Contracts.AssertValue(parent);
        Contracts.Assert(parent is Host);
        Contracts.AssertNonEmpty(name);

        var pipe = new Pipe<TMessage>(parent, name, GetDispatchDelegate<TMessage>());
        return pipe;
    }

    /// <summary>
    /// Registers a nested host by constructing another <see cref="Host"/> with the same arguments.
    /// </summary>
    protected override IHost RegisterCore(HostEnvironmentBase<LocalEnvironment> source, string shortName, string parentFullName, IRandom rand, bool verbose, int? conc)
        => new Host(source, shortName, parentFullName, rand, verbose, conc);
}
}

}
4 changes: 2 additions & 2 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args)
ParallelTraining = Args.ParallelTrainer != null ? Args.ParallelTrainer.CreateComponent(env) : new SingleTrainer();
ParallelTraining.InitEnvironment();
// REVIEW: CLR 4.6 has a bug that is only exposed in Scope, and if we trigger GC.Collect in scope environment
// with memory consumption more than 5GB, GC get stuck in infinite loop. So for now let's call GC only if we call things from TlcEnvironment.
AllowGC = (env is HostEnvironmentBase<TlcEnvironment>);
// with memory consumption more than 5GB, GC get stuck in infinite loop. So for now let's call GC only if we call things from ConsoleEnvironment.
AllowGC = (env is HostEnvironmentBase<ConsoleEnvironment>);
Tests = new List<Test>();

InitializeThreads(numThreads);
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Legacy/LearningPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ public PredictionModel<TInput, TOutput> Train<TInput, TOutput>()
where TInput : class
where TOutput : class, new()
{
using (var environment = new TlcEnvironment(seed: _seed, conc: _conc))
using (var environment = new ConsoleEnvironment(seed: _seed, conc: _conc))
{
Experiment experiment = environment.CreateExperiment();
ILearningPipelineStep step = null;
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Legacy/LearningPipelineDebugProxy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ internal sealed class LearningPipelineDebugProxy
private const int MaxSlotNamesToDisplay = 100;

private readonly LearningPipeline _pipeline;
private readonly TlcEnvironment _environment;
private readonly ConsoleEnvironment _environment;
private IDataView _preview;
private Exception _pipelineExecutionException;
private PipelineItemDebugColumn[] _columns;
Expand All @@ -39,7 +39,7 @@ public LearningPipelineDebugProxy(LearningPipeline pipeline)
_pipeline = new LearningPipeline();

// use a ConcurrencyFactor of 1 so other threads don't need to run in the debugger
_environment = new TlcEnvironment(conc: 1);
_environment = new ConsoleEnvironment(conc: 1);

foreach (ILearningPipelineItem item in pipeline)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public sealed partial class BinaryClassificationEvaluator
/// </returns>
public BinaryClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
{
using (var environment = new TlcEnvironment())
using (var environment = new ConsoleEnvironment())
{
environment.CheckValue(model, nameof(model));
environment.CheckValue(testData, nameof(testData));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Legacy/Models/ClassificationEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public sealed partial class ClassificationEvaluator
/// </returns>
public ClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
{
using (var environment = new TlcEnvironment())
using (var environment = new ConsoleEnvironment())
{
environment.CheckValue(model, nameof(model));
environment.CheckValue(testData, nameof(testData));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Legacy/Models/ClusterEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public sealed partial class ClusterEvaluator
/// </returns>
public ClusterMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
{
using (var environment = new TlcEnvironment())
using (var environment = new ConsoleEnvironment())
{
environment.CheckValue(model, nameof(model));
environment.CheckValue(testData, nameof(testData));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Legacy/Models/CrossValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(Lea
where TInput : class
where TOutput : class, new()
{
using (var environment = new TlcEnvironment())
using (var environment = new ConsoleEnvironment())
{
Experiment subGraph = environment.CreateExperiment();
ILearningPipelineStep step = null;
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Legacy/Models/OneVersusAll.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public OvaPipelineItem(ITrainerInputWithLabel trainer, bool useProbabilities)

public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
{
using (var env = new TlcEnvironment())
using (var env = new ConsoleEnvironment())
{
var subgraph = env.CreateExperiment();
subgraph.Add(_trainer);
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Legacy/Models/OnnxConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public sealed partial class OnnxConverter
/// <param name="model">Model that needs to be converted to ONNX format.</param>
public void Convert(PredictionModel model)
{
using (var environment = new TlcEnvironment())
using (var environment = new ConsoleEnvironment())
{
environment.CheckValue(model, nameof(model));

Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Legacy/Models/RegressionEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public sealed partial class RegressionEvaluator
/// </returns>
public RegressionMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
{
using (var environment = new TlcEnvironment())
using (var environment = new ConsoleEnvironment())
{
environment.CheckValue(model, nameof(model));
environment.CheckValue(testData, nameof(testData));
Expand Down
Loading