From 0bd3788e2d2426b8c778d036efe4e04f6b0ea850 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Wed, 7 Nov 2018 14:29:48 -0800 Subject: [PATCH 01/12] Internalization of Utils. Move out Sweeper types. --- .../Data/ReadOnlyMemoryUtils.cs | 12 ------ .../Properties/AssemblyInfo.cs | 11 ++++- src/Microsoft.ML.Core/Utilities/BigArray.cs | 3 +- src/Microsoft.ML.Core/Utilities/BinFinder.cs | 9 ++-- src/Microsoft.ML.Core/Utilities/BitUtils.cs | 2 +- src/Microsoft.ML.Core/Utilities/CharUtils.cs | 3 +- .../Utilities/CmdIndenter.cs | 3 +- .../Utilities/DoubleParser.cs | 3 +- .../Utilities/FixedSizeQueue.cs | 3 +- src/Microsoft.ML.Core/Utilities/FloatUtils.cs | 3 +- src/Microsoft.ML.Core/Utilities/HashArray.cs | 3 +- src/Microsoft.ML.Core/Utilities/Hashing.cs | 3 +- src/Microsoft.ML.Core/Utilities/Heap.cs | 8 ++-- .../Utilities/HybridMemoryStream.cs | 3 +- .../Utilities/IndentingTextWriter.cs | 3 +- .../Utilities/ListExtensions.cs | 42 ------------------- src/Microsoft.ML.Core/Utilities/LruCache.cs | 3 +- src/Microsoft.ML.Core/Utilities/MathUtils.cs | 3 +- .../Utilities/MatrixTransposeOps.cs | 3 +- src/Microsoft.ML.Core/Utilities/MemUtils.cs | 27 ------------ src/Microsoft.ML.Core/Utilities/MinWaiter.cs | 3 +- src/Microsoft.ML.Core/Utilities/NormStr.cs | 3 +- src/Microsoft.ML.Core/Utilities/ObjectPool.cs | 8 ++-- .../Utilities/OrderedWaiter.cs | 3 +- src/Microsoft.ML.Core/Utilities/PathUtils.cs | 2 +- .../Utilities/PlatformUtils.cs | 3 +- src/Microsoft.ML.Core/Utilities/Stream.cs | 5 +-- .../Utilities/ThreadUtils.cs | 2 +- src/Microsoft.ML.Core/Utilities/Utils.cs | 3 +- .../Evaluators/AnomalyDetectionEvaluator.cs | 2 +- src/Microsoft.ML.Data/Model/ModelHeader.cs | 2 +- .../Model/ModelLoadContext.cs | 5 ++- .../Model/ModelSaveContext.cs | 6 ++- .../Properties/AssemblyInfo.cs | 31 ++++++++++++-- .../Transforms/TermTransformImpl.cs | 14 +++---- .../Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs | 6 +-- .../LightGbmTrainerBase.cs | 4 +- .../DatasetFeaturesInference.cs | 4 +- .../Microsoft.ML.PipelineInference.csproj | 2 +- .../MulticlassLogisticRegression.cs | 21 +--------- .../ISweeper.cs | 0 src/Microsoft.ML.Sweeper/Parameters.cs | 1 - src/Microsoft.ML.Sweeper/SweepCommand.cs | 3 -- .../SweepResultEvaluator.cs | 1 - src/Microsoft.ML.Sweeper/SynthConfigRunner.cs | 3 -- ...AdaptiveSingularSpectrumSequenceModeler.cs | 31 +++++++------- .../ExponentialAverageTransform.cs | 8 ++-- .../IidAnomalyDetectionBase.cs | 6 +-- .../MovingAverageTransform.cs | 12 +++--- .../PValueTransform.cs | 8 ++-- .../PercentileThresholdTransform.cs | 12 +++--- ...uenceModeler.cs => SequenceModelerBase.cs} | 25 +++++++---- ...SequentialAnomalyDetectionTransformBase.cs | 36 ++++++++-------- .../SequentialTransformBase.cs | 37 ++++++++-------- .../SequentialTransformerBase.cs | 38 +++++++++-------- .../SlidingWindowTransformBase.cs | 8 ++-- .../SsaAnomalyDetectionBase.cs | 14 +++---- .../TermLookupTransform.cs | 4 +- .../Microsoft.ML.Sweeper.Tests/TestSweeper.cs | 2 +- 59 files changed, 242 insertions(+), 286 deletions(-) delete mode 100644 src/Microsoft.ML.Core/Utilities/ListExtensions.cs delete mode 100644 src/Microsoft.ML.Core/Utilities/MemUtils.cs rename src/{Microsoft.ML.Core/Prediction => Microsoft.ML.Sweeper}/ISweeper.cs (100%) rename src/Microsoft.ML.TimeSeries/{ISequenceModeler.cs => SequenceModelerBase.cs} (75%) diff --git a/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs b/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs index 4b207ab507..4f89e29201 100644 --- a/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs +++ b/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs @@ -208,18 +208,6 @@ public static ReadOnlyMemory TrimEndWhiteSpace(ReadOnlyMemory memory return memory.Slice(0, ichLim); } - public static NormStr AddToPool(ReadOnlyMemory memory, NormStr.Pool pool) - { - Contracts.CheckValue(pool, nameof(pool)); - return pool.Add(memory); - } - - public static NormStr FindInPool(ReadOnlyMemory memory, NormStr.Pool pool) - { - Contracts.CheckValue(pool, nameof(pool)); - return pool.Get(memory); - } - public static void AddLowerCaseToStringBuilder(ReadOnlySpan span, StringBuilder sb) { Contracts.CheckValue(sb, nameof(sb)); diff --git a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs index 838e12384c..5a30b74dee 100644 --- a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs @@ -6,11 +6,16 @@ using Microsoft.ML; [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TestFramework" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.InferenceTesting" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipelineTesting" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransformTest" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Maml" + PublicKey.Value)] -[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.PipelineInference" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.ResultProcessor" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Data" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Api" + PublicKey.Value)] @@ -21,12 +26,16 @@ [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.LightGBM" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Onnx" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransform" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Parquet" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.PCA" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.PipelineInference" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Recommender" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Runtime.ImageAnalytics" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Scoring" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StandardLearners" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Sweeper" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TensorFlow" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Transforms" + PublicKey.Value)] [assembly: WantsToBeBestFriends] diff --git a/src/Microsoft.ML.Core/Utilities/BigArray.cs b/src/Microsoft.ML.Core/Utilities/BigArray.cs index b005c4cefe..7b78d583ae 100644 --- a/src/Microsoft.ML.Core/Utilities/BigArray.cs +++ b/src/Microsoft.ML.Core/Utilities/BigArray.cs @@ -19,7 +19,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// than the total capacity. /// /// The type of entries. - public sealed class BigArray : IEnumerable + [BestFriend] + internal sealed class BigArray : IEnumerable { // REVIEW: This class merges and replaces the original private BigArray implementation in CacheDataView. // There the block size was 25 bits. Need to understand the performance implication of this 32x change. diff --git a/src/Microsoft.ML.Core/Utilities/BinFinder.cs b/src/Microsoft.ML.Core/Utilities/BinFinder.cs index 5cd7fd2a61..cdfd0ad08b 100644 --- a/src/Microsoft.ML.Core/Utilities/BinFinder.cs +++ b/src/Microsoft.ML.Core/Utilities/BinFinder.cs @@ -9,7 +9,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities { - public abstract class BinFinderBase + [BestFriend] + internal abstract class BinFinderBase { private Single[] _valuesSng; // distinct values private Double[] _valuesDbl; // distinct values @@ -278,7 +279,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities using EnergyType = System.Int64; // Uses the energy function: sum(1,N) dx^2 where dx is the difference in accum values. - public sealed class GreedyBinFinder : BinFinderBase + [BestFriend] + internal sealed class GreedyBinFinder : BinFinderBase { // Potential drop location for another peg, together with its energy improvement. // PlacePegs uses a heap of these. Note that this is a struct so size matters. @@ -529,7 +531,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities using EnergyType = System.Double; // Uses dynamic programming. - public sealed class DynamicBinFinder : BinFinderBase + [BestFriend] + internal sealed class DynamicBinFinder : BinFinderBase { private int[] _accum; // integral of counts diff --git a/src/Microsoft.ML.Core/Utilities/BitUtils.cs b/src/Microsoft.ML.Core/Utilities/BitUtils.cs index 9af79139ef..376b0ac8ef 100644 --- a/src/Microsoft.ML.Core/Utilities/BitUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/BitUtils.cs @@ -7,7 +7,7 @@ namespace Microsoft.ML.Runtime.Internal.Utilities { - public static partial class Utils + internal static partial class Utils { private const int CbitUint = 32; private const int CbitUlong = 64; diff --git a/src/Microsoft.ML.Core/Utilities/CharUtils.cs b/src/Microsoft.ML.Core/Utilities/CharUtils.cs index bf7ae4677e..d88197c8e7 100644 --- a/src/Microsoft.ML.Core/Utilities/CharUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/CharUtils.cs @@ -10,7 +10,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities { - public static class CharUtils + [BestFriend] + internal static class CharUtils { private const int CharsCount = 0x10000; private static volatile char[] _lowerInvariantChars; diff --git a/src/Microsoft.ML.Core/Utilities/CmdIndenter.cs b/src/Microsoft.ML.Core/Utilities/CmdIndenter.cs index d82fffe682..d25a731da9 100644 --- a/src/Microsoft.ML.Core/Utilities/CmdIndenter.cs +++ b/src/Microsoft.ML.Core/Utilities/CmdIndenter.cs @@ -11,7 +11,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities { - public static class CmdIndenter + [BestFriend] + internal static class CmdIndenter { /// /// Get indented version of command line or same string if we unable to produce it. diff --git a/src/Microsoft.ML.Core/Utilities/DoubleParser.cs b/src/Microsoft.ML.Core/Utilities/DoubleParser.cs index a1a82d5218..4f2ef9ad80 100644 --- a/src/Microsoft.ML.Core/Utilities/DoubleParser.cs +++ b/src/Microsoft.ML.Core/Utilities/DoubleParser.cs @@ -13,7 +13,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities { - public class DoubleParser + [BestFriend] + internal static class DoubleParser { private const ulong TopBit = 0x8000000000000000UL; private const ulong TopTwoBits = 0xC000000000000000UL; diff --git a/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs b/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs index 1c375bd625..bdcea0c0ae 100644 --- a/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs +++ b/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs @@ -12,7 +12,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// A fixed-length circular array. Items are added at the end. If the array is full, adding /// an item will result in discarding the least recently added item. /// - public sealed class FixedSizeQueue + [BestFriend] + internal sealed class FixedSizeQueue { private readonly T[] _array; private int _startIndex; diff --git a/src/Microsoft.ML.Core/Utilities/FloatUtils.cs b/src/Microsoft.ML.Core/Utilities/FloatUtils.cs index fdbc59aa89..06d403da9a 100644 --- a/src/Microsoft.ML.Core/Utilities/FloatUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/FloatUtils.cs @@ -8,7 +8,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities { - public static class FloatUtils + [BestFriend] + internal static class FloatUtils { // This is used to read and write the bits of a Double. // Thanks to Vance Morrison for educating me about this excellent aliasing mechanism. diff --git a/src/Microsoft.ML.Core/Utilities/HashArray.cs b/src/Microsoft.ML.Core/Utilities/HashArray.cs index 27f0ec9b5d..bfaafbdf8d 100644 --- a/src/Microsoft.ML.Core/Utilities/HashArray.cs +++ b/src/Microsoft.ML.Core/Utilities/HashArray.cs @@ -17,7 +17,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// Also implements memory efficient sorting. /// Note: Supports adding and looking up of items but does not support removal of items. /// - public sealed class HashArray + [BestFriend] + internal sealed class HashArray // REVIEW: May want to not consider not making TItem have to be IComparable but instead // could support user specified sort order. where TItem : IEquatable, IComparable diff --git a/src/Microsoft.ML.Core/Utilities/Hashing.cs b/src/Microsoft.ML.Core/Utilities/Hashing.cs index a15677451b..ae36fae95d 100644 --- a/src/Microsoft.ML.Core/Utilities/Hashing.cs +++ b/src/Microsoft.ML.Core/Utilities/Hashing.cs @@ -9,7 +9,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities { - public static class Hashing + [BestFriend] + internal static class Hashing { private const uint _defaultSeed = (5381 << 16) + 5381; diff --git a/src/Microsoft.ML.Core/Utilities/Heap.cs b/src/Microsoft.ML.Core/Utilities/Heap.cs index 241acd0209..7652163f10 100644 --- a/src/Microsoft.ML.Core/Utilities/Heap.cs +++ b/src/Microsoft.ML.Core/Utilities/Heap.cs @@ -12,7 +12,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// /// Implements a heap. /// - public sealed class Heap + [BestFriend] + internal sealed class Heap { private readonly List _rgv; // The heap elements. The 0th element is a dummy. private readonly Func _fnReverse; @@ -196,7 +197,7 @@ private void BubbleDown(int iv) /// /// For the heap to allow deletion, the heap node has to derive from this class. /// - public abstract partial class HeapNode + internal abstract class HeapNode { // Where this node lives in the heap. Zero means it isn't in the heap. private int _index; @@ -207,10 +208,7 @@ protected HeapNode() } public bool InHeap { get { return _index > 0; } } - } - public abstract partial class HeapNode - { /// /// Implements a heap. /// diff --git a/src/Microsoft.ML.Core/Utilities/HybridMemoryStream.cs b/src/Microsoft.ML.Core/Utilities/HybridMemoryStream.cs index 73b4c4a828..02f713dd6e 100644 --- a/src/Microsoft.ML.Core/Utilities/HybridMemoryStream.cs +++ b/src/Microsoft.ML.Core/Utilities/HybridMemoryStream.cs @@ -15,7 +15,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// file system. This can be useful if we have intermediate operations that require streams. /// The temporary file will be destroyed if the object is properly disposed. /// - public sealed class HybridMemoryStream : Stream + [BestFriend] + internal sealed class HybridMemoryStream : Stream { private MemoryStream _memStream; private Stream _overflowStream; diff --git a/src/Microsoft.ML.Core/Utilities/IndentingTextWriter.cs b/src/Microsoft.ML.Core/Utilities/IndentingTextWriter.cs index e42bdf6519..d71bc93a3b 100644 --- a/src/Microsoft.ML.Core/Utilities/IndentingTextWriter.cs +++ b/src/Microsoft.ML.Core/Utilities/IndentingTextWriter.cs @@ -8,7 +8,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities { - public sealed class IndentingTextWriter : TextWriter + [BestFriend] + internal sealed class IndentingTextWriter : TextWriter { public struct Scope : IDisposable { diff --git a/src/Microsoft.ML.Core/Utilities/ListExtensions.cs b/src/Microsoft.ML.Core/Utilities/ListExtensions.cs deleted file mode 100644 index 22c45a3bca..0000000000 --- a/src/Microsoft.ML.Core/Utilities/ListExtensions.cs +++ /dev/null @@ -1,42 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; - -namespace Microsoft.ML.Runtime.Internal.Utilities -{ - public static class ListExtensions - { - public static void Push(this List list, T item) - { - Contracts.AssertValue(list); - list.Add(item); - } - - public static T Pop(this List list) - { - Contracts.AssertValue(list); - Contracts.Assert(list.Count > 0); - int index = list.Count - 1; - T item = list[index]; - list.RemoveAt(index); - return item; - } - - public static void PopTo(this List list, int depth) - { - Contracts.AssertValue(list); - Contracts.Assert(0 <= depth && depth <= list.Count); - list.RemoveRange(depth, list.Count - depth); - } - - public static T Peek(this List list) - { - Contracts.AssertValue(list); - Contracts.Assert(list.Count > 0); - return list[list.Count - 1]; - } - } -} diff --git a/src/Microsoft.ML.Core/Utilities/LruCache.cs b/src/Microsoft.ML.Core/Utilities/LruCache.cs index 21897f4e27..e041efde02 100644 --- a/src/Microsoft.ML.Core/Utilities/LruCache.cs +++ b/src/Microsoft.ML.Core/Utilities/LruCache.cs @@ -10,7 +10,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// /// Implements a least recently used cache. /// - public sealed class LruCache + [BestFriend] + internal sealed class LruCache { private readonly int _size; private readonly Dictionary>> _cache; diff --git a/src/Microsoft.ML.Core/Utilities/MathUtils.cs b/src/Microsoft.ML.Core/Utilities/MathUtils.cs index fb68ee82d6..7550a949c5 100644 --- a/src/Microsoft.ML.Core/Utilities/MathUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/MathUtils.cs @@ -13,7 +13,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// /// Some useful math methods. /// - public static class MathUtils + [BestFriend] + internal static class MathUtils { public static Float ToFloat(this Double dbl) { diff --git a/src/Microsoft.ML.Core/Utilities/MatrixTransposeOps.cs b/src/Microsoft.ML.Core/Utilities/MatrixTransposeOps.cs index 0afdd9a6a5..cb945e56a2 100644 --- a/src/Microsoft.ML.Core/Utilities/MatrixTransposeOps.cs +++ b/src/Microsoft.ML.Core/Utilities/MatrixTransposeOps.cs @@ -9,7 +9,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities { - public static class MatrixTransposeOps + [BestFriend] + internal static class MatrixTransposeOps { private const int _block = 32; diff --git a/src/Microsoft.ML.Core/Utilities/MemUtils.cs b/src/Microsoft.ML.Core/Utilities/MemUtils.cs deleted file mode 100644 index 1dba9205e9..0000000000 --- a/src/Microsoft.ML.Core/Utilities/MemUtils.cs +++ /dev/null @@ -1,27 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -namespace Microsoft.ML.Runtime.Internal.CpuMath -{ - public static class MemUtils - { - // The signature of this method is intentionally identical to - // .Net 4.6's Buffer.MemoryCopy. - // REVIEW: Remove once we're on a version of .NET which includes - // Buffer.MemoryCopy. - public static unsafe void MemoryCopy(void* source, void* destination, long destinationSizeInBytes, long sourceBytesToCopy) - { - // MemCpy has undefined behavior when handed overlapping source and - // destination buffers. - // Do not pass it overlapping source and destination buffers. - Contracts.Check((byte*)destination + sourceBytesToCopy <= source || destination >= (byte*)source + sourceBytesToCopy); - Contracts.Check(destinationSizeInBytes >= sourceBytesToCopy); -#if CORECLR - System.Buffer.MemoryCopy(source, destination, destinationSizeInBytes, sourceBytesToCopy); -#else - Thunk.MemCpy(destination, source, sourceBytesToCopy); -#endif - } - } -} diff --git a/src/Microsoft.ML.Core/Utilities/MinWaiter.cs b/src/Microsoft.ML.Core/Utilities/MinWaiter.cs index 42ddf0c69c..fbaf8fb6d6 100644 --- a/src/Microsoft.ML.Core/Utilities/MinWaiter.cs +++ b/src/Microsoft.ML.Core/Utilities/MinWaiter.cs @@ -21,7 +21,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// registering itself for a new event (or, finally, retiring itself through /// ). /// - public sealed class MinWaiter + [BestFriend] + internal sealed class MinWaiter { /// /// This is an event-line pair. The intended usage is, when the line diff --git a/src/Microsoft.ML.Core/Utilities/NormStr.cs b/src/Microsoft.ML.Core/Utilities/NormStr.cs index fea018ac58..c79e2425d1 100644 --- a/src/Microsoft.ML.Core/Utilities/NormStr.cs +++ b/src/Microsoft.ML.Core/Utilities/NormStr.cs @@ -17,7 +17,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// /// Normalized string type. For string pooling. /// - public sealed class NormStr + [BestFriend] + internal sealed class NormStr { public readonly ReadOnlyMemory Value; public readonly int Id; diff --git a/src/Microsoft.ML.Core/Utilities/ObjectPool.cs b/src/Microsoft.ML.Core/Utilities/ObjectPool.cs index 4a65286551..a06202af76 100644 --- a/src/Microsoft.ML.Core/Utilities/ObjectPool.cs +++ b/src/Microsoft.ML.Core/Utilities/ObjectPool.cs @@ -8,7 +8,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities { - public sealed class ObjectPool : ObjectPoolBase where T : class, new() + [BestFriend] + internal sealed class ObjectPool : ObjectPoolBase where T : class, new() { protected override T Create() { @@ -16,7 +17,8 @@ protected override T Create() } } - public sealed class MadeObjectPool : ObjectPoolBase + [BestFriend] + internal sealed class MadeObjectPool : ObjectPoolBase { private readonly Func _maker; @@ -31,7 +33,7 @@ protected override T Create() } } - public abstract class ObjectPoolBase + internal abstract class ObjectPoolBase { private readonly ConcurrentBag _pool; private int _numCreated; diff --git a/src/Microsoft.ML.Core/Utilities/OrderedWaiter.cs b/src/Microsoft.ML.Core/Utilities/OrderedWaiter.cs index e3e118d8bf..ca0ef23445 100644 --- a/src/Microsoft.ML.Core/Utilities/OrderedWaiter.cs +++ b/src/Microsoft.ML.Core/Utilities/OrderedWaiter.cs @@ -17,7 +17,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// any order), the first thread to clear the wait will be 0, then 1 will /// be cleared once incremented, then 2 will be cleared once incremented. /// - public sealed class OrderedWaiter + [BestFriend] + internal sealed class OrderedWaiter { /// /// This is an event-line pair. The intended usage is, when the line diff --git a/src/Microsoft.ML.Core/Utilities/PathUtils.cs b/src/Microsoft.ML.Core/Utilities/PathUtils.cs index 6698c11f7f..98407e24a1 100644 --- a/src/Microsoft.ML.Core/Utilities/PathUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/PathUtils.cs @@ -8,7 +8,7 @@ namespace Microsoft.ML.Runtime.Internal.Utilities { - public static partial class Utils + internal static partial class Utils { /// /// Environment variable containing optional resources path. diff --git a/src/Microsoft.ML.Core/Utilities/PlatformUtils.cs b/src/Microsoft.ML.Core/Utilities/PlatformUtils.cs index 8f46f93017..2bd8acab3e 100644 --- a/src/Microsoft.ML.Core/Utilities/PlatformUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/PlatformUtils.cs @@ -11,7 +11,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// /// Contains extension methods that aid in building cross platform. /// - public static class PlatformUtils + [BestFriend] + internal static class PlatformUtils { public static ReadOnlyCollection AsReadOnly(this T[] items) { diff --git a/src/Microsoft.ML.Core/Utilities/Stream.cs b/src/Microsoft.ML.Core/Utilities/Stream.cs index 926a425f5b..4fe21e7df7 100644 --- a/src/Microsoft.ML.Core/Utilities/Stream.cs +++ b/src/Microsoft.ML.Core/Utilities/Stream.cs @@ -8,11 +8,10 @@ using System.IO; using System.Text; using System.Threading; -using Microsoft.ML.Runtime.Internal.CpuMath; namespace Microsoft.ML.Runtime.Internal.Utilities { - public static partial class Utils + internal static partial class Utils { private const int _bulkReadThresholdInBytes = 4096; @@ -853,7 +852,7 @@ public static unsafe void ReadBytes(this BinaryReader reader, void* destination, int toRead = (int)Math.Min(bytesToRead - offset, blockSize); int read = reader.Read(work, 0, toRead); Contracts.CheckDecode(read == toRead); - MemUtils.MemoryCopy(src, (byte*)destination + offset, destinationSizeInBytes - offset, read); + Buffer.MemoryCopy(src, (byte*)destination + offset, destinationSizeInBytes - offset, read); offset += read; } Contracts.Assert(offset == bytesToRead); diff --git a/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs b/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs index 859ae7b28d..4756337b39 100644 --- a/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs @@ -10,7 +10,7 @@ namespace Microsoft.ML.Runtime.Internal.Utilities { - public static partial class Utils + internal static partial class Utils { public static Thread CreateBackgroundThread(ParameterizedThreadStart start) { diff --git a/src/Microsoft.ML.Core/Utilities/Utils.cs b/src/Microsoft.ML.Core/Utilities/Utils.cs index 9a6ecb9b0b..a0b9019d14 100644 --- a/src/Microsoft.ML.Core/Utilities/Utils.cs +++ b/src/Microsoft.ML.Core/Utilities/Utils.cs @@ -16,7 +16,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities { - public static partial class Utils + [BestFriend] + internal static partial class Utils { // Maximum size of one-dimensional array. // See: https://msdn.microsoft.com/en-us/library/hh285054(v=vs.110).aspx diff --git a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs index 128dafc81e..b795245766 100644 --- a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs @@ -293,7 +293,7 @@ public virtual void FinishFirstPass() { } - protected IEnumerable ReverseHeap(Heap heap) + private protected Info[] ReverseHeap(Heap heap) { var res = new Info[heap.Count]; while (heap.Count > 0) diff --git a/src/Microsoft.ML.Data/Model/ModelHeader.cs b/src/Microsoft.ML.Data/Model/ModelHeader.cs index 9acaaccfee..b74d0fdca6 100644 --- a/src/Microsoft.ML.Data/Model/ModelHeader.cs +++ b/src/Microsoft.ML.Data/Model/ModelHeader.cs @@ -12,7 +12,7 @@ namespace Microsoft.ML.Runtime.Model { [StructLayout(LayoutKind.Explicit, Size = ModelHeader.Size)] - public struct ModelHeader + internal struct ModelHeader { /// /// This spells 'ML MODEL' with zero replacing space (assuming little endian). diff --git a/src/Microsoft.ML.Data/Model/ModelLoadContext.cs b/src/Microsoft.ML.Data/Model/ModelLoadContext.cs index 7efd9f4b47..f3507277cf 100644 --- a/src/Microsoft.ML.Data/Model/ModelLoadContext.cs +++ b/src/Microsoft.ML.Data/Model/ModelLoadContext.cs @@ -50,7 +50,8 @@ public sealed partial class ModelLoadContext : IDisposable /// /// The main stream's model header. /// - public ModelHeader Header; + [BestFriend] + internal ModelHeader Header; /// /// The min file position of the main stream. @@ -96,7 +97,7 @@ public ModelLoadContext(RepositoryReader rep, Repository.Entry ent, string dir) /// /// Create a ModelLoadContext supporting loading from a single-stream, for implementors of ICanSaveInBinaryFormat. /// - public ModelLoadContext(BinaryReader reader, IExceptionContext ectx = null) + internal ModelLoadContext(BinaryReader reader, IExceptionContext ectx = null) { Contracts.AssertValueOrNull(ectx); _ectx = ectx; diff --git a/src/Microsoft.ML.Data/Model/ModelSaveContext.cs b/src/Microsoft.ML.Data/Model/ModelSaveContext.cs index c5a9199758..4617be093d 100644 --- a/src/Microsoft.ML.Data/Model/ModelSaveContext.cs +++ b/src/Microsoft.ML.Data/Model/ModelSaveContext.cs @@ -37,12 +37,14 @@ public sealed partial class ModelSaveContext : IDisposable /// /// The strings that will be saved in the main stream's string table. /// - public readonly NormStr.Pool Strings; + [BestFriend] + internal readonly NormStr.Pool Strings; /// /// The main stream's model header. /// - public ModelHeader Header; + [BestFriend] + internal ModelHeader Header; /// /// The min file position of the main stream. diff --git a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs index 97db2a8a07..5056571670 100644 --- a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs @@ -2,8 +2,33 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Reflection; using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; +using Microsoft.ML; -[assembly: InternalsVisibleTo("Microsoft.ML.TestFramework, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TestFramework" + PublicKey.TestValue)] + +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Maml" + PublicKey.Value)] + +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Data" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Api" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Ensemble" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.FastTree" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.HalLearners" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.KMeansClustering" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.LightGBM" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Onnx" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransform" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Parquet" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.PCA" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.PipelineInference" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Recommender" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Runtime.ImageAnalytics" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Scoring" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StandardLearners" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Sweeper" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TensorFlow" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Transforms" + PublicKey.Value)] + +[assembly: WantsToBeBestFriends] diff --git a/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs b/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs index 8ddc1f89ce..1b70ff0f2d 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs @@ -111,7 +111,7 @@ public override bool TryAdd(ref ReadOnlyMemory val) if (val.IsEmpty) return false; int count = _pool.Count; - return ReadOnlyMemoryUtils.AddToPool(val, _pool).Id == count; + return _pool.Add(val).Id == count; } public override TermMap Finish() @@ -568,9 +568,9 @@ private static TermMap LoadCodecCore(ModelLoadContext ctx, IExceptionContext return new HashArrayImpl(codec.Type.AsPrimitive, values); } - public abstract void WriteTextTerms(TextWriter writer); + internal abstract void WriteTextTerms(TextWriter writer); - public sealed class TextImpl : TermMap> + internal sealed class TextImpl : TermMap> { private readonly NormStr.Pool _pool; @@ -634,7 +634,7 @@ internal override void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFa private void KeyMapper(in ReadOnlyMemory src, ref uint dst) { - var nstr = ReadOnlyMemoryUtils.FindInPool(src, _pool); + var nstr = _pool.Get(src); if (nstr == null) dst = 0; else @@ -663,7 +663,7 @@ public override void GetTerms(ref VBuffer> dst) dst = new VBuffer>(_pool.Count, values, dst.Indices); } - public override void WriteTextTerms(TextWriter writer) + internal override void WriteTextTerms(TextWriter writer) { writer.WriteLine("# Number of terms = {0}", Count); foreach (var nstr in _pool) @@ -671,7 +671,7 @@ public override void WriteTextTerms(TextWriter writer) } } - public sealed class HashArrayImpl : TermMap + internal sealed class HashArrayImpl : TermMap where T : IEquatable, IComparable { // One of the two must exist. If we need one we can initialize it @@ -743,7 +743,7 @@ public override void GetTerms(ref VBuffer dst) dst = new VBuffer(Count, values, dst.Indices); } - public override void WriteTextTerms(TextWriter writer) + internal override void WriteTextTerms(TextWriter writer) { writer.WriteLine("# Number of terms of type '{0}' = {1}", ItemType, Count); StringBuilder sb = null; diff --git a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs index 48fd32fc31..5179c60e56 100644 --- a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs +++ b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs @@ -2,12 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Transforms.TensorFlow; using System; -using System.Linq; namespace Microsoft.ML.DnnAnalyzer { @@ -15,7 +11,7 @@ public static class DnnAnalyzer { public static void Main(string[] args) { - if (Utils.Size(args) != 1) + if (args == null || args.Length != 1) { Console.Error.WriteLine("Usage: dotnet DnnAnalyzer.dll "); return; diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index c6a68a8a51..3a0c496868 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -540,8 +540,8 @@ private void GetFeatureValueSparse(IChannel ch, FloatLabelCursor cursor, fv = catMetaData.OnehotBias[colIdx] + 1; if (newColIdx != lastIdx) { - featureIndices.Push(newColIdx); - values.Push(fv); + featureIndices.Add(newColIdx); + values.Add(fv); nhot = 1; } else diff --git a/src/Microsoft.ML.PipelineInference/DatasetFeaturesInference.cs b/src/Microsoft.ML.PipelineInference/DatasetFeaturesInference.cs index 097a61e621..f51c0871f8 100644 --- a/src/Microsoft.ML.PipelineInference/DatasetFeaturesInference.cs +++ b/src/Microsoft.ML.PipelineInference/DatasetFeaturesInference.cs @@ -428,8 +428,8 @@ private void ApplyCore(ReadOnlyMemory[][] data, Column column) NumericColumnFeatures.Add(new ColumnStatistics { Column = column, Stats = stats }); else { - NonNumericColumnLengthFeature.Push(new ColumnStatistics { Column = column, Stats = stats }); - NonNumericColumnSpacesFeature.Push(new ColumnStatistics { Column = column, Stats = spacesStats }); + NonNumericColumnLengthFeature.Add(new ColumnStatistics { Column = column, Stats = stats }); + NonNumericColumnSpacesFeature.Add(new ColumnStatistics { Column = column, Stats = spacesStats }); } } diff --git a/src/Microsoft.ML.PipelineInference/Microsoft.ML.PipelineInference.csproj b/src/Microsoft.ML.PipelineInference/Microsoft.ML.PipelineInference.csproj index fdc8d2802d..9f79aebbe1 100644 --- a/src/Microsoft.ML.PipelineInference/Microsoft.ML.PipelineInference.csproj +++ b/src/Microsoft.ML.PipelineInference/Microsoft.ML.PipelineInference.csproj @@ -15,9 +15,9 @@ + - diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 4a6a9e4413..5b237aa47b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -552,17 +552,7 @@ private MulticlassLogisticRegressionPredictor(IHostEnvironment env, ModelLoadCon if (ctx.TryLoadBinaryStream(LabelNamesSubModelFilename, r => labelNames = LoadLabelNames(ctx, r))) _labelNames = labelNames; - string statsDir = Path.Combine(ctx.Directory ?? "", ModelStatsSubModelFilename); - using (var statsEntry = ctx.Repository.OpenEntryOrNull(statsDir, ModelLoadContext.ModelStreamName)) - { - if (statsEntry == null) - _stats = null; - else - { - using (var statsCtx = new ModelLoadContext(ctx.Repository, statsEntry, statsDir)) - _stats = LinearModelStatistics.Create(Host, statsCtx); - } - } + ctx.LoadModelOrNull< LinearModelStatistics, SignatureLoadModel>(Host, out _stats, ModelStatsSubModelFilename); } public static MulticlassLogisticRegressionPredictor Create(IHostEnvironment env, ModelLoadContext ctx) @@ -698,14 +688,7 @@ protected override void SaveCore(ModelSaveContext ctx) Contracts.AssertValueOrNull(_stats); if (_stats != null) - { - using (var statsCtx = new ModelSaveContext(ctx.Repository, - Path.Combine(ctx.Directory ?? "", ModelStatsSubModelFilename), ModelLoadContext.ModelStreamName)) - { - _stats.Save(statsCtx); - statsCtx.Done(); - } - } + ctx.SaveModel(_stats, ModelStatsSubModelFilename); } // REVIEW: Destroy. diff --git a/src/Microsoft.ML.Core/Prediction/ISweeper.cs b/src/Microsoft.ML.Sweeper/ISweeper.cs similarity index 100% rename from src/Microsoft.ML.Core/Prediction/ISweeper.cs rename to src/Microsoft.ML.Sweeper/ISweeper.cs diff --git a/src/Microsoft.ML.Sweeper/Parameters.cs b/src/Microsoft.ML.Sweeper/Parameters.cs index 1618881867..7e87d34c35 100644 --- a/src/Microsoft.ML.Sweeper/Parameters.cs +++ b/src/Microsoft.ML.Sweeper/Parameters.cs @@ -9,7 +9,6 @@ using System.Globalization; using System.Linq; using System.Text.RegularExpressions; -using Microsoft.ML; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Internal.Utilities; diff --git a/src/Microsoft.ML.Sweeper/SweepCommand.cs b/src/Microsoft.ML.Sweeper/SweepCommand.cs index c738db0ef5..51dc28559d 100644 --- a/src/Microsoft.ML.Sweeper/SweepCommand.cs +++ b/src/Microsoft.ML.Sweeper/SweepCommand.cs @@ -5,13 +5,10 @@ using System; using System.Collections.Generic; using System.IO; -using Microsoft.ML; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Sweeper; -using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Command; [assembly: LoadableClass(SweepCommand.Summary, typeof(SweepCommand), typeof(SweepCommand.Arguments), typeof(SignatureCommand), diff --git a/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs b/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs index fded15ddaf..1bfd37ec70 100644 --- a/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs +++ b/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs @@ -4,7 +4,6 @@ using System; using System.Text; -using Microsoft.ML; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.CommandLine; diff --git a/src/Microsoft.ML.Sweeper/SynthConfigRunner.cs b/src/Microsoft.ML.Sweeper/SynthConfigRunner.cs index bee7b8a60b..27da71cc64 100644 --- a/src/Microsoft.ML.Sweeper/SynthConfigRunner.cs +++ b/src/Microsoft.ML.Sweeper/SynthConfigRunner.cs @@ -7,12 +7,9 @@ using System.IO; using System.Threading.Tasks; -using Microsoft.ML; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Sweeper; -using Microsoft.ML.Runtime.Internal.Internallearn; [assembly: LoadableClass(typeof(SynthConfigRunner), typeof(SynthConfigRunner.Arguments), typeof(SignatureConfigRunner), "", "Synth")] diff --git a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs index a2b85313d6..a48b59a0e9 100644 --- a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs +++ b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs @@ -14,7 +14,7 @@ using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.TimeSeriesProcessing; -[assembly: LoadableClass(typeof(ISequenceModeler), typeof(AdaptiveSingularSpectrumSequenceModeler), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(AdaptiveSingularSpectrumSequenceModeler), typeof(AdaptiveSingularSpectrumSequenceModeler), null, typeof(SignatureLoadModel), "SSA Sequence Modeler", AdaptiveSingularSpectrumSequenceModeler.LoaderSignature)] @@ -24,7 +24,7 @@ namespace Microsoft.ML.Runtime.TimeSeriesProcessing /// This class implements basic Singular Spectrum Analysis (SSA) model for modeling univariate time-series. /// For the details of the model, refer to http://arxiv.org/pdf/1206.6910.pdf. /// - public sealed class AdaptiveSingularSpectrumSequenceModeler : ISequenceModeler + public sealed class AdaptiveSingularSpectrumSequenceModeler : SequenceModelerBase { public const string LoaderSignature = "SSAModel"; @@ -239,7 +239,6 @@ private static VersionInfo GetVersionInfo() /// The length of series that is kept in buffer for modeling (parameter N). /// The length of the window on the series for building the trajectory matrix (parameter L). /// The discount factor in [0,1] used for online updates (default = 1). - /// The buffer used to keep the series in the memory. If null, an internal buffer is created (default = null). /// The rank selection method (default = Exact). /// The desired rank of the subspace used for SSA projection (parameter r). This parameter should be in the range in [1, windowSize]. /// If set to null, the rank is automatically determined based on prediction error minimization. (default = null) @@ -249,8 +248,9 @@ private static VersionInfo GetVersionInfo() /// The flag determining whether the meta information for the model needs to be maintained. /// The maximum growth on the exponential trend public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, int trainSize, int seriesLength, int windowSize, Single discountFactor = 1, - FixedSizeQueue buffer = null, RankSelectionMethod rankSelectionMethod = RankSelectionMethod.Exact, int? rank = null, int? maxRank = null, + RankSelectionMethod rankSelectionMethod = RankSelectionMethod.Exact, int? rank = null, int? maxRank = null, bool shouldComputeForecastIntervals = true, bool shouldstablize = true, bool shouldMaintainInfo = false, GrowthRatio? maxGrowth = null) + : base() { Contracts.CheckValue(env, nameof(env)); _host = env.Register(LoaderSignature); @@ -285,10 +285,7 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, int trainSi _trainSize = trainSize; _discountFactor = discountFactor; - if (buffer == null) - _buffer = new FixedSizeQueue(seriesLength); - else - _buffer = buffer; + _buffer = new FixedSizeQueue(seriesLength); _alpha = new Single[windowSize - 1]; _state = new Single[windowSize - 1]; @@ -312,7 +309,7 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, int trainSi } /// - /// The copy constructor + /// The copy constructor. /// /// An object whose contents are copied to the current object. private AdaptiveSingularSpectrumSequenceModeler(AdaptiveSingularSpectrumSequenceModeler model) @@ -465,7 +462,7 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo _xSmooth = new CpuAlignedVector(_windowSize, SseUtils.CbAlign); } - public void Save(ModelSaveContext ctx) + public override void Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -741,7 +738,7 @@ private static int DetermineSignalRank(Single[] series, TrajectoryMatrix tMat, S return minIndex + 1; } - public void InitState() + internal override void InitState() { for (int i = 0; i < _windowSize - 2; ++i) _state[i] = 0; @@ -1114,7 +1111,7 @@ private bool Stabilize() /// /// The next observation on the series. /// Determines whether the model parameters also need to be updated upon consuming the new observation (default = false). - public void Consume(ref Single input, bool updateModel = false) + internal override void Consume(ref Single input, bool updateModel = false) { if (Single.IsNaN(input)) return; @@ -1177,7 +1174,7 @@ public void Consume(ref Single input, bool updateModel = false) /// Train the model parameters based on a training series. /// /// The training time-series. - public void Train(FixedSizeQueue data) + internal override void Train(FixedSizeQueue data) { _host.CheckParam(data != null, nameof(data), "The input series for training cannot be null."); _host.CheckParam(data.Count >= _trainSize, nameof(data), "The input series for training does not have enough points for training."); @@ -1218,7 +1215,7 @@ public void Train(FixedSizeQueue data) /// Train the model parameters based on a training series. /// /// The training time-series. - public void Train(RoleMappedData data) + internal override void Train(RoleMappedData data) { _host.CheckParam(data != null, nameof(data), "The input series for training cannot be null."); if (data.Schema.Feature.Type != NumberType.Float) @@ -1428,7 +1425,7 @@ private void TrainCore(Single[] dataArray, int originalSeriesLength) /// /// The forecast result. /// The forecast horizon. - public void Forecast(ref ForecastResultBase result, int horizon = 1) + internal override void Forecast(ref ForecastResultBase result, int horizon = 1) { _host.CheckParam(horizon >= 1, nameof(horizon), "The horizon parameter should be greater than 0."); if (result == null) @@ -1500,12 +1497,12 @@ public void Forecast(ref ForecastResultBase result, int horizon = 1) /// Predicts the next value on the series. /// /// The prediction result. - public void PredictNext(ref Single output) + internal override void PredictNext(ref Single output) { output = _nextPrediction; } - public ISequenceModeler Clone() + internal override SequenceModelerBase Clone() { return new AdaptiveSingularSpectrumSequenceModeler(this); } diff --git a/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs b/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs index a6a7c20986..8a1807a797 100644 --- a/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs +++ b/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs @@ -109,12 +109,12 @@ public State() _firstIteration = true; } - protected override void SetNaOutput(ref Single output) + private protected override void SetNaOutput(ref Single output) { output = Single.NaN; } - protected override void TransformCore(ref Single input, FixedSizeQueue windowedBuffer, long iteration, ref Single output) + private protected override void TransformCore(ref Single input, FixedSizeQueue windowedBuffer, long iteration, ref Single output) { if (_firstIteration) { @@ -127,13 +127,13 @@ protected override void TransformCore(ref Single input, FixedSizeQueue w _previousAverage = output; } - protected override void InitializeStateCore() + private protected override void InitializeStateCore() { _firstIteration = true; _decay = ((ExponentialAverageTransform)ParentTransform)._decay; } - protected override void LearnStateFromDataCore(FixedSizeQueue data) + private protected override void LearnStateFromDataCore(FixedSizeQueue data) { // This method is empty because there is no need for parameter learning from the initial windowed buffer for this transform. } diff --git a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs index 900407adec..f19694e38d 100644 --- a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs +++ b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs @@ -54,17 +54,17 @@ public override void Save(ModelSaveContext ctx) public sealed class State : AnomalyDetectionStateBase { - protected override void LearnStateFromDataCore(FixedSizeQueue data) + private protected override void LearnStateFromDataCore(FixedSizeQueue data) { // This method is empty because there is no need for initial tuning for this transform. } - protected override void InitializeAnomalyDetector() + private protected override void InitializeAnomalyDetector() { // This method is empty because there is no need for any extra initialization for this transform. } - protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue windowedBuffer, long iteration) + private protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue windowedBuffer, long iteration) { // This transform treats the input sequenence as the raw anomaly score. return (double)input; diff --git a/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs b/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs index da71c8d5e9..da4094a125 100644 --- a/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs +++ b/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs @@ -143,7 +143,7 @@ private static Single ComputeMovingAverageUniformInitialisation(FixedSizeQueue others, Single input, Single[] weights, int lag) + internal static Single ComputeMovingAverageNonUniform(FixedSizeQueue others, Single input, Single[] weights, int lag) { Single sumWeights = 0; Single sumValues = 0; @@ -178,7 +178,7 @@ public static Single ComputeMovingAverageNonUniform(FixedSizeQueue other /// NaN value: only NaN values in the sliding window or +/- Infinite /// Inifinite value: one infinite value in the sliding window (sign is no relevant) /// - public static Single ComputeMovingAverageUniform(FixedSizeQueue others, Single input, int lag, + internal static Single ComputeMovingAverageUniform(FixedSizeQueue others, Single input, int lag, Single lastDropped, ref Single currentSum, ref bool initUniformMovingAverage, ref int nbNanValues) @@ -262,7 +262,7 @@ public sealed class State : StateBase // take part of the computation. private int _nbNanValues; - protected override void SetNaOutput(ref Single output) + private protected override void SetNaOutput(ref Single output) { output = Single.NaN; } @@ -274,7 +274,7 @@ protected override void SetNaOutput(ref Single output) /// /// /// - protected override void TransformCore(ref Single input, FixedSizeQueue windowedBuffer, long iteration, ref Single output) + private protected override void TransformCore(ref Single input, FixedSizeQueue windowedBuffer, long iteration, ref Single output) { if (_weights == null) output = ComputeMovingAverageUniform(windowedBuffer, input, _lag, _lastDroppedValue, ref _currentSum, ref _initUniformMovingAverage, ref _nbNanValues); @@ -283,14 +283,14 @@ protected override void TransformCore(ref Single input, FixedSizeQueue w _lastDroppedValue = windowedBuffer[0]; } - protected override void InitializeStateCore() + private protected override void InitializeStateCore() { _weights = ((MovingAverageTransform)ParentTransform)._weights; _lag = ((MovingAverageTransform)ParentTransform)._lag; _initUniformMovingAverage = true; } - protected override void LearnStateFromDataCore(FixedSizeQueue data) + private protected override void LearnStateFromDataCore(FixedSizeQueue data) { // This method is empty because there is no need for parameter learning from the initial windowed buffer for this transform. } diff --git a/src/Microsoft.ML.TimeSeries/PValueTransform.cs b/src/Microsoft.ML.TimeSeries/PValueTransform.cs index b6490de3c7..36e8bebf4b 100644 --- a/src/Microsoft.ML.TimeSeries/PValueTransform.cs +++ b/src/Microsoft.ML.TimeSeries/PValueTransform.cs @@ -113,12 +113,12 @@ public sealed class State : StateBase private PValueTransform _parent; - protected override void SetNaOutput(ref Single dst) + private protected override void SetNaOutput(ref Single dst) { dst = Single.NaN; } - protected override void TransformCore(ref Single input, FixedSizeQueue windowedBuffer, long iteration, ref Single dst) + private protected override void TransformCore(ref Single input, FixedSizeQueue windowedBuffer, long iteration, ref Single dst) { int count; int equalCount; @@ -131,13 +131,13 @@ protected override void TransformCore(ref Single input, FixedSizeQueue w // Based on the equation in http://arxiv.org/pdf/1204.3251.pdf } - protected override void InitializeStateCore() + private protected override void InitializeStateCore() { _parent = (PValueTransform)ParentTransform; _randomGen = RandomUtils.Create(_parent._seed); } - protected override void LearnStateFromDataCore(FixedSizeQueue data) + private protected override void LearnStateFromDataCore(FixedSizeQueue data) { // This method is empty because there is no need for parameter learning from the initial windowed buffer for this transform. } diff --git a/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs b/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs index 9d771ef21a..571a1b1bc4 100644 --- a/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs +++ b/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs @@ -101,7 +101,7 @@ public override void Save(ModelSaveContext ctx) ctx.Writer.Write(_percentile); } - public static void CountGreaterOrEqualValues(FixedSizeQueue others, Single theValue, out int greaterVals, out int equalVals, out int totalVals) + internal static void CountGreaterOrEqualValues(FixedSizeQueue others, Single theValue, out int greaterVals, out int equalVals, out int totalVals) { // The current linear algorithm for counting greater and equal elements takes O(n), // but it can be improved to O(log n) if a separate Binary Search Tree data structure is used. @@ -130,12 +130,12 @@ public sealed class State : StateBase /// private PercentileThresholdTransform _parent; - protected override void SetNaOutput(ref bool dst) + private protected override void SetNaOutput(ref bool dst) { dst = false; } - protected override void TransformCore(ref Single input, FixedSizeQueue windowedBuffer, long iteration, ref bool dst) + private protected override void TransformCore(ref Single input, FixedSizeQueue windowedBuffer, long iteration, ref bool dst) { int greaterCount; int equalCount; @@ -145,15 +145,15 @@ protected override void TransformCore(ref Single input, FixedSizeQueue w dst = greaterCount < (int)(_parent._percentile * totalCount / 100); } - protected override void InitializeStateCore() + private protected override void InitializeStateCore() { _parent = (PercentileThresholdTransform)ParentTransform; } - protected override void LearnStateFromDataCore(FixedSizeQueue data) + private protected override void LearnStateFromDataCore(FixedSizeQueue data) { // This method is empty because there is no need for parameter learning from the initial windowed buffer for this transform. } } } -} +} \ No newline at end of file diff --git a/src/Microsoft.ML.TimeSeries/ISequenceModeler.cs b/src/Microsoft.ML.TimeSeries/SequenceModelerBase.cs similarity index 75% rename from src/Microsoft.ML.TimeSeries/ISequenceModeler.cs rename to src/Microsoft.ML.TimeSeries/SequenceModelerBase.cs index 5ee0c01711..5b7b86d93f 100644 --- a/src/Microsoft.ML.TimeSeries/ISequenceModeler.cs +++ b/src/Microsoft.ML.TimeSeries/SequenceModelerBase.cs @@ -22,50 +22,59 @@ public abstract class ForecastResultBase /// /// The type of the elements in the input sequence /// The type of the elements in the output sequence - public interface ISequenceModeler : ICanSaveModel + public abstract class SequenceModelerBase : ICanSaveModel { + private protected SequenceModelerBase() + { + } + /// /// Initializes the state of the modeler /// - void InitState(); + internal abstract void InitState(); /// /// Consumes one element from the input sequence. /// /// An element in the sequence /// determines whether the sequence model should be updated according to the input - void Consume(ref TInput input, bool updateModel = false); + internal abstract void Consume(ref TInput input, bool updateModel = false); /// /// Trains the sequence model on a given sequence. /// /// The input sequence used for training - void Train(FixedSizeQueue data); + internal abstract void Train(FixedSizeQueue data); /// /// Trains the sequence model on a given sequence. The method accepts an object of RoleMappedData, /// and assumes the input column is the 'Feature' column of type TInput. /// /// The input sequence used for training - void Train(RoleMappedData data); + internal abstract void Train(RoleMappedData data); /// /// Forecasts the next 'horizon' elements in the output sequence. /// /// The forecast result for the given horizon along with optional information depending on the algorithm /// The forecast horizon - void Forecast(ref ForecastResultBase result, int horizon = 1); + internal abstract void Forecast(ref ForecastResultBase result, int horizon = 1); /// /// Predicts the next element in the output sequence. /// /// The output ref parameter the will contain the prediction result - void PredictNext(ref TOutput output); + internal abstract void PredictNext(ref TOutput output); /// /// Creates a clone of the model. /// /// A clone of the object - ISequenceModeler Clone(); + internal abstract SequenceModelerBase Clone(); + + /// + /// Implementation of . + /// + public abstract void Save(ModelSaveContext ctx); } } diff --git a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs index e80f41a985..b0c1199439 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs @@ -159,7 +159,7 @@ private static int GetOutputLength(AlertingScore alertingScore, IHostEnvironment } } - protected SequentialAnomalyDetectionTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, string name, IHostEnvironment env, + private protected SequentialAnomalyDetectionTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, string name, IHostEnvironment env, AnomalySide anomalySide, MartingaleType martingale, AlertingScore alertingScore, Double powerMartingaleEpsilon, Double alertThreshold) : base(Contracts.CheckRef(env, nameof(env)).Register(name), windowSize, initialWindowSize, inputColumnName, outputColumnName, new VectorType(NumberType.R8, GetOutputLength(alertingScore, env))) @@ -183,13 +183,13 @@ protected SequentialAnomalyDetectionTransformBase(int windowSize, int initialWin _outputLength = GetOutputLength(ThresholdScore, Host); } - protected SequentialAnomalyDetectionTransformBase(ArgumentsBase args, string name, IHostEnvironment env) + private protected SequentialAnomalyDetectionTransformBase(ArgumentsBase args, string name, IHostEnvironment env) : this(args.WindowSize, args.InitialWindowSize, args.Source, args.Name, name, env, args.Side, args.Martingale, args.AlertOn, args.PowerMartingaleEpsilon, args.AlertThreshold) { } - protected SequentialAnomalyDetectionTransformBase(IHostEnvironment env, ModelLoadContext ctx, string name) + private protected SequentialAnomalyDetectionTransformBase(IHostEnvironment env, ModelLoadContext ctx, string name) : base(Contracts.CheckRef(env, nameof(env)).Register(name), ctx) { // *** Binary format *** @@ -319,8 +319,10 @@ public abstract class AnomalyDetectionStateBase : StateBase private int _martingaleAlertCounter; - protected Double LatestMartingaleScore { - get { return Math.Exp(_logMartingaleValue); } + protected Double LatestMartingaleScore => Math.Exp(_logMartingaleValue); + + private protected AnomalyDetectionStateBase() : base() + { } private Double ComputeKernelPValue(Double rawScore) @@ -359,7 +361,7 @@ private Double ComputeKernelPValue(Double rawScore) return pValue; } - protected override void SetNaOutput(ref VBuffer dst) + private protected override void SetNaOutput(ref VBuffer dst) { var values = dst.Values; var outputLength = Parent._outputLength; @@ -372,7 +374,7 @@ protected override void SetNaOutput(ref VBuffer dst) dst = new VBuffer(Utils.Size(values), values, dst.Indices); } - protected override sealed void TransformCore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration, ref VBuffer dst) + private protected override sealed void TransformCore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration, ref VBuffer dst) { var outputLength = Parent._outputLength; Host.Assert(outputLength >= 2); @@ -508,7 +510,7 @@ protected override sealed void TransformCore(ref TInput input, FixedSizeQueue(outputLength, result, dst.Indices); } - protected override sealed void InitializeStateCore() + private protected override sealed void InitializeStateCore() { Parent = (SequentialAnomalyDetectionTransformBase)ParentTransform; Host.Assert(WindowSize >= 0); @@ -525,7 +527,7 @@ protected override sealed void InitializeStateCore() /// /// The abstract method that realizes the initialization functionality for the anomaly detector. /// - protected abstract void InitializeAnomalyDetector(); + private protected abstract void InitializeAnomalyDetector(); /// /// The abstract method that realizes the main logic for calculating the raw anomaly score bfor the current input given a windowed buffer @@ -535,7 +537,7 @@ protected override sealed void InitializeStateCore() /// A long number that indicates the number of times ComputeRawAnomalyScore has been called so far (starting value = 0). /// The raw anomaly score for the input. The Assumption is the higher absolute value of the raw score, the more anomalous the input is. /// The sign of the score determines whether it's a positive anomaly or a negative one. - protected abstract Double ComputeRawAnomalyScore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration); + private protected abstract Double ComputeRawAnomalyScore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration); } protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(Host, this, schema); @@ -609,13 +611,13 @@ private Delegate MakeGetter(IRow input, TState state) _host.AssertValue(input); var srcGetter = input.GetGetter(_inputColumnIndex); ProcessData processData = _parent.WindowSize > 0 ? - (ProcessData) state.Process : state.ProcessWithoutBuffer; - ValueGetter > valueGetter = (ref VBuffer dst) => - { - TInput src = default; - srcGetter(ref src); - processData(ref src, ref dst); - }; + (ProcessData)state.Process : state.ProcessWithoutBuffer; + ValueGetter> valueGetter = (ref VBuffer dst) => + { + TInput src = default; + srcGetter(ref src); + processData(ref src, ref dst); + }; return valueGetter; } diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs index 5d6bd8f553..3482487a6e 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs @@ -51,28 +51,28 @@ public abstract class StateBase /// /// A reference to the parent transform that operates on the state object. /// - protected SequentialTransformBase ParentTransform; + private protected SequentialTransformBase ParentTransform; /// /// The internal windowed buffer for buffering the values in the input sequence. /// - protected FixedSizeQueue WindowedBuffer; + private protected FixedSizeQueue WindowedBuffer; /// /// The buffer used to buffer the training data points. /// - protected FixedSizeQueue InitialWindowedBuffer; + private protected FixedSizeQueue InitialWindowedBuffer; - protected int WindowSize { get; private set; } + private protected int WindowSize { get; private set; } - protected int InitialWindowSize { get; private set; } + private protected int InitialWindowSize { get; private set; } /// /// Counts the number of rows observed by the transform so far. /// - protected int RowCounter { get; private set; } + private protected int RowCounter { get; private set; } - protected int IncrementRowCounter() + private protected int IncrementRowCounter() { RowCounter++; return RowCounter; @@ -166,10 +166,10 @@ public void ProcessWithoutBuffer(ref TInput input, ref TOutput output) } /// - /// The abstract method that specifies the NA value for the dst type. + /// The abstract method that specifies the NA value for 's type. /// /// - protected abstract void SetNaOutput(ref TOutput dst); + private protected abstract void SetNaOutput(ref TOutput dst); /// /// The abstract method that realizes the main logic for the transform. @@ -178,18 +178,18 @@ public void ProcessWithoutBuffer(ref TInput input, ref TOutput output) /// A reference to the dst object. /// A reference to the windowed buffer. /// A long number that indicates the number of times TransformCore has been called so far (starting value = 0). - protected abstract void TransformCore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration, ref TOutput dst); + private protected abstract void TransformCore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration, ref TOutput dst); /// /// The abstract method that realizes the logic for initializing the state object. /// - protected abstract void InitializeStateCore(); + private protected abstract void InitializeStateCore(); /// /// The abstract method that realizes the logic for learning the parameters and the initial state object from data. /// /// A queue of data points used for training - protected abstract void LearnStateFromDataCore(FixedSizeQueue data); + private protected abstract void LearnStateFromDataCore(FixedSizeQueue data); } /// @@ -242,13 +242,13 @@ private static IDataTransform CreateLambdaTransform(IHost host, IDataView input, /// A reference to the environment variable. /// A reference to the input data view. /// - protected SequentialTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, + private protected SequentialTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, string name, IHostEnvironment env, IDataView input, ColumnType outputColTypeOverride = null) : this(windowSize, initialWindowSize, inputColumnName, outputColumnName, Contracts.CheckRef(env, nameof(env)).Register(name), input, outputColTypeOverride) { } - protected SequentialTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, + private protected SequentialTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, IHost host, IDataView input, ColumnType outputColTypeOverride = null) : base(host, input) { @@ -268,7 +268,7 @@ protected SequentialTransformBase(int windowSize, int initialWindowSize, string _transform = CreateLambdaTransform(Host, input, InputColumnName, OutputColumnName, InitFunction, WindowSize > 0, outputColTypeOverride); } - protected SequentialTransformBase(IHostEnvironment env, ModelLoadContext ctx, string name, IDataView input) + private protected SequentialTransformBase(IHostEnvironment env, ModelLoadContext ctx, string name, IDataView input) : base(env, name, input) { Host.CheckValue(ctx, nameof(ctx)); @@ -343,7 +343,7 @@ private void InitFunction(TState state) state.InitState(WindowSize, InitialWindowSize, this, Host); } - public override bool CanShuffle { get { return false; } } + public override bool CanShuffle => false; protected override bool? ShouldUseParallelCursors(Func predicate) { @@ -357,10 +357,7 @@ protected override IRowCursor GetRowCursorCore(Func predicate, IRando return new Cursor(this, srcCursor); } - public override Schema Schema - { - get { return _transform.Schema; } - } + public override Schema Schema => _transform.Schema; public override long? GetRowCount(bool lazy = true) { diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs index 5e2b32a81b..91ccc51101 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs @@ -29,7 +29,7 @@ public abstract class StateBase { // Ideally this class should be private. However, due to the current constraints with the LambdaTransform, we need to have // access to the state class when inheriting from SequentialTransformerBase. - protected IHost Host; + private protected IHost Host; /// /// A reference to the parent transform that operates on the state object. @@ -39,22 +39,26 @@ public abstract class StateBase /// /// The internal windowed buffer for buffering the values in the input sequence. /// - protected FixedSizeQueue WindowedBuffer; + private protected FixedSizeQueue WindowedBuffer; /// /// The buffer used to buffer the training data points. /// - protected FixedSizeQueue InitialWindowedBuffer; + private protected FixedSizeQueue InitialWindowedBuffer; - protected int WindowSize { get; private set; } + private protected int WindowSize { get; private set; } - protected int InitialWindowSize { get; private set; } + private protected int InitialWindowSize { get; private set; } /// /// Counts the number of rows observed by the transform so far. /// protected long RowCounter { get; private set; } + private protected StateBase() + { + } + protected long IncrementRowCounter() { RowCounter++; @@ -147,7 +151,7 @@ public void ProcessWithoutBuffer(ref TInput input, ref TOutput output) /// The abstract method that specifies the NA value for the dst type. /// /// - protected abstract void SetNaOutput(ref TOutput dst); + private protected abstract void SetNaOutput(ref TOutput dst); /// /// The abstract method that realizes the main logic for the transform. @@ -156,35 +160,35 @@ public void ProcessWithoutBuffer(ref TInput input, ref TOutput output) /// A reference to the dst object. /// A reference to the windowed buffer. /// A long number that indicates the number of times TransformCore has been called so far (starting value = 0). - protected abstract void TransformCore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration, ref TOutput dst); + private protected abstract void TransformCore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration, ref TOutput dst); /// /// The abstract method that realizes the logic for initializing the state object. /// - protected abstract void InitializeStateCore(); + private protected abstract void InitializeStateCore(); /// /// The abstract method that realizes the logic for learning the parameters and the initial state object from data. /// /// A queue of data points used for training - protected abstract void LearnStateFromDataCore(FixedSizeQueue data); + private protected abstract void LearnStateFromDataCore(FixedSizeQueue data); } - protected readonly IHost Host; + private protected readonly IHost Host; /// /// The window size for buffering. /// - protected readonly int WindowSize; + private protected readonly int WindowSize; /// /// The number of datapoints from the beginning of the sequence that are used for learning the initial state. /// - protected int InitialWindowSize; + private protected int InitialWindowSize; - public string InputColumnName; - public string OutputColumnName; - protected ColumnType OutputColumnType; + internal readonly string InputColumnName; + internal readonly string OutputColumnName; + private protected ColumnType OutputColumnType; public bool IsRowToRowMapper => false; @@ -197,7 +201,7 @@ public void ProcessWithoutBuffer(ref TInput input, ref TOutput output) /// The name of the input column. /// The name of the dst column. /// - protected SequentialTransformerBase(IHost host, int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, ColumnType outputColType) + private protected SequentialTransformerBase(IHost host, int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, ColumnType outputColType) { Host = host; Host.CheckParam(initialWindowSize >= 0, nameof(initialWindowSize), "Must be non-negative."); @@ -214,7 +218,7 @@ protected SequentialTransformerBase(IHost host, int windowSize, int initialWindo WindowSize = windowSize; } - protected SequentialTransformerBase(IHost host, ModelLoadContext ctx) + private protected SequentialTransformerBase(IHost host, ModelLoadContext ctx) { Host = host; Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs b/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs index 0947d02b25..5ff02a672d 100644 --- a/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs @@ -131,7 +131,7 @@ public sealed class StateSlide : StateBase { private SlidingWindowTransformBase _parentSliding; - protected override void SetNaOutput(ref VBuffer output) + private protected override void SetNaOutput(ref VBuffer output) { int size = _parentSliding.WindowSize - _parentSliding._lag + 1; @@ -156,7 +156,7 @@ protected override void SetNaOutput(ref VBuffer output) output = new VBuffer(size, result, output.Indices); } - protected override void TransformCore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration, ref VBuffer output) + private protected override void TransformCore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration, ref VBuffer output) { int size = _parentSliding.WindowSize - _parentSliding._lag + 1; var result = output.Values; @@ -177,12 +177,12 @@ protected override void TransformCore(ref TInput input, FixedSizeQueue w output = new VBuffer(size, result, output.Indices); } - protected override void InitializeStateCore() + private protected override void InitializeStateCore() { _parentSliding = (SlidingWindowTransformBase)base.ParentTransform; } - protected override void LearnStateFromDataCore(FixedSizeQueue data) + private protected override void LearnStateFromDataCore(FixedSizeQueue data) { // This method is empty because there is no need for parameter learning from the initial windowed buffer for this transform. } diff --git a/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs index 86d422e484..b5a5cd9d16 100644 --- a/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs +++ b/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs @@ -104,7 +104,7 @@ public abstract class SsaArguments : ArgumentsBase protected readonly bool IsAdaptive; protected readonly ErrorFunctionUtils.ErrorFunction ErrorFunction; protected readonly Func ErrorFunc; - protected readonly ISequenceModeler Model; + protected readonly SequenceModelerBase Model; public SsaAnomalyDetectionBase(SsaArguments args, string name, IHostEnvironment env) : base(args.WindowSize, 0, args.Source, args.Name, name, env, args.Side, args.Martingale, args.AlertOn, args.PowerMartingaleEpsilon, args.AlertThreshold) @@ -120,7 +120,7 @@ public SsaAnomalyDetectionBase(SsaArguments args, string name, IHostEnvironment IsAdaptive = args.IsAdaptive; // Creating the master SSA model Model = new AdaptiveSingularSpectrumSequenceModeler(Host, args.InitialWindowSize, SeasonalWindowSize + 1, SeasonalWindowSize, - DiscountFactor, null, AdaptiveSingularSpectrumSequenceModeler.RankSelectionMethod.Exact, null, SeasonalWindowSize / 2, false, false); + DiscountFactor, AdaptiveSingularSpectrumSequenceModeler.RankSelectionMethod.Exact, null, SeasonalWindowSize / 2, false, false); } public SsaAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name) @@ -150,7 +150,7 @@ public SsaAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, strin IsAdaptive = ctx.Reader.ReadBoolean(); - ctx.LoadModel, SignatureLoadModel>(env, out Model, "SSA"); + ctx.LoadModel, SignatureLoadModel>(env, out Model, "SSA"); Host.CheckDecode(Model != null); } @@ -197,21 +197,21 @@ public override void Save(ModelSaveContext ctx) public sealed class State : AnomalyDetectionStateBase { - private ISequenceModeler _model; + private SequenceModelerBase _model; private SsaAnomalyDetectionBase _parentAnomalyDetector; - protected override void LearnStateFromDataCore(FixedSizeQueue data) + private protected override void LearnStateFromDataCore(FixedSizeQueue data) { // This method is empty because there is no need to implement a training logic here. } - protected override void InitializeAnomalyDetector() + private protected override void InitializeAnomalyDetector() { _parentAnomalyDetector = (SsaAnomalyDetectionBase)Parent; _model = _parentAnomalyDetector.Model.Clone(); } - protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue windowedBuffer, long iteration) + private protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue windowedBuffer, long iteration) { // Get the prediction for the next point opn the series Single expectedValue = 0; diff --git a/src/Microsoft.ML.Transforms/TermLookupTransform.cs b/src/Microsoft.ML.Transforms/TermLookupTransform.cs index 14fc80322a..0d9ad086bf 100644 --- a/src/Microsoft.ML.Transforms/TermLookupTransform.cs +++ b/src/Microsoft.ML.Transforms/TermLookupTransform.cs @@ -158,7 +158,7 @@ public override void Train(IExceptionContext ectx, IRowCursor cursor, int colTer getTerm(ref term); // REVIEW: Should we trim? term = ReadOnlyMemoryUtils.TrimSpaces(term); - var nstr = ReadOnlyMemoryUtils.AddToPool(term, terms); + var nstr = terms.Add(term); if (nstr.Id != values.Count) throw ectx.Except("Duplicate term in lookup data: '{0}'", nstr); @@ -193,7 +193,7 @@ private ValueGetter GetGetterCore(ValueGetter> getTer { getTerm(ref src); src = ReadOnlyMemoryUtils.TrimSpaces(src); - var nstr = ReadOnlyMemoryUtils.FindInPool(src, _terms); + var nstr = _terms.Get(src); if (nstr == null) GetMissing(ref dst); else diff --git a/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs b/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs index 9aa18dc30a..122d612c7f 100644 --- a/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs +++ b/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs @@ -653,7 +653,7 @@ public void TestNelderMeadSweeperWithDefaultFirstBatchSweeper() sweeps = sweeper.ProposeSweeps(5, results); } - Assert.True(Utils.Size(sweeps) <= 5); + Assert.True(sweeps.Length <= 5); } } From 325727a255775df1c18f7fbc21bf234cc9746b61 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Thu, 8 Nov 2018 16:20:29 -0800 Subject: [PATCH 02/12] Internalize command line related infrastructure. --- .../CommandLine/ArgumentAttribute.cs | 80 +--- .../CommandLine/ArgumentType.cs | 54 +++ .../CommandLine/CharCursor.cs | 2 +- src/Microsoft.ML.Core/CommandLine/CmdLexer.cs | 5 +- .../CommandLine/CmdParser.cs | 359 +++++++----------- .../CommandLine/DefaultArgumentAttribute.cs | 29 ++ .../CommandLine/EnumValueDisplayAttribute.cs | 23 ++ .../CommandLine/HideEnumValueAttribute.cs | 20 + .../{Utils.cs => SpecialPurpose.cs} | 3 +- .../Code/ContractsCheckTest.cs | 24 +- .../Helpers/CodeFixVerifier.cs | 5 + .../Microsoft.ML.CodeAnalyzer.Tests.csproj | 11 +- .../Microsoft.ML.Sweeper.Tests/TestSweeper.cs | 2 +- 13 files changed, 320 insertions(+), 297 deletions(-) create mode 100644 src/Microsoft.ML.Core/CommandLine/ArgumentType.cs create mode 100644 src/Microsoft.ML.Core/CommandLine/DefaultArgumentAttribute.cs create mode 100644 src/Microsoft.ML.Core/CommandLine/EnumValueDisplayAttribute.cs create mode 100644 src/Microsoft.ML.Core/CommandLine/HideEnumValueAttribute.cs rename src/Microsoft.ML.Core/CommandLine/{Utils.cs => SpecialPurpose.cs} (93%) diff --git a/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs b/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs index 64fe3b5b80..70f9ec8d98 100644 --- a/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs +++ b/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -// This is separated from CmdParser.cs - using System; using System.Linq; @@ -15,7 +13,8 @@ namespace Microsoft.ML.Runtime.CommandLine /// as the destination of command line argument parsing. /// [AttributeUsage(AttributeTargets.Field)] - public class ArgumentAttribute : Attribute + [BestFriend] + internal class ArgumentAttribute : Attribute { public enum VisibilityType { @@ -24,17 +23,8 @@ public enum VisibilityType EntryPointsOnly } - private ArgumentType _type; private string _shortName; - private string _helpText; - private bool _hide; - private double _sortOrder; - private string _nullName; - private bool _isInputFileName; - private string _specialPurpose; - private VisibilityType _visibility; private string _name; - private Type _signatureType; /// /// Allows control of command line parsing. @@ -42,17 +32,14 @@ public enum VisibilityType /// Specifies the error checking to be done on the argument. public ArgumentAttribute(ArgumentType type) { - _type = type; - _sortOrder = 150; + Type = type; + SortOrder = 150; } /// /// The error checking to be done on the argument. /// - public ArgumentType Type - { - get { return _type; } - } + public ArgumentType Type { get; } /// /// The short name(s) of the argument. @@ -64,7 +51,7 @@ public ArgumentType Type /// public string ShortName { - get { return _shortName; } + get => _shortName; set { Contracts.Check(value == null || !(this is DefaultArgumentAttribute)); @@ -75,54 +62,26 @@ public string ShortName /// /// The help text for the argument. /// - public string HelpText - { - get { return _helpText; } - set { _helpText = value; } - } + public string HelpText { get; set; } - public bool Hide - { - get { return _hide; } - set { _hide = value; } - } + public bool Hide { get; set; } - public double SortOrder - { - get { return _sortOrder; } - set { _sortOrder = value; } - } + public double SortOrder { get; set; } - public string NullName - { - get { return _nullName; } - set { _nullName = value; } - } + public string NullName { get; set; } - public bool IsInputFileName - { - get { return _isInputFileName; } - set { _isInputFileName = value; } - } + public bool IsInputFileName { get; set; } /// /// Allows the GUI or other tools to inspect the intended purpose of the argument and pick a correct custom control. /// - public string Purpose - { - get { return _specialPurpose; } - set { _specialPurpose = value; } - } + public string Purpose { get; set; } - public VisibilityType Visibility - { - get { return _visibility; } - set { _visibility = value; } - } + public VisibilityType Visibility { get; set; } public string Name { - get { return _name; } + get => _name; set { _name = string.IsNullOrWhiteSpace(value) ? null : value; } } @@ -136,15 +95,8 @@ public string[] Aliases } } - public bool IsRequired - { - get { return ArgumentType.Required == (_type & ArgumentType.Required); } - } + public bool IsRequired => ArgumentType.Required == (Type & ArgumentType.Required); - public Type SignatureType - { - get { return _signatureType; } - set { _signatureType = value; } - } + public Type SignatureType { get; set; } } } \ No newline at end of file diff --git a/src/Microsoft.ML.Core/CommandLine/ArgumentType.cs b/src/Microsoft.ML.Core/CommandLine/ArgumentType.cs new file mode 100644 index 0000000000..5840615fd5 --- /dev/null +++ b/src/Microsoft.ML.Core/CommandLine/ArgumentType.cs @@ -0,0 +1,54 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.Runtime.CommandLine +{ + /// + /// Used to control parsing of command line arguments. + /// + [Flags] + [BestFriend] + internal enum ArgumentType + { + /// + /// Indicates that this field is required. An error will be displayed + /// if it is not present when parsing arguments. + /// + Required = 0x01, + + /// + /// Only valid in conjunction with Multiple. + /// Duplicate values will result in an error. + /// + Unique = 0x02, + + /// + /// Inidicates that the argument may be specified more than once. + /// Only valid if the argument is a collection + /// + Multiple = 0x04, + + /// + /// The default type for non-collection arguments. + /// The argument is not required, but an error will be reported if it is specified more than once. + /// + AtMostOnce = 0x00, + + /// + /// For non-collection arguments, when the argument is specified more than + /// once no error is reported and the value of the argument is the last + /// value which occurs in the argument list. + /// + LastOccurenceWins = Multiple, + + /// + /// The default type for collection arguments. + /// The argument is permitted to occur multiple times, but duplicate + /// values will cause an error to be reported. + /// + MultipleUnique = Multiple | Unique, + } +} diff --git a/src/Microsoft.ML.Core/CommandLine/CharCursor.cs b/src/Microsoft.ML.Core/CommandLine/CharCursor.cs index 2ea5dbe8d5..d8a591331c 100644 --- a/src/Microsoft.ML.Core/CommandLine/CharCursor.cs +++ b/src/Microsoft.ML.Core/CommandLine/CharCursor.cs @@ -6,7 +6,7 @@ namespace Microsoft.ML.Runtime.CommandLine { - public sealed class CharCursor + internal sealed class CharCursor { private readonly string _text; private readonly int _ichLim; diff --git a/src/Microsoft.ML.Core/CommandLine/CmdLexer.cs b/src/Microsoft.ML.Core/CommandLine/CmdLexer.cs index e7e6ae19cd..a5c14259a8 100644 --- a/src/Microsoft.ML.Core/CommandLine/CmdLexer.cs +++ b/src/Microsoft.ML.Core/CommandLine/CmdLexer.cs @@ -2,13 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.Text; -using Microsoft.ML.Runtime.Internal.Utilities; namespace Microsoft.ML.Runtime.CommandLine { - public sealed class CmdLexer + [BestFriend] + internal sealed class CmdLexer { private CharCursor _curs; diff --git a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs index 37fd791358..28c623f3c2 100644 --- a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs +++ b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs @@ -1,132 +1,6 @@ -////////////////////////////////////////////////////////////////////////////// -// Command Line Argument Parser -// ---------------------------- -// Usage -// ----- -// -// Parsing command line arguments to a console application is a common problem. -// This library handles the common task of reading arguments from a command line -// and filling in the values in a type. -// -// To use this library, define a class whose fields represent the data that your -// application wants to receive from arguments on the command line. Then call -// CommandLine.ParseArguments() to fill the object with the data -// from the command line. Each field in the class defines a command line argument. -// The type of the field is used to validate the data read from the command line. -// The name of the field defines the name of the command line option. -// -// The parser can handle fields of the following types: -// -// - string -// - int -// - uint -// - bool -// - enum -// - array of the above type -// -// For example, suppose you want to read in the argument list for wc (word count). -// wc takes three optional boolean arguments: -l, -w, and -c and a list of files. -// -// You could parse these arguments using the following code: -// -// class WCArguments -// { -// public bool lines; -// public bool words; -// public bool chars; -// public string[] files; -// } -// -// class WC -// { -// static void Main(string[] args) -// { -// if (CommandLine.ParseArgumentsWithUsage(args, parsedArgs)) -// { -// // insert application code here -// } -// } -// } -// -// So you could call this aplication with the following command line to count -// lines in the foo and bar files: -// -// wc.exe /lines /files:foo /files:bar -// -// The program will display the following usage message when bad command line -// arguments are used: -// -// wc.exe -x -// -// Unrecognized command line argument '-x' -// /lines[+|-] short form /l -// /words[+|-] short form /w -// /chars[+|-] short form /c -// /files= short form /f -// @ Read response file for more options -// -// That was pretty easy. However, you realy want to omit the "/files:" for the -// list of files. The details of field parsing can be controled using custom -// attributes. The attributes which control parsing behaviour are: -// -// ArgumentAttribute -// - controls short name, long name, required, allow duplicates, default value -// and help text -// DefaultArgumentAttribute -// - allows omition of the "/name". -// - This attribute is allowed on only one field in the argument class. -// -// So for the wc.exe program we want this: -// -// using System; -// using Utilities; -// -// class WCArguments -// { -// [Argument(ArgumentType.AtMostOnce, HelpText="Count number of lines in the input text.")] -// public bool lines; -// [Argument(ArgumentType.AtMostOnce, HelpText="Count number of words in the input text.")] -// public bool words; -// [Argument(ArgumentType.AtMostOnce, HelpText="Count number of chars in the input text.")] -// public bool chars; -// [DefaultArgument(ArgumentType.MultipleUnique, HelpText="Input files to count.")] -// public string[] files; -// } -// -// class WC -// { -// static void Main(string[] args) -// { -// WCArguments parsedArgs = new WCArguments(); -// if (CommandLine.ParseArgumentsWithUsage(args, parsedArgs)) -// { -// // insert application code here -// } -// } -// } -// -// -// -// So now we have the command line we want: -// -// wc.exe /lines foo bar -// -// This will set lines to true and will set files to an array containing the -// strings "foo" and "bar". -// -// The new usage message becomes: -// -// wc.exe -x -// -// Unrecognized command line argument '-x' -// /lines[+|-] Count number of lines in the input text. (short form /l) -// /words[+|-] Count number of words in the input text. (short form /w) -// /chars[+|-] Count number of chars in the input text. (short form /c) -// @ Read response file for more options -// Input files to count. (short form /f) -// -// If you want more control over how error messages are reported, how /help is -// dealt with, etc you can instantiate the CommandLine.Parser class. +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. using System; using System.Collections; @@ -142,103 +16,15 @@ namespace Microsoft.ML.Runtime.CommandLine { - /// - /// Used to control parsing of command line arguments. - /// - [Flags] - public enum ArgumentType - { - /// - /// Indicates that this field is required. An error will be displayed - /// if it is not present when parsing arguments. - /// - Required = 0x01, - - /// - /// Only valid in conjunction with Multiple. - /// Duplicate values will result in an error. - /// - Unique = 0x02, - - /// - /// Inidicates that the argument may be specified more than once. - /// Only valid if the argument is a collection - /// - Multiple = 0x04, - - /// - /// The default type for non-collection arguments. - /// The argument is not required, but an error will be reported if it is specified more than once. - /// - AtMostOnce = 0x00, - - /// - /// For non-collection arguments, when the argument is specified more than - /// once no error is reported and the value of the argument is the last - /// value which occurs in the argument list. - /// - LastOccurenceWins = Multiple, - - /// - /// The default type for collection arguments. - /// The argument is permitted to occur multiple times, but duplicate - /// values will cause an error to be reported. - /// - MultipleUnique = Multiple | Unique, - } - - /// - /// Indicates that this argument is the default argument. - /// '/' or '-' prefix only the argument value is specified. - /// The ShortName property should not be set for DefaultArgumentAttribute - /// instances. The LongName property is used for usage text only and - /// does not affect the usage of the argument. - /// - [AttributeUsage(AttributeTargets.Field)] - public class DefaultArgumentAttribute : ArgumentAttribute - { - /// - /// Indicates that this argument is the default argument. - /// - /// Specifies the error checking to be done on the argument. - public DefaultArgumentAttribute(ArgumentType type) - : base(type) - { - } - } - - /// - /// On an enum value - indicates that the value should not be shown in help or UI. - /// - [AttributeUsage(AttributeTargets.Field)] - public class HideEnumValueAttribute : Attribute - { - public HideEnumValueAttribute() - { - } - } - - /// - /// On an enum value - specifies the display name. - /// - [AttributeUsage(AttributeTargets.Field)] - public class EnumValueDisplayAttribute : Attribute - { - public readonly string Name; - - public EnumValueDisplayAttribute(string name) - { - Name = name; - } - } /// /// A delegate used in error reporting. /// - public delegate void ErrorReporter(string message); + internal delegate void ErrorReporter(string message); [Flags] - public enum SettingsFlags + [BestFriend] + internal enum SettingsFlags { None = 0x00, @@ -261,6 +47,136 @@ public interface ICommandLineComponentFactory : IComponentFactory string GetSettingsString(); } + ////////////////////////////////////////////////////////////////////////////// + // Command Line Argument Parser + // ---------------------------- + // Usage + // ----- + // + // Parsing command line arguments to a console application is a common problem. + // This library handles the common task of reading arguments from a command line + // and filling in the values in a type. + // + // To use this library, define a class whose fields represent the data that your + // application wants to receive from arguments on the command line. Then call + // CommandLine.ParseArguments() to fill the object with the data + // from the command line. Each field in the class defines a command line argument. + // The type of the field is used to validate the data read from the command line. + // The name of the field defines the name of the command line option. + // + // The parser can handle fields of the following types: + // + // - string + // - int + // - uint + // - bool + // - enum + // - array of the above type + // + // For example, suppose you want to read in the argument list for wc (word count). + // wc takes three optional boolean arguments: -l, -w, and -c and a list of files. + // + // You could parse these arguments using the following code: + // + // class WCArguments + // { + // public bool lines; + // public bool words; + // public bool chars; + // public string[] files; + // } + // + // class WC + // { + // static void Main(string[] args) + // { + // if (CommandLine.ParseArgumentsWithUsage(args, parsedArgs)) + // { + // // insert application code here + // } + // } + // } + // + // So you could call this aplication with the following command line to count + // lines in the foo and bar files: + // + // wc.exe /lines /files:foo /files:bar + // + // The program will display the following usage message when bad command line + // arguments are used: + // + // wc.exe -x + // + // Unrecognized command line argument '-x' + // /lines[+|-] short form /l + // /words[+|-] short form /w + // /chars[+|-] short form /c + // /files= short form /f + // @ Read response file for more options + // + // That was pretty easy. However, you realy want to omit the "/files:" for the + // list of files. The details of field parsing can be controled using custom + // attributes. The attributes which control parsing behaviour are: + // + // ArgumentAttribute + // - controls short name, long name, required, allow duplicates, default value + // and help text + // DefaultArgumentAttribute + // - allows omition of the "/name". + // - This attribute is allowed on only one field in the argument class. + // + // So for the wc.exe program we want this: + // + // using System; + // using Utilities; + // + // class WCArguments + // { + // [Argument(ArgumentType.AtMostOnce, HelpText="Count number of lines in the input text.")] + // public bool lines; + // [Argument(ArgumentType.AtMostOnce, HelpText="Count number of words in the input text.")] + // public bool words; + // [Argument(ArgumentType.AtMostOnce, HelpText="Count number of chars in the input text.")] + // public bool chars; + // [DefaultArgument(ArgumentType.MultipleUnique, HelpText="Input files to count.")] + // public string[] files; + // } + // + // class WC + // { + // static void Main(string[] args) + // { + // WCArguments parsedArgs = new WCArguments(); + // if (CommandLine.ParseArgumentsWithUsage(args, parsedArgs)) + // { + // // insert application code here + // } + // } + // } + // + // + // + // So now we have the command line we want: + // + // wc.exe /lines foo bar + // + // This will set lines to true and will set files to an array containing the + // strings "foo" and "bar". + // + // The new usage message becomes: + // + // wc.exe -x + // + // Unrecognized command line argument '-x' + // /lines[+|-] Count number of lines in the input text. (short form /l) + // /words[+|-] Count number of words in the input text. (short form /w) + // /chars[+|-] Count number of chars in the input text. (short form /c) + // @ Read response file for more options + // Input files to count. (short form /f) + // + // If you want more control over how error messages are reported, how /help is + // dealt with, etc you can instantiate the CommandLine.Parser class. + /// /// Parser for command line arguments. /// @@ -285,7 +201,8 @@ public interface ICommandLineComponentFactory : IComponentFactory /// Arguments which are array types are collection arguments. Collection /// arguments can be specified multiple times. /// - public sealed class CmdParser + [BestFriend] + internal sealed class CmdParser { private const int SpaceBeforeParam = 2; private readonly ErrorReporter _reporter; diff --git a/src/Microsoft.ML.Core/CommandLine/DefaultArgumentAttribute.cs b/src/Microsoft.ML.Core/CommandLine/DefaultArgumentAttribute.cs new file mode 100644 index 0000000000..2d676f1ece --- /dev/null +++ b/src/Microsoft.ML.Core/CommandLine/DefaultArgumentAttribute.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.Runtime.CommandLine +{ + /// + /// Indicates that this argument is the default argument. + /// '/' or '-' prefix only the argument value is specified. + /// The ShortName property should not be set for DefaultArgumentAttribute + /// instances. The LongName property is used for usage text only and + /// does not affect the usage of the argument. + /// + [AttributeUsage(AttributeTargets.Field)] + [BestFriend] + internal class DefaultArgumentAttribute : ArgumentAttribute + { + /// + /// Indicates that this argument is the default argument. + /// + /// Specifies the error checking to be done on the argument. + public DefaultArgumentAttribute(ArgumentType type) + : base(type) + { + } + } +} diff --git a/src/Microsoft.ML.Core/CommandLine/EnumValueDisplayAttribute.cs b/src/Microsoft.ML.Core/CommandLine/EnumValueDisplayAttribute.cs new file mode 100644 index 0000000000..b6cf4254ca --- /dev/null +++ b/src/Microsoft.ML.Core/CommandLine/EnumValueDisplayAttribute.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.Runtime.CommandLine +{ + /// + /// On an enum value - specifies the display name. + /// + [AttributeUsage(AttributeTargets.Field)] + [BestFriend] + internal class EnumValueDisplayAttribute : Attribute + { + public readonly string Name; + + public EnumValueDisplayAttribute(string name) + { + Name = name; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Core/CommandLine/HideEnumValueAttribute.cs b/src/Microsoft.ML.Core/CommandLine/HideEnumValueAttribute.cs new file mode 100644 index 0000000000..964a5cc3f3 --- /dev/null +++ b/src/Microsoft.ML.Core/CommandLine/HideEnumValueAttribute.cs @@ -0,0 +1,20 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.Runtime.CommandLine +{ + /// + /// On an enum value - indicates that the value should not be shown in help or UI. + /// + [AttributeUsage(AttributeTargets.Field)] + [BestFriend] + internal class HideEnumValueAttribute : Attribute + { + public HideEnumValueAttribute() + { + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Core/CommandLine/Utils.cs b/src/Microsoft.ML.Core/CommandLine/SpecialPurpose.cs similarity index 93% rename from src/Microsoft.ML.Core/CommandLine/Utils.cs rename to src/Microsoft.ML.Core/CommandLine/SpecialPurpose.cs index 7006fbf6d2..46423d43d9 100644 --- a/src/Microsoft.ML.Core/CommandLine/Utils.cs +++ b/src/Microsoft.ML.Core/CommandLine/SpecialPurpose.cs @@ -4,7 +4,8 @@ namespace Microsoft.ML.Runtime.CommandLine { - public static class SpecialPurpose + [BestFriend] + internal static class SpecialPurpose { /// /// This is used to specify a column mapping of a data transform. diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs index ff528bfcce..2743b8dc9f 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs @@ -2,7 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.CodeAnalysis; using Microsoft.ML.CodeAnalyzer.Tests.Helpers; +using System; +using System.Linq; using Xunit; namespace Microsoft.ML.InternalCodeAnalyzer.Tests @@ -68,16 +71,27 @@ public TypeName() public sealed class ContractsCheckFixTest : CodeFixVerifier { - private static string _preFix; - private static string _postFix; + private readonly Lazy SourcePreFix = TestUtils.LazySource("ContractsCheckBeforeFix.cs"); + private readonly Lazy SourcePostFix = TestUtils.LazySource("ContractsCheckAfterFix.cs"); + + private readonly Lazy SourceArgAttr = TestUtils.LazySource("ArgumentAttribute.cs"); + private readonly Lazy SourceArgType = TestUtils.LazySource("ArgumentType.cs"); + private readonly Lazy SourceBestAttr = TestUtils.LazySource("BestFriendAttribute.cs"); + private readonly Lazy SourceDefArgAttr = TestUtils.LazySource("DefaultArgumentAttribute.cs"); [Fact] public void ContractsCheckFix() { - string test = TestUtils.EnsureSourceLoaded(ref _preFix, "ContractsCheckBeforeFix.cs"); - string expected = TestUtils.EnsureSourceLoaded(ref _postFix, "ContractsCheckAfterFix.cs"); + //VerifyCSharpFix(SourcePreFix.Value, SourcePostFix.Value); + + Solution solution = null; + var proj = CreateProject(TestProjectName, ref solution, SourcePostFix.Value, SourceArgAttr.Value, + SourceArgType.Value, SourceBestAttr.Value, SourceDefArgAttr.Value); + var document = proj.Documents.First(); + var analyzer = GetCSharpDiagnosticAnalyzer(); + var comp = proj.GetCompilationAsync().Result; - VerifyCSharpFix(test, expected); + CycleAndVerifyFix(analyzer, GetCSharpCodeFixProvider(), SourcePostFix.Value, document); } } } diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/CodeFixVerifier.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/CodeFixVerifier.cs index 489ec5c446..483e7ace65 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/CodeFixVerifier.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/CodeFixVerifier.cs @@ -76,6 +76,11 @@ protected void VerifyBasicFix(string oldSource, string newSource, int? codeFixIn private void VerifyFix(string language, DiagnosticAnalyzer analyzer, CodeFixProvider codeFixProvider, string oldSource, string newSource, int? codeFixIndex, bool allowNewCompilerDiagnostics) { var document = CreateDocument(oldSource); + CycleAndVerifyFix(analyzer, codeFixProvider, newSource, document, codeFixIndex, allowNewCompilerDiagnostics); + } + + internal static void CycleAndVerifyFix(DiagnosticAnalyzer analyzer, CodeFixProvider codeFixProvider, string newSource, Document document, int? codeFixIndex = null, bool allowNewCompilerDiagnostics = false) + { var analyzerDiagnostics = GetSortedDiagnosticsFromDocuments(analyzer, new[] { document }); var compilerDiagnostics = GetCompilerDiagnostics(document); int attempts = analyzerDiagnostics.Length; diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Microsoft.ML.CodeAnalyzer.Tests.csproj b/test/Microsoft.ML.CodeAnalyzer.Tests/Microsoft.ML.CodeAnalyzer.Tests.csproj index 386ef06c8b..b55fc8c987 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/Microsoft.ML.CodeAnalyzer.Tests.csproj +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Microsoft.ML.CodeAnalyzer.Tests.csproj @@ -8,7 +8,16 @@ - + + %(Filename)%(Extension) + + + %(Filename)%(Extension) + + + %(Filename)%(Extension) + + %(Filename)%(Extension) diff --git a/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs b/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs index 122d612c7f..d191084ba8 100644 --- a/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs +++ b/test/Microsoft.ML.Sweeper.Tests/TestSweeper.cs @@ -653,7 +653,7 @@ public void TestNelderMeadSweeperWithDefaultFirstBatchSweeper() sweeps = sweeper.ProposeSweeps(5, results); } - Assert.True(sweeps.Length <= 5); + Assert.True(sweeps == null || sweeps.Length <= 5); } } From 3ed082811f83c65e435a05e3f437c231de486837 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Thu, 8 Nov 2018 16:56:17 -0800 Subject: [PATCH 03/12] Remove a handful pointless interfaces. --- .../Data/ITrainerArguments.cs | 15 ---- .../Prediction/IPredictor.cs | 90 ------------------- 2 files changed, 105 deletions(-) delete mode 100644 src/Microsoft.ML.Core/Data/ITrainerArguments.cs diff --git a/src/Microsoft.ML.Core/Data/ITrainerArguments.cs b/src/Microsoft.ML.Core/Data/ITrainerArguments.cs deleted file mode 100644 index e4fdbbdc59..0000000000 --- a/src/Microsoft.ML.Core/Data/ITrainerArguments.cs +++ /dev/null @@ -1,15 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -namespace Microsoft.ML.Runtime -{ - // This is basically a no-op interface put in primarily - // for backward binary compat support for AFx. - // REVIEW: This interface was removed in TLC 3.0 as part of the - // deprecation of the *Factory interfaces, but added back as a temporary - // hack. Remove it asap. - public interface ITrainerArguments - { - } -} diff --git a/src/Microsoft.ML.Core/Prediction/IPredictor.cs b/src/Microsoft.ML.Core/Prediction/IPredictor.cs index d88b81d558..6bd1ac2056 100644 --- a/src/Microsoft.ML.Core/Prediction/IPredictor.cs +++ b/src/Microsoft.ML.Core/Prediction/IPredictor.cs @@ -69,94 +69,4 @@ public interface IPredictor : IPredictorProducing : IPredictorProducing { } - - /// - /// Predictor that returns a probability distribution associated with a prediction result - /// - /// Type of features container (instance) on which to make predictions - /// Type of prediction result - /// Type of probability distribution associated with the predicton - public interface IDistributionPredictor - : IDistPredictorProducing, IPredictor - { - /// - /// Return a probability distribution associated wtih the prediction. - /// - /// Data instance - /// Distribution associated with the prediction - TResultDistribution PredictDistribution(TFeatures features); - - /// - /// Return a probability distribution associated wtih the prediction, as well as the prediction. - /// - /// Data instance - /// Prediction - /// Distribution associated with the prediction - TResultDistribution PredictDistribution(TFeatures features, out TResult result); - } - - /// - /// Predictor that produces predictions for sets of instances at a time - /// for cases where this is more efficient than serial calls to Predict for each instance. - /// - /// Type of features container (instance) on which to make predictions - /// Type of collection of instances - /// Type of prediction result - /// Type of the collection of prediction results - public interface IBulkPredictor - : IPredictor - { - /// - /// Produce predictions for a set of instances - /// - /// Collection of instances - /// Collection of predictions - TResultCollection BulkPredict(TFeaturesCollection featuresCollection); - } - - /// - /// Predictor that can score sets of instances (presumably more efficiently) - /// and returns a distribution associated with a prediction result. - /// - /// Type of features container (instance) on which to make predictions - /// Type of collection of instances - /// Type of prediction result - /// Type of probability distribution associated with the predicton - /// Type of the collection of prediction results - /// Type of the collection of distributions for prediction results - public interface IBulkDistributionPredictor - : IBulkPredictor, - IDistributionPredictor - { - /// - /// Produce distributions over predictions for a set of instances - /// - /// Collection of instances - /// Collection of prediction distributions - TResultDistributionCollection BulkPredictDistribution(TFeaturesCollection featuresCollection); - - /// - /// Produce distributions over predictions for a set of instances, along with actual prediction results - /// - /// Collection of instances - /// Collection of prediction results - /// Collection of distributions associated with prediction results - TResultDistributionCollection BulkPredictDistribution(TFeaturesCollection featuresCollection, - out TResultCollection resultCollection); - } - -#if FUTURE - public interface IBulkPredictor : - IPredictor - { - // REVIEW: Should we also have versions where the caller supplies the "memory" to be filled in. - TResultSet BulkPredict(TTestDataSet dataset); - } - - public interface IBulkPredictor - : IBulkPredictor, IPredictor - { - } -#endif } From d57920006daeace60f5fd0428bab71ee40c1545b Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Thu, 8 Nov 2018 17:45:34 -0800 Subject: [PATCH 04/12] Internalize base cursor classes. --- .../Data/LinkedRootCursorBase.cs | 3 +- .../Data/LinkedRowFilterCursorBase.cs | 3 +- .../Data/LinkedRowRootCursorBase.cs | 3 +- .../Data/ReadOnlyMemoryUtils.cs | 3 +- src/Microsoft.ML.Core/Data/RootCursorBase.cs | 3 +- .../Data/SynchronizedCursorBase.cs | 3 +- .../{Data => EntryPoints}/IMlState.cs | 0 .../{Data => EntryPoints}/IPredictorModel.cs | 0 src/Microsoft.ML.Core/Prediction/ITrainer.cs | 32 ++++++++++++------- src/Microsoft.ML.Data/Data/IColumn.cs | 12 +++++-- src/Microsoft.ML.Data/Dirty/IniFileUtils.cs | 3 +- .../Scorers/PredictedLabelScorerBase.cs | 2 +- .../Scorers/RowToRowScorerBase.cs | 2 +- .../Transforms/BindingsWrappedRowCursor.cs | 2 +- .../Trainer/IModelCombiner.cs | 2 ++ .../OnnxTransformTests.cs | 6 ++-- 16 files changed, 52 insertions(+), 27 deletions(-) rename src/Microsoft.ML.Core/{Data => EntryPoints}/IMlState.cs (100%) rename src/Microsoft.ML.Core/{Data => EntryPoints}/IPredictorModel.cs (100%) diff --git a/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs index 20f84c6e24..ecec0b1a0d 100644 --- a/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/LinkedRootCursorBase.cs @@ -8,7 +8,8 @@ namespace Microsoft.ML.Runtime.Data /// Base class for a cursor has an input cursor, but still needs to do work on /// MoveNext/MoveMany. /// - public abstract class LinkedRootCursorBase : RootCursorBase + [BestFriend] + internal abstract class LinkedRootCursorBase : RootCursorBase where TInput : class, ICursor { private readonly ICursor _root; diff --git a/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs index 4a07bbd47b..22ade4a983 100644 --- a/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/LinkedRowFilterCursorBase.cs @@ -7,7 +7,8 @@ namespace Microsoft.ML.Runtime.Data /// /// Base class for creating a cursor of rows that filters out some input rows. /// - public abstract class LinkedRowFilterCursorBase : LinkedRowRootCursorBase + [BestFriend] + internal abstract class LinkedRowFilterCursorBase : LinkedRowRootCursorBase { public override long Batch => Input.Batch; diff --git a/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs b/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs index f4b7a4da67..7874686797 100644 --- a/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs @@ -10,7 +10,8 @@ namespace Microsoft.ML.Runtime.Data /// that the default assumes /// that each input column is exposed as an output column with the same column index. /// - public abstract class LinkedRowRootCursorBase : LinkedRootCursorBase, IRowCursor + [BestFriend] + internal abstract class LinkedRowRootCursorBase : LinkedRootCursorBase, IRowCursor { private readonly bool[] _active; diff --git a/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs b/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs index 4f89e29201..f579e5f609 100644 --- a/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs +++ b/src/Microsoft.ML.Core/Data/ReadOnlyMemoryUtils.cs @@ -10,7 +10,8 @@ namespace Microsoft.ML.Runtime.Data { - public static class ReadOnlyMemoryUtils + [BestFriend] + internal static class ReadOnlyMemoryUtils { /// diff --git a/src/Microsoft.ML.Core/Data/RootCursorBase.cs b/src/Microsoft.ML.Core/Data/RootCursorBase.cs index 73be098e0f..5b64a40d6a 100644 --- a/src/Microsoft.ML.Core/Data/RootCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/RootCursorBase.cs @@ -17,7 +17,8 @@ namespace Microsoft.ML.Runtime.Data /// that has an input cursor and does NOT need notification on /, /// use . /// - public abstract class RootCursorBase : ICursor + [BestFriend] + internal abstract class RootCursorBase : ICursor { protected readonly IChannel Ch; diff --git a/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs b/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs index 202c3d8cfd..da60c84ccf 100644 --- a/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs +++ b/src/Microsoft.ML.Core/Data/SynchronizedCursorBase.cs @@ -10,7 +10,8 @@ namespace Microsoft.ML.Runtime.Data /// It delegates all ICursor functionality except Dispose() to the root cursor. /// Dispose is virtual with the default implementation delegating to the input cursor. /// - public abstract class SynchronizedCursorBase : ICursor + [BestFriend] + internal abstract class SynchronizedCursorBase : ICursor where TBase : class, ICursor { protected readonly IChannel Ch; diff --git a/src/Microsoft.ML.Core/Data/IMlState.cs b/src/Microsoft.ML.Core/EntryPoints/IMlState.cs similarity index 100% rename from src/Microsoft.ML.Core/Data/IMlState.cs rename to src/Microsoft.ML.Core/EntryPoints/IMlState.cs diff --git a/src/Microsoft.ML.Core/Data/IPredictorModel.cs b/src/Microsoft.ML.Core/EntryPoints/IPredictorModel.cs similarity index 100% rename from src/Microsoft.ML.Core/Data/IPredictorModel.cs rename to src/Microsoft.ML.Core/EntryPoints/IPredictorModel.cs diff --git a/src/Microsoft.ML.Core/Prediction/ITrainer.cs b/src/Microsoft.ML.Core/Prediction/ITrainer.cs index 7eca991b5c..6647f77592 100644 --- a/src/Microsoft.ML.Core/Prediction/ITrainer.cs +++ b/src/Microsoft.ML.Core/Prediction/ITrainer.cs @@ -13,19 +13,27 @@ namespace Microsoft.ML.Runtime /// Loadable class signatures for trainers. Typically each trainer should register with /// both SignatureTrainer and SignatureXxxTrainer where Xxx is the prediction kind. /// - public delegate void SignatureTrainer(); + [BestFriend] + internal delegate void SignatureTrainer(); - public delegate void SignatureBinaryClassifierTrainer(); - public delegate void SignatureMultiClassClassifierTrainer(); - public delegate void SignatureRegressorTrainer(); - public delegate void SignatureMultiOutputRegressorTrainer(); - public delegate void SignatureRankerTrainer(); - public delegate void SignatureAnomalyDetectorTrainer(); - public delegate void SignatureClusteringTrainer(); - public delegate void SignatureSequenceTrainer(); - public delegate void SignatureMatrixRecommendingTrainer(); - - public delegate void SignatureModelCombiner(PredictionKind kind); + [BestFriend] + internal delegate void SignatureBinaryClassifierTrainer(); + [BestFriend] + internal delegate void SignatureMultiClassClassifierTrainer(); + [BestFriend] + internal delegate void SignatureRegressorTrainer(); + [BestFriend] + internal delegate void SignatureMultiOutputRegressorTrainer(); + [BestFriend] + internal delegate void SignatureRankerTrainer(); + [BestFriend] + internal delegate void SignatureAnomalyDetectorTrainer(); + [BestFriend] + internal delegate void SignatureClusteringTrainer(); + [BestFriend] + internal delegate void SignatureSequenceTrainer(); + [BestFriend] + internal delegate void SignatureMatrixRecommendingTrainer(); /// /// The base interface for a trainers. Implementors should not implement this interface directly, diff --git a/src/Microsoft.ML.Data/Data/IColumn.cs b/src/Microsoft.ML.Data/Data/IColumn.cs index 948d056677..14638081d8 100644 --- a/src/Microsoft.ML.Data/Data/IColumn.cs +++ b/src/Microsoft.ML.Data/Data/IColumn.cs @@ -274,7 +274,7 @@ public ValueGetter GetGetter() /// /// The base class for a few implementations that do not "go" anywhere. /// - public abstract class DefaultCounted : ICounted + private abstract class DefaultCounted : ICounted { public long Position => 0; public long Batch => 0; @@ -331,7 +331,7 @@ public ValueGetter GetGetter() /// column as an . This class will cease to be necessary at the point when all /// metadata implementations are just simple s. /// - public sealed class MetadataRow : DefaultCounted, IRow + public sealed class MetadataRow : IRow { public Schema Schema => _schemaImpl.AsSchema; @@ -341,6 +341,14 @@ public sealed class MetadataRow : DefaultCounted, IRow private readonly KeyValuePair[] _map; + long ICounted.Position => 0; + long ICounted.Batch => 0; + ValueGetter ICounted.GetIdGetter() + => IdGetter; + + private static void IdGetter(ref UInt128 id) + => id = default; + private sealed class SchemaImpl : ISchema { private readonly MetadataRow _parent; diff --git a/src/Microsoft.ML.Data/Dirty/IniFileUtils.cs b/src/Microsoft.ML.Data/Dirty/IniFileUtils.cs index a54ab2eb4f..782c07af01 100644 --- a/src/Microsoft.ML.Data/Dirty/IniFileUtils.cs +++ b/src/Microsoft.ML.Data/Dirty/IniFileUtils.cs @@ -8,7 +8,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities { - public static class IniFileUtils + [BestFriend] + internal static class IniFileUtils { // This could be done better by having something that actually parses the .ini file and provides more // functionality. For now, we'll just provide the minimum needed. If we went the nicer route, probably would diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index 508cab37b9..7fae48f5c7 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -194,7 +194,7 @@ protected override IEnumerable> GetMetadataType yield return TextType.Instance.GetPair(MetadataUtils.Kinds.ScoreValueKind); if (_predColMetadata != null) { - ISchema sch = _predColMetadata.Schema; + var sch = _predColMetadata.Schema; for (int i = 0; i < sch.ColumnCount; ++i) yield return new KeyValuePair(sch.GetColumnName(i), sch.GetColumnType(i)); } diff --git a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs index c05baf50c6..066e459a54 100644 --- a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs @@ -214,7 +214,7 @@ protected override int MapColumnIndex(out bool isSrc, int col) return bindings.MapColumnIndex(out isSrc, col); } - protected sealed class RowCursor : SynchronizedCursorBase, IRowCursor + private sealed class RowCursor : SynchronizedCursorBase, IRowCursor { private readonly BindingsBase _bindings; private readonly bool[] _active; diff --git a/src/Microsoft.ML.Data/Transforms/BindingsWrappedRowCursor.cs b/src/Microsoft.ML.Data/Transforms/BindingsWrappedRowCursor.cs index 4959b7c7b2..409d0dbeaa 100644 --- a/src/Microsoft.ML.Data/Transforms/BindingsWrappedRowCursor.cs +++ b/src/Microsoft.ML.Data/Transforms/BindingsWrappedRowCursor.cs @@ -11,7 +11,7 @@ namespace Microsoft.ML.Runtime.Data /// inconvenient or inefficient to handle the "no output selected" case in their /// own implementation. /// - public sealed class BindingsWrappedRowCursor : SynchronizedCursorBase, IRowCursor + internal sealed class BindingsWrappedRowCursor : SynchronizedCursorBase, IRowCursor { private readonly ColumnBindingsBase _bindings; diff --git a/src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs b/src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs index 0a392585a7..196df4dc54 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs @@ -7,6 +7,8 @@ namespace Microsoft.ML.Runtime.Ensemble { + public delegate void SignatureModelCombiner(PredictionKind kind); + /// /// An interface that combines multiple predictors into a single predictor. /// diff --git a/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs b/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs index a95d5c15c8..00927936b1 100644 --- a/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs +++ b/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs @@ -45,7 +45,7 @@ private class TestDataDifferntType public string[] data_0; } - private float[] getSampleArrayData() + private float[] GetSampleArrayData() { var samplevector = new float[inputSize]; for (int i = 0; i < inputSize; i++) @@ -65,7 +65,7 @@ void TestSimpleCase() var modelFile = "squeezenet/00000001/model.onnx"; - var samplevector = getSampleArrayData(); + var samplevector = GetSampleArrayData(); var dataView = ComponentCreation.CreateDataView(Env, new TestData[] { @@ -108,7 +108,7 @@ void TestOldSavingAndLoading() var modelFile = "squeezenet/00000001/model.onnx"; - var samplevector = getSampleArrayData(); + var samplevector = GetSampleArrayData(); var dataView = ComponentCreation.CreateDataView(Env, new TestData[] { From 8478934a40199d85dae74c7802fe734847af3db2 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Thu, 8 Nov 2018 18:17:35 -0800 Subject: [PATCH 05/12] Internalize PFA support. --- .../DataView/RowToRowMapperTransform.cs | 2 +- .../Model/Pfa/ICanSavePfa.cs | 19 ++++--- .../Prediction/Calibrator.cs | 10 ++-- .../Scorers/GenericScorer.cs | 2 +- .../Scorers/PredictedLabelScorerBase.cs | 4 +- .../Transforms/TransformBase.cs | 49 ++++++++++++------- src/Microsoft.ML.FastTree/FastTree.cs | 4 +- .../Standard/LinearPredictor.cs | 4 +- .../MulticlassLogisticRegression.cs | 4 +- .../Standard/MultiClass/Ova.cs | 4 +- src/Microsoft.ML.Transforms/GroupTransform.cs | 2 +- .../Text/WordTokenizeTransform.cs | 2 +- 12 files changed, 61 insertions(+), 45 deletions(-) diff --git a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs index e968348618..b8276f3abf 100644 --- a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs +++ b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs @@ -215,7 +215,7 @@ public void SaveAsOnnx(OnnxContext ctx) } } - public void SaveAsPfa(BoundPfaContext ctx) + void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx) { Host.CheckValue(ctx, nameof(ctx)); if (_mapper is ISaveAsPfa pfa) diff --git a/src/Microsoft.ML.Data/Model/Pfa/ICanSavePfa.cs b/src/Microsoft.ML.Data/Model/Pfa/ICanSavePfa.cs index bc0266f96e..bcda835fe8 100644 --- a/src/Microsoft.ML.Data/Model/Pfa/ICanSavePfa.cs +++ b/src/Microsoft.ML.Data/Model/Pfa/ICanSavePfa.cs @@ -7,7 +7,8 @@ namespace Microsoft.ML.Runtime.Model.Pfa { - public interface ICanSavePfa + [BestFriend] + internal interface ICanSavePfa { /// /// Whether this object really is capable of saving itself as part of a PFA @@ -22,7 +23,8 @@ public interface ICanSavePfa /// /// This component know how to save himself in Pfa format. /// - public interface ISaveAsPfa : ICanSavePfa + [BestFriend] + internal interface ISaveAsPfa : ICanSavePfa { /// /// Save as PFA. For any columns that are output, this interface should use @@ -37,9 +39,9 @@ public interface ISaveAsPfa : ICanSavePfa /// /// This data model component is savable as PFA. See https://dmg.org/pfa/ . /// - public interface ITransformCanSavePfa : ISaveAsPfa, IDataTransform + [BestFriend] + internal interface ITransformCanSavePfa : ISaveAsPfa, IDataTransform { - } /// @@ -47,7 +49,8 @@ public interface ITransformCanSavePfa : ISaveAsPfa, IDataTransform /// typically called within an that is wrapping /// this mapper, and has already been bound to it. /// - public interface IBindableCanSavePfa : ICanSavePfa, ISchemaBindableMapper + [BestFriend] + internal interface IBindableCanSavePfa : ICanSavePfa, ISchemaBindableMapper { /// /// Save as PFA. If is @@ -71,7 +74,8 @@ public interface IBindableCanSavePfa : ICanSavePfa, ISchemaBindableMapper /// For simple mappers. Intended to be used for and /// instances. /// - public interface ISingleCanSavePfa : ICanSavePfa + [BestFriend] + internal interface ISingleCanSavePfa : ICanSavePfa { /// /// Implementors of this method are responsible for providing the PFA expression that @@ -92,7 +96,8 @@ public interface ISingleCanSavePfa : ICanSavePfa /// For simple mappers. Intended to be used for /// instances. /// - public interface IDistCanSavePfa : ISingleCanSavePfa, IValueMapperDist + [BestFriend] + internal interface IDistCanSavePfa : ISingleCanSavePfa, IValueMapperDist { /// /// The call for distribution predictors. Unlike , diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index c05a87dbaf..4b24e84945 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -265,7 +265,7 @@ public ValueMapper> GetWhatTheFeatureMapper(int return _whatTheFeature.GetWhatTheFeatureMapper(top, bottom, normalize); } - public JToken SaveAsPfa(BoundPfaContext ctx, JToken input) + JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) { Host.CheckValue(ctx, nameof(ctx)); Host.CheckValue(input, nameof(input)); @@ -275,7 +275,7 @@ public JToken SaveAsPfa(BoundPfaContext ctx, JToken input) return mapper.SaveAsPfa(ctx, input); } - public void SaveAsPfa(BoundPfaContext ctx, JToken input, + void IDistCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input, string score, out JToken scoreToken, string prob, out JToken probToken) { Host.CheckValue(ctx, nameof(ctx)); @@ -283,7 +283,7 @@ public void SaveAsPfa(BoundPfaContext ctx, JToken input, Host.CheckValueOrNull(score); Host.CheckValueOrNull(prob); - JToken scoreExpression = SaveAsPfa(ctx, input); + JToken scoreExpression = ((ISingleCanSavePfa)this).SaveAsPfa(ctx, input); scoreToken = ctx.DeclareVar(score, scoreExpression); var calibrator = Calibrator as ISingleCanSavePfa; if (calibrator?.CanSavePfa != true) @@ -1348,7 +1348,7 @@ private static VersionInfo GetVersionInfo() public Double ParamA { get; } public Double ParamB { get; } - public bool CanSavePfa => true; + bool ICanSavePfa.CanSavePfa => true; public bool CanSaveOnnx(OnnxContext ctx) => true; public PlattCalibrator(IHostEnvironment env, Double paramA, Double paramB) @@ -1426,7 +1426,7 @@ public static Float PredictProbability(Float output, Double a, Double b) return (Float)(1 / (1 + Math.Exp(a * output + b))); } - public JToken SaveAsPfa(BoundPfaContext ctx, JToken input) + JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) { _host.CheckValue(ctx, nameof(ctx)); _host.CheckValue(input, nameof(input)); diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs index 18fe6f2fda..5bb691eb1f 100644 --- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs @@ -205,7 +205,7 @@ protected override void SaveCore(ModelSaveContext ctx) _bindings.Save(ctx); } - public void SaveAsPfa(BoundPfaContext ctx) + void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(Bindable is IBindableCanSavePfa); diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index 7fae48f5c7..56183b4170 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -279,7 +279,7 @@ public override Func GetActiveMapperColumns(bool[] active) protected override BindingsBase GetBindings() => Bindings; public override Schema Schema { get; } - public bool CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true; + bool ICanSavePfa.CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true; public bool CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; @@ -336,7 +336,7 @@ protected override void SaveCore(ModelSaveContext ctx) Bindings.Save(ctx); } - public void SaveAsPfa(BoundPfaContext ctx) + void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(Bindable is IBindableCanSavePfa); diff --git a/src/Microsoft.ML.Data/Transforms/TransformBase.cs b/src/Microsoft.ML.Data/Transforms/TransformBase.cs index ca74ec9838..a675ce7acc 100644 --- a/src/Microsoft.ML.Data/Transforms/TransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/TransformBase.cs @@ -112,23 +112,25 @@ protected RowToRowTransformBase(IHost host, IDataView input) /// public abstract class FilterBase : TransformBase, ITransformCanSavePfa { - protected FilterBase(IHostEnvironment env, string name, IDataView input) + [BestFriend] + private protected FilterBase(IHostEnvironment env, string name, IDataView input) : base(env, name, input) { } - protected FilterBase(IHost host, IDataView input) + [BestFriend] + private protected FilterBase(IHost host, IDataView input) : base(host, input) { } - public override long? GetRowCount(bool lazy = true) { return null; } + public override long? GetRowCount(bool lazy = true) => null; - public sealed override Schema Schema { get { return Source.Schema; } } + public sealed override Schema Schema => Source.Schema; - public virtual bool CanSavePfa => true; + bool ICanSavePfa.CanSavePfa => true; - public virtual void SaveAsPfa(BoundPfaContext ctx) + void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx) { Host.CheckValue(ctx, nameof(ctx)); // Because filters do not modify the schema, this is a no-op. @@ -468,11 +470,14 @@ private sealed class ColumnTmp : OneToOneColumn // The InputTranspose transpose schema, null iff InputTranspose is null. protected ITransposeSchema InputTransposeSchema => InputTranspose?.TransposeSchema; - public virtual bool CanSavePfa => false; + bool ICanSavePfa.CanSavePfa => CanSavePfaCore; + + private protected virtual bool CanSavePfaCore => false; public virtual bool CanSaveOnnx(OnnxContext ctx) => false; - protected OneToOneTransformBase(IHostEnvironment env, string name, OneToOneColumn[] column, + [BestFriend] + private protected OneToOneTransformBase(IHostEnvironment env, string name, OneToOneColumn[] column, IDataView input, Func testType) : base(env, name, input) { @@ -485,7 +490,8 @@ protected OneToOneTransformBase(IHostEnvironment env, string name, OneToOneColum Metadata = new MetadataDispatcher(Infos.Length); } - protected OneToOneTransformBase(IHost host, OneToOneColumn[] column, + [BestFriend] + private protected OneToOneTransformBase(IHost host, OneToOneColumn[] column, IDataView input, Func testType) : base(host, input) { @@ -498,7 +504,8 @@ protected OneToOneTransformBase(IHost host, OneToOneColumn[] column, Metadata = new MetadataDispatcher(Infos.Length); } - protected OneToOneTransformBase(IHost host, ModelLoadContext ctx, + [BestFriend] + private protected OneToOneTransformBase(IHost host, ModelLoadContext ctx, IDataView input, Func testType) : base(host, input) { @@ -514,7 +521,8 @@ protected OneToOneTransformBase(IHost host, ModelLoadContext ctx, /// /// Re-applying constructor. /// - protected OneToOneTransformBase(IHostEnvironment env, string name, OneToOneTransformBase transform, + [BestFriend] + private protected OneToOneTransformBase(IHostEnvironment env, string name, OneToOneTransformBase transform, IDataView newInput, Func checkType) : base(env, name, newInput) { @@ -534,18 +542,20 @@ protected OneToOneTransformBase(IHostEnvironment env, string name, OneToOneTrans Metadata = new MetadataDispatcher(Infos.Length); } - protected MetadataDispatcher Metadata { get; } + [BestFriend] + private protected MetadataDispatcher Metadata { get; } - protected void SaveBase(ModelSaveContext ctx) + [BestFriend] + private protected void SaveBase(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); _bindings.Save(ctx); } - public void SaveAsPfa(BoundPfaContext ctx) + void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx) { Host.CheckValue(ctx, nameof(ctx)); - Host.Assert(CanSavePfa); + Host.Assert(((ICanSavePfa)this).CanSavePfa); var toHide = new List(); var toDeclare = new List>(); @@ -596,8 +606,8 @@ public void SaveAsOnnx(OnnxContext ctx) } /// - /// Called by . Should be implemented by subclasses that return - /// true from . Will be called + /// Called by . Should be implemented by subclasses that return + /// true from . Will be called /// /// The context. Can be used to declare cells, access other information, /// and whatnot. This method should not actually, however, declare the variable corresponding @@ -607,13 +617,14 @@ public void SaveAsOnnx(OnnxContext ctx) /// The token in the PFA corresponding to the source col /// Shuold return the declaration corresponding to the value of this column. Will /// return null in the event that we do not know how to express this column as PFA - protected virtual JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo info, JToken srcToken) + [BestFriend] + private protected virtual JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo info, JToken srcToken) { Host.AssertValue(ctx); Host.Assert(0 <= iinfo && iinfo < _bindings.InfoCount); Host.Assert(Infos[iinfo] == info); Host.AssertValue(srcToken); - Host.Assert(CanSavePfa); + Host.Assert(((ICanSavePfa)this).CanSavePfa); return null; } diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 32ada36a09..ec2bc6edaa 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -2838,7 +2838,7 @@ public abstract class FastTreePredictionWrapper : public ColumnType InputType { get; } public ColumnType OutputType => NumberType.Float; - public bool CanSavePfa => true; + bool ICanSavePfa.CanSavePfa => true; public bool CanSaveOnnx(OnnxContext ctx) => true; protected FastTreePredictionWrapper(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs) @@ -3019,7 +3019,7 @@ private string AddCalibrationToIni(string ini, ICalibrator calibrator) } } - public JToken SaveAsPfa(BoundPfaContext ctx, JToken input) + JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) { Host.CheckValue(ctx, nameof(ctx)); Host.CheckValue(input, nameof(input)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index 6a0cfaff99..eb521da787 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -99,7 +99,7 @@ IEnumerator IEnumerable.GetEnumerator() public ColumnType OutputType => NumberType.Float; - public bool CanSavePfa => true; + bool ICanSavePfa.CanSavePfa => true; public bool CanSaveOnnx(OnnxContext ctx) => true; @@ -203,7 +203,7 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.Writer.WriteSingleArray(Weight.GetValues()); } - public JToken SaveAsPfa(BoundPfaContext ctx, JToken input) + JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) { Host.CheckValue(ctx, nameof(ctx)); Host.CheckValue(input, nameof(input)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 5b237aa47b..ab7745e072 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -392,7 +392,7 @@ private static VersionInfo GetVersionInfo() public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification; public ColumnType InputType { get; } public ColumnType OutputType { get; } - public bool CanSavePfa => true; + bool ICanSavePfa.CanSavePfa => true; public bool CanSaveOnnx(OnnxContext ctx) => true; internal MulticlassLogisticRegressionPredictor(IHostEnvironment env, in VBuffer weights, int numClasses, int numFeatures, string[] labelNames, LinearModelStatistics stats = null) @@ -867,7 +867,7 @@ public void SaveSummary(TextWriter writer, RoleMappedSchema schema) SaveAsText(writer, schema); } - public JToken SaveAsPfa(BoundPfaContext ctx, JToken input) + JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) { Host.CheckValue(ctx, nameof(ctx)); Host.CheckValue(input, nameof(input)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index 529ee30f46..459a47ec39 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -225,7 +225,7 @@ private static VersionInfo GetVersionInfo() public ColumnType InputType => _impl.InputType; public ColumnType OutputType { get; } public ColumnType DistType => OutputType; - public bool CanSavePfa => _impl.CanSavePfa; + bool ICanSavePfa.CanSavePfa => _impl.CanSavePfa; [BestFriend] internal static OvaPredictor Create(IHost host, bool useProb, TScalarPredictor[] predictors) @@ -337,7 +337,7 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.SaveModel(preds[i], string.Format(SubPredictorFmt, i)); } - public JToken SaveAsPfa(BoundPfaContext ctx, JToken input) + JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) { Host.CheckValue(ctx, nameof(ctx)); Host.CheckValue(input, nameof(input)); diff --git a/src/Microsoft.ML.Transforms/GroupTransform.cs b/src/Microsoft.ML.Transforms/GroupTransform.cs index 8eebd94230..e679299831 100644 --- a/src/Microsoft.ML.Transforms/GroupTransform.cs +++ b/src/Microsoft.ML.Transforms/GroupTransform.cs @@ -429,7 +429,7 @@ public void GetMetadata(string kind, int col, ref TValue value) /// - The group column getters are taken directly from the trailing cursor. /// - The keep column getters are provided by the aggregators. /// - public sealed class Cursor : RootCursorBase, IRowCursor + private sealed class Cursor : RootCursorBase, IRowCursor { /// /// This class keeps track of the previous group key and tests the current group key against the previous one. diff --git a/src/Microsoft.ML.Transforms/Text/WordTokenizeTransform.cs b/src/Microsoft.ML.Transforms/Text/WordTokenizeTransform.cs index 9cc4d56bb6..05288d7135 100644 --- a/src/Microsoft.ML.Transforms/Text/WordTokenizeTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordTokenizeTransform.cs @@ -361,7 +361,7 @@ private void AddTerms(ReadOnlyMemory txt, char[] separators, List Date: Thu, 8 Nov 2018 22:41:14 -0800 Subject: [PATCH 06/12] Finish internalizing PFA support. Internalize ONNX support. --- .../DataView/RowToRowMapperTransform.cs | 6 ++-- .../Model/Onnx/ICanSaveOnnx.cs | 18 +++++++---- .../Model/Onnx/OnnxContext.cs | 6 ++-- .../Model/Pfa/BoundPfaContext.cs | 3 +- .../Prediction/Calibrator.cs | 25 ++++++++------- .../Properties/AssemblyInfo.cs | 1 + .../Scorers/BinaryClassifierScorer.cs | 4 +-- .../Scorers/GenericScorer.cs | 6 ++-- .../Scorers/MultiClassClassifierScorer.cs | 12 +++---- .../Scorers/PredictedLabelScorerBase.cs | 7 ++-- .../Scorers/SchemaBindablePredictorWrapper.cs | 32 +++++++++++++------ .../Transforms/TransformBase.cs | 11 ++++--- src/Microsoft.ML.FastTree/FastTree.cs | 4 +-- .../KMeansPredictor.cs | 4 +-- src/Microsoft.ML.Onnx/AssemblyInfo.cs | 10 ++++-- src/Microsoft.ML.Onnx/SaveOnnxCommand.cs | 2 +- .../Standard/LinearPredictor.cs | 4 +-- .../MulticlassLogisticRegression.cs | 4 +-- 18 files changed, 99 insertions(+), 60 deletions(-) diff --git a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs index b8276f3abf..308c08ba24 100644 --- a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs +++ b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs @@ -76,9 +76,9 @@ private static VersionInfo GetVersionInfo() public override Schema Schema => _bindings.Schema; - public bool CanSaveOnnx(OnnxContext ctx) => _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false; - public bool CanSavePfa => _mapper is ICanSavePfa pfaMapper ? pfaMapper.CanSavePfa : false; + bool ICanSavePfa.CanSavePfa => _mapper is ICanSavePfa pfaMapper ? pfaMapper.CanSavePfa : false; public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper mapper, Func mapperFactory) : base(env, RegistrationName, input) @@ -205,7 +205,7 @@ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolid return cursors; } - public void SaveAsOnnx(OnnxContext ctx) + void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); if (_mapper is ISaveAsOnnx onnx) diff --git a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs index 37d938e234..dfc8756303 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs @@ -6,7 +6,8 @@ namespace Microsoft.ML.Runtime.Model.Onnx { - public interface ICanSaveOnnx + [BestFriend] + internal interface ICanSaveOnnx { /// /// Whether this object really is capable of saving itself as part of an ONNX @@ -21,7 +22,8 @@ public interface ICanSaveOnnx /// /// This component know how to save himself in ONNX format. /// - public interface ISaveAsOnnx : ICanSaveOnnx + [BestFriend] + internal interface ISaveAsOnnx : ICanSaveOnnx { /// /// Save as ONNX. @@ -33,7 +35,8 @@ public interface ISaveAsOnnx : ICanSaveOnnx /// /// This data model component is savable as ONNX. /// - public interface ITransformCanSaveOnnx : ISaveAsOnnx, IDataTransform + [BestFriend] + internal interface ITransformCanSaveOnnx : ISaveAsOnnx, IDataTransform { } @@ -42,7 +45,8 @@ public interface ITransformCanSaveOnnx : ISaveAsOnnx, IDataTransform /// typically called within an that is wrapping /// this mapper, and has already been bound to it. /// - public interface IBindableCanSaveOnnx : ICanSaveOnnx, ISchemaBindableMapper + [BestFriend] + internal interface IBindableCanSaveOnnx : ICanSaveOnnx, ISchemaBindableMapper { /// /// Save as ONNX. If is @@ -66,7 +70,8 @@ public interface IBindableCanSaveOnnx : ICanSaveOnnx, ISchemaBindableMapper /// For simple mappers. Intended to be used for and /// instances. /// - public interface ISingleCanSaveOnnx : ICanSaveOnnx + [BestFriend] + internal interface ISingleCanSaveOnnx : ICanSaveOnnx { bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn); } @@ -75,7 +80,8 @@ public interface ISingleCanSaveOnnx : ICanSaveOnnx /// For simple mappers. Intended to be used for /// instances. /// - public interface IDistCanSaveOnnx : ISingleCanSaveOnnx, IValueMapperDist + [BestFriend] + internal interface IDistCanSaveOnnx : ISingleCanSaveOnnx, IValueMapperDist { new bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn); } diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index 850d2f6b64..38d9f77915 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -7,7 +7,8 @@ namespace Microsoft.ML.Runtime.Model.Onnx { - public enum OnnxVersion { Stable=0, Experimental=1 } + [BestFriend] + internal enum OnnxVersion { Stable = 0, Experimental = 1 } /// /// A context for defining a ONNX output. The context internally contains the model-in-progress being built. This @@ -16,7 +17,8 @@ public enum OnnxVersion { Stable=0, Experimental=1 } /// given to a component, all other components up to that component have already attempted to express themselves in /// this context, with their outputs possibly available in the ONNX graph. /// - public abstract class OnnxContext + [BestFriend] + internal abstract class OnnxContext { /// /// Generates a unique name for the node based on a prefix. diff --git a/src/Microsoft.ML.Data/Model/Pfa/BoundPfaContext.cs b/src/Microsoft.ML.Data/Model/Pfa/BoundPfaContext.cs index c75b760f3a..5090ea542c 100644 --- a/src/Microsoft.ML.Data/Model/Pfa/BoundPfaContext.cs +++ b/src/Microsoft.ML.Data/Model/Pfa/BoundPfaContext.cs @@ -20,7 +20,8 @@ namespace Microsoft.ML.Runtime.Model.Pfa /// has facilities to remember what column name in maps to /// what token in the PFA being built up. /// - public sealed class BoundPfaContext + [BestFriend] + internal sealed class BoundPfaContext { /// /// The internal PFA context, for an escape hatch. diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 4b24e84945..b6c650d77c 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -224,8 +224,8 @@ public abstract class ValueMapperCalibratedPredictorBase : CalibratedPredictorBa public ColumnType InputType => _mapper.InputType; public ColumnType OutputType => _mapper.OutputType; public ColumnType DistType => NumberType.Float; - public bool CanSavePfa => (_mapper as ICanSavePfa)?.CanSavePfa == true; - public bool CanSaveOnnx(OnnxContext ctx) => (_mapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; + bool ICanSavePfa.CanSavePfa => (_mapper as ICanSavePfa)?.CanSavePfa == true; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_mapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; protected ValueMapperCalibratedPredictorBase(IHostEnvironment env, string name, IPredictorProducing predictor, ICalibrator calibrator) : base(env, name, predictor, calibrator) @@ -296,7 +296,10 @@ void IDistCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input, probToken = ctx.DeclareVar(prob, probExpression); } - public bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumnName) + bool IDistCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumnName) + => ((ISingleCanSaveOnnx)this).SaveAsOnnx(ctx, outputNames, featureColumnName); + + bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumnName) { Host.CheckValue(ctx, nameof(ctx)); Host.CheckValue(outputNames, nameof(outputNames)); @@ -620,9 +623,9 @@ private static VersionInfo GetVersionInfo() /// Whether we can save as PFA. Note that this depends on whether the underlying predictor /// can save as PFA, since in the event that this in particular does not get saved, /// - public bool CanSavePfa => (_bindable as ICanSavePfa)?.CanSavePfa == true; + bool ICanSavePfa.CanSavePfa => (_bindable as ICanSavePfa)?.CanSavePfa == true; - public bool CanSaveOnnx(OnnxContext ctx) => (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; public SchemaBindableCalibratedPredictor(IHostEnvironment env, IPredictorProducing predictor, ICalibrator calibrator) : base(env, LoaderSignature, predictor, calibrator) @@ -653,22 +656,22 @@ public void Save(ModelSaveContext ctx) SaveCore(ctx); } - public void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputs) + void IBindableCanSavePfa.SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputs) { Host.CheckValue(ctx, nameof(ctx)); Host.CheckValue(schema, nameof(schema)); Host.CheckParam(Utils.Size(outputs) == 2, nameof(outputs), "Expected this to have two outputs"); - Host.Check(CanSavePfa, "Called despite not being savable"); + Host.Check(((ICanSavePfa)this).CanSavePfa, "Called despite not being savable"); ctx.Hide(outputs); } - public bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputs) + bool IBindableCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputs) { Host.CheckValue(ctx, nameof(ctx)); Host.CheckParam(Utils.Size(outputs) == 2, nameof(outputs), "Expected this to have two outputs"); Host.CheckValue(schema, nameof(schema)); - Host.Check(CanSaveOnnx(ctx), "Called despite not being savable"); + Host.Check(((ICanSaveOnnx)this).CanSaveOnnx(ctx), "Called despite not being savable"); return false; } @@ -1349,7 +1352,7 @@ private static VersionInfo GetVersionInfo() public Double ParamA { get; } public Double ParamB { get; } bool ICanSavePfa.CanSavePfa => true; - public bool CanSaveOnnx(OnnxContext ctx) => true; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; public PlattCalibrator(IHostEnvironment env, Double paramA, Double paramB) { @@ -1435,7 +1438,7 @@ JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) PfaUtils.Call("+", -ParamB, PfaUtils.Call("*", -ParamA, input))); } - public bool SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColumnNames, string featureColumnName) + bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColumnNames, string featureColumnName) { _host.CheckValue(ctx, nameof(ctx)); _host.CheckValue(scoreProbablityColumnNames, nameof(scoreProbablityColumnNames)); diff --git a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs index 5056571670..5fed334c9d 100644 --- a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs @@ -6,6 +6,7 @@ using Microsoft.ML; [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TestFramework" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Maml" + PublicKey.Value)] diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index 958aba0da1..56a72d3f92 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -185,7 +185,7 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.Writer.Write(_threshold); } - public override void SaveAsOnnx(OnnxContext ctx) + private protected override void SaveAsOnnxCore(OnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(Bindable is IBindableCanSaveOnnx); @@ -193,7 +193,7 @@ public override void SaveAsOnnx(OnnxContext ctx) if (!ctx.ContainsColumn(DefaultColumnNames.Features)) return; - base.SaveAsOnnx(ctx); + base.SaveAsOnnxCore(ctx); int delta = Bindings.DerivedColumnCount; Host.Assert(delta == 1); diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs index 5bb691eb1f..ab3a74b26b 100644 --- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs @@ -141,9 +141,9 @@ private static VersionInfo GetVersionInfo() public override Schema Schema { get; } - public bool CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true; + bool ICanSavePfa.CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true; - public bool CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; /// /// The entry point for creating a . @@ -220,7 +220,7 @@ void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx) pfaBindable.SaveAsPfa(ctx, schema, outColNames); } - public void SaveAsOnnx(OnnxContext ctx) + void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(Bindable is IBindableCanSaveOnnx); diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs index e4c2b80c83..ddde2f9396 100644 --- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs @@ -77,8 +77,8 @@ public sealed class LabelNameBindableMapper : ISchemaBindableMapper, ICanSaveMod private readonly Func _canWrap; public VectorType Type => _type; - public bool CanSavePfa => (_bindable as ICanSavePfa)?.CanSavePfa == true; - public bool CanSaveOnnx(OnnxContext ctx) => (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; + bool ICanSavePfa.CanSavePfa => (_bindable as ICanSavePfa)?.CanSavePfa == true; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; public ISchemaBindableMapper InnerBindable => _bindable; private static VersionInfo GetVersionInfo() @@ -196,20 +196,20 @@ private void SaveCore(ModelSaveContext ctx) throw _host.Except("We do not know how to serialize label names of type '{0}'", _type.ItemType); } - public void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputNames) + void IBindableCanSavePfa.SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputNames) { Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); - Contracts.Check(CanSavePfa, "Cannot be saved as PFA"); + Contracts.Check(((ICanSavePfa)this).CanSavePfa, "Cannot be saved as PFA"); Contracts.Assert(_bindable is IBindableCanSavePfa); ((IBindableCanSavePfa)_bindable).SaveAsPfa(ctx, schema, outputNames); } - public bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) + bool IBindableCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) { Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); - Contracts.Check(CanSaveOnnx(ctx), "Cannot be saved as ONNX."); + Contracts.Check(((ICanSaveOnnx)this).CanSaveOnnx(ctx), "Cannot be saved as ONNX."); Contracts.Assert(_bindable is IBindableCanSaveOnnx); return ((IBindableCanSaveOnnx)_bindable).SaveAsOnnx(ctx, schema, outputNames); } diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index 56183b4170..3a6ef6d1e8 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -281,7 +281,7 @@ public override Func GetActiveMapperColumns(bool[] active) bool ICanSavePfa.CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true; - public bool CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnvironment env, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema, string registrationName, string scoreColKind, string scoreColName, @@ -365,7 +365,10 @@ void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx) protected abstract JToken PredictedLabelPfa(string[] mapperOutputs); - public virtual void SaveAsOnnx(OnnxContext ctx) + void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx) => SaveAsOnnxCore(ctx); + + [BestFriend] + private protected virtual void SaveAsOnnxCore(OnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(Bindable is IBindableCanSaveOnnx); diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index 5a43f069c4..b61193480e 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -44,9 +44,9 @@ public abstract class SchemaBindablePredictorWrapperBase : ISchemaBindableMapper protected readonly IValueMapper ValueMapper; protected readonly ColumnType ScoreType; - public bool CanSavePfa => (ValueMapper as ICanSavePfa)?.CanSavePfa == true; + bool ICanSavePfa.CanSavePfa => (ValueMapper as ICanSavePfa)?.CanSavePfa == true; - public bool CanSaveOnnx(OnnxContext ctx) => (ValueMapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (ValueMapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; public SchemaBindablePredictorWrapperBase(IPredictor predictor) { @@ -89,17 +89,31 @@ public virtual void Save(ModelSaveContext ctx) ctx.SaveModel(Predictor, ModelFileUtils.DirPredictor); } - public virtual void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputNames) + void IBindableCanSavePfa.SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputNames) { Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); Contracts.Assert(ValueMapper is ISingleCanSavePfa); - var mapper = (ISingleCanSavePfa)ValueMapper; + SaveAsPfaCore(ctx, schema, outputNames); + } + [BestFriend] + private protected virtual void SaveAsPfaCore(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputNames) + { ctx.Hide(outputNames); } - public virtual bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) => false; + bool IBindableCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) + { + Contracts.CheckValue(ctx, nameof(ctx)); + Contracts.CheckValue(schema, nameof(schema)); + Contracts.Assert(ValueMapper is ISingleCanSaveOnnx); + var mapper = (ISingleCanSaveOnnx)ValueMapper; + return SaveAsOnnxCore(ctx, schema, outputNames); + } + + [BestFriend] + private protected virtual bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) => false; public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) { @@ -271,7 +285,7 @@ public override void Save(ModelSaveContext ctx) base.Save(ctx); } - public override void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputNames) + private protected override void SaveAsPfaCore(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputNames) { Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); @@ -287,7 +301,7 @@ public override void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, str ctx.DeclareVar(outputNames[0], scoreToken); } - public override bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) + private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) { Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); @@ -382,7 +396,7 @@ public override void Save(ModelSaveContext ctx) base.Save(ctx); } - public override void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputNames) + private protected override void SaveAsPfaCore(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputNames) { Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); @@ -402,7 +416,7 @@ public override void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, str Contracts.Assert(ctx.TokenOrNullForName(outputNames[1]) == probToken.ToString()); } - public override bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) + private protected override sealed bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) { Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); diff --git a/src/Microsoft.ML.Data/Transforms/TransformBase.cs b/src/Microsoft.ML.Data/Transforms/TransformBase.cs index a675ce7acc..9cdc99d1f9 100644 --- a/src/Microsoft.ML.Data/Transforms/TransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/TransformBase.cs @@ -474,7 +474,9 @@ private sealed class ColumnTmp : OneToOneColumn private protected virtual bool CanSavePfaCore => false; - public virtual bool CanSaveOnnx(OnnxContext ctx) => false; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => CanSaveOnnxCore; + + private protected virtual bool CanSaveOnnxCore => false; [BestFriend] private protected OneToOneTransformBase(IHostEnvironment env, string name, OneToOneColumn[] column, @@ -582,10 +584,10 @@ void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx) ctx.DeclareVar(toDeclare.ToArray()); } - public void SaveAsOnnx(OnnxContext ctx) + void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); - Host.Assert(CanSaveOnnx(ctx)); + Host.Assert(((ICanSaveOnnx)this).CanSaveOnnx(ctx)); for (int iinfo = 0; iinfo < Infos.Length; ++iinfo) { @@ -628,7 +630,8 @@ private protected virtual JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, C return null; } - protected virtual bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, + [BestFriend] + private protected virtual bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) => false; public sealed override Schema Schema => _bindings.AsSchema; diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index ec2bc6edaa..4bf1eb947b 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -2839,7 +2839,7 @@ public abstract class FastTreePredictionWrapper : public ColumnType InputType { get; } public ColumnType OutputType => NumberType.Float; bool ICanSavePfa.CanSavePfa => true; - public bool CanSaveOnnx(OnnxContext ctx) => true; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; protected FastTreePredictionWrapper(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs) : base(env, name) @@ -3068,7 +3068,7 @@ private enum AggregateFunction Max } - public virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) + bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs b/src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs index 326b549a98..2fb4d29b59 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs @@ -49,7 +49,7 @@ private static VersionInfo GetVersionInfo() public ColumnType InputType { get; } public ColumnType OutputType { get; } - public bool CanSaveOnnx(OnnxContext ctx) => true; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; private readonly int _dimensionality; private readonly int _k; @@ -282,7 +282,7 @@ public void GetClusterCentroids(ref VBuffer[] centroids, out int k) k = _k; } - public bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) + bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) { // Computation graph of distances to all centriods for a batch of examples. Note that a centriod is just // the center of a cluster. We use [] to denote the dimension of a variable; for example, X [3, 2] means diff --git a/src/Microsoft.ML.Onnx/AssemblyInfo.cs b/src/Microsoft.ML.Onnx/AssemblyInfo.cs index a540e9aff2..2cfc638423 100644 --- a/src/Microsoft.ML.Onnx/AssemblyInfo.cs +++ b/src/Microsoft.ML.Onnx/AssemblyInfo.cs @@ -1,3 +1,9 @@ -using System.Runtime.CompilerServices; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. -[assembly: InternalsVisibleTo("Microsoft.ML.Tests, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")] +using System.Runtime.CompilerServices; +using Microsoft.ML; + +[assembly: InternalsVisibleTo("Microsoft.ML.Core.Tests" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo("Microsoft.ML.Tests" + PublicKey.TestValue)] diff --git a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs index 241abc6fbb..3c3dd97dcd 100644 --- a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs @@ -21,7 +21,7 @@ namespace Microsoft.ML.Runtime.Model.Onnx { - public sealed class SaveOnnxCommand : DataCommand.ImplBase + internal sealed class SaveOnnxCommand : DataCommand.ImplBase { public const string Summary = "Given a data model, write out the corresponding ONNX."; public const string LoadName = "SaveOnnx"; diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index eb521da787..f54fd84b62 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -101,7 +101,7 @@ IEnumerator IEnumerable.GetEnumerator() bool ICanSavePfa.CanSavePfa => true; - public bool CanSaveOnnx(OnnxContext ctx) => true; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; /// /// Constructs a new linear predictor. @@ -230,7 +230,7 @@ JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) return PfaUtils.Call("model.reg.linear", input, cellRef); } - public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn) + bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn) { Host.CheckValue(ctx, nameof(ctx)); Host.Check(Utils.Size(outputs) == 1); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index ab7745e072..573b1514ce 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -393,7 +393,7 @@ private static VersionInfo GetVersionInfo() public ColumnType InputType { get; } public ColumnType OutputType { get; } bool ICanSavePfa.CanSavePfa => true; - public bool CanSaveOnnx(OnnxContext ctx) => true; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; internal MulticlassLogisticRegressionPredictor(IHostEnvironment env, in VBuffer weights, int numClasses, int numFeatures, string[] labelNames, LinearModelStatistics stats = null) : base(env, RegistrationName) @@ -894,7 +894,7 @@ JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) return PfaUtils.Call("m.link.softmax", PfaUtils.Call("model.reg.linear", input, cellRef)); } - public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn) + bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn) { Host.CheckValue(ctx, nameof(ctx)); From fe5267669402f919d5a89c4da8909cbef92a6ed8 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Thu, 8 Nov 2018 23:06:41 -0800 Subject: [PATCH 07/12] Polish on PFA and creation of model save/load contexts. --- .../Model/ModelLoadContext.cs | 2 +- .../Model/ModelSaveContext.cs | 16 ++++----- src/Microsoft.ML.Data/Model/Pfa/ModelUtils.cs | 2 +- src/Microsoft.ML.Data/Model/Pfa/PfaContext.cs | 3 +- src/Microsoft.ML.Data/Model/Pfa/PfaUtils.cs | 3 +- .../Model/Pfa/SavePfaCommand.cs | 2 +- .../Standard/LinearPredictor.cs | 35 ++++++------------- .../MulticlassLogisticRegression.cs | 1 + 8 files changed, 27 insertions(+), 37 deletions(-) diff --git a/src/Microsoft.ML.Data/Model/ModelLoadContext.cs b/src/Microsoft.ML.Data/Model/ModelLoadContext.cs index f3507277cf..749a226012 100644 --- a/src/Microsoft.ML.Data/Model/ModelLoadContext.cs +++ b/src/Microsoft.ML.Data/Model/ModelLoadContext.cs @@ -71,7 +71,7 @@ public sealed partial class ModelLoadContext : IDisposable /// /// Create a ModelLoadContext supporting loading from a repository, for implementors of ICanSaveModel. /// - public ModelLoadContext(RepositoryReader rep, Repository.Entry ent, string dir) + internal ModelLoadContext(RepositoryReader rep, Repository.Entry ent, string dir) { Contracts.CheckValue(rep, nameof(rep)); Repository = rep; diff --git a/src/Microsoft.ML.Data/Model/ModelSaveContext.cs b/src/Microsoft.ML.Data/Model/ModelSaveContext.cs index 4617be093d..eee8d23df7 100644 --- a/src/Microsoft.ML.Data/Model/ModelSaveContext.cs +++ b/src/Microsoft.ML.Data/Model/ModelSaveContext.cs @@ -11,9 +11,9 @@ namespace Microsoft.ML.Runtime.Model { /// /// This is a convenience context object for saving models to a repository, for - /// implementors of ICanSaveModel. It is not mandated but designed to reduce the + /// implementors of . It is not mandated but designed to reduce the /// amount of boiler plate code. It can also be used when saving to a single stream, - /// for implementors of ICanSaveInBinaryFormat. + /// for implementors of . /// public sealed partial class ModelSaveContext : IDisposable { @@ -72,9 +72,9 @@ public sealed partial class ModelSaveContext : IDisposable public bool InRepository { get { return Repository != null; } } /// - /// Create a ModelSaveContext supporting saving to a repository, for implementors of ICanSaveModel. + /// Create a supporting saving to a repository, for implementors of . /// - public ModelSaveContext(RepositoryWriter rep, string dir, string name) + internal ModelSaveContext(RepositoryWriter rep, string dir, string name) { Contracts.CheckValue(rep, nameof(rep)); Repository = rep; @@ -108,9 +108,9 @@ public ModelSaveContext(RepositoryWriter rep, string dir, string name) } /// - /// Create a ModelSaveContext supporting saving to a single-stream, for implementors of ICanSaveInBinaryFormat. + /// Create a supporting saving to a single-stream, for implementors of . /// - public ModelSaveContext(BinaryWriter writer, IExceptionContext ectx = null) + internal ModelSaveContext(BinaryWriter writer, IExceptionContext ectx = null) { Contracts.AssertValueOrNull(ectx); _ectx = ectx; @@ -132,7 +132,7 @@ public void CheckAtModel() /// /// Set the version information in the main stream's header. This should be called before - /// Done is called. + /// is called. /// /// public void SetVersionInfo(VersionInfo ver) @@ -215,7 +215,7 @@ public void SaveNonEmptyString(ReadOnlyMemory str) /// /// Commit the save operation. This completes writing of the main stream. When in repository - /// mode, it disposes the Writer (but not the repository). + /// mode, it disposes (but not ). /// public void Done() { diff --git a/src/Microsoft.ML.Data/Model/Pfa/ModelUtils.cs b/src/Microsoft.ML.Data/Model/Pfa/ModelUtils.cs index e772a65334..110296e1d0 100644 --- a/src/Microsoft.ML.Data/Model/Pfa/ModelUtils.cs +++ b/src/Microsoft.ML.Data/Model/Pfa/ModelUtils.cs @@ -6,7 +6,7 @@ namespace Microsoft.ML.Runtime.Model { - public static class ModelUtils + internal static class ModelUtils { private static string ArgCase(string name) { diff --git a/src/Microsoft.ML.Data/Model/Pfa/PfaContext.cs b/src/Microsoft.ML.Data/Model/Pfa/PfaContext.cs index b21ceaa3f0..2d89916028 100644 --- a/src/Microsoft.ML.Data/Model/Pfa/PfaContext.cs +++ b/src/Microsoft.ML.Data/Model/Pfa/PfaContext.cs @@ -11,7 +11,8 @@ namespace Microsoft.ML.Runtime.Model.Pfa /// /// A context for defining a restricted sort of PFA output. /// - public sealed class PfaContext + [BestFriend] + internal sealed class PfaContext { public JToken InputType { get; set; } public JToken OutputType { get; set; } diff --git a/src/Microsoft.ML.Data/Model/Pfa/PfaUtils.cs b/src/Microsoft.ML.Data/Model/Pfa/PfaUtils.cs index dbfbb99732..9b4f6bcf3f 100644 --- a/src/Microsoft.ML.Data/Model/Pfa/PfaUtils.cs +++ b/src/Microsoft.ML.Data/Model/Pfa/PfaUtils.cs @@ -8,7 +8,8 @@ namespace Microsoft.ML.Runtime.Model.Pfa { - public static class PfaUtils + [BestFriend] + internal static class PfaUtils { public static JObject AddReturn(this JObject toEdit, string name, JToken value) { diff --git a/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs b/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs index 2f789b7851..57a1f782c7 100644 --- a/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs +++ b/src/Microsoft.ML.Data/Model/Pfa/SavePfaCommand.cs @@ -19,7 +19,7 @@ namespace Microsoft.ML.Runtime.Model.Pfa { - public sealed class SavePfaCommand : DataCommand.ImplBase + internal sealed class SavePfaCommand : DataCommand.ImplBase { public const string Summary = "Given a data model, write out the corresponding PFA."; public const string LoadName = "SavePfa"; diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index f54fd84b62..89382860f9 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -63,8 +63,10 @@ private sealed class WeightsCollection : IReadOnlyList public int Count => _pred.Weight.Length; - public Float this[int index] { - get { + public Float this[int index] + { + get + { Contracts.CheckParam(0 <= index && index < Count, nameof(index), "Out of range"); Float value = 0; _pred.Weight.GetItemOrDefault(index, ref value); @@ -439,17 +441,7 @@ private LinearBinaryPredictor(IHostEnvironment env, ModelLoadContext ctx) // (Base class) // LinearModelStatistics: model statistics (optional, in a separate stream) - string statsDir = Path.Combine(ctx.Directory ?? "", ModelStatsSubModelFilename); - using (var statsEntry = ctx.Repository.OpenEntryOrNull(statsDir, ModelLoadContext.ModelStreamName)) - { - if (statsEntry == null) - _stats = null; - else - { - using (var statsCtx = new ModelLoadContext(ctx.Repository, statsEntry, statsDir)) - _stats = LinearModelStatistics.Create(Host, statsCtx); - } - } + ctx.LoadModelOrNull(Host, out _stats, ModelStatsSubModelFilename); } public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) @@ -474,18 +466,12 @@ protected override void SaveCore(ModelSaveContext ctx) // LinearModelStatistics: model statistics (optional, in a separate stream) base.SaveCore(ctx); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + Contracts.AssertValueOrNull(_stats); if (_stats != null) - { - using (var statsCtx = new ModelSaveContext(ctx.Repository, - Path.Combine(ctx.Directory ?? "", ModelStatsSubModelFilename), ModelLoadContext.ModelStreamName)) - { - _stats.Save(statsCtx); - statsCtx.Done(); - } - } - - ctx.SetVersionInfo(GetVersionInfo()); + ctx.SaveModel(_stats, ModelStatsSubModelFilename); } public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; @@ -562,7 +548,8 @@ protected RegressionPredictor(IHostEnvironment env, string name, ModelLoadContex { } - public override PredictionKind PredictionKind { + public override PredictionKind PredictionKind + { get { return PredictionKind.Regression; } } diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 573b1514ce..8f38ff9b43 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -566,6 +566,7 @@ public static MulticlassLogisticRegressionPredictor Create(IHostEnvironment env, protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); + ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); Host.Assert(_biases.Length == _numClasses); From 4908a876e234c56a2a140d9fe01db0ac31e25372 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Thu, 8 Nov 2018 23:19:52 -0800 Subject: [PATCH 08/12] Internalize ITree and related types. --- src/Microsoft.ML.Core/Prediction/ITree.cs | 20 +++++++++++---- src/Microsoft.ML.FastTree/FastTree.cs | 10 +++----- .../TreeEnsembleFeaturizer.cs | 25 ++++++++++--------- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/src/Microsoft.ML.Core/Prediction/ITree.cs b/src/Microsoft.ML.Core/Prediction/ITree.cs index 9014eadd68..67642ecfc5 100644 --- a/src/Microsoft.ML.Core/Prediction/ITree.cs +++ b/src/Microsoft.ML.Core/Prediction/ITree.cs @@ -7,10 +7,16 @@ namespace Microsoft.ML.Runtime.TreePredictor { + // The interfaces contained herein are meant to allow tree visualizer to run without an explicit dependency + // on FastTree, so as to allow it greater generality. These should probably be moved somewhere else, but where? + // FastTree itself is not a good candidate since their entire purpose was to avoid tying the tree visualizer + // to FastTree itself. They are semi-tolerable though as a set of internal types here. + /// /// Predictor that has ensemble tree structures and returns collection of trees. /// - public interface ITreeEnsemble + [BestFriend] + internal interface ITreeEnsemble { /// /// Returns the number of trees in the ensemble. @@ -27,7 +33,8 @@ public interface ITreeEnsemble /// /// Type of tree used in ensemble of tree based predictors /// - public interface ITree + [BestFriend] + internal interface ITree { /// /// Returns the array of right(Greater than) child nodes of every interior nodes @@ -63,7 +70,8 @@ public interface ITree /// Type of tree used in ensemble of tree based predictors /// /// Type of features container (instance) on which to make predictions - public interface ITree : ITree + [BestFriend] + internal interface ITree : ITree { /// /// Returns the leaf node for the given instance. @@ -77,7 +85,8 @@ public interface ITree : ITree /// /// Type to represent the structure of node /// - public interface INode + [BestFriend] + internal interface INode { /// /// Returns Key value pairs representing the properties of the node. @@ -88,7 +97,8 @@ public interface INode /// /// Keys to represent the properties of node. /// - public static class NodeKeys + [BestFriend] + internal static class NodeKeys { /// /// Name of the the interior node. It is Feature name if it is fasttree. Type is string for default trees. diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 4bf1eb947b..377a1e4f54 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -2818,7 +2818,7 @@ public abstract class FastTreePredictionWrapper : { //The below two properties are necessary for tree Visualizer public TreeEnsemble TrainedEnsemble { get; } - public int NumTrees => TrainedEnsemble.NumTrees; + int ITreeEnsemble.NumTrees => TrainedEnsemble.NumTrees; // Inner args is used only for documentation purposes when saving comments to INI files. protected readonly string InnerArgs; @@ -3288,7 +3288,7 @@ private static int FindMaxFeatureIndex(TreeEnsemble ensemble) return ifeatMax; } - public ITree[] GetTrees() + ITree[] ITreeEnsemble.GetTrees() { return TrainedEnsemble.Trees.Select(k => new Tree(k)).ToArray(); } @@ -3392,14 +3392,12 @@ public double GetLeafValue(int leafId) private sealed class TreeNode : INode { - private readonly Dictionary _keyValues; - public TreeNode(Dictionary keyValues) { - _keyValues = keyValues; + KeyValues = keyValues; } - public Dictionary KeyValues { get { return _keyValues; } } + public Dictionary KeyValues { get; } } } } diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index f2821ed611..fde6cf4cd0 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -15,6 +15,7 @@ using System; using System.Collections.Generic; using System.IO; +using Microsoft.ML.Runtime.TreePredictor; [assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(TreeEnsembleFeaturizerTransform), typeof(TreeEnsembleFeaturizerBindableMapper.Arguments), typeof(SignatureBindableMapper), "Tree Ensemble Featurizer Mapper", TreeEnsembleFeaturizerBindableMapper.LoadNameShort)] @@ -191,7 +192,7 @@ public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper InputRoleMappedSchema = schema; // A vector containing the output of each tree on a given example. - var treeValueType = new VectorType(NumberType.Float, _owner._ensemble.NumTrees); + var treeValueType = new VectorType(NumberType.Float, _owner._ensemble.TrainedEnsemble.NumTrees); // An indicator vector with length = the total number of leaves in the ensemble, indicating which leaf the example // ends up in all the trees in the ensemble. var leafIdType = new VectorType(NumberType.Float, _owner._totalLeafCount); @@ -202,7 +203,7 @@ public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper // plus one (since the root node is not a child of any node). So we have #internal + #leaf = 2*(#internal) + 1, // which means that #internal = #leaf - 1. // Therefore, the number of internal nodes in the ensemble is #leaf - #trees. - var pathIdType = new VectorType(NumberType.Float, _owner._totalLeafCount - _owner._ensemble.NumTrees); + var pathIdType = new VectorType(NumberType.Float, _owner._totalLeafCount - _owner._ensemble.TrainedEnsemble.NumTrees); Schema = Schema.Create(new SchemaImpl(ectx, owner, treeValueType, leafIdType, pathIdType)); } @@ -280,10 +281,10 @@ public State(IExceptionContext ectx, IRow input, FastTreePredictionWrapper ensem _ectx = ectx; _ectx.AssertValue(input); _ectx.AssertValue(ensemble); - _ectx.Assert(ensemble.NumTrees > 0); + _ectx.Assert(ensemble.TrainedEnsemble.NumTrees > 0); _input = input; _ensemble = ensemble; - _numTrees = _ensemble.NumTrees; + _numTrees = _ensemble.TrainedEnsemble.NumTrees; _numLeaves = numLeaves; _src = default(VBuffer); @@ -326,7 +327,7 @@ public void GetLeafIds(ref VBuffer dst) _leafIdBuilder.Reset(_numLeaves, false); var offset = 0; - var trees = _ensemble.GetTrees(); + var trees = ((ITreeEnsemble)_ensemble).GetTrees(); for (int i = 0; i < trees.Length; i++) { _leafIdBuilder.AddFeature(offset + _leafIds[i], 1); @@ -350,7 +351,7 @@ public void GetPathIds(ref VBuffer dst) if (_pathIdBuilder == null) _pathIdBuilder = BufferBuilder.CreateDefault(); - var trees = _ensemble.GetTrees(); + var trees = ((ITreeEnsemble)_ensemble).GetTrees(); _pathIdBuilder.Reset(_numLeaves - _numTrees, dense: false); var offset = 0; for (int i = 0; i < _numTrees; i++) @@ -471,7 +472,7 @@ private static int CountLeaves(FastTreePredictionWrapper ensemble) { Contracts.AssertValue(ensemble); - var trees = ensemble.GetTrees(); + var trees = ((ITreeEnsemble)ensemble).GetTrees(); var numTrees = trees.Length; var totalLeafCount = 0; for (int i = 0; i < numTrees; i++) @@ -481,7 +482,7 @@ private static int CountLeaves(FastTreePredictionWrapper ensemble) private void GetTreeSlotNames(int col, ref VBuffer> dst) { - var numTrees = _ensemble.NumTrees; + var numTrees = _ensemble.TrainedEnsemble.NumTrees; var names = dst.Values; if (Utils.Size(names) < numTrees) @@ -495,7 +496,7 @@ private void GetTreeSlotNames(int col, ref VBuffer> dst) private void GetLeafSlotNames(int col, ref VBuffer> dst) { - var numTrees = _ensemble.NumTrees; + var numTrees = _ensemble.TrainedEnsemble.NumTrees; var names = dst.Values; if (Utils.Size(names) < _totalLeafCount) @@ -503,7 +504,7 @@ private void GetLeafSlotNames(int col, ref VBuffer> dst) int i = 0; int t = 0; - foreach (var tree in _ensemble.GetTrees()) + foreach (var tree in ((ITreeEnsemble)_ensemble).GetTrees()) { for (int l = 0; l < tree.NumLeaves; l++) names[i++] = string.Format("Tree{0:000}Leaf{1:000}", t, l).AsMemory(); @@ -515,7 +516,7 @@ private void GetLeafSlotNames(int col, ref VBuffer> dst) private void GetPathSlotNames(int col, ref VBuffer> dst) { - var numTrees = _ensemble.NumTrees; + var numTrees = _ensemble.TrainedEnsemble.NumTrees; var totalNodeCount = _totalLeafCount - numTrees; var names = dst.Values; @@ -524,7 +525,7 @@ private void GetPathSlotNames(int col, ref VBuffer> dst) int i = 0; int t = 0; - foreach (var tree in _ensemble.GetTrees()) + foreach (var tree in ((ITreeEnsemble)_ensemble).GetTrees()) { var numLeaves = tree.NumLeaves; for (int l = 0; l < tree.NumLeaves - 1; l++) From 58910835573e08ce62bbcddb6ce58b323c7e7503 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Thu, 8 Nov 2018 23:31:41 -0800 Subject: [PATCH 09/12] Remove misuse of CheckAtModel. --- src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs | 1 - .../Standard/LogisticRegression/MulticlassLogisticRegression.cs | 1 - 2 files changed, 2 deletions(-) diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index 89382860f9..5f3c0c72cc 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -466,7 +466,6 @@ protected override void SaveCore(ModelSaveContext ctx) // LinearModelStatistics: model statistics (optional, in a separate stream) base.SaveCore(ctx); - ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); Contracts.AssertValueOrNull(_stats); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 8f38ff9b43..573b1514ce 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -566,7 +566,6 @@ public static MulticlassLogisticRegressionPredictor Create(IHostEnvironment env, protected override void SaveCore(ModelSaveContext ctx) { base.SaveCore(ctx); - ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); Host.Assert(_biases.Length == _numClasses); From e3e2e51dc83c5d79ea22c4930c5481897fd0bd86 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Fri, 9 Nov 2018 09:21:52 -0800 Subject: [PATCH 10/12] Finish forgotten utils since I apparently got distracted --- .../Utilities/ReservoirSampler.cs | 9 ++++--- .../Utilities/ResourceManagerUtils.cs | 5 +++- src/Microsoft.ML.Core/Utilities/Stats.cs | 3 ++- .../Utilities/SubsetStream.cs | 3 ++- .../Utilities/SummaryStatistics.cs | 8 ++++--- .../Utilities/SupervisedBinFinder.cs | 3 ++- .../Utilities/TextReaderStream.cs | 3 ++- .../Utilities/ThreadUtils.cs | 6 +++-- src/Microsoft.ML.Core/Utilities/Tree.cs | 3 ++- .../Utilities/VBufferUtils.cs | 3 ++- .../DataLoadSave/Binary/Zlib/Zlib.cs | 2 ++ .../Evaluators/AucAggregator.cs | 16 ++++++------- .../Evaluators/BinaryClassifierEvaluator.cs | 2 +- .../Properties/AssemblyInfo.cs | 1 + .../Utilities/StreamUtils.cs | 3 ++- src/Microsoft.ML.Data/Utilities/TimerScope.cs | 7 +++--- .../{IntSequencePool.cs => SequencePool.cs} | 9 ++----- .../DatasetFeaturesInference.cs | 24 +++++++++---------- 18 files changed, 62 insertions(+), 48 deletions(-) rename src/Microsoft.ML.Data/Utils/{IntSequencePool.cs => SequencePool.cs} (98%) diff --git a/src/Microsoft.ML.Core/Utilities/ReservoirSampler.cs b/src/Microsoft.ML.Core/Utilities/ReservoirSampler.cs index acc7d4f3a2..8438a2e742 100644 --- a/src/Microsoft.ML.Core/Utilities/ReservoirSampler.cs +++ b/src/Microsoft.ML.Core/Utilities/ReservoirSampler.cs @@ -13,7 +13,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// The sample is created in one pass by calling for every data point in the stream. Implementations should have /// a delegate for getting the next data point, which is invoked if the current data point should go into the reservoir. /// - public interface IReservoirSampler + [BestFriend] + internal interface IReservoirSampler { /// /// If the number of elements sampled is less than the reservoir size, this should return the number of elements sampled. @@ -49,7 +50,8 @@ public interface IReservoirSampler /// for every data point in the stream. In case the next data point does not get 'picked' into the reservoir, the delegate is not invoked. /// Sampling is done according to the algorithm in this paper: https://epubs.siam.org/doi/pdf/10.1137/1.9781611972740.53. /// - public sealed class ReservoirSamplerWithoutReplacement : IReservoirSampler + [BestFriend] + internal sealed class ReservoirSamplerWithoutReplacement : IReservoirSampler { // This array contains a cache of the elements composing the reservoir. private readonly T[] _cache; @@ -122,7 +124,8 @@ public IEnumerable GetSample() /// for every data point in the stream. In case the next data point does not get 'picked' into the reservoir, the delegate is not invoked. /// Sampling is done according to the algorithm in this paper: https://epubs.siam.org/doi/pdf/10.1137/1.9781611972740.53. /// - public sealed class ReservoirSamplerWithReplacement : IReservoirSampler + [BestFriend] + internal sealed class ReservoirSamplerWithReplacement : IReservoirSampler { // This array contains pointers to the elements in the _cache array that are currently in the reservoir (may contain duplicates). private readonly int[] _reservoir; diff --git a/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs b/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs index 4ea8f0d5b7..9e9c6f80bb 100644 --- a/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs @@ -16,7 +16,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// This class takes care of downloading resources needed by ML.NET components. Resources are located in /// a pre-defined location, that can be overridden by defining Environment variable . /// - public sealed class ResourceManagerUtils + [BestFriend] + internal sealed class ResourceManagerUtils { private static volatile ResourceManagerUtils _instance; public static ResourceManagerUtils Instance @@ -301,7 +302,9 @@ public static ResourceDownloadResults GetErrorMessage(out string errorMessage, p return errorResult; } +#pragma warning disable IDE1006 [DllImport("libc", SetLastError = true)] private static extern int chmod(string pathname, int mode); +#pragma warning restore IDE1006 } } diff --git a/src/Microsoft.ML.Core/Utilities/Stats.cs b/src/Microsoft.ML.Core/Utilities/Stats.cs index 119ac22946..182239adde 100644 --- a/src/Microsoft.ML.Core/Utilities/Stats.cs +++ b/src/Microsoft.ML.Core/Utilities/Stats.cs @@ -11,7 +11,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// /// A class containing common statistical functions /// - public static class Stats + [BestFriend] + internal static class Stats { /// /// Returns a number uniformly sampled from 0...(rangeSize-1) diff --git a/src/Microsoft.ML.Core/Utilities/SubsetStream.cs b/src/Microsoft.ML.Core/Utilities/SubsetStream.cs index 84cf4826e6..3bf9ad2c9e 100644 --- a/src/Microsoft.ML.Core/Utilities/SubsetStream.cs +++ b/src/Microsoft.ML.Core/Utilities/SubsetStream.cs @@ -22,7 +22,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// subset stream, the underlying stream will always remain open and /// undisposed. /// - public sealed class SubsetStream : Stream + [BestFriend] + internal sealed class SubsetStream : Stream { private readonly Stream _stream; // The position of the stream. diff --git a/src/Microsoft.ML.Core/Utilities/SummaryStatistics.cs b/src/Microsoft.ML.Core/Utilities/SummaryStatistics.cs index 3843342a2c..3e36191564 100644 --- a/src/Microsoft.ML.Core/Utilities/SummaryStatistics.cs +++ b/src/Microsoft.ML.Core/Utilities/SummaryStatistics.cs @@ -6,7 +6,7 @@ namespace Microsoft.ML.Runtime.Internal.Utilities { - public abstract class SummaryStatisticsBase + internal abstract class SummaryStatisticsBase { // Sum of squared difference from the current mean. protected double M2; @@ -152,7 +152,8 @@ public void Add(SummaryStatisticsBase s) } } - public sealed class SummaryStatisticsUpToSecondOrderMoments : SummaryStatisticsBase + [BestFriend] + internal sealed class SummaryStatisticsUpToSecondOrderMoments : SummaryStatisticsBase { /// /// A convenient way to combine the observations of two Stats objects @@ -177,7 +178,8 @@ public sealed class SummaryStatisticsUpToSecondOrderMoments : SummaryStatisticsB /// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance /// All quantities are weighted, except for RawCount. /// - public sealed class SummaryStatistics : SummaryStatisticsBase + [BestFriend] + internal sealed class SummaryStatistics : SummaryStatisticsBase { // Sum of cubed difference from the current mean. private double _m3; diff --git a/src/Microsoft.ML.Core/Utilities/SupervisedBinFinder.cs b/src/Microsoft.ML.Core/Utilities/SupervisedBinFinder.cs index bb76952c25..63257823a2 100644 --- a/src/Microsoft.ML.Core/Utilities/SupervisedBinFinder.cs +++ b/src/Microsoft.ML.Core/Utilities/SupervisedBinFinder.cs @@ -20,7 +20,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// The class can be used several times sequentially, it is stateful and not thread-safe. /// Both Single and Double precision processing is implemented, and is identical. /// - public sealed class SupervisedBinFinder + [BestFriend] + internal sealed class SupervisedBinFinder { private readonly struct ValuePair : IComparable> where T : IComparable diff --git a/src/Microsoft.ML.Core/Utilities/TextReaderStream.cs b/src/Microsoft.ML.Core/Utilities/TextReaderStream.cs index 5c05275ba7..682a0336cb 100644 --- a/src/Microsoft.ML.Core/Utilities/TextReaderStream.cs +++ b/src/Microsoft.ML.Core/Utilities/TextReaderStream.cs @@ -14,7 +14,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// compensates by inserting \n line feed characters at the end of every /// input line, including the last one. /// - public sealed class TextReaderStream : Stream + [BestFriend] + internal sealed class TextReaderStream : Stream { private readonly TextReader _baseReader; private readonly Encoding _encoding; diff --git a/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs b/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs index 4756337b39..46a82a4e7c 100644 --- a/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/ThreadUtils.cs @@ -59,7 +59,8 @@ public static Thread CreateForegroundThread(ThreadStart start) /// that the workers have finished by its own means, will call to throw /// the set exception as an inner exception, in the wrapping thread. /// - public sealed class ExceptionMarshaller : IDisposable + [BestFriend] + internal sealed class ExceptionMarshaller : IDisposable { private readonly CancellationTokenSource _ctSource; private readonly object _lock; @@ -130,7 +131,8 @@ public void ThrowIfSet(IExceptionContext ectx) /// Provides a task scheduler that ensures a maximum concurrency level while /// running on top of the ThreadPool. /// - public sealed class LimitedConcurrencyLevelTaskScheduler : TaskScheduler + [BestFriend] + internal sealed class LimitedConcurrencyLevelTaskScheduler : TaskScheduler { // Whether the current thread is processing work items. [ThreadStatic] diff --git a/src/Microsoft.ML.Core/Utilities/Tree.cs b/src/Microsoft.ML.Core/Utilities/Tree.cs index 7d030cf46c..8d4c0f7585 100644 --- a/src/Microsoft.ML.Core/Utilities/Tree.cs +++ b/src/Microsoft.ML.Core/Utilities/Tree.cs @@ -17,7 +17,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// /// Children are keyed with values of this type /// The type of the node value - public sealed class Tree : IDictionary> + [BestFriend] + internal sealed class Tree : IDictionary> { // The key of this node in the parent, assuming this is a child node at all. // This back reference is necessary to complete any "remove" operations. diff --git a/src/Microsoft.ML.Core/Utilities/VBufferUtils.cs b/src/Microsoft.ML.Core/Utilities/VBufferUtils.cs index 19a7819325..bc1e7f4f7f 100644 --- a/src/Microsoft.ML.Core/Utilities/VBufferUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/VBufferUtils.cs @@ -14,7 +14,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// /// Convenience utilities for vector operations on . /// - public static class VBufferUtils + [BestFriend] + internal static class VBufferUtils { private const float SparsityThreshold = 0.25f; diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/Zlib.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/Zlib.cs index 024eaef4a2..7b2ae812a8 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/Zlib.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Zlib/Zlib.cs @@ -12,6 +12,7 @@ internal static class Zlib { public const string DllPath = "zlib.dll"; +#pragma warning disable IDE1006 [DllImport(DllPath), SuppressUnmanagedCodeSecurity] private static extern unsafe Constants.RetCode deflateInit2_(ZStream* strm, int level, int method, int windowBits, int memLevel, Constants.Strategy strategy, byte* version, int streamSize); @@ -44,6 +45,7 @@ public static unsafe Constants.RetCode InflateInit2(ZStream* strm, int windowBit [DllImport(DllPath), SuppressUnmanagedCodeSecurity] public static extern unsafe Constants.RetCode inflateEnd(ZStream* strm); +#pragma warning restore IDE1006 } [StructLayout(LayoutKind.Sequential)] diff --git a/src/Microsoft.ML.Data/Evaluators/AucAggregator.cs b/src/Microsoft.ML.Data/Evaluators/AucAggregator.cs index 1390015697..87bbf897d6 100644 --- a/src/Microsoft.ML.Data/Evaluators/AucAggregator.cs +++ b/src/Microsoft.ML.Data/Evaluators/AucAggregator.cs @@ -11,7 +11,7 @@ namespace Microsoft.ML.Runtime.Data { public abstract partial class EvaluatorBase { - protected abstract class AucAggregatorBase + internal abstract class AucAggregatorBase { protected Single Score; protected Single Label; @@ -30,7 +30,7 @@ public void ProcessRow(Single label, Single score, Single weight = 1) public abstract Double ComputeWeightedAuc(out Double unweighted); } - protected abstract class AucAggregatorBase : AucAggregatorBase + internal abstract class AucAggregatorBase : AucAggregatorBase { private readonly ReservoirSamplerWithoutReplacement _posReservoir; private readonly ReservoirSamplerWithoutReplacement _negReservoir; @@ -117,7 +117,7 @@ public override Double ComputeWeightedAuc(out Double unweighted) protected abstract Double ComputeWeightedAucCore(out double unweighted); } - protected sealed class UnweightedAucAggregator : AucAggregatorBase + internal sealed class UnweightedAucAggregator : AucAggregatorBase { public UnweightedAucAggregator(IRandom rand, int reservoirSize) : base(rand, reservoirSize) @@ -210,7 +210,7 @@ protected override void AddExample(List examples) } } - protected sealed class WeightedAucAggregator : AucAggregatorBase + internal sealed class WeightedAucAggregator : AucAggregatorBase { public struct AucInfo { @@ -345,7 +345,7 @@ protected override void AddExample(List examples) } } - public abstract class AuPrcAggregatorBase + internal abstract class AuPrcAggregatorBase { protected Single Score; protected Single Label; @@ -364,7 +364,7 @@ public void ProcessRow(Single label, Single score, Single weight = 1) public abstract Double ComputeWeightedAuPrc(out Double unweighted); } - protected abstract class AuPrcAggregatorBase : AuPrcAggregatorBase + private protected abstract class AuPrcAggregatorBase : AuPrcAggregatorBase { protected readonly ReservoirSamplerWithoutReplacement Reservoir; @@ -393,7 +393,7 @@ public override Double ComputeWeightedAuPrc(out Double unweighted) protected abstract Double ComputeWeightedAuPrcCore(out Double unweighted); } - protected sealed class UnweightedAuPrcAggregator : AuPrcAggregatorBase + private protected sealed class UnweightedAuPrcAggregator : AuPrcAggregatorBase { public struct Info { @@ -466,7 +466,7 @@ protected override ValueGetter GetSampleGetter() } } - protected sealed class WeightedAuPrcAggregator : AuPrcAggregatorBase + private protected sealed class WeightedAuPrcAggregator : AuPrcAggregatorBase { public struct Info { diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs index bee170ecef..c2cea7c48b 100644 --- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs @@ -535,7 +535,7 @@ private struct RocInfo public readonly List WeightedRecall; public readonly List WeightedFalsePositiveRate; - public readonly AuPrcAggregatorBase AuPrcAggregator; + internal readonly AuPrcAggregatorBase AuPrcAggregator; public double WeightedAuPrc; public double UnweightedAuPrc; diff --git a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs index 5fed334c9d..dde426f0d6 100644 --- a/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.Data/Properties/AssemblyInfo.cs @@ -10,6 +10,7 @@ [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Maml" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.ResultProcessor" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Data" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Api" + PublicKey.Value)] diff --git a/src/Microsoft.ML.Data/Utilities/StreamUtils.cs b/src/Microsoft.ML.Data/Utilities/StreamUtils.cs index ac05684a8e..5d52b45b3b 100644 --- a/src/Microsoft.ML.Data/Utilities/StreamUtils.cs +++ b/src/Microsoft.ML.Data/Utilities/StreamUtils.cs @@ -9,7 +9,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities { // REVIEW: Implement properly on CoreCLR. - public static class StreamUtils + [BestFriend] + internal static class StreamUtils { public static Stream OpenInStream(string fileName) { diff --git a/src/Microsoft.ML.Data/Utilities/TimerScope.cs b/src/Microsoft.ML.Data/Utilities/TimerScope.cs index 403d8da81d..c2a590ffe8 100644 --- a/src/Microsoft.ML.Data/Utilities/TimerScope.cs +++ b/src/Microsoft.ML.Data/Utilities/TimerScope.cs @@ -3,17 +3,16 @@ // See the LICENSE file in the project root for more information. using System; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; namespace Microsoft.ML.Runtime.Internal.Utilities { using Stopwatch = System.Diagnostics.Stopwatch; /// - /// A timer scope class that starts a Stopwatch when created, calculates and prints elapsed time, physical and virtual memory usages before sending these to the telemetry when disposed. + /// A timer scope class that starts a when created, calculates and prints elapsed time, physical and virtual memory usages before sending these to the telemetry when disposed. /// - public sealed class TimerScope : IDisposable + [BestFriend] + internal sealed class TimerScope : IDisposable { // Note that this class does not own nor dispose of this channel. private readonly IChannel _ch; diff --git a/src/Microsoft.ML.Data/Utils/IntSequencePool.cs b/src/Microsoft.ML.Data/Utils/SequencePool.cs similarity index 98% rename from src/Microsoft.ML.Data/Utils/IntSequencePool.cs rename to src/Microsoft.ML.Data/Utils/SequencePool.cs index 9a83d17168..6b9ebe01e3 100644 --- a/src/Microsoft.ML.Data/Utils/IntSequencePool.cs +++ b/src/Microsoft.ML.Data/Utils/SequencePool.cs @@ -3,13 +3,7 @@ // See the LICENSE file in the project root for more information. using System; -using System.Collections.Generic; -using System.Collections; using System.IO; -using System.Linq; -using System.Threading; -using System.Text; -using Microsoft.ML.Runtime.Model; namespace Microsoft.ML.Runtime.Internal.Utilities { @@ -19,7 +13,8 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// A dictionary of uint sequences of variable length. Stores the sequences as /// byte sequences encoded with LEB128. Empty sequences (or null) are also valid. /// - public sealed class SequencePool + [BestFriend] + internal sealed class SequencePool { // uint sequences are hashed into _mask+1 buckets. _buckets contains the ID of the first // sequence that falls in it (or -1 if it is empty). diff --git a/src/Microsoft.ML.PipelineInference/DatasetFeaturesInference.cs b/src/Microsoft.ML.PipelineInference/DatasetFeaturesInference.cs index f51c0871f8..235a85e96c 100644 --- a/src/Microsoft.ML.PipelineInference/DatasetFeaturesInference.cs +++ b/src/Microsoft.ML.PipelineInference/DatasetFeaturesInference.cs @@ -21,19 +21,19 @@ public static class DatasetFeatureInference { public sealed class Stats { - [JsonIgnore] public SummaryStatistics Statistics; + [JsonIgnore] private SummaryStatistics _statistics; [JsonIgnore] public double Sum; public Stats() { - Statistics = new SummaryStatistics(); + _statistics = new SummaryStatistics(); } public void Add(double x) { Sum += x; - Statistics.Add(x); + _statistics.Add(x); } public void Add(IEnumerable x) @@ -43,31 +43,31 @@ public void Add(IEnumerable x) } [JsonProperty] - public long Count => Statistics.RawCount; + public long Count => _statistics.RawCount; [JsonProperty] - public double? NonZeroValueCount => Statistics.RawCount > 20 ? (double?)Statistics.Nonzero : null; + public double? NonZeroValueCount => _statistics.RawCount > 20 ? (double?)_statistics.Nonzero : null; [JsonProperty] - public double? Variance => Statistics.RawCount > 20 ? (double?)Statistics.SampleVariance : null; + public double? Variance => _statistics.RawCount > 20 ? (double?)_statistics.SampleVariance : null; [JsonProperty] - public double? StandardDeviation => Statistics.RawCount > 20 ? (double?)Statistics.SampleStdDev : null; + public double? StandardDeviation => _statistics.RawCount > 20 ? (double?)_statistics.SampleStdDev : null; [JsonProperty] - public double? Skewness => Statistics.RawCount > 20 ? (double?)Statistics.Skewness : null; + public double? Skewness => _statistics.RawCount > 20 ? (double?)_statistics.Skewness : null; [JsonProperty] - public double? Kurtosis => Statistics.RawCount > 20 ? (double?)Statistics.Kurtosis : null; + public double? Kurtosis => _statistics.RawCount > 20 ? (double?)_statistics.Kurtosis : null; [JsonProperty] - public double? Mean => Statistics.RawCount > 20 ? (double?)Statistics.Mean : null; + public double? Mean => _statistics.RawCount > 20 ? (double?)_statistics.Mean : null; [JsonIgnore] - public double Min => Statistics.Min; + public double Min => _statistics.Min; [JsonIgnore] - public double Max => Statistics.Max; + public double Max => _statistics.Max; } public sealed class Column From 0d3a10bf25d826239b38a104a26ec222a1628305 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Fri, 9 Nov 2018 10:40:53 -0800 Subject: [PATCH 11/12] Contracts, misc fix --- src/Microsoft.ML.Core/Utilities/Contracts.cs | 979 +++++++++--------- .../Scorers/SchemaBindablePredictorWrapper.cs | 2 +- .../TreeEnsembleFeaturizer.cs | 2 +- 3 files changed, 491 insertions(+), 492 deletions(-) diff --git a/src/Microsoft.ML.Core/Utilities/Contracts.cs b/src/Microsoft.ML.Core/Utilities/Contracts.cs index 671bf41db8..e1ee11a0d4 100644 --- a/src/Microsoft.ML.Core/Utilities/Contracts.cs +++ b/src/Microsoft.ML.Core/Utilities/Contracts.cs @@ -55,11 +55,10 @@ internal enum MessageSensitivity } #endif -#if PRIVATE_CONTRACTS - internal static partial class Contracts -#else - public static partial class Contracts +#if !PRIVATE_CONTRACTS + [BestFriend] #endif + internal static partial class Contracts { public const string IsMarkedKey = "ML_IsMarked"; public const string SensitivityKey = "ML_Sensitivity"; @@ -467,281 +466,281 @@ private static string MakeSchemaMismatchMsg(string columnRole, string columnName return $"Schema mismatch for {columnRole} column '{columnName}': expected {expectedType}, got {actualType}"; } - // Check - these check a condition and if it fails, throw the corresponding exception. - // NOTE: The ordering of arguments to these is standardized to be: - // * boolean condition - // * parameter name - // * parameter value - // * message string - // - // Note that these do NOT support a params array of arguments since that would - // involve memory allocation whenever the condition is checked. When message string - // args are need, the condition test should be inlined, eg: - // if (!condition) - // throw Contracts.ExceptXxx(fmt, arg1, arg2); - - public static void Check(bool f) - { - if (!f) - throw Except(); - } - public static void Check(this IExceptionContext ctx, bool f) - { - if (!f) - throw Except(ctx); - } - public static void Check(bool f, string msg) - { - if (!f) - throw Except(msg); - } - public static void Check(this IExceptionContext ctx, bool f, string msg) - { - if (!f) - throw Except(ctx, msg); - } + // Check - these check a condition and if it fails, throw the corresponding exception. + // NOTE: The ordering of arguments to these is standardized to be: + // * boolean condition + // * parameter name + // * parameter value + // * message string + // + // Note that these do NOT support a params array of arguments since that would + // involve memory allocation whenever the condition is checked. When message string + // args are need, the condition test should be inlined, eg: + // if (!condition) + // throw Contracts.ExceptXxx(fmt, arg1, arg2); + + public static void Check(bool f) + { + if (!f) + throw Except(); + } + public static void Check(this IExceptionContext ctx, bool f) + { + if (!f) + throw Except(ctx); + } + public static void Check(bool f, string msg) + { + if (!f) + throw Except(msg); + } + public static void Check(this IExceptionContext ctx, bool f, string msg) + { + if (!f) + throw Except(ctx, msg); + } - /// - /// CheckUserArg / ExceptUserArg should be used when the validation of user-provided arguments failed. - /// Typically, this is shortly after the arguments are parsed using CmdParser. - /// - public static void CheckUserArg(bool f, string name) - { - if (!f) - throw ExceptUserArg(name); - } - public static void CheckUserArg(this IExceptionContext ctx, bool f, string name) - { - if (!f) - throw ExceptUserArg(ctx, name); - } - public static void CheckUserArg(bool f, string name, string msg) - { - if (!f) - throw ExceptUserArg(name, msg); - } - public static void CheckUserArg(this IExceptionContext ctx, bool f, string name, string msg) - { - if (!f) - throw ExceptUserArg(ctx, name, msg); - } + /// + /// CheckUserArg / ExceptUserArg should be used when the validation of user-provided arguments failed. + /// Typically, this is shortly after the arguments are parsed using CmdParser. + /// + public static void CheckUserArg(bool f, string name) + { + if (!f) + throw ExceptUserArg(name); + } + public static void CheckUserArg(this IExceptionContext ctx, bool f, string name) + { + if (!f) + throw ExceptUserArg(ctx, name); + } + public static void CheckUserArg(bool f, string name, string msg) + { + if (!f) + throw ExceptUserArg(name, msg); + } + public static void CheckUserArg(this IExceptionContext ctx, bool f, string name, string msg) + { + if (!f) + throw ExceptUserArg(ctx, name, msg); + } - public static void CheckParam(bool f, string paramName) - { - if (!f) - throw ExceptParam(paramName); - } - public static void CheckParam(this IExceptionContext ctx, bool f, string paramName) - { - if (!f) - throw ExceptParam(ctx, paramName); - } - public static void CheckParam(bool f, string paramName, string msg) - { - if (!f) - throw ExceptParam(paramName, msg); - } - public static void CheckParam(this IExceptionContext ctx, bool f, string paramName, string msg) - { - if (!f) - throw ExceptParam(ctx, paramName, msg); - } - public static void CheckParamValue(bool f, T value, string paramName, string msg) - { - if (!f) - throw ExceptParamValue(value, paramName, msg); - } - public static void CheckParamValue(this IExceptionContext ctx, bool f, T value, string paramName, string msg) - { - if (!f) - throw ExceptParamValue(ctx, value, paramName, msg); - } + public static void CheckParam(bool f, string paramName) + { + if (!f) + throw ExceptParam(paramName); + } + public static void CheckParam(this IExceptionContext ctx, bool f, string paramName) + { + if (!f) + throw ExceptParam(ctx, paramName); + } + public static void CheckParam(bool f, string paramName, string msg) + { + if (!f) + throw ExceptParam(paramName, msg); + } + public static void CheckParam(this IExceptionContext ctx, bool f, string paramName, string msg) + { + if (!f) + throw ExceptParam(ctx, paramName, msg); + } + public static void CheckParamValue(bool f, T value, string paramName, string msg) + { + if (!f) + throw ExceptParamValue(value, paramName, msg); + } + public static void CheckParamValue(this IExceptionContext ctx, bool f, T value, string paramName, string msg) + { + if (!f) + throw ExceptParamValue(ctx, value, paramName, msg); + } - public static T CheckRef(T val, string paramName) where T : class - { - if (object.ReferenceEquals(val, null)) - throw ExceptValue(paramName); - return val; - } - public static T CheckRef(this IExceptionContext ctx, T val, string paramName) where T : class - { - if (object.ReferenceEquals(val, null)) - throw ExceptValue(ctx, paramName); - return val; - } + public static T CheckRef(T val, string paramName) where T : class + { + if (object.ReferenceEquals(val, null)) + throw ExceptValue(paramName); + return val; + } + public static T CheckRef(this IExceptionContext ctx, T val, string paramName) where T : class + { + if (object.ReferenceEquals(val, null)) + throw ExceptValue(ctx, paramName); + return val; + } - public static T CheckRef(this IExceptionContext ctx, T val, string paramName, string msg) where T : class - { - if (object.ReferenceEquals(val, null)) - throw ExceptValue(ctx, paramName, msg); - return val; - } + public static T CheckRef(this IExceptionContext ctx, T val, string paramName, string msg) where T : class + { + if (object.ReferenceEquals(val, null)) + throw ExceptValue(ctx, paramName, msg); + return val; + } public static void CheckValue(T val, string paramName) where T : class - { - if (object.ReferenceEquals(val, null)) - throw ExceptValue(paramName); - } - public static void CheckValue(this IExceptionContext ctx, T val, string paramName) where T : class - { - if (object.ReferenceEquals(val, null)) - throw ExceptValue(ctx, paramName); - } - public static T CheckValue(T val, string paramName, string msg) where T : class - { - if (object.ReferenceEquals(val, null)) - throw ExceptValue(paramName, msg); - return val; - } - public static T CheckValue(this IExceptionContext ctx, T val, string paramName, string msg) where T : class - { - if (object.ReferenceEquals(val, null)) - throw ExceptValue(ctx, paramName, msg); - return val; - } + { + if (object.ReferenceEquals(val, null)) + throw ExceptValue(paramName); + } + public static void CheckValue(this IExceptionContext ctx, T val, string paramName) where T : class + { + if (object.ReferenceEquals(val, null)) + throw ExceptValue(ctx, paramName); + } + public static T CheckValue(T val, string paramName, string msg) where T : class + { + if (object.ReferenceEquals(val, null)) + throw ExceptValue(paramName, msg); + return val; + } + public static T CheckValue(this IExceptionContext ctx, T val, string paramName, string msg) where T : class + { + if (object.ReferenceEquals(val, null)) + throw ExceptValue(ctx, paramName, msg); + return val; + } - public static string CheckNonEmpty(string s, string paramName) - { - if (string.IsNullOrEmpty(s)) - throw ExceptEmpty(paramName); - return s; - } - public static string CheckNonEmpty(this IExceptionContext ctx, string s, string paramName) - { - if (string.IsNullOrEmpty(s)) - throw ExceptEmpty(ctx, paramName); - return s; - } + public static string CheckNonEmpty(string s, string paramName) + { + if (string.IsNullOrEmpty(s)) + throw ExceptEmpty(paramName); + return s; + } + public static string CheckNonEmpty(this IExceptionContext ctx, string s, string paramName) + { + if (string.IsNullOrEmpty(s)) + throw ExceptEmpty(ctx, paramName); + return s; + } - public static string CheckNonWhiteSpace(string s, string paramName) - { - if (string.IsNullOrWhiteSpace(s)) - throw ExceptWhiteSpace(paramName); - return s; - } - public static string CheckNonWhiteSpace(this IExceptionContext ctx, string s, string paramName) - { - if (string.IsNullOrWhiteSpace(s)) - throw ExceptWhiteSpace(ctx, paramName); - return s; - } + public static string CheckNonWhiteSpace(string s, string paramName) + { + if (string.IsNullOrWhiteSpace(s)) + throw ExceptWhiteSpace(paramName); + return s; + } + public static string CheckNonWhiteSpace(this IExceptionContext ctx, string s, string paramName) + { + if (string.IsNullOrWhiteSpace(s)) + throw ExceptWhiteSpace(ctx, paramName); + return s; + } - public static string CheckNonEmpty(string s, string paramName, string msg) - { - if (string.IsNullOrEmpty(s)) - throw ExceptEmpty(paramName, msg); - return s; - } - public static string CheckNonEmpty(this IExceptionContext ctx, string s, string paramName, string msg) - { - if (string.IsNullOrEmpty(s)) - throw ExceptEmpty(ctx, paramName, msg); - return s; - } + public static string CheckNonEmpty(string s, string paramName, string msg) + { + if (string.IsNullOrEmpty(s)) + throw ExceptEmpty(paramName, msg); + return s; + } + public static string CheckNonEmpty(this IExceptionContext ctx, string s, string paramName, string msg) + { + if (string.IsNullOrEmpty(s)) + throw ExceptEmpty(ctx, paramName, msg); + return s; + } - public static string CheckNonWhiteSpace(string s, string paramName, string msg) - { - if (string.IsNullOrWhiteSpace(s)) - throw ExceptWhiteSpace(paramName, msg); - return s; - } - public static string CheckNonWhiteSpace(this IExceptionContext ctx, string s, string paramName, string msg) - { - if (string.IsNullOrWhiteSpace(s)) - throw ExceptWhiteSpace(ctx, paramName, msg); - return s; - } + public static string CheckNonWhiteSpace(string s, string paramName, string msg) + { + if (string.IsNullOrWhiteSpace(s)) + throw ExceptWhiteSpace(paramName, msg); + return s; + } + public static string CheckNonWhiteSpace(this IExceptionContext ctx, string s, string paramName, string msg) + { + if (string.IsNullOrWhiteSpace(s)) + throw ExceptWhiteSpace(ctx, paramName, msg); + return s; + } - public static T[] CheckNonEmpty(T[] args, string paramName) - { - if (Size(args) == 0) - throw ExceptEmpty(paramName); - return args; - } - public static T[] CheckNonEmpty(this IExceptionContext ctx, T[] args, string paramName) - { - if (Size(args) == 0) - throw ExceptEmpty(ctx, paramName); - return args; - } - public static T[] CheckNonEmpty(T[] args, string paramName, string msg) - { - if (Size(args) == 0) - throw ExceptEmpty(paramName, msg); - return args; - } - public static T[] CheckNonEmpty(this IExceptionContext ctx, T[] args, string paramName, string msg) - { - if (Size(args) == 0) - throw ExceptEmpty(ctx, paramName, msg); - return args; - } - public static ICollection CheckNonEmpty(ICollection args, string paramName) - { - if (Size(args) == 0) - throw ExceptEmpty(paramName); - return args; - } - public static ICollection CheckNonEmpty(this IExceptionContext ctx, ICollection args, string paramName) - { - if (Size(args) == 0) - throw ExceptEmpty(ctx, paramName); - return args; - } - public static ICollection CheckNonEmpty(ICollection args, string paramName, string msg) - { - if (Size(args) == 0) - throw ExceptEmpty(paramName, msg); - return args; - } - public static ICollection CheckNonEmpty(this IExceptionContext ctx, ICollection args, string paramName, string msg) - { - if (Size(args) == 0) - throw ExceptEmpty(ctx, paramName, msg); - return args; - } + public static T[] CheckNonEmpty(T[] args, string paramName) + { + if (Size(args) == 0) + throw ExceptEmpty(paramName); + return args; + } + public static T[] CheckNonEmpty(this IExceptionContext ctx, T[] args, string paramName) + { + if (Size(args) == 0) + throw ExceptEmpty(ctx, paramName); + return args; + } + public static T[] CheckNonEmpty(T[] args, string paramName, string msg) + { + if (Size(args) == 0) + throw ExceptEmpty(paramName, msg); + return args; + } + public static T[] CheckNonEmpty(this IExceptionContext ctx, T[] args, string paramName, string msg) + { + if (Size(args) == 0) + throw ExceptEmpty(ctx, paramName, msg); + return args; + } + public static ICollection CheckNonEmpty(ICollection args, string paramName) + { + if (Size(args) == 0) + throw ExceptEmpty(paramName); + return args; + } + public static ICollection CheckNonEmpty(this IExceptionContext ctx, ICollection args, string paramName) + { + if (Size(args) == 0) + throw ExceptEmpty(ctx, paramName); + return args; + } + public static ICollection CheckNonEmpty(ICollection args, string paramName, string msg) + { + if (Size(args) == 0) + throw ExceptEmpty(paramName, msg); + return args; + } + public static ICollection CheckNonEmpty(this IExceptionContext ctx, ICollection args, string paramName, string msg) + { + if (Size(args) == 0) + throw ExceptEmpty(ctx, paramName, msg); + return args; + } - public static void CheckDecode(bool f) - { - if (!f) - throw ExceptDecode(); - } - public static void CheckDecode(this IExceptionContext ctx, bool f) - { - if (!f) - throw ExceptDecode(ctx); - } - public static void CheckDecode(bool f, string msg) - { - if (!f) - throw ExceptDecode(msg); - } - public static void CheckDecode(this IExceptionContext ctx, bool f, string msg) - { - if (!f) - throw ExceptDecode(ctx, msg); - } + public static void CheckDecode(bool f) + { + if (!f) + throw ExceptDecode(); + } + public static void CheckDecode(this IExceptionContext ctx, bool f) + { + if (!f) + throw ExceptDecode(ctx); + } + public static void CheckDecode(bool f, string msg) + { + if (!f) + throw ExceptDecode(msg); + } + public static void CheckDecode(this IExceptionContext ctx, bool f, string msg) + { + if (!f) + throw ExceptDecode(ctx, msg); + } - public static void CheckIO(bool f) - { - if (!f) - throw ExceptIO(); - } - public static void CheckIO(this IExceptionContext ctx, bool f) - { - if (!f) - throw ExceptIO(ctx); - } - public static void CheckIO(bool f, string msg) - { - if (!f) - throw ExceptIO(msg); - } - public static void CheckIO(this IExceptionContext ctx, bool f, string msg) - { - if (!f) - throw ExceptIO(ctx, msg); - } + public static void CheckIO(bool f) + { + if (!f) + throw ExceptIO(); + } + public static void CheckIO(this IExceptionContext ctx, bool f) + { + if (!f) + throw ExceptIO(ctx); + } + public static void CheckIO(bool f, string msg) + { + if (!f) + throw ExceptIO(msg); + } + public static void CheckIO(this IExceptionContext ctx, bool f, string msg) + { + if (!f) + throw ExceptIO(ctx, msg); + } #if !PRIVATE_CONTRACTS /// @@ -753,247 +752,247 @@ public static void CheckAlive(this IHostEnvironment env) throw Process(new OperationCanceledException("Operation was cancelled."), env); } #endif - /// - /// This documents that the parameter can legally be null. - /// - [Conditional("INVARIANT_CHECKS")] - public static void CheckValueOrNull(T val) where T : class - { - } - [Conditional("INVARIANT_CHECKS")] - public static void CheckValueOrNull(this IExceptionContext ctx, T val) where T : class - { - } + /// + /// This documents that the parameter can legally be null. + /// + [Conditional("INVARIANT_CHECKS")] + public static void CheckValueOrNull(T val) where T : class + { + } + [Conditional("INVARIANT_CHECKS")] + public static void CheckValueOrNull(this IExceptionContext ctx, T val) where T : class + { + } - // Assert + // Assert - #region Private assert handling + #region Private assert handling - private static void DbgFailCore(string msg, IExceptionContext ctx = null) - { - var handler = _handler; - - if (handler != null) - handler(msg, ctx); - else if (ctx != null) - Debug.Fail(msg, ctx.ContextDescription); - else - Debug.Fail(msg); - } + private static void DbgFailCore(string msg, IExceptionContext ctx = null) + { + var handler = _handler; - private static void DbgFail(IExceptionContext ctx = null) - { - DbgFailCore("Assertion Failed", ctx); - } - private static void DbgFail(string msg) - { - DbgFailCore(msg); - } - private static void DbgFail(IExceptionContext ctx, string msg) - { - DbgFailCore(msg, ctx); - } - private static void DbgFailValue(IExceptionContext ctx = null) - { - DbgFailCore("Non-null assertion failure", ctx); - } - private static void DbgFailValue(string paramName) - { - DbgFailCore(string.Format(CultureInfo.InvariantCulture, "Non-null assertion failure: {0}", paramName)); - } - private static void DbgFailValue(IExceptionContext ctx, string paramName) - { - DbgFailCore(string.Format(CultureInfo.InvariantCulture, "Non-null assertion failure: {0}", paramName), ctx); - } - private static void DbgFailValue(string paramName, string msg) - { - DbgFailCore(string.Format(CultureInfo.InvariantCulture, "Non-null assertion failure: {0}: {1}", paramName, msg)); - } - private static void DbgFailValue(IExceptionContext ctx, string paramName, string msg) - { - DbgFailCore(string.Format(CultureInfo.InvariantCulture, "Non-null assertion failure: {0}: {1}", paramName, msg), ctx); - } - private static void DbgFailEmpty(IExceptionContext ctx = null) - { - DbgFailCore("Non-empty assertion failure", ctx); - } - private static void DbgFailEmpty(string msg) - { - DbgFailCore(string.Format(CultureInfo.InvariantCulture, "Non-empty assertion failure: {0}", msg)); - } - private static void DbgFailEmpty(IExceptionContext ctx, string msg) - { - DbgFailCore(string.Format(CultureInfo.InvariantCulture, "Non-empty assertion failure: {0}", msg), ctx); - } + if (handler != null) + handler(msg, ctx); + else if (ctx != null) + Debug.Fail(msg, ctx.ContextDescription); + else + Debug.Fail(msg); + } - #endregion Private assert handling + private static void DbgFail(IExceptionContext ctx = null) + { + DbgFailCore("Assertion Failed", ctx); + } + private static void DbgFail(string msg) + { + DbgFailCore(msg); + } + private static void DbgFail(IExceptionContext ctx, string msg) + { + DbgFailCore(msg, ctx); + } + private static void DbgFailValue(IExceptionContext ctx = null) + { + DbgFailCore("Non-null assertion failure", ctx); + } + private static void DbgFailValue(string paramName) + { + DbgFailCore(string.Format(CultureInfo.InvariantCulture, "Non-null assertion failure: {0}", paramName)); + } + private static void DbgFailValue(IExceptionContext ctx, string paramName) + { + DbgFailCore(string.Format(CultureInfo.InvariantCulture, "Non-null assertion failure: {0}", paramName), ctx); + } + private static void DbgFailValue(string paramName, string msg) + { + DbgFailCore(string.Format(CultureInfo.InvariantCulture, "Non-null assertion failure: {0}: {1}", paramName, msg)); + } + private static void DbgFailValue(IExceptionContext ctx, string paramName, string msg) + { + DbgFailCore(string.Format(CultureInfo.InvariantCulture, "Non-null assertion failure: {0}: {1}", paramName, msg), ctx); + } + private static void DbgFailEmpty(IExceptionContext ctx = null) + { + DbgFailCore("Non-empty assertion failure", ctx); + } + private static void DbgFailEmpty(string msg) + { + DbgFailCore(string.Format(CultureInfo.InvariantCulture, "Non-empty assertion failure: {0}", msg)); + } + private static void DbgFailEmpty(IExceptionContext ctx, string msg) + { + DbgFailCore(string.Format(CultureInfo.InvariantCulture, "Non-empty assertion failure: {0}", msg), ctx); + } - [Conditional("DEBUG")] - public static void Assert(bool f) - { - if (!f) - DbgFail(); - } - [Conditional("DEBUG")] - public static void Assert(this IExceptionContext ctx, bool f) - { - if (!f) - DbgFail(ctx); - } + #endregion Private assert handling - [Conditional("DEBUG")] - public static void Assert(bool f, string msg) - { - if (!f) - DbgFail(msg); - } - [Conditional("DEBUG")] - public static void Assert(this IExceptionContext ctx, bool f, string msg) - { - if (!f) - DbgFail(ctx, msg); - } + [Conditional("DEBUG")] + public static void Assert(bool f) + { + if (!f) + DbgFail(); + } + [Conditional("DEBUG")] + public static void Assert(this IExceptionContext ctx, bool f) + { + if (!f) + DbgFail(ctx); + } - [Conditional("DEBUG")] - public static void AssertValue(T val) where T : class - { - if (object.ReferenceEquals(val, null)) - DbgFailValue(); - } - [Conditional("DEBUG")] - public static void AssertValue(this IExceptionContext ctx, T val) where T : class - { - if (object.ReferenceEquals(val, null)) - DbgFailValue(ctx); - } + [Conditional("DEBUG")] + public static void Assert(bool f, string msg) + { + if (!f) + DbgFail(msg); + } + [Conditional("DEBUG")] + public static void Assert(this IExceptionContext ctx, bool f, string msg) + { + if (!f) + DbgFail(ctx, msg); + } - [Conditional("DEBUG")] - public static void AssertValue(T val, string paramName) where T : class - { - if (object.ReferenceEquals(val, null)) - DbgFailValue(paramName); - } - [Conditional("DEBUG")] - public static void AssertValue(this IExceptionContext ctx, T val, string paramName) where T : class - { - if (object.ReferenceEquals(val, null)) - DbgFailValue(ctx, paramName); - } + [Conditional("DEBUG")] + public static void AssertValue(T val) where T : class + { + if (object.ReferenceEquals(val, null)) + DbgFailValue(); + } + [Conditional("DEBUG")] + public static void AssertValue(this IExceptionContext ctx, T val) where T : class + { + if (object.ReferenceEquals(val, null)) + DbgFailValue(ctx); + } - [Conditional("DEBUG")] - public static void AssertValue(T val, string name, string msg) where T : class - { - if (object.ReferenceEquals(val, null)) - DbgFailValue(name, msg); - } - [Conditional("DEBUG")] - public static void AssertValue(this IExceptionContext ctx, T val, string name, string msg) where T : class - { - if (object.ReferenceEquals(val, null)) - DbgFailValue(ctx, name, msg); - } + [Conditional("DEBUG")] + public static void AssertValue(T val, string paramName) where T : class + { + if (object.ReferenceEquals(val, null)) + DbgFailValue(paramName); + } + [Conditional("DEBUG")] + public static void AssertValue(this IExceptionContext ctx, T val, string paramName) where T : class + { + if (object.ReferenceEquals(val, null)) + DbgFailValue(ctx, paramName); + } - [Conditional("DEBUG")] - public static void AssertNonEmpty(string s) - { - if (string.IsNullOrEmpty(s)) - DbgFailEmpty(); - } - [Conditional("DEBUG")] - public static void AssertNonEmpty(this IExceptionContext ctx, string s) - { - if (string.IsNullOrEmpty(s)) - DbgFailEmpty(ctx); - } + [Conditional("DEBUG")] + public static void AssertValue(T val, string name, string msg) where T : class + { + if (object.ReferenceEquals(val, null)) + DbgFailValue(name, msg); + } + [Conditional("DEBUG")] + public static void AssertValue(this IExceptionContext ctx, T val, string name, string msg) where T : class + { + if (object.ReferenceEquals(val, null)) + DbgFailValue(ctx, name, msg); + } - [Conditional("DEBUG")] - public static void AssertNonWhiteSpace(string s) - { - if (string.IsNullOrWhiteSpace(s)) - DbgFailEmpty(); - } - [Conditional("DEBUG")] - public static void AssertNonWhiteSpace(this IExceptionContext ctx, string s) - { - if (string.IsNullOrWhiteSpace(s)) - DbgFailEmpty(ctx); - } + [Conditional("DEBUG")] + public static void AssertNonEmpty(string s) + { + if (string.IsNullOrEmpty(s)) + DbgFailEmpty(); + } + [Conditional("DEBUG")] + public static void AssertNonEmpty(this IExceptionContext ctx, string s) + { + if (string.IsNullOrEmpty(s)) + DbgFailEmpty(ctx); + } - [Conditional("DEBUG")] - public static void AssertNonEmpty(string s, string msg) - { - if (string.IsNullOrEmpty(s)) - DbgFailEmpty(msg); - } - [Conditional("DEBUG")] - public static void AssertNonEmpty(this IExceptionContext ctx, string s, string msg) - { - if (string.IsNullOrEmpty(s)) - DbgFailEmpty(ctx, msg); - } + [Conditional("DEBUG")] + public static void AssertNonWhiteSpace(string s) + { + if (string.IsNullOrWhiteSpace(s)) + DbgFailEmpty(); + } + [Conditional("DEBUG")] + public static void AssertNonWhiteSpace(this IExceptionContext ctx, string s) + { + if (string.IsNullOrWhiteSpace(s)) + DbgFailEmpty(ctx); + } - [Conditional("DEBUG")] - public static void AssertNonWhiteSpace(string s, string msg) - { - if (string.IsNullOrWhiteSpace(s)) - DbgFailEmpty(msg); - } - [Conditional("DEBUG")] - public static void AssertNonWhiteSpace(this IExceptionContext ctx, string s, string msg) - { - if (string.IsNullOrWhiteSpace(s)) - DbgFailEmpty(ctx, msg); - } + [Conditional("DEBUG")] + public static void AssertNonEmpty(string s, string msg) + { + if (string.IsNullOrEmpty(s)) + DbgFailEmpty(msg); + } + [Conditional("DEBUG")] + public static void AssertNonEmpty(this IExceptionContext ctx, string s, string msg) + { + if (string.IsNullOrEmpty(s)) + DbgFailEmpty(ctx, msg); + } - [Conditional("DEBUG")] - public static void AssertNonEmpty(ReadOnlySpan args) - { - if (args.IsEmpty) - DbgFail(); - } - [Conditional("DEBUG")] - public static void AssertNonEmpty(Span args) - { - if (args.IsEmpty) - DbgFail(); - } + [Conditional("DEBUG")] + public static void AssertNonWhiteSpace(string s, string msg) + { + if (string.IsNullOrWhiteSpace(s)) + DbgFailEmpty(msg); + } + [Conditional("DEBUG")] + public static void AssertNonWhiteSpace(this IExceptionContext ctx, string s, string msg) + { + if (string.IsNullOrWhiteSpace(s)) + DbgFailEmpty(ctx, msg); + } - [Conditional("DEBUG")] - public static void AssertNonEmpty(ICollection args) - { - if (Size(args) == 0) - DbgFail(); - } - [Conditional("DEBUG")] - public static void AssertNonEmpty(this IExceptionContext ctx, ICollection args) - { - if (Size(args) == 0) - DbgFail(ctx); - } + [Conditional("DEBUG")] + public static void AssertNonEmpty(ReadOnlySpan args) + { + if (args.IsEmpty) + DbgFail(); + } + [Conditional("DEBUG")] + public static void AssertNonEmpty(Span args) + { + if (args.IsEmpty) + DbgFail(); + } - [Conditional("DEBUG")] - public static void AssertNonEmpty(ICollection args, string msg) - { - if (Size(args) == 0) - DbgFail(msg); - } - [Conditional("DEBUG")] - public static void AssertNonEmpty(this IExceptionContext ctx, ICollection args, string msg) - { - if (Size(args) == 0) - DbgFail(ctx, msg); - } + [Conditional("DEBUG")] + public static void AssertNonEmpty(ICollection args) + { + if (Size(args) == 0) + DbgFail(); + } + [Conditional("DEBUG")] + public static void AssertNonEmpty(this IExceptionContext ctx, ICollection args) + { + if (Size(args) == 0) + DbgFail(ctx); + } - /// - /// This documents that the parameter can legally be null. - /// - [Conditional("INVARIANT_CHECKS")] - public static void AssertValueOrNull(T val) where T : class - { - } - [Conditional("INVARIANT_CHECKS")] - public static void AssertValueOrNull(this IExceptionContext ctx, T val) where T : class - { + [Conditional("DEBUG")] + public static void AssertNonEmpty(ICollection args, string msg) + { + if (Size(args) == 0) + DbgFail(msg); + } + [Conditional("DEBUG")] + public static void AssertNonEmpty(this IExceptionContext ctx, ICollection args, string msg) + { + if (Size(args) == 0) + DbgFail(ctx, msg); + } + + /// + /// This documents that the parameter can legally be null. + /// + [Conditional("INVARIANT_CHECKS")] + public static void AssertValueOrNull(T val) where T : class + { + } + [Conditional("INVARIANT_CHECKS")] + public static void AssertValueOrNull(this IExceptionContext ctx, T val) where T : class + { + } } } -} diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index b61193480e..36389a7d1d 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -416,7 +416,7 @@ private protected override void SaveAsPfaCore(BoundPfaContext ctx, RoleMappedSch Contracts.Assert(ctx.TokenOrNullForName(outputNames[1]) == probToken.ToString()); } - private protected override sealed bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) + private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) { Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index fde6cf4cd0..6ca7e937d5 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -11,11 +11,11 @@ using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.TreePredictor; using Microsoft.ML.Transforms; using System; using System.Collections.Generic; using System.IO; -using Microsoft.ML.Runtime.TreePredictor; [assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(TreeEnsembleFeaturizerTransform), typeof(TreeEnsembleFeaturizerBindableMapper.Arguments), typeof(SignatureBindableMapper), "Tree Ensemble Featurizer Mapper", TreeEnsembleFeaturizerBindableMapper.LoadNameShort)] From 1d61e0b8daf317d3b9af6910ab8144b0edc0f741 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Fri, 9 Nov 2018 13:44:47 -0800 Subject: [PATCH 12/12] Adjust code analyzer test to account for contracts now being internal --- .../Code/ContractsCheckTest.cs | 7 ++++--- .../Microsoft.ML.CodeAnalyzer.Tests.csproj | 3 +++ .../Resources/ContractsCheckResource.cs | 12 ++++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs index 2743b8dc9f..da5994ec7a 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs @@ -12,8 +12,9 @@ namespace Microsoft.ML.InternalCodeAnalyzer.Tests { public sealed class ContractsCheckTest : DiagnosticVerifier { - private static string _contractsSource; - internal static string Source => TestUtils.EnsureSourceLoaded(ref _contractsSource, "ContractsCheckResource.cs"); + private readonly Lazy Source = TestUtils.LazySource("ContractsCheckResource.cs"); + private readonly Lazy SourceContracts = TestUtils.LazySource("Contracts.cs"); + private readonly Lazy SourceFriend = TestUtils.LazySource("BestFriendAttribute.cs"); [Fact] public void ContractsCheck() @@ -40,7 +41,7 @@ public void ContractsCheck() diagDecode.CreateDiagnosticResult(basis + 39, 41, "CheckDecode", "\"This message is suspicious\""), }; - VerifyCSharpDiagnostic(Source, expected); + VerifyCSharpDiagnostic(Source.Value + SourceContracts.Value + SourceFriend.Value, expected); } [Fact] diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Microsoft.ML.CodeAnalyzer.Tests.csproj b/test/Microsoft.ML.CodeAnalyzer.Tests/Microsoft.ML.CodeAnalyzer.Tests.csproj index b55fc8c987..7851f01ecb 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/Microsoft.ML.CodeAnalyzer.Tests.csproj +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Microsoft.ML.CodeAnalyzer.Tests.csproj @@ -8,6 +8,9 @@ + + %(Filename)%(Extension) + %(Filename)%(Extension) diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/ContractsCheckResource.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/ContractsCheckResource.cs index 0c8ca4f332..cd6690942f 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/ContractsCheckResource.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/ContractsCheckResource.cs @@ -57,3 +57,15 @@ public static class Messages public const string CoolMessage = "This is super cool"; } } + +// Dummy declarations so that the independent compilation of contracts works as expected. +namespace Microsoft.ML.Runtime +{ + [Flags] + internal enum MessageSensitivity + { + None = 0, + Unknown = ~None + } + internal interface IHostEnvironment : IExceptionContext { } +} \ No newline at end of file