diff --git a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj index ae327a26c2..f9a1b5ef27 100644 --- a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj +++ b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj @@ -15,6 +15,7 @@ + diff --git a/src/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.csproj b/src/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.csproj index a5f3c4b748..39f2db5c80 100644 --- a/src/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.csproj +++ b/src/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.csproj @@ -1,8 +1,9 @@ - + netstandard2.0 Microsoft.ML.HalLearners + true diff --git a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs new file mode 100644 index 0000000000..47734ae55b --- /dev/null +++ b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs @@ -0,0 +1,850 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Float = System.Single; + +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Security; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.Conversion; +using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.Internal.Calibration; +using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.SymSgd; +using Microsoft.ML.Runtime.Training; + +[assembly: LoadableClass(typeof(SymSgdClassificationTrainer), typeof(SymSgdClassificationTrainer.Arguments), + new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, + SymSgdClassificationTrainer.UserNameValue, + SymSgdClassificationTrainer.LoadNameValue, + SymSgdClassificationTrainer.ShortName)] + +[assembly: LoadableClass(typeof(void), typeof(SymSgdClassificationTrainer), null, typeof(SignatureEntryPointModule), "SymSGD")] + +namespace Microsoft.ML.Runtime.SymSgd +{ + using TPredictor = IPredictorWithFeatureWeights; + + public sealed class SymSgdClassificationTrainer : + TrainerBase, + ITrainer + { + public const string LoadNameValue = "SymbolicSGD"; + public const string UserNameValue = "Symbolic SGD (binary)"; + public const string ShortName = "SymSGD"; + + public sealed class Arguments : LearnerInputBaseWithLabel + { + [Argument(ArgumentType.AtMostOnce, HelpText = "Degree of lock-free parallelism. Determinism not guaranteed. " + + "Multi-threading is not supported currently.", ShortName = "nt")] + public int? 
NumberOfThreads; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Number of passes over the data.", ShortName = "iter", SortOrder = 50)] + [TGUI(SuggestedSweeps = "1,5,10,20,30,40,50")] + [TlcModule.SweepableDiscreteParam("NumberOfIterations", new object[] { 1, 5, 10, 20, 30, 40, 50 })] + public int NumberOfIterations = 50; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance for difference in average loss in consecutive passes.", ShortName = "tol")] + public float Tol = 1e-4f; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr", NullName = "", SortOrder = 51)] + [TGUI(SuggestedSweeps = ",1e1,1e0,1e-1,1e-2,1e-3")] + [TlcModule.SweepableDiscreteParam("LearningRate", new object[] { "", 1e1f, 1e0f, 1e-1f, 1e-2f, 1e-3f })] + public float? LearningRate; + + [Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularization", ShortName = "l2", SortOrder = 52)] + [TGUI(SuggestedSweeps = "0.0,1e-5,1e-5,1e-6,1e-7")] + [TlcModule.SweepableDiscreteParam("L2Regularization", new object[] { 0.0f, 1e-5f, 1e-5f, 1e-6f, 1e-7f })] + public float L2Regularization; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The number of iterations each thread learns a local model until combining it with the " + + "global model. Low value means more updated global model and high value means less cache traffic.", ShortName = "freq", NullName = "")] + [TGUI(SuggestedSweeps = ",5,20")] + [TlcModule.SweepableDiscreteParam("UpdateFrequency", new object[] { "", 5, 20 })] + public int? UpdateFrequency; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The acceleration memory budget in MB", ShortName = "accelMemBudget")] + public long MemorySize = 1024; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Shuffle data?", ShortName = "shuf")] + public bool Shuffle = true; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Apply weight to the positive class, for imbalanced data", ShortName = "piw")] + public Float PositiveInstanceWeight = 1; + + public void Check(IExceptionContext ectx) + { + ectx.CheckUserArg(LearningRate == null || LearningRate.Value > 0, nameof(LearningRate), "Must be positive."); + ectx.CheckUserArg(NumberOfIterations > 0, nameof(NumberOfIterations), "Must be positive."); + ectx.CheckUserArg(PositiveInstanceWeight > 0, nameof(PositiveInstanceWeight), "Must be positive"); + ectx.CheckUserArg(UpdateFrequency == null || UpdateFrequency > 0, nameof(UpdateFrequency), "Must be positive"); + } + } + + public override TrainerInfo Info { get; } + private readonly Arguments _args; + + /// + /// This method ensures that the data meets the requirements of this trainer and its + /// subclasses, injects necessary transforms, and throws if it couldn't meet them. + /// + /// The channel + /// The training examples + /// Gets the length of weights and bias array. For binary classification and regression, + /// this is 1. For multi-class classification, this equals the number of classes on the label. 
+ /// A potentially modified version of + private RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedData examples, out int weightSetCount) + { + ch.AssertValue(examples); + CheckLabel(examples, out weightSetCount); + examples.CheckFeatureFloatVector(); + var idvToShuffle = examples.Data; + IDataView idvToFeedTrain; + if (idvToShuffle.CanShuffle) + idvToFeedTrain = idvToShuffle; + else + { + var shuffleArgs = new ShuffleTransform.Arguments + { + PoolOnly = false, + ForceShuffle = _args.Shuffle + }; + idvToFeedTrain = new ShuffleTransform(Host, shuffleArgs, idvToShuffle); + } + + ch.Assert(idvToFeedTrain.CanShuffle); + + var roles = examples.Schema.GetColumnRoleNames(); + var examplesToFeedTrain = new RoleMappedData(idvToFeedTrain, roles); + + ch.AssertValue(examplesToFeedTrain.Schema.Label); + ch.AssertValue(examplesToFeedTrain.Schema.Feature); + if (examples.Schema.Weight != null) + ch.AssertValue(examplesToFeedTrain.Schema.Weight); + + int numFeatures = examplesToFeedTrain.Schema.Feature.Type.VectorSize; + ch.Check(numFeatures > 0, "Training set has no features, aborting training."); + return examplesToFeedTrain; + } + + public override TPredictor Train(TrainContext context) + { + Host.CheckValue(context, nameof(context)); + TPredictor pred; + using (var ch = Host.Start("Training")) + { + var preparedData = PrepareDataFromTrainingExamples(ch, context.TrainingSet, out int weightSetCount); + var initPred = context.InitialPredictor; + var linInitPred = (initPred as CalibratedPredictorBase)?.SubPredictor as LinearPredictor; + linInitPred = linInitPred ?? initPred as LinearPredictor; + Host.CheckParam(context.InitialPredictor == null || linInitPred != null, nameof(context), + "Initial predictor was not a linear predictor."); + pred = TrainCore(ch, preparedData, linInitPred, weightSetCount); + ch.Done(); + } + return pred; + } + + public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; + + public SymSgdClassificationTrainer(IHostEnvironment env, Arguments args) + : base(env, LoadNameValue) + { + args.Check(Host); + _args = args; + Info = new TrainerInfo(); + } + + private TPredictor CreatePredictor(VBuffer weights, Float bias) + { + Host.CheckParam(weights.Length > 0, nameof(weights)); + + VBuffer maybeSparseWeights = default; + VBufferUtils.CreateMaybeSparseCopy(ref weights, ref maybeSparseWeights, + Conversions.Instance.GetIsDefaultPredicate(NumberType.Float)); + var predictor = new LinearBinaryPredictor(Host, ref maybeSparseWeights, bias); + return new ParameterMixingCalibratedPredictor(Host, predictor, new PlattCalibrator(Host, -1, 0)); + } + + [TlcModule.EntryPoint(Name = "Trainers.SymSgdBinaryClassifier", Desc = "Train a symbolic SGD.", UserName = SymSgdClassificationTrainer.UserNameValue, ShortName = SymSgdClassificationTrainer.ShortName)] + public static CommonOutputs.BinaryClassificationOutput TrainSymSgd(IHostEnvironment env, Arguments input) + { + Contracts.CheckValue(env, nameof(env)); + var host = env.Register("TrainSymSGD"); + host.CheckValue(input, nameof(input)); + EntryPointUtils.CheckInputArgs(host, input); + + return LearnerEntryPointsUtils.Train(host, input, + () => new SymSgdClassificationTrainer(host, input), + () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn)); + } + + // We buffer instances from the cursor (limited to memorySize) and passes that buffer to + // the native code to learn for multiple instances by one interop call. 
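The trainer's public surface (Arguments, the constructor, and Train) is now fully defined above. As an aside before the buffering machinery below, here is a hedged sketch of how the trainer might be driven directly — illustrative only; it assumes an existing IHostEnvironment `env` and RoleMappedData `trainingData`, and that a `TrainContext` can be constructed from the training set alone:

```csharp
// Illustrative usage sketch (not part of this diff). The property names
// match the Arguments class defined in this file.
var args = new SymSgdClassificationTrainer.Arguments
{
    NumberOfIterations = 50,   // passes over the data
    Tol = 1e-4f,               // stop when the consecutive-pass loss difference is small
    MemorySize = 1024,         // acceleration memory budget, in MB
    Shuffle = true,
};
var trainer = new SymSgdClassificationTrainer(env, args);
var predictor = trainer.Train(new TrainContext(trainingData));
```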
+ + /// + /// This struct holds the information about the size, label and isDense of each instance + /// to be able to pass it to the native code. + /// + private struct InstanceProperties + { + public readonly int FeatureCount; + public readonly float Label; + public readonly bool IsDense; + + public InstanceProperties(int featureCount, float label, bool isDense) + { + FeatureCount = featureCount; + Label = label; + IsDense = isDense; + } + } + + /// + /// ArrayManager stores multiple arrays of type in a "very long" array whose size is specified by accelChunkSize. + /// Once one of the very long arrays is full, another one is allocated to store additional arrays. The required memory + /// for this buffering is limited by memorySize. + /// + /// Note that these very long arrays can be reused. This means that learning can be done in batches without the overhead associated + /// with allocation. + /// + /// The benefit of this way of storage is that only a handful of new calls will be needed + /// which saves time. + /// + /// The type of arrays to be stored + private sealed class ArrayManager : IDisposable + { + /// + /// This structure is used for pinning very long arrays to stop GC from moving them. + /// The reason for this design is that when these arrays are passed to native code, + /// GC does not move the objects. + /// + private struct VeryLongArray + { + public T[] Buffer; + public GCHandle GcHandle; + public int Length => Buffer.Length; + public VeryLongArray(int veryLongArrayLength) + { + Buffer = new T[veryLongArrayLength]; + GcHandle = GCHandle.Alloc(Buffer, GCHandleType.Pinned); + } + + public void Free() + { + GcHandle.Free(); + } + } + // This list holds very long arrays. + private readonly List _storage; + // Length of each very long array + // This is not readonly because there might be an instance where the length of the + // instance is longer than _veryLongArrayLength and we have to adjust it + private int _veryLongArrayLength; + // This index is used to walk over _storage list. During storing or giving an array, + // we are at _storage[_storageIndex]. + private int _storageIndex; + // This index is used within a very long array from _storage[_storageIndex]. During storing or + // giving an array, we are at _storage[_storageIndex][_indexInCurArray]. + private int _indexInCurArray; + // This is used to access AccelMemBudget, AccelChunkSize and UsedMemory + private readonly SymSgdClassificationTrainer _trainer; + + private readonly IChannel _ch; + + // Size of type T + private readonly int _sizeofT; + + /// + /// Constructor for initializing _storage and other indices. + /// + /// + /// + public ArrayManager(SymSgdClassificationTrainer trainer, IChannel ch) + { + _storage = new List(); + // Setting the default value to 2^17. + _veryLongArrayLength = (1 << 17); + _indexInCurArray = 0; + _storageIndex = 0; + _trainer = trainer; + _ch = ch; + _sizeofT = Marshal.SizeOf(typeof(T)); + } + + /// + /// + /// Returns if the allocation was successful + private bool CheckAndAllocate() + { + // Check if this allocation violates the memorySize. + if (_trainer.UsedMemory + _veryLongArrayLength * _sizeofT <= _trainer.AcceleratedMemoryBudgetBytes) + { + // Add the additional allocation to UsedMemory + _trainer.UsedMemory += _veryLongArrayLength * _sizeofT; + _storage.Add(new VeryLongArray(_veryLongArrayLength)); + return true; + } + // If allocation violates the budget, bail. 
+                return false;
+            }
+
+            ///
+            /// This method checks if an array of the given size fits in _storage[_storageIndex][_indexInCurArray.._indexInCurArray+size-1].
+            ///
+            /// The size of the array to fit in the very long array _storage[_storageIndex]
+            ///
+            private bool FitsInCurArray(int size)
+            {
+                _ch.Assert(_storage[_storageIndex].Length == _veryLongArrayLength);
+                return _indexInCurArray <= _veryLongArrayLength - size;
+            }
+
+            ///
+            /// Tries to add an array to the storage without violating the memorySize restriction.
+            ///
+            /// The array to be added
+            /// Length of the array. .Length is unreliable since TLC cursoring
+            /// has its own allocation mechanism.
+            /// Returns whether the allocation was successful
+            public bool AddToStorage(T[] instArray, int instArrayLength)
+            {
+                _ch.Assert(0 < instArrayLength && instArrayLength <= Utils.Size(instArray));
+                _ch.Assert(instArrayLength * _sizeofT * 2 < _trainer.AcceleratedMemoryBudgetBytes);
+                if (instArrayLength > _veryLongArrayLength)
+                {
+                    // In this case, we need to increase _veryLongArrayLength.
+                    if (_indexInCurArray == 0 && _storageIndex == 0)
+                    {
+                        // If there are no instances loaded, all of the allocated very long arrays need to be deallocated
+                        // and a longer _veryLongArrayLength used instead.
+                        DeallocateVeryLongArrays();
+                        _storage.Clear();
+
+                        _veryLongArrayLength = instArrayLength;
+                    }
+                    else
+                    {
+                        // If there are already instances loaded into _storage, train on them.
+                        return false;
+                    }
+                }
+                // Special case that happens only when _storage is empty
+                if (_storage.Count == 0)
+                {
+                    if (!CheckAndAllocate())
+                        return false;
+                    _indexInCurArray = 0;
+                }
+                // Check if instArray can be fitted in the current setup.
+                else if (!FitsInCurArray(instArrayLength))
+                {
+                    // Check if we reached the end of _storage. If so, try to allocate a new very long array.
+                    // Otherwise, there are more very long arrays left, just move to the next one.
+                    if (_storageIndex == _storage.Count - 1)
+                    {
+                        if (!CheckAndAllocate())
+                            return false;
+                    }
+                    _indexInCurArray = 0;
+                    _storageIndex++;
+                }
+                Array.Copy(instArray, 0, _storage[_storageIndex].Buffer, _indexInCurArray, instArrayLength);
+                _indexInCurArray += instArrayLength;
+                return true;
+            }
+
+            ///
+            /// This is a soft clear, meaning that it doesn't reallocate; it only sets _storageIndex and
+            /// _indexInCurArray to 0.
+            ///
+            public void ResetIndexing()
+            {
+                _storageIndex = 0;
+                _indexInCurArray = 0;
+            }
+
+            ///
+            /// Gives an array of the requested size.
+            ///
+            /// The size of the array to give
+            ///
+            ///
+            public void GiveArrayOfSize(int size, out GCHandle? outGcHandle, out int outArrayStartIndex)
+            {
+                // Generally it is the user's responsibility not to ask for an array of a size that has not been
+                // previously allocated.
+
+                // In case no allocation has occurred.
+                if (_storage.Count == 0)
+                {
+                    outGcHandle = null;
+                    outArrayStartIndex = 0;
+                }
+                else
+                {
+                    // Check if the array fits in _storage[_storageIndex].
+                    if (!FitsInCurArray(size))
+                    {
+                        // If not, it must be in the next very long array.
+ _storageIndex++; + _indexInCurArray = 0; + } + outGcHandle = _storage[_storageIndex].GcHandle; + outArrayStartIndex = _indexInCurArray; + _indexInCurArray += size; + } + } + + private void DeallocateVeryLongArrays() + { + foreach (var veryLongArray in _storage) + veryLongArray.Free(); + } + + public void Dispose() + { + DeallocateVeryLongArrays(); + } + } + + /// + /// This class manages the buffering for instances + /// + private sealed class InputDataManager : IDisposable + { + // This ArrayManager is used for indices of instances + private readonly ArrayManager _instIndices; + // This ArrayManager is used for values of instances + private readonly ArrayManager _instValues; + // This is a list of the properties of instances that are buffered. + private readonly List _instanceProperties; + private readonly FloatLabelCursor.Factory _cursorFactory; + private FloatLabelCursor _cursor; + // This is used as a mechanism to make sure that the memorySize restriction is not violated. + private bool _cursorMoveNext; + // This is the index to go over the instances in instanceProperties + private int _instanceIndex; + // This is used to access AccelMemBudget, AccelChunkSize and UsedMemory + private readonly SymSgdClassificationTrainer _trainer; + private readonly IChannel _ch; + + // Whether memorySize was big enough to load the entire instances into the buffer + private bool _isFullyLoaded; + public bool IsFullyLoaded => _isFullyLoaded; + public int Count => _instanceProperties.Count; + + // Tells if we have gone through the dataset entirely. + public bool FinishedTheLoad => !_cursorMoveNext; + + public InputDataManager(SymSgdClassificationTrainer trainer, FloatLabelCursor.Factory cursorFactory, IChannel ch) + { + _instIndices = new ArrayManager(trainer, ch); + _instValues = new ArrayManager(trainer, ch); + _instanceProperties = new List(); + + _cursorFactory = cursorFactory; + _ch = ch; + _cursor = cursorFactory.Create(); + _cursorMoveNext = _cursor.MoveNext(); + _isFullyLoaded = true; + + _instanceIndex = 0; + + _trainer = trainer; + } + + // Has to be called for cursoring through the data + public void RestartLoading(bool needShuffle, IHost host) + { + _cursor.Dispose(); + if (needShuffle) + _cursor = _cursorFactory.Create(RandomUtils.Create(host.Rand.Next())); + else + _cursor = _cursorFactory.Create(); + _cursorMoveNext = _cursor.MoveNext(); + } + + /// + /// This method tries to load as much as possible from the cursor into the buffer until the memorySize is reached. + /// + public void LoadAsMuchAsPossible() + { + _instValues.ResetIndexing(); + _instIndices.ResetIndexing(); + _instanceProperties.Clear(); + + while (_cursorMoveNext) + { + int featureCount = _cursor.Features.Count; + // If the instance has no feature, ignore it! + if (featureCount == 0) + { + _cursorMoveNext = _cursor.MoveNext(); + continue; + } + + // We assume that cursor.Features.values are represented by Float and cursor.Features.indices are represented by int + // We conservatively assume that an instance is sparse and therefore, it has an array of Floats and ints for values and indices + int perNonZeroInBytes = sizeof(Float) + sizeof(int); + if (featureCount > _trainer.AcceleratedMemoryBudgetBytes / perNonZeroInBytes) + { + // Hopefully this never happens. But the memorySize must >= perNonZeroInBytes * length(the longest instance). + throw _ch.Except("Acceleration memory budget is too small! 
Need at least {0} MB for at least one of the instances",
+                            featureCount * perNonZeroInBytes / (1024 * 1024));
+                    }
+
+                    bool couldLoad = true;
+                    if (!_cursor.Features.IsDense)
+                        // If it is a sparse instance, load its indices into the instIndices buffer
+                        couldLoad = _instIndices.AddToStorage(_cursor.Features.Indices, featureCount);
+                    // Load the values of an instance into instValues
+                    if (couldLoad)
+                        couldLoad = _instValues.AddToStorage(_cursor.Features.Values, featureCount);
+
+                    // If the load was successful, add the instance properties to instanceProperties
+                    if (couldLoad)
+                    {
+                        float label = _cursor.Label;
+                        InstanceProperties prop = new InstanceProperties(featureCount, label, _cursor.Features.IsDense);
+                        _instanceProperties.Add(prop);
+
+                        _cursorMoveNext = _cursor.MoveNext();
+
+                        if (_instanceProperties.Count > (1 << 30))
+                        {
+                            // If it happened to be the case that we have so much memory that we were able to load (1<<30) instances,
+                            // break. This is because in such a case _instanceProperties can only be addressed by int32 and (1<<30) is
+                            // getting close to the limits. This should rarely happen!
+                            _isFullyLoaded = false;
+                            break;
+                        }
+                    }
+                    else
+                    {
+                        // If couldLoad fails at any point (which is because of memorySize), isFullyLoaded becomes false forever
+                        _isFullyLoaded = false;
+                        break;
+                    }
+                }
+            }
+
+            public void PrepareCursoring()
+            {
+                _instanceIndex = 0;
+                _instIndices.ResetIndexing();
+                _instValues.ResetIndexing();
+            }
+
+            ///
+            /// This method provides the buffered instances in sequential order. Note that PrepareCursoring must be called before using this method.
+            ///
+            /// The properties of the given instance. It is set to null in case there are no more instances.
+            ///
+            /// The offset for the indices array.
+            ///
+            /// The offset for the values array.
+            /// Returns whether the output is valid; false means we have gone through all of the loaded instances.
+            public bool GiveNextInstance(out InstanceProperties? prop, out GCHandle? indicesGcHandle, out int indicesStartIndex,
+                out GCHandle? valuesGcHandle, out int valuesStartIndex)
+            {
+                if (_instanceIndex == _instanceProperties.Count)
+                {
+                    // We hit the end.
+                    prop = null;
+                    indicesGcHandle = null;
+                    indicesStartIndex = 0;
+                    valuesGcHandle = null;
+                    valuesStartIndex = 0;
+                    return false;
+                }
+                prop = _instanceProperties[_instanceIndex];
+                if (!prop.Value.IsDense)
+                {
+                    // If sparse, set the indices array accordingly.
+                    _instIndices.GiveArrayOfSize(prop.Value.FeatureCount, out indicesGcHandle, out indicesStartIndex);
+                }
+                else
+                {
+                    indicesGcHandle = null;
+                    indicesStartIndex = 0;
+                }
+                // Load values here.
+ _instValues.GiveArrayOfSize(prop.Value.FeatureCount, out valuesGcHandle, out valuesStartIndex); + _instanceIndex++; + return true; + } + + public void Dispose() + { + _cursor.Dispose(); + _instIndices.Dispose(); + _instValues.Dispose(); + } + } + + private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor, int weightSetCount) + { + int numFeatures = data.Schema.Feature.Type.VectorSize; + var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | CursOpt.Features | CursOpt.Weight); + int numThreads = 1; + ch.CheckUserArg(numThreads > 0, nameof(_args.NumberOfThreads), + "The number of threads must be either null or a positive integer."); + + ch.Assert(numThreads > 0); + var positiveInstanceWeight = _args.PositiveInstanceWeight; + VBuffer weights = default; + float bias = 0.0f; + if (predictor != null) + { + predictor.GetFeatureWeights(ref weights); + VBufferUtils.Densify(ref weights); + bias = predictor.Bias; + } + else + weights = VBufferUtils.CreateDense(numFeatures); + + // Reference: Parasail. SymSGD. + bool tuneLR = _args.LearningRate == null; + var lr = _args.LearningRate ?? 1.0f; + + bool tuneNumLocIter = (_args.UpdateFrequency == null); + var numLocIter = _args.UpdateFrequency ?? 1; + + var l2Const = _args.L2Regularization; + var piw = _args.PositiveInstanceWeight; + + // This is state of the learner that is shared with the native code. + State state = new State(); + GCHandle stateGCHandle = default; + try + { + stateGCHandle = GCHandle.Alloc(state, GCHandleType.Pinned); + + state.TotalInstancesProcessed = 0; + using (InputDataManager inputDataManager = new InputDataManager(this, cursorFactory, ch)) + { + bool shouldInitialize = true; + using (var pch = Host.StartProgressChannel("Preprocessing")) + inputDataManager.LoadAsMuchAsPossible(); + + int iter = 0; + if (inputDataManager.IsFullyLoaded) + ch.Info("Data fully loaded into memory."); + using (var pch = Host.StartProgressChannel("Training")) + { + if (inputDataManager.IsFullyLoaded) + { + pch.SetHeader(new ProgressHeader(new[] { "iterations" }), + entry => entry.SetProgress(0, state.PassIteration, _args.NumberOfIterations)); + // If fully loaded, call the SymSGDNative and do not come back until learned for all iterations. + Native.LearnAll(inputDataManager, tuneLR, ref lr, l2Const, piw, weights.Values, ref bias, numFeatures, + _args.NumberOfIterations, numThreads, tuneNumLocIter, ref numLocIter, _args.Tol, _args.Shuffle, shouldInitialize, stateGCHandle); + shouldInitialize = false; + } + else + { + pch.SetHeader(new ProgressHeader(new[] { "iterations" }), + entry => entry.SetProgress(0, iter, _args.NumberOfIterations)); + + // Since we loaded data in batch sizes, multiple passes over the loaded data is feasible. + int numPassesForABatch = inputDataManager.Count / 10000; + while (iter < _args.NumberOfIterations) + { + // We want to train on the final passes thoroughly (without learning on the same batch multiple times) + // This is for fine tuning the AUC. 
Experimentally, we found that 1 or 2 passes is enough + int numFinalPassesToTrainThoroughly = 2; + // We also do not want to learn for more passes than what the user asked + int numPassesForThisBatch = Math.Min(numPassesForABatch, _args.NumberOfIterations - iter - numFinalPassesToTrainThoroughly); + // If all of this leaves us with 0 passes, then set numPassesForThisBatch to 1 + numPassesForThisBatch = Math.Max(1, numPassesForThisBatch); + state.PassIteration = iter; + Native.LearnAll(inputDataManager, tuneLR, ref lr, l2Const, piw, weights.Values, ref bias, numFeatures, + numPassesForThisBatch, numThreads, tuneNumLocIter, ref numLocIter, _args.Tol, _args.Shuffle, shouldInitialize, stateGCHandle); + shouldInitialize = false; + + // Check if we are done with going through the data + if (inputDataManager.FinishedTheLoad) + { + iter += numPassesForThisBatch; + // Check if more passes are left + if (iter < _args.NumberOfIterations) + inputDataManager.RestartLoading(_args.Shuffle, Host); + } + + // If more passes are left, load as much as possible + if (iter < _args.NumberOfIterations) + inputDataManager.LoadAsMuchAsPossible(); + } + } + + // Maps back the dense features that are mislocated + if (numThreads > 1) + Native.MapBackWeightVector(weights.Values, stateGCHandle); + Native.DeallocateSequentially(stateGCHandle); + } + } + } + finally + { + if (stateGCHandle.IsAllocated) + stateGCHandle.Free(); + } + return CreatePredictor(weights, bias); + } + + private void CheckLabel(RoleMappedData examples, out int weightSetCount) + { + examples.CheckBinaryLabel(); + weightSetCount = 1; + } + + private long AcceleratedMemoryBudgetBytes => _args.MemorySize * 1024 * 1024; + private long UsedMemory { get; set; } + + private static unsafe class Native + { + internal const string DllName = "SymSgdNative"; + internal const string MklDllName = "MklImports"; + + static Native() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + //Work around to get libMKLImport loaded before SymSGDNative + cblas_sdot(0, null, 0, null, 0); + } + } + + [DllImport(MklDllName), SuppressUnmanagedCodeSecurity] + private static extern void cblas_sdot(int vecSize, float* denseVecX, int incX, float* denseVecY, int incY); + + [DllImport(DllName), SuppressUnmanagedCodeSecurity] + private static extern void LearnAll(int totalNumInstances, int* instSizes, int** instIndices, + float** instValues, float* labels, bool tuneLR, ref float lr, float l2Const, float piw, float* weightVector, ref float bias, + int numFeatres, int numPasses, int numThreads, bool tuneNumLocIter, ref int numLocIter, float tolerance, bool needShuffle, bool shouldInitialize, State* state); + + /// + /// This method puts all of the buffered instances in array of pointers to pass it to SymSGDNative. 
+ /// + /// The buffered data + /// Specifies if SymSGD should tune alpha automatically + /// Initial learning rate + /// + /// + /// The storage for the weight vector + /// bias + /// Number of features + /// Number of passes + /// Number of threads + /// Specifies if SymSGD should tune numLocIter automatically + /// Number of thread local iterations of SGD before combining with the global model + /// Tolerance for the amount of decrease in the total loss in consecutive passes + /// Specifies if data needs to be shuffled + /// Specifies if this is the first time to run SymSGD + /// + public static void LearnAll(InputDataManager inputDataManager, bool tuneLR, + ref float lr, float l2Const, float piw, float[] weightVector, ref float bias, int numFeatres, int numPasses, + int numThreads, bool tuneNumLocIter, ref int numLocIter, float tolerance, bool needShuffle, bool shouldInitialize, GCHandle stateGCHandle) + { + inputDataManager.PrepareCursoring(); + + int totalNumInstances = inputDataManager.Count; + // Each instance has a pointer to indices array and a pointer to values array + int*[] arrayIndicesPointers = new int*[totalNumInstances]; + float*[] arrayValuesPointers = new float*[totalNumInstances]; + // Labels of the instances + float[] instLabels = new float[totalNumInstances]; + // Sizes of each inst + int[] instSizes = new int[totalNumInstances]; + + int instanceIndex = 0; + // Going through the buffer to set the properties and the pointers + while (inputDataManager.GiveNextInstance(out InstanceProperties? prop, out GCHandle? indicesGcHandle, out int indicesStartIndex, out GCHandle? valuesGcHandle, out int valuesStartIndex)) + { + if (prop.Value.IsDense) + { + arrayIndicesPointers[instanceIndex] = null; + } + else + { + int* pIndicesArray = (int*)indicesGcHandle.Value.AddrOfPinnedObject(); + arrayIndicesPointers[instanceIndex] = &pIndicesArray[indicesStartIndex]; + } + float* pValuesArray = (float*)valuesGcHandle.Value.AddrOfPinnedObject(); + arrayValuesPointers[instanceIndex] = &pValuesArray[valuesStartIndex]; + + instLabels[instanceIndex] = prop.Value.Label; + instSizes[instanceIndex] = prop.Value.FeatureCount; + instanceIndex++; + } + + fixed (float* pweightVector = &weightVector[0]) + fixed (int** pIndicesPointer = &arrayIndicesPointers[0]) + fixed (float** pValuesPointer = &arrayValuesPointers[0]) + fixed (int* pInstSizes = &instSizes[0]) + fixed (float* pInstLabels = &instLabels[0]) + { + LearnAll(totalNumInstances, pInstSizes, pIndicesPointer, pValuesPointer, pInstLabels, tuneLR, ref lr, l2Const, piw, + pweightVector, ref bias, numFeatres, numPasses, numThreads, tuneNumLocIter, ref numLocIter, tolerance, needShuffle, shouldInitialize, (State*)stateGCHandle.AddrOfPinnedObject()); + } + } + + [DllImport(DllName), SuppressUnmanagedCodeSecurity] + private static extern void MapBackWeightVector(float* weightVector, State* state); + + /// + /// Maps back the dense feature to the correct position + /// + /// The weight vector + /// + public static void MapBackWeightVector(float[] weightVector, GCHandle stateGCHandle) + { + fixed (float* pweightVector = &weightVector[0]) + MapBackWeightVector(pweightVector, (State*)stateGCHandle.AddrOfPinnedObject()); + } + + [DllImport(DllName), SuppressUnmanagedCodeSecurity] + private static extern void DeallocateSequentially(State* state); + + public static void DeallocateSequentially(GCHandle stateGCHandle) + { + DeallocateSequentially((State*)stateGCHandle.AddrOfPinnedObject()); + } + } + + /// + /// This is the state of a SymSGD learner that is 
shared between the managed and native code. + /// + [StructLayout(LayoutKind.Explicit)] + public unsafe struct State + { +#pragma warning disable 649 // never assigned + [FieldOffset(0x00)] + public readonly int NumLearners; + [FieldOffset(0x04)] + public int TotalInstancesProcessed; + [FieldOffset(0x08)] + public readonly void* Learners; + [FieldOffset(0x10)] + public readonly void* FreqFeatUnorderedMap; + [FieldOffset(0x18)] + public readonly int* FreqFeatDirectMap; + [FieldOffset(0x20)] + public readonly int NumFrequentFeatures; + [FieldOffset(0x24)] + public int PassIteration; + [FieldOffset(0x28)] + public readonly float WeightScaling; +#pragma warning restore 649 // never assigned + } + } + +} diff --git a/src/Microsoft.ML.HalLearners/doc.xml b/src/Microsoft.ML.HalLearners/doc.xml index d7ec04bb89..15022b4aa2 100644 --- a/src/Microsoft.ML.HalLearners/doc.xml +++ b/src/Microsoft.ML.HalLearners/doc.xml @@ -1,4 +1,4 @@ - + @@ -22,6 +22,24 @@ - + + + Parallel Stochastic Gradient Descent trainer. + + + Stochastic gradient descent (SGD) is an interative algorithm + that optimizes a differentiable objective function. SYMSGD parallelizes SGD using Sound Combiners. + + + + new SymSgdBinaryClassifier() + { + NumberOfIterations = 50, + L2Regularization = 0, + Shuffle = true + } + + + \ No newline at end of file diff --git a/src/Microsoft.ML.StandardLearners/Microsoft.ML.StandardLearners.csproj b/src/Microsoft.ML.StandardLearners/Microsoft.ML.StandardLearners.csproj index 6bada43299..f23f8c1ad9 100644 --- a/src/Microsoft.ML.StandardLearners/Microsoft.ML.StandardLearners.csproj +++ b/src/Microsoft.ML.StandardLearners/Microsoft.ML.StandardLearners.csproj @@ -1,4 +1,4 @@ - + netstandard2.0 @@ -11,6 +11,7 @@ + diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index bf753fe486..4f887dd2b4 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -838,6 +838,18 @@ public void Add(Microsoft.ML.Trainers.StochasticGradientDescentBinaryClassifier _jsonNodes.Add(Serialize("Trainers.StochasticGradientDescentBinaryClassifier", input, output)); } + public Microsoft.ML.Trainers.SymSgdBinaryClassifier.Output Add(Microsoft.ML.Trainers.SymSgdBinaryClassifier input) + { + var output = new Microsoft.ML.Trainers.SymSgdBinaryClassifier.Output(); + Add(input, output); + return output; + } + + public void Add(Microsoft.ML.Trainers.SymSgdBinaryClassifier input, Microsoft.ML.Trainers.SymSgdBinaryClassifier.Output output) + { + _jsonNodes.Add(Serialize("Trainers.SymSgdBinaryClassifier", input, output)); + } + public Microsoft.ML.Transforms.ApproximateBootstrapSampler.Output Add(Microsoft.ML.Transforms.ApproximateBootstrapSampler input) { var output = new Microsoft.ML.Transforms.ApproximateBootstrapSampler.Output(); @@ -9761,6 +9773,128 @@ public StochasticGradientDescentBinaryClassifierPipelineStep(Output output) } } + namespace Trainers + { + + /// + /// Train a symbolic SGD. + /// + public sealed partial class SymSgdBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem + { + + + /// + /// Degree of lock-free parallelism. Determinism not guaranteed. Multi-threading is not supported currently. + /// + public int? NumberOfThreads { get; set; } + + /// + /// Number of passes over the data. 
+ /// + [TlcModule.SweepableDiscreteParamAttribute("NumberOfIterations", new object[]{1, 5, 10, 20, 30, 40, 50})] + public int NumberOfIterations { get; set; } = 50; + + /// + /// Tolerance for difference in average loss in consecutive passes. + /// + public float Tol { get; set; } = 0.0001f; + + /// + /// Learning rate + /// + [TlcModule.SweepableDiscreteParamAttribute("LearningRate", new object[]{"", 10f, 1f, 0.1f, 0.01f, 0.001f})] + public float? LearningRate { get; set; } + + /// + /// L2 regularization + /// + [TlcModule.SweepableDiscreteParamAttribute("L2Regularization", new object[]{0f, 1E-05f, 1E-05f, 1E-06f, 1E-07f})] + public float L2Regularization { get; set; } + + /// + /// The number of iterations each thread learns a local model until combining it with the global model. Low value means more updated global model and high value means less cache traffic. + /// + [TlcModule.SweepableDiscreteParamAttribute("UpdateFrequency", new object[]{"", 5, 20})] + public int? UpdateFrequency { get; set; } + + /// + /// The acceleration memory budget in MB + /// + public long MemorySize { get; set; } = 1024; + + /// + /// Shuffle data? + /// + public bool Shuffle { get; set; } = true; + + /// + /// Apply weight to the positive class, for imbalanced data + /// + public float PositiveInstanceWeight { get; set; } = 1f; + + /// + /// Column to use for labels + /// + public string LabelColumn { get; set; } = "Label"; + + /// + /// The data to be used for training + /// + public Var TrainingData { get; set; } = new Var(); + + /// + /// Column to use for features + /// + public string FeatureColumn { get; set; } = "Features"; + + /// + /// Normalize option for the feature column + /// + public Microsoft.ML.Models.NormalizeOption NormalizeFeatures { get; set; } = Microsoft.ML.Models.NormalizeOption.Auto; + + /// + /// Whether learner should cache input training data + /// + public Microsoft.ML.Models.CachingOptions Caching { get; set; } = Microsoft.ML.Models.CachingOptions.Auto; + + + public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + { + /// + /// The trained model + /// + public Var PredictorModel { get; set; } = new Var(); + + } + public Var GetInputData() => TrainingData; + + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) + { + if (previousStep != null) + { + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(SymSgdBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } + + TrainingData = dataStep.Data; + } + Output output = experiment.Add(this); + return new SymSgdBinaryClassifierPipelineStep(output); + } + + private class SymSgdBinaryClassifierPipelineStep : ILearningPipelinePredictorStep + { + public SymSgdBinaryClassifierPipelineStep(Output output) + { + Model = output.PredictorModel; + } + + public Var Model { get; } + } + } + } + namespace Transforms { diff --git a/src/Native/CMakeLists.txt b/src/Native/CMakeLists.txt index 767f6151fa..a0b06f864f 100644 --- a/src/Native/CMakeLists.txt +++ b/src/Native/CMakeLists.txt @@ -182,3 +182,4 @@ add_subdirectory(CpuMathNative) add_subdirectory(FastTreeNative) add_subdirectory(LdaNative) add_subdirectory(FactorizationMachineNative) +add_subdirectory(SymSgdNative) \ No newline at end of file diff --git a/src/Native/SymSgdNative/CMakeLists.txt b/src/Native/SymSgdNative/CMakeLists.txt new file 
mode 100644 index 0000000000..92168832a7 --- /dev/null +++ b/src/Native/SymSgdNative/CMakeLists.txt @@ -0,0 +1,32 @@ +project (SymSgdNative) + +set(SOURCES + SymSgdNative.cpp +) + +if(WIN32) + find_library(MKL_LIBRARY MklImports HINTS ${CMAKE_SOURCE_DIR}/../../packages/mlnetmkldeps/0.0.0.5/runtimes/win-x64/native) +else() + list(APPEND SOURCES ${VERSION_FILE_PATH}) + if(CMAKE_SYSTEM_NAME STREQUAL Darwin) + message("Linking SymSgdNative with MKL on macOS.") + find_library(MKL_LIBRARY libMklImports.dylib HINTS "${CMAKE_SOURCE_DIR}/../../packages/mlnetmkldeps/0.0.0.5/runtimes/osx-x64/native") + else() + message("Linking SymSgdNative with MKL on linux.") + find_library(MKL_LIBRARY libMklImports.so HINTS ${CMAKE_SOURCE_DIR}/../../packages/mlnetmkldeps/0.0.0.5/runtimes/linux-x64/native) + SET(CMAKE_SKIP_BUILD_RPATH FALSE) + SET(CMAKE_BUILD_WITH_INSTALL_RPATH FALSE) + SET(CMAKE_INSTALL_RPATH "${CMAKE_SOURCE_DIR}/../../packages/mlnetmkldeps/0.0.0.5/runtimes") + SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) + SET(CMAKE_INSTALL_RPATH "${CMAKE_SOURCE_DIR}/../../packages/mlnetmkldeps/0.0.0.5/runtimes") + endif() +endif() + +add_library(SymSgdNative SHARED ${SOURCES} ${RESOURCES}) +target_link_libraries(SymSgdNative PUBLIC ${MKL_LIBRARY}) + +if(CMAKE_SYSTEM_NAME STREQUAL Darwin) + set_target_properties(SymSgdNative PROPERTIES INSTALL_RPATH "${CMAKE_SOURCE_DIR}/../../packages/mlnetmkldeps/0.0.0.5/runtimes/osx-x64/native") +endif() + +install_library_and_symbols (SymSgdNative) \ No newline at end of file diff --git a/src/Native/SymSgdNative/Macros.h b/src/Native/SymSgdNative/Macros.h new file mode 100644 index 0000000000..728dbe3efa --- /dev/null +++ b/src/Native/SymSgdNative/Macros.h @@ -0,0 +1,9 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#pragma once +#define MIN(__X__, __Y__) (((__X__) > (__Y__)) ? (__Y__) : (__X__)) + +// This is a very large prime number used for permutation +#define VERYLARGEPRIME 961748941 \ No newline at end of file diff --git a/src/Native/SymSgdNative/SparseBLAS.h b/src/Native/SymSgdNative/SparseBLAS.h new file mode 100644 index 0000000000..211ccbadb6 --- /dev/null +++ b/src/Native/SymSgdNative/SparseBLAS.h @@ -0,0 +1,27 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
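One aside on Macros.h above: VERYLARGEPRIME exists to drive the shuffle-free, in-place permutation used by LearnAll further down. A small self-contained C# check of that idea (a hypothetical demo, not part of this PR):

```csharp
using System;

// Checks the modular-permutation trick behind VERYLARGEPRIME: because the
// prime P exceeds n and is coprime to it, m = P % n is coprime to n as well,
// so j -> (j * m) % n visits every index in [0, n) exactly once.
class PermutationDemo
{
    const long VeryLargePrime = 961748941;   // same constant as Macros.h

    static void Main()
    {
        int n = 1000;                        // e.g. number of instances
        long m = VeryLargePrime % n;
        var visited = new bool[n];
        for (int j = 0; j < n; j++)
            visited[(int)(j * m % n)] = true;
        Console.WriteLine(Array.TrueForAll(visited, v => v));  // True: a full permutation
    }
}
```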
+ +#pragma once +#include "../Stdafx.h" + +extern "C" float cblas_sdot(const int vecSize, const float* denseVecX, const int incX, const float* denseVecY, const int incY); +extern "C" float cblas_sdoti(const int sparseVecSize, const float* sparseVecValues, const int* sparseVecIndices, float* denseVec); +extern "C" void cblas_saxpy(const int vecSize, const float coef, const float* denseVecX, const int incX, float* denseVecY, const int incY); +extern "C" void cblas_saxpyi(const int sparseVecSize, const float coef, const float* sparseVecValues, const int* sparseVecIndices, float* denseVec); + +float SDOT(const int vecSize, const float* denseVecX, const float* denseVecY) { + return cblas_sdot(vecSize, denseVecX, 1, denseVecY, 1); +} + +float SDOTI(const int sparseVecSize, const int* sparseVecIndices, const float* sparseVecValues, float* denseVec) { + return cblas_sdoti(sparseVecSize, sparseVecValues, sparseVecIndices, denseVec); +} + +void SAXPY(const int vecSize, const float* denseVecX, float* denseVecY, float coef) { + return cblas_saxpy(vecSize, coef, denseVecX, 1, denseVecY, 1); +} + +void SAXPYI(const int sparseVecSize, const int* sparseVecIndices, const float* sparseVecValues, float* denseVec, float coef) { + cblas_saxpyi(sparseVecSize, coef, sparseVecValues, sparseVecIndices, denseVec); +} \ No newline at end of file diff --git a/src/Native/SymSgdNative/SymSgdNative.cpp b/src/Native/SymSgdNative/SymSgdNative.cpp new file mode 100644 index 0000000000..e13798fdc2 --- /dev/null +++ b/src/Native/SymSgdNative/SymSgdNative.cpp @@ -0,0 +1,440 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +// SymSGDNative.cpp : Defines the exported functions for the DLL application. + +#include +#include +#if defined(USE_OMP) +#include +#endif +#include +#include "../Stdafx.h" +#include "Macros.h" +#include "SparseBLAS.h" +#include "SymSgdNative.h" + +// This method learns for a single instance +inline void LearnInstance(int instSize, int* instIndices, float* instValues, + float label, float alpha, float l2Const, float piw, float& weightScaling, float* weightVector, float& bias) { + float dotProduct = 0.0f; + if (instIndices) // If it is a sparse instance + dotProduct = SDOTI(instSize, instIndices, instValues, weightVector) * weightScaling + bias; + else // If it is dense case. + dotProduct = SDOT(instSize, instValues, weightVector)*weightScaling + bias; + // Compute the derivative coefficient + float sigmoidPrediction = 1.0f / (1.0f + exp(-dotProduct)); + float derivative = (label > 0) ? 
piw * (sigmoidPrediction - 1) : sigmoidPrediction;
+    float derivativeCoef = -alpha * derivative;
+    weightScaling *= (1.0f - alpha*l2Const);
+    // Apply the derivative back to the weightVector
+    if (instIndices) // If it is a sparse instance
+        SAXPYI(instSize, instIndices, instValues, weightVector, derivativeCoef / weightScaling);
+    else
+        SAXPY(instSize, instValues, weightVector, derivativeCoef / weightScaling);
+
+    bias = bias + derivativeCoef;
+}
+
+// This method permutes the frequent features to the starting consecutive positions of the feature space
+void ComputeRemapping(int totalNumInstances, int* instSizes, int** instIndices,
+    int numFeat, int numLocIter, int numThreads, SymSGDState* state, int& numFreqFeat) {
+    // There are two maps used for this permutation:
+    // 1) a direct map
+    state->FreqFeatDirectMap = new int[numFeat];
+    // 2) an unordered map
+    state->FreqFeatUnorderedMap = new std::unordered_map<int, int>();
+
+    int* freqFeatDirectMap = state->FreqFeatDirectMap;
+    std::unordered_map<int, int>* freqFeatUnorderedMap = (std::unordered_map<int, int>*)state->FreqFeatUnorderedMap;
+
+    // If numLocIter is 1, it means that every iteration must be reduced and therefore, no feature should be considered frequent.
+    // In this special case, we do not need a mapping since it is the identity.
+    if (numLocIter == 1) {
+        memset(freqFeatDirectMap, -1, sizeof(int)*numFeat);
+        numFreqFeat = 0;
+        return;
+    }
+
+    memset(freqFeatDirectMap, 0, sizeof(int)*numFeat);
+
+    // Frequent features are found by subsampling the data and histogramming the frequency of each feature
+    int subSampleSize = MIN(1000 * numLocIter*numThreads, totalNumInstances);
+    // The threshold to call a feature frequent
+    float threshold = (float)subSampleSize / (float)numLocIter;
+
+    // Compute the histogram for the frequency of features
+    // freqFeatDirectMap is used to store the histogram
+    for (int i = 0; i < subSampleSize; i++) {
+        if (instIndices[i]) {
+            for (int j = 0; j < instSizes[i]; j++) {
+                freqFeatDirectMap[instIndices[i][j]]++;
+            }
+        } else
+        {
+            for (int j = 0; j < numFeat; j++) {
+                freqFeatDirectMap[j]++;
+            }
+        }
+    }
+    // Compute the permutation that is required to re-order features such that
+    // frequent features are at the beginning of the feature space.
+
+    // Feature i and feature numFreqFeat are subject to swap. The only difficulty is when
+    // numFreqFeat is already recognized as a frequent feature.
+    // Variable numFreqFeat keeps the number of frequent features observed so far.
+    numFreqFeat = 0;
+    for (int i = 0; i < numFeat; i++) {
+        // Check if i is a frequent feature
+        if (freqFeatDirectMap[i] > threshold) {
+            // Check if the features seen so far are all frequent
+            if (numFreqFeat != i) {
+                // We have to swap i with numFreqFeat
+                auto searchedRes = freqFeatUnorderedMap->find(numFreqFeat);
+                // Check if numFreqFeat is already occupied in freqFeatUnorderedMap,
+                // which means that numFreqFeat is already a frequent feature
+                if (searchedRes == freqFeatUnorderedMap->end()) {
+                    // In this case, numFreqFeat is unoccupied and can be used as the permutation for i.
+                    (*freqFeatUnorderedMap)[i] = numFreqFeat;
+                    (*freqFeatUnorderedMap)[numFreqFeat] = i;
+                } else {
+                    // Since numFreqFeat is already a frequent feature, its mapped index, (*freqFeatUnorderedMap)[numFreqFeat],
+                    // was a non-frequent feature. So in this case, the numFreqFeat mapping is removed and its mapped feature index
+                    // is used for i instead.
+                    int oldFreqFeat = numFreqFeat;
+                    int oldFreqFeatMappedTo = searchedRes->second;
+                    freqFeatUnorderedMap->erase(searchedRes);
+                    (*freqFeatUnorderedMap)[oldFreqFeatMappedTo] = i;
+                    (*freqFeatUnorderedMap)[i] = oldFreqFeatMappedTo;
+                }
+            }
+            numFreqFeat++;
+        }
+        // freqFeatDirectMap is a direct map and -1 means freqFeatDirectMap[i] = i (identity mapping)
+        freqFeatDirectMap[i] = -1;
+    }
+
+    // Here, using the unordered_map, we set the direct map accordingly
+    auto endOfUnorderedMap = freqFeatUnorderedMap->end();
+    for (auto it = freqFeatUnorderedMap->begin(); it != endOfUnorderedMap; it++) {
+        freqFeatDirectMap[it->first] = it->second;
+    }
+}
+
+// This method remaps instances using the maps provided by ComputeRemapping
+void RemapInstances(int* instSizes, int** instIndices, float** instValues, int myStart, int myEnd, SymSGDState* state) {
+    int* freqFeatDirectMap = state->FreqFeatDirectMap;
+    std::unordered_map<int, int>* freqFeatUnorderedMap = (std::unordered_map<int, int>*)state->FreqFeatUnorderedMap;
+    auto itBegin = freqFeatUnorderedMap->begin();
+    auto itEnd = freqFeatUnorderedMap->end();
+    for (int j = myStart; j < myEnd; j++) {
+        int instSize = instSizes[j];
+        // Check if the instance is sparse
+        if (instIndices[j]) {
+            // Just swap the indices accordingly
+            for (int k = 0; k < instSize; k++) {
+                int oldIndex = instIndices[j][k];
+                // The direct map is used here since it is much more efficient to access
+                auto mappedIndex = freqFeatDirectMap[oldIndex];
+                if (mappedIndex != -1) {
+                    int newIndex = mappedIndex;
+                    instIndices[j][k] = newIndex;
+                }
+            }
+        } else
+        {
+            // If the instance is dense, we have to swap the values instead.
+            // Each pair appears in the map in both directions, so guard to swap only once.
+            for (auto it = itBegin; it != itEnd; it++) {
+                if (it->first < it->second) {
+                    float temp = instValues[j][it->second];
+                    instValues[j][it->second] = instValues[j][it->first];
+                    instValues[j][it->first] = temp;
+                }
+            }
+        }
+    }
+}
+
+float MaxPossibleAlpha(float alpha, float l2Const, int totalNumInstances) {
+    return (1.0f - pow(10.0f, -6.0f / (float)totalNumInstances)) / l2Const;
+}
+
+void TuneAlpha(float& alpha, float l2Const, int totalNumInstances, int* instSizes, int** instIndices,
+    float** instValues, int numFeat, int numThreads) {
+    alpha = 1e0f;
+    int logSqrtNumInst = (int) round(log10(sqrt(totalNumInstances)))-3;
+    //int logAverageNorm = int(log10(averageNorm));
+    if (logSqrtNumInst > 0)
+        for (int i = 0; i < logSqrtNumInst; i++)
+            alpha = alpha / 10.0f;
+    else if (logSqrtNumInst < 0)
+        for (int i = 0; i < -logSqrtNumInst; i++)
+            alpha = alpha * 10.0f;
+
+    // If we have l2Const > 0, we want to make sure alpha is not too large.
+    if (l2Const > 0) {
+        // Since weightScaling is multiplied by (1-alpha*lambda),
+        // we should make sure (1-alpha*lambda)^totalNumInstances > 1e-6, which is
+        // the threshold for applying the weightScaling. Therefore,
+        // alpha < (1-10^(-6/totalNumInstances))/l2Const.
+        alpha = MIN(alpha, MaxPossibleAlpha(alpha, l2Const, totalNumInstances));
+    }
+
+    printf("Initial learning rate is tuned to %f\n", alpha);
+}
+
+
+void TuneNumLocIter(int& numLocIter, int totalNumInstances, int* instSizes, int numThreads) {
+    int averageInstSizes = 0;
+    for (int i = 0; i < totalNumInstances; i++)
+        averageInstSizes += instSizes[i];
+    averageInstSizes = averageInstSizes / totalNumInstances;
+
+    if (averageInstSizes > 1000)
+        numLocIter = 40 / numThreads;
+    else
+        numLocIter = 160 / numThreads;
+}
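To make the decade heuristic in TuneAlpha concrete, here is a rough C# paraphrase (a sketch for illustration only, not the exported native routine; the helper name is hypothetical):

```csharp
// Sketch of TuneAlpha: start at 1.0, shift one decade per order of magnitude
// of sqrt(totalNumInstances) (offset by 3), then clamp so the weight-scaling
// factor (1 - alpha * l2Const) cannot decay past 1e-6 within a single pass.
static float TuneAlphaSketch(int totalNumInstances, float l2Const)
{
    int logSqrtNumInst = (int)Math.Round(Math.Log10(Math.Sqrt(totalNumInstances))) - 3;
    float alpha = (float)Math.Pow(10, -logSqrtNumInst);
    if (l2Const > 0)
    {
        float maxAlpha = (1f - (float)Math.Pow(10, -6.0 / totalNumInstances)) / l2Const;
        alpha = Math.Min(alpha, maxAlpha);
    }
    return alpha;
}
```

For example, with totalNumInstances = 10^8 the heuristic starts alpha at 10^-1, and a positive l2Const can only lower it further.

+// This method sets SymSGDState, computes the remapping for indices, and allocates the
+// required memory for the SymSGD learners.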
+void InitializeState(int totalNumInstances, int* instSizes, int** instIndices, float** instValues,
+    int numFeat, bool tuneNumLocIter, int& numLocIter, int numThreads, bool tuneAlpha, float& alpha,
+    float l2Const, SymSGDState* state) {
+    if (tuneAlpha) {
+        TuneAlpha(alpha, l2Const, totalNumInstances, instSizes, instIndices, instValues, numFeat, numThreads);
+    } else {
+        // Check if the user-provided alpha is too large because of l2Const. See the comment about positive l2Const in TuneAlpha.
+        float maxPossibleAlpha = MaxPossibleAlpha(alpha, l2Const, totalNumInstances);
+        if (alpha > maxPossibleAlpha)
+            printf("Warning: learning rate is too high! Try using a value < %e instead\n", maxPossibleAlpha);
+    }
+
+    if (tuneNumLocIter)
+        TuneNumLocIter(numLocIter, totalNumInstances, instSizes, numThreads);
+
+    state->WeightScaling = 1.0f;
+#if defined(USE_OMP)
+    if (numThreads > 1) {
+        state->PassIteration = 0;
+        state->NumFrequentFeatures = 0;
+        state->TotalInstancesProcessed = 0;
+        ComputeRemapping(totalNumInstances, instSizes, instIndices, numFeat,
+            numLocIter, numThreads, state, state->NumFrequentFeatures);
+        printf("Number of frequent features: %d\nNumber of features: %d\n", state->NumFrequentFeatures, numFeat);
+
+        state->NumLearners = numThreads;
+        state->Learners = new SymSGD*[numThreads];
+        SymSGD** learners = (SymSGD**)(state->Learners);
+
+        // Allocation of the SymSGD learners happens in parallel to follow the first-touch policy.
+        #pragma omp parallel num_threads(numThreads)
+        {
+            int threadId = omp_get_thread_num();
+            learners[threadId] = new SymSGD(state->NumFrequentFeatures, threadId);
+        }
+    }
+
+    // Make sure that MKL runs sequentially
+    omp_set_num_threads(1);
+#endif
+}
+
+float Loss(int instSize, int* instIndices, float* instValues,
+    float label, float piw, float& weightScaling, float* weightVector, float& bias) {
+    float dotProduct = 0.0f;
+    if (instIndices) // If it is a sparse instance
+        dotProduct = SDOTI(instSize, instIndices, instValues, weightVector) * weightScaling + bias;
+    else // If it is the dense case
+        dotProduct = SDOT(instSize, instValues, weightVector) * weightScaling + bias;
+    float sigmoidPrediction = 1.0f / (1.0f + exp(-dotProduct));
+    float loss = (label > 0) ? -log2(sigmoidPrediction) : -log2(1 - sigmoidPrediction);
+    // Prevent the loss from going to infinity
+    if (loss > 100.0f)
+        loss = 100.0f;
+    if (label > 0)
+        loss *= piw;
+    return loss;
+}
+
+// This method learns on the loaded instances for as many passes as requested.
+// Note that InitializeState should be called before this method.
+EXPORT_API(void) LearnAll(int totalNumInstances, int* instSizes, int** instIndices, float** instValues,
+    float* labels, bool tuneAlpha, float& alpha, float l2Const, float piw, float* weightVector,
+    float& bias, int numFeat, int numPasses, int numThreads, bool tuneNumLocIter, int& numLocIter, float tolerance,
+    bool needShuffle, bool shouldInitialize, SymSGDState* state)
+{
+    // If this is the first time LearnAll is called, initialize the state.
+    if (shouldInitialize)
+        InitializeState(totalNumInstances, instSizes, instIndices, instValues, numFeat, tuneNumLocIter, numLocIter, numThreads, tuneAlpha, alpha, l2Const, state);
+    float& weightScaling = state->WeightScaling;
+
+    float totalAverageLoss = 0.0f; // Reserved for total loss computation
+    float oldAverageLoss = INFINITY;
+    float olderAverageLoss = INFINITY;
+    float oldestAverageLoss = INFINITY;
+    float totalOverallAverageLoss = INFINITY;
+
+    float adjustedAlpha = alpha;
+    // Check if totalNumInstances is too small to run in parallel
+    if (numThreads == 1 || totalNumInstances < numThreads) {
+        // For i=[0..totalNumInstances-1], (curPermMultiplier*i) % totalNumInstances always creates a pseudo-random permutation
+        int64_t curPermMultiplier = (VERYLARGEPRIME % totalNumInstances);
+        // In the sequential case, just apply normal SGD
+        for (int i = 0; i < numPasses; i++) {
+            for (int j = 0; j < totalNumInstances; j++) {
+                int64_t index = j;
+                if (needShuffle)
+                    index = (((int64_t)index * (int64_t)curPermMultiplier) % (int64_t)totalNumInstances);
+                // alpha decays with the square root of the number of instances processed.
+                float thisAlpha = adjustedAlpha / (float)sqrt(1 + state->PassIteration * totalNumInstances + j);
+                LearnInstance(instSizes[index], instIndices[index], instValues[index], labels[index], thisAlpha, l2Const, piw,
+                    weightScaling, weightVector, bias);
+                //state->TotalInstancesProcessed++;
+                if (weightScaling < 1e-6) {
+                    for (int k = 0; k < numFeat; k++) {
+                        weightVector[k] *= weightScaling;
+                    }
+                    weightScaling = 1.0f;
+                }
+            }
+
+            float averageLoss = 0.0f;
+            // Compute the total loss
+            for (int j = 0; j < totalNumInstances; j++)
+                averageLoss += Loss(instSizes[j], instIndices[j], instValues[j], labels[j], piw, weightScaling, weightVector, bias);
+            averageLoss = averageLoss / (float)totalNumInstances;
+            // If the loss did not improve, the learning rate was too high; decay it.
+            if (tuneAlpha && oldAverageLoss - averageLoss < tolerance)
+                adjustedAlpha = adjustedAlpha / 10.0f;
+            float overallAverageLoss = oldestAverageLoss - averageLoss;
+            oldestAverageLoss = olderAverageLoss;
+            olderAverageLoss = oldAverageLoss;
+            oldAverageLoss = averageLoss;
+
+            averageLoss = 0.0f;
+            // Terminate if the average loss difference between the current model and the model from 3 passes ago is small
+            if (overallAverageLoss < tolerance)
+                break;
+
+            // For shuffling in the next passes, use curPermMultiplier^2 (mod totalNumInstances) instead of curPermMultiplier;
+            // it is again coprime to totalNumInstances and so yields another valid permutation.
+            if (needShuffle)
+                curPermMultiplier = (((int64_t)curPermMultiplier * (int64_t)curPermMultiplier) % (int64_t)totalNumInstances);
+            state->PassIteration++;
+        }
+    } else {
+#if defined(USE_OMP)
+        // In the parallel case...
+        bool shouldRemap = !((std::unordered_map<int, int>*)state->FreqFeatUnorderedMap)->empty();
+        SymSGD** learners = (SymSGD**)(state->Learners);
+
+        float oldWeightScaling = 1.0f;
+        #pragma omp parallel num_threads(numThreads)
+        {
+            int threadId = omp_get_thread_num();
+            // Compute the portion of instances associated with threadId
+            int myStart = (totalNumInstances * threadId) / numThreads;
+            int myEnd = (totalNumInstances * (threadId + 1)) / numThreads;
+            int myRangeLength = myEnd - myStart;
+
+            if (shouldRemap)
+                RemapInstances(instSizes, instIndices, instValues, myStart, myEnd, state);
+
+            // This variable keeps track of how many instances have been learned since the last reduction
+            int instancesLearnedSinceReduction = 0;
+
+            learners[threadId]->ResetModel(bias, weightVector, weightScaling);
+            int64_t curPermMultiplier = (VERYLARGEPRIME % myRangeLength);
+
+            for (int i = 0; i < numPasses; i++) {
+                for (int j = 0; j < myRangeLength; j++) {
+                    int64_t index = myStart + j;
+                    if (needShuffle)
+                        index = myStart + (((int64_t)j * (int64_t)curPermMultiplier) % (int64_t)myRangeLength);
+                    // alpha decays with the square root of the number of instances processed.
+                    float thisAlpha = adjustedAlpha / (float)sqrt(1 + state->PassIteration*totalNumInstances + j*numThreads);
+                    learners[threadId]->LearnLocalModel(instSizes[index], instIndices[index], instValues[index], labels[index],
+                        thisAlpha, l2Const, piw, weightVector);
+                    instancesLearnedSinceReduction++;
+                    // If it reached numLocIter, do a reduction
+                    if (instancesLearnedSinceReduction == numLocIter) {
+                        learners[threadId]->Reduction(weightVector, bias, weightScaling);
+                        learners[threadId]->ResetModel(bias, weightVector, weightScaling);
+                        instancesLearnedSinceReduction = 0;
+                    }
+                }
+
+                // Check if we need to rescale the weight vector
+                if (l2Const > 0.0f) {
+                    #pragma omp barrier
+                    if (weightScaling < 1e-6) {
+                        #pragma omp for
+                        for (int featIndex = 0; featIndex < numFeat; featIndex++)
+                            weightVector[featIndex] *= weightScaling;
+                        if (threadId == 0)
+                            weightScaling = 1.0f;
+                    }
+                }
+
+                #pragma omp barrier
+                #pragma omp for reduction(+:totalAverageLoss)
+                for (int j = 0; j < totalNumInstances; j++)
+                    totalAverageLoss += Loss(instSizes[j], instIndices[j], instValues[j], labels[j], piw, weightScaling, weightVector, bias);
+                #pragma omp barrier
+                if (threadId == 0) {
+                    totalAverageLoss = totalAverageLoss / (float)totalNumInstances;
+                    // If the loss did not improve, the learning rate was too high; decay it.
+                    if (tuneAlpha && oldAverageLoss - totalAverageLoss < tolerance)
+                        adjustedAlpha = adjustedAlpha / 10.0f;
+                    state->PassIteration++;
+                    totalOverallAverageLoss = oldestAverageLoss - totalAverageLoss;
+                    oldestAverageLoss = olderAverageLoss;
+                    olderAverageLoss = oldAverageLoss;
+                    oldAverageLoss = totalAverageLoss;
+
+                    totalAverageLoss = 0.0f;
+                }
+                #pragma omp barrier
+                // Terminate if the average-loss difference between the current model and the model from 3 passes ago is small
+                if (totalOverallAverageLoss < tolerance)
+                    break;
+
+                if (needShuffle)
+                    curPermMultiplier = (((int64_t)curPermMultiplier * (int64_t)curPermMultiplier) % (int64_t)myRangeLength);
+            }
+        }
+        state->TotalInstancesProcessed += numPasses*totalNumInstances;
+#endif
+    }
+}
+
+// This method maps the weight vector back to the original feature space
+EXPORT_API(void) MapBackWeightVector(float* weightVector, SymSGDState* state) {
+    std::unordered_map<int, int>* freqFeatUnorderedMap = (std::unordered_map<int, int>*)state->FreqFeatUnorderedMap;
+    auto endOfUnorderedMap = freqFeatUnorderedMap->end();
+    for (auto it = freqFeatUnorderedMap->begin(); it != endOfUnorderedMap; it++) {
+        if (it->first < it->second) {
+            float temp = weightVector[it->second];
+            weightVector[it->second] = weightVector[it->first];
+            weightVector[it->first] = temp;
+        }
+    }
+}
+
+// Deallocation method
+EXPORT_API(void) DeallocateSequentially(SymSGDState* state) {
+#if defined(USE_OMP)
+    // Make sure that the rest of the MKL calls use parallelism
+    omp_set_num_threads(omp_get_num_procs());
+#endif
+
+    SymSGD** learners = (SymSGD**)(state->Learners);
+    if (learners) {
+        for (int i = 0; i < state->NumLearners; i++)
+            delete learners[i];
+    }
+    if (state->FreqFeatUnorderedMap)
+        delete (std::unordered_map<int, int>*)state->FreqFeatUnorderedMap;
+    if (state->FreqFeatDirectMap)
+        delete[] state->FreqFeatDirectMap;
+}
\ No newline at end of file
diff --git a/src/Native/SymSgdNative/SymSgdNative.h b/src/Native/SymSgdNative/SymSgdNative.h
new file mode 100644
index 0000000000..286cf72fbd
--- /dev/null
+++ b/src/Native/SymSgdNative/SymSgdNative.h
@@ -0,0 +1,131 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#pragma once
+#include "../Stdafx.h"
+
+using namespace std;
+
+// In almost every sparse dataset, there is a great imbalance in the frequency of features.
+// This class learns a local model for the frequent features and modifies the global model
+// directly for the non-frequent features, applying what it learned locally about the
+// frequent features to the global model only after a certain number of iterations.
+class SymSGD {
+private:
+    int _numFreqFeat;
+    // Local model that is learned
+    float* _localModel;
+    // A copy of the local model when it started learning
+    float* _startModel;
+
+    // Bias is by default a frequent feature
+    float _bias, _startBias;
+
+    // Local weightScaling for L2 regularization
+    float _weightScaling, _startWeightScaling;
+public:
+    SymSGD(int numFreqFeat, int seed) {
+        _numFreqFeat = numFreqFeat;
+        if (numFreqFeat > 0) {
+            _localModel = new float[numFreqFeat];
+            _startModel = new float[numFreqFeat];
+        } else {
+            _localModel = NULL;
+            _startModel = NULL;
+        }
+    }
+
+    ~SymSGD() {
+        if (_numFreqFeat > 0) {
+            delete[] _localModel;
+            delete[] _startModel;
+        }
+    }
+
+    // Learns on the local model for frequent features and on the global model for non-frequent features
+    void LearnLocalModel(int instSize, int * instIndices,
+        float * instValues, float instLabel, float alpha, float l2Const, float piw, float* globModel) {
+        float dotProduct = 0.0f;
+        // Check if it is a sparse instance
+        if (instIndices) {
+            for (int i = 0; i < instSize; i++) {
+                int curIndex = instIndices[i];
+                if (curIndex < _numFreqFeat) {
+                    // The dot product on frequent features is computed with the local model
+                    dotProduct += _localModel[curIndex] * instValues[i];
+                } else {
+                    // Otherwise with the global model
+                    dotProduct += globModel[curIndex] * instValues[i];
+                }
+            }
+        } else {
+            // In the dense case, there is no need to check indices
+            dotProduct += SDOT(_numFreqFeat, &_localModel[0], instValues) +
+                SDOT(instSize - _numFreqFeat, &globModel[_numFreqFeat], &instValues[_numFreqFeat]);
+        }
+        dotProduct = dotProduct*_weightScaling + _bias;
+
+        _weightScaling *= (1.0f - alpha*l2Const);
+
+        // Compute the derivative coefficient of the logistic loss: sigmoid(z) - y for labels y in {0, 1}
+        float sigmoidPrediction = 1.0f / (1.0f + exp(-dotProduct));
+        float derivative = (instLabel > 0) ?
+            (sigmoidPrediction - 1) : sigmoidPrediction;
+        if (instLabel > 0)
+            derivative *= piw;
+        float derivativeCoef = -2 * alpha * derivative;
+        float derivativeWeightScaledCoef = derivativeCoef / _weightScaling;
+        if (instIndices) {
+            for (int i = 0; i < instSize; i++) {
+                int curIndex = instIndices[i];
+                if (curIndex < _numFreqFeat) { // Apply the gradient to the local model for frequent features
+                    _localModel[curIndex] += derivativeWeightScaledCoef * instValues[i];
+                } else { // Apply the gradient to the global model for non-frequent features
+                    globModel[curIndex] += derivativeWeightScaledCoef * instValues[i];
+                }
+            }
+        } else {
+            // In the dense case, there is no need to check indices
+            SAXPY(_numFreqFeat, instValues, &_localModel[0], derivativeWeightScaledCoef);
+            SAXPY(instSize - _numFreqFeat, &instValues[_numFreqFeat], &globModel[_numFreqFeat], derivativeWeightScaledCoef);
+        }
+        _bias = _bias + derivativeCoef;
+    }
+
+    // This method copies the global model to _localModel and _startModel
+    void ResetModel(float bias, float* globModel, float weightScaling) {
+        memcpy(&_localModel[0], globModel, _numFreqFeat * sizeof(float));
+        memcpy(&_startModel[0], &_localModel[0], _numFreqFeat * sizeof(float));
+
+        _bias = bias;
+        _startBias = bias;
+        _weightScaling = weightScaling;
+        _startWeightScaling = weightScaling;
+    }
+
+    // Adds the delta between _localModel and _startModel to the global model for frequent features
+    void Reduction(float* globModel, float& bias, float& weightScaling) {
+        for (int i = 0; i < _numFreqFeat; i++) {
+            globModel[i] += _localModel[i] - _startModel[i];
+        }
+        bias += _bias - _startBias;
+        weightScaling *= (_weightScaling / _startWeightScaling);
+    }
+};
+
+// The state that is shared between SymSGD and SymSGDNative
+struct SymSGDState
+{
+    int NumLearners;
+    int TotalInstancesProcessed;
+    void* Learners;
+    void* FreqFeatUnorderedMap;
+    int* FreqFeatDirectMap;
+    int NumFrequentFeatures;
+    int PassIteration;
+    float WeightScaling;
+};
\ No newline at end of file
diff --git a/src/Native/build.cmd b/src/Native/build.cmd
index e2bbc3a4dc..2e0b6bfb11 100644
--- a/src/Native/build.cmd
+++ b/src/Native/build.cmd
@@ -106,6 +106,13 @@ if exist "%__IntermediatesDir%\INSTALL.vcxproj" goto BuildNativeProj
goto :Failure

:BuildNativeProj
+echo Copying MKL library into the bin folder. This is a temporary fix.
+mkdir "%__binDir%\AnyCPU.%CMAKE_BUILD_TYPE%"
+mkdir "%__binDir%\AnyCPU.%CMAKE_BUILD_TYPE%\Microsoft.ML.Tests"
+mkdir "%__binDir%\AnyCPU.%CMAKE_BUILD_TYPE%\Microsoft.ML.Tests\netcoreapp2.0"
+mkdir "%__binDir%\AnyCPU.%CMAKE_BUILD_TYPE%\Microsoft.ML.Predictor.Tests"
+mkdir "%__binDir%\AnyCPU.%CMAKE_BUILD_TYPE%\Microsoft.ML.Predictor.Tests\netcoreapp2.0"
+
:: Build the project created by Cmake
set __msbuildArgs=/p:Platform=%__BuildArch% /p:PlatformToolset="%__PlatformToolset%"
diff --git a/src/Native/build.proj b/src/Native/build.proj
index 41611a24ff..fe9e53cd61 100644
--- a/src/Native/build.proj
+++ b/src/Native/build.proj
@@ -90,6 +90,8 @@
                  RelativePath="Microsoft.ML\runtimes\$(PackageRid)\native" />
+
diff --git a/src/Native/build.sh b/src/Native/build.sh
index 7f5b835cad..402a1a4b91 100755
--- a/src/Native/build.sh
+++ b/src/Native/build.sh
@@ -98,3 +98,7 @@ set -x # turn on trace
cmake "$DIR" -G "Unix Makefiles" $__cmake_defines
set +x # turn off trace
make install
+echo "Changing libMklImports.dylib's executable path within libSymSgdNative.dylib so that the loader can find it."
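+# On macOS, install_name_tool -change rewrites the dependency entry recorded in the
+# Mach-O binary: libSymSgdNative.dylib was linked expecting libMklImports.dylib next to
+# itself (@loader_path), and the command below repoints that entry to @rpath so dyld
+# resolves it through the runtime search paths instead. To verify the rewrite (assuming
+# the standard Xcode command-line tools), inspect the binary's load commands with:
+#   otool -L "$RootRepo"/bin/x64."$__configuration"/Native/libSymSgdNative.dylib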
+if [[ "$OSTYPE" == "darwin"* ]]; then + install_name_tool "-change" @loader_path/libMklImports.dylib @rpath/libMklImports.dylib "$RootRepo"/bin/x64."$__configuration"/Native/libSymSgdNative.dylib +fi \ No newline at end of file diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 59d168cc3c..3331225365 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -66,6 +66,7 @@ Trainers.StochasticDualCoordinateAscentBinaryClassifier Train an SDCA binary mod Trainers.StochasticDualCoordinateAscentClassifier The SDCA linear multi-class classification trainer. Microsoft.ML.Runtime.Learners.Sdca TrainMultiClass Microsoft.ML.Runtime.Learners.SdcaMultiClassTrainer+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MulticlassClassificationOutput Trainers.StochasticDualCoordinateAscentRegressor The SDCA linear regression trainer. Microsoft.ML.Runtime.Learners.Sdca TrainRegression Microsoft.ML.Runtime.Learners.SdcaRegressionTrainer+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+RegressionOutput Trainers.StochasticGradientDescentBinaryClassifier Train an Hogwild SGD binary model. Microsoft.ML.Runtime.Learners.StochasticGradientDescentClassificationTrainer TrainBinary Microsoft.ML.Runtime.Learners.StochasticGradientDescentClassificationTrainer+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+BinaryClassificationOutput +Trainers.SymSgdBinaryClassifier Train a symbolic SGD. Microsoft.ML.Runtime.SymSgd.SymSgdClassificationTrainer TrainSymSgd Microsoft.ML.Runtime.SymSgd.SymSgdClassificationTrainer+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+BinaryClassificationOutput Transforms.ApproximateBootstrapSampler Approximate bootstrap sampling. Microsoft.ML.Runtime.Data.BootstrapSample GetSample Microsoft.ML.Runtime.Data.BootstrapSampleTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.BinaryPredictionScoreColumnsRenamer For binary prediction, it renames the PredictedLabel and Score columns to include the name of the positive class. Microsoft.ML.Runtime.EntryPoints.ScoreModel RenameBinaryPredictionScoreColumns Microsoft.ML.Runtime.EntryPoints.ScoreModel+RenameBinaryPredictionScoreColumnsInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.BinNormalizer The values are assigned into equidensity bins and a value is mapped to its bin_number/number_of_bins. 
Microsoft.ML.Runtime.Data.Normalize Bin Microsoft.ML.Runtime.Data.NormalizeTransform+BinArguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 35b554c6ed..51b69ec5c8 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -15790,6 +15790,253 @@ "ITrainerOutput" ] }, + { + "Name": "Trainers.SymSgdBinaryClassifier", + "Desc": "Train a symbolic SGD.", + "FriendlyName": "Symbolic SGD (binary)", + "ShortName": "SymSGD", + "Inputs": [ + { + "Name": "TrainingData", + "Type": "DataView", + "Desc": "The data to be used for training", + "Aliases": [ + "data" + ], + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "FeatureColumn", + "Type": "String", + "Desc": "Column to use for features", + "Aliases": [ + "feat" + ], + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": "Features" + }, + { + "Name": "LabelColumn", + "Type": "String", + "Desc": "Column to use for labels", + "Aliases": [ + "lab" + ], + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": "Label" + }, + { + "Name": "NormalizeFeatures", + "Type": { + "Kind": "Enum", + "Values": [ + "No", + "Warn", + "Auto", + "Yes" + ] + }, + "Desc": "Normalize option for the feature column", + "Aliases": [ + "norm" + ], + "Required": false, + "SortOrder": 5.0, + "IsNullable": false, + "Default": "Auto" + }, + { + "Name": "Caching", + "Type": { + "Kind": "Enum", + "Values": [ + "Auto", + "Memory", + "Disk", + "None" + ] + }, + "Desc": "Whether learner should cache input training data", + "Aliases": [ + "cache" + ], + "Required": false, + "SortOrder": 6.0, + "IsNullable": false, + "Default": "Auto" + }, + { + "Name": "NumberOfIterations", + "Type": "Int", + "Desc": "Number of passes over the data.", + "Aliases": [ + "iter" + ], + "Required": false, + "SortOrder": 50.0, + "IsNullable": false, + "Default": 50, + "SweepRange": { + "RangeType": "Discrete", + "Values": [ + 1, + 5, + 10, + 20, + 30, + 40, + 50 + ] + } + }, + { + "Name": "LearningRate", + "Type": "Float", + "Desc": "Learning rate", + "Aliases": [ + "lr" + ], + "Required": false, + "SortOrder": 51.0, + "IsNullable": true, + "Default": null, + "SweepRange": { + "RangeType": "Discrete", + "Values": [ + "", + 10.0, + 1.0, + 0.1, + 0.01, + 0.001 + ] + } + }, + { + "Name": "L2Regularization", + "Type": "Float", + "Desc": "L2 regularization", + "Aliases": [ + "l2" + ], + "Required": false, + "SortOrder": 52.0, + "IsNullable": false, + "Default": 0.0, + "SweepRange": { + "RangeType": "Discrete", + "Values": [ + 0.0, + 1E-05, + 1E-05, + 1E-06, + 1E-07 + ] + } + }, + { + "Name": "NumberOfThreads", + "Type": "Int", + "Desc": "Degree of lock-free parallelism. Determinism not guaranteed. Multi-threading is not supported currently.", + "Aliases": [ + "nt" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "Tol", + "Type": "Float", + "Desc": "Tolerance for difference in average loss in consecutive passes.", + "Aliases": [ + "tol" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": 0.0001 + }, + { + "Name": "UpdateFrequency", + "Type": "Int", + "Desc": "The number of iterations each thread learns a local model until combining it with the global model. 
Low value means more updated global model and high value means less cache traffic.", + "Aliases": [ + "freq" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": true, + "Default": null, + "SweepRange": { + "RangeType": "Discrete", + "Values": [ + "", + 5, + 20 + ] + } + }, + { + "Name": "MemorySize", + "Type": "Int", + "Desc": "The acceleration memory budget in MB", + "Aliases": [ + "accelMemBudget" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": 1024 + }, + { + "Name": "Shuffle", + "Type": "Bool", + "Desc": "Shuffle data?", + "Aliases": [ + "shuf" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": true + }, + { + "Name": "PositiveInstanceWeight", + "Type": "Float", + "Desc": "Apply weight to the positive class, for imbalanced data", + "Aliases": [ + "piw" + ], + "Required": false, + "SortOrder": 150.0, + "IsNullable": false, + "Default": 1.0 + } + ], + "Outputs": [ + { + "Name": "PredictorModel", + "Type": "PredictorModel", + "Desc": "The trained model" + } + ], + "InputKind": [ + "ITrainerInputWithLabel", + "ITrainerInput" + ], + "OutputKind": [ + "IBinaryClassificationOutput", + "ITrainerOutput" + ] + }, { "Name": "Transforms.ApproximateBootstrapSampler", "Desc": "Approximate bootstrap sampling.", diff --git a/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-CV-breast-cancer-out.txt b/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-CV-breast-cancer-out.txt new file mode 100644 index 0000000000..e8f77de3b1 --- /dev/null +++ b/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-CV-breast-cancer-out.txt @@ -0,0 +1,65 @@ +maml.exe CV tr=SymSGD{nt=1} threads=- norm=No dout=%Output% data=%Data% seed=1 +Not adding a normalizer. +Data fully loaded into memory. +Not training a calibrator because it is not needed. +Not adding a normalizer. +Data fully loaded into memory. +Not training a calibrator because it is not needed. +Warning: The predictor produced non-finite prediction values on 8 instances during testing. Possible causes: abnormal data or the predictor is numerically unstable. +TEST POSITIVE RATIO: 0.3785 (134.0/(134.0+220.0)) +Confusion table + ||====================== +PREDICTED || positive | negative | Recall +TRUTH ||====================== + positive || 132 | 2 | 0.9851 + negative || 8 | 212 | 0.9636 + ||====================== +Precision || 0.9429 | 0.9907 | +OVERALL 0/1 ACCURACY: 0.971751 +LOG LOSS/instance: Infinity +Test-set entropy (prior Log-Loss/instance): 0.956998 +LOG-LOSS REDUCTION (RIG): -Infinity +AUC: 0.991045 +Warning: The predictor produced non-finite prediction values on 8 instances during testing. Possible causes: abnormal data or the predictor is numerically unstable. 
+TEST POSITIVE RATIO: 0.3191 (105.0/(105.0+224.0)) +Confusion table + ||====================== +PREDICTED || positive | negative | Recall +TRUTH ||====================== + positive || 96 | 9 | 0.9143 + negative || 11 | 213 | 0.9509 + ||====================== +Precision || 0.8972 | 0.9595 | +OVERALL 0/1 ACCURACY: 0.939210 +LOG LOSS/instance: Infinity +Test-set entropy (prior Log-Loss/instance): 0.903454 +LOG-LOSS REDUCTION (RIG): -Infinity +AUC: 0.963435 + +OVERALL RESULTS +--------------------------------------- +AUC: 0.977240 (0.0138) +Accuracy: 0.955481 (0.0163) +Positive precision: 0.920027 (0.0228) +Positive recall: 0.949680 (0.0354) +Negative precision: 0.975057 (0.0156) +Negative recall: 0.957265 (0.0064) +Log-loss: Infinity (NaN) +Log-loss reduction: -Infinity (NaN) +F1 Score: 0.934582 (0.0289) +AUPRC: 0.964431 (0.0168) + +--------------------------------------- +Physical memory usage(MB): %Number% +Virtual memory usage(MB): %Number% +%DateTime% Time elapsed(s): %Number% + +--- Progress log --- +[1] 'Preprocessing' started. +[1] 'Preprocessing' finished in %Time%. +[2] 'Training' started. +[2] 'Training' finished in %Time%. +[3] 'Preprocessing #2' started. +[3] 'Preprocessing #2' finished in %Time%. +[4] 'Training #2' started. +[4] 'Training #2' finished in %Time%. diff --git a/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-CV-breast-cancer-rp.txt b/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-CV-breast-cancer-rp.txt new file mode 100644 index 0000000000..cbe85dbf29 --- /dev/null +++ b/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-CV-breast-cancer-rp.txt @@ -0,0 +1,4 @@ +SymSGD +AUC Accuracy Positive precision Positive recall Negative precision Negative recall Log-loss Log-loss reduction F1 Score AUPRC /nt Learner Name Train Dataset Test Dataset Results File Run Time Physical Memory Virtual Memory Command Line Settings +0.97724 0.955481 0.920027 0.94968 0.975057 0.957265 Infinity -Infinity 0.934582 0.964431 1 SymSGD %Data% %Output% 99 0 0 maml.exe CV tr=SymSGD{nt=1} threads=- norm=No dout=%Output% data=%Data% seed=1 /nt:1 + diff --git a/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-CV-breast-cancer.txt b/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-CV-breast-cancer.txt new file mode 100644 index 0000000000..d0e7499c6d --- /dev/null +++ b/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-CV-breast-cancer.txt @@ -0,0 +1,700 @@ +Instance Label Score Probability Log-loss Assigned +5 1 1021.923 1 0 1 +6 0 101.91272 1 Infinity 1 +8 0 -334.817749 0 0 0 +9 0 -235.957581 0 0 0 +10 0 -272.79837 0 0 0 +11 0 -320.737427 0 0 0 +18 1 647.4429 1 0 1 +20 1 217.233337 1 0 1 +21 1 479.27478 1 0 1 +25 1 140.3844 1 0 1 +28 0 -320.737427 0 0 0 +31 0 -296.864 0 0 0 +32 1 454.251282 1 0 1 +35 0 -320.737427 0 0 0 +37 0 -78.9787 5.011722E-35 0 0 +40 0 +41 1 199.0091 1 0 1 +44 1 656.8247 1 0 1 +45 0 -322.804565 0 0 0 +46 1 682.001953 1 0 1 +48 0 -308.837372 0 0 0 +50 1 269.319824 1 0 1 +51 1 55.29657 1 0 1 +52 1 287.521057 1 0 1 +54 1 316.589355 1 0 1 +56 1 913.6865 1 0 1 +60 1 119.077576 1 0 1 +63 1 4.26116943 0.9860904 0.020208151201720485 1 +64 0 -325.252838 0 0 0 +66 0 -287.8332 0 0 0 +68 1 560.21106 1 0 1 +69 0 -288.1764 0 0 0 +70 0 -262.6084 0 0 0 +71 1 377.826416 1 0 1 +72 0 -7.10409546 0.00082105794 0.0011850227709561453 0 +73 1 361.900269 1 0 1 +74 1 308.8277 1 0 1 +76 0 -258.3786 0 0 0 +77 0 -166.195282 0 0 0 +79 0 -349.318481 0 0 0 +82 0 -207.568787 0 0 0 +88 0 -287.8332 0 0 0 +90 0 -301.379425 0 0 0 +91 0 -350.5975 0 0 0 +92 0 -287.8332 0 0 0 +93 0 -325.252838 0 0 0 +95 0 -301.379425 0 0 0 +96 0 
-355.1129 0 0 0 +97 0 -283.3178 0 0 0 +98 1 380.50592 1 0 1 +99 1 571.659058 1 0 1 +100 1 51.7908325 1 0 1 +102 0 -265.5418 0 0 0 +104 1 527.381836 1 0 1 +105 1 1.51104736 0.8192164 0.28768354956614323 1 +106 1 975.2739 1 0 1 +108 0 -298.846741 0 0 0 +109 1 582.2739 1 0 1 +111 1 427.5822 1 0 1 +112 1 370.9096 1 0 1 +113 1 805.0089 1 0 1 +115 0 -169.885254 0 0 0 +117 1 521.3834 1 0 1 +120 0 -289.590118 0 0 0 +121 0 -167.273956 0 0 0 +122 1 864.813965 1 0 1 +123 1 263.304382 1 0 1 +125 0 -325.252838 0 0 0 +128 1 277.651428 1 0 1 +129 0 -577.7748 0 0 0 +131 0 -296.864 0 0 0 +132 1 762.661865 1 0 1 +133 0 -303.018982 0 0 0 +137 0 -340.095428 0 0 0 +138 0 -289.415222 0 0 0 +141 0 -344.61084 0 0 0 +144 0 -320.737427 0 0 0 +145 0 +147 0 -309.023651 0 0 0 +150 0 -272.79837 0 0 0 +151 1 249.55658 1 0 1 +152 1 884.802856 1 0 1 +154 0 -349.126221 0 0 0 +156 0 -227.212433 0 0 0 +161 0 -274.6302 0 0 0 +164 0 +167 1 283.316284 1 0 1 +169 0 -331.047241 0 0 0 +171 0 -301.379425 0 0 0 +173 1 1041.47913 1 0 1 +174 1 612.9641 1 0 1 +176 0 -296.864 0 0 0 +177 1 578.026 1 0 1 +179 1 180.726685 1 0 1 +180 0 -272.79837 0 0 0 +181 0 -349.126221 0 0 0 +183 1 834.9143 1 0 1 +187 1 544.907654 1 0 1 +188 1 709.276855 1 0 1 +189 0 -181.0476 0 0 0 +191 1 457.053833 1 0 1 +192 0 -307.1912 0 0 0 +196 0 384.403748 1 Infinity 1 +198 0 -349.126221 0 0 0 +199 0 -316.222015 0 0 0 +201 1 688.0464 1 0 1 +202 0 -301.379425 0 0 0 +204 0 -301.379425 0 0 0 +205 1 1028.11108 1 0 1 +206 1 667.4751 1 0 1 +207 0 -272.79837 0 0 0 +209 0 -254.736725 0 0 0 +210 1 1086.058 1 0 1 +211 1 925.716064 1 0 1 +212 0 -301.379425 0 0 0 +216 0 -325.252838 0 0 0 +218 1 734.4236 1 0 1 +219 0 -234.219574 0 0 0 +223 1 539.3552 1 0 1 +226 1 715.6188 1 0 1 +228 0 -272.79837 0 0 0 +233 1 355.085876 1 0 1 +237 1 459.5498 1 0 1 +239 1 451.899231 1 0 1 +240 0 -212.39624 0 0 0 +241 0 -284.771729 0 0 0 +242 0 -296.864 0 0 0 +244 0 -301.379425 0 0 0 +246 1 1044.54126 1 0 1 +247 1 409.237671 1 0 1 +248 0 -221.818024 0 0 0 +249 0 +250 0 -251.085815 0 0 0 +252 0 307.68988 1 Infinity 1 +254 1 728.536743 1 0 1 +257 0 -316.222015 0 0 0 +258 0 -292.348633 0 0 0 +259 0 572.462158 1 Infinity 1 +260 1 531.1308 1 0 1 +262 1 912.760254 1 0 1 +267 1 408.017578 1 0 1 +268 1 736.6128 1 0 1 +269 0 -301.379425 0 0 0 +271 0 -283.3178 0 0 0 +272 1 408.017578 1 0 1 +275 0 +276 0 -316.222015 0 0 0 +277 0 -325.252838 0 0 0 +278 0 -301.379425 0 0 0 +279 1 438.823853 1 0 1 +280 0 -292.348633 0 0 0 +283 1 555.4083 1 0 1 +284 1 445.768616 1 0 1 +285 1 1001.10779 1 0 1 +288 1 54.43335 1 0 1 +290 0 -349.126221 0 0 0 +291 0 -301.379425 0 0 0 +293 1 386.6949 1 0 1 +296 0 139.220642 1 Infinity 1 +297 0 +299 1 227.814941 1 0 1 +300 1 407.6792 1 0 1 +301 0 -301.379425 0 0 0 +303 0 -301.379425 0 0 0 +304 1 265.7011 1 0 1 +308 1 588.7855 1 0 1 +309 0 -65.36084 4.11289773E-29 0 0 +311 0 -349.126221 0 0 0 +312 1 -94.50748 9.035572E-42 136.34536397147204 0 +314 0 -296.671753 0 0 0 +316 1 466.170166 1 0 1 +317 1 736.0132 1 0 1 +319 0 161.598083 1 Infinity 1 +321 0 +323 1 388.03302 1 0 1 +327 0 -325.252838 0 0 0 +328 1 584.984 1 0 1 +329 1 459.499573 1 0 1 +331 0 -280.869537 0 0 0 +332 0 -206.44986 0 0 0 +333 1 399.285278 1 0 1 +336 1 425.5838 1 0 1 +338 0 -296.671753 0 0 0 +343 0 -349.126221 0 0 0 +344 1 473.410339 1 0 1 +346 0 -245.701279 0 0 0 +347 0 -294.1391 0 0 0 +348 1 -152.083282 0 Infinity 0 +349 1 237.0896 1 0 1 +350 0 -352.0688 0 0 0 +352 0 39.6635437 1 Infinity 1 +353 1 660.963257 1 0 1 +354 0 -325.252838 0 0 0 +355 0 -327.084656 0 0 0 +358 1 590.6726 1 0 1 +360 1 943.9149 1 0 1 +361 1 
683.516 1 0 1 +366 1 969.8589 1 0 1 +368 0 -304.543427 0 0 0 +370 0 -166.498291 0 0 0 +371 0 -304.543427 0 0 0 +373 0 -317.6933 0 0 0 +376 0 -325.252838 0 0 0 +377 0 -296.671753 0 0 0 +378 0 -363.106323 0 0 0 +379 0 -122.1077 0 0 0 +381 1 651.466431 1 0 1 +383 0 -344.61084 0 0 0 +384 0 -344.61084 0 0 0 +387 0 -126.32016 0 0 0 +388 0 -307.5344 0 0 0 +389 0 -277.759 0 0 0 +391 1 947.5735 1 0 1 +392 0 -316.222015 0 0 0 +395 0 -316.222015 0 0 0 +396 0 -292.348633 0 0 0 +398 0 -227.269958 0 0 0 +399 0 -228.794418 0 0 0 +404 0 -281.178345 0 0 0 +406 0 -213.6662 0 0 0 +409 0 -293.9306 0 0 0 +413 0 -261.0264 0 0 0 +414 1 646.5259 1 0 1 +415 0 -57.21808 1.41417837E-25 0 0 +416 1 582.0941 1 0 1 +418 0 -135.317932 0 0 0 +419 0 -278.856018 0 0 0 +422 0 -65.34631 4.17307942E-29 0 0 +423 0 -262.6084 0 0 0 +428 0 -325.252838 0 0 0 +429 0 -320.737427 0 0 0 +430 0 -160.551788 0 0 0 +434 0 683.9347 1 Infinity 1 +436 1 575.846558 1 0 1 +439 0 -331.0646 0 0 0 +440 1 429.814148 1 0 1 +441 0 -130.0997 0 0 0 +442 0 -280.509918 0 0 0 +449 1 803.155 1 0 1 +450 0 -304.1297 0 0 0 +451 0 -331.0646 0 0 0 +452 0 -361.0996 0 0 0 +453 1 566.1138 1 0 1 +454 0 -221.693909 0 0 0 +455 1 16.0523376 0.9999999 1.7198266111377426E-07 1 +456 1 636.5864 1 0 1 +457 1 647.529541 1 0 1 +464 0 -335.580017 0 0 0 +465 1 677.602661 1 0 1 +466 1 787.547852 1 0 1 +467 1 679.4784 1 0 1 +474 0 -331.0646 0 0 0 +480 0 -302.483521 0 0 0 +482 1 781.156738 1 0 1 +483 1 878.3706 1 0 1 +484 0 -308.7732 0 0 0 +487 1 945.2323 1 0 1 +489 1 24.710083 1 0 1 +492 0 -283.125549 0 0 0 +493 1 920.3756 1 0 1 +495 0 -287.64093 0 0 0 +497 0 -259.831 0 0 0 +501 0 -311.7066 0 0 0 +502 0 -322.208679 0 0 0 +504 0 -349.126221 0 0 0 +507 0 -214.695511 0 0 0 +510 0 -349.126221 0 0 0 +513 0 -287.64093 0 0 0 +514 1 754.5259 1 0 1 +517 0 -296.671753 0 0 0 +519 1 791.0669 1 0 1 +520 0 -377.7073 0 0 0 +521 0 -364.161072 0 0 0 +522 1 296.72522 1 0 1 +523 1 465.930176 1 0 1 +527 0 -287.8332 0 0 0 +528 0 -292.468475 0 0 0 +529 0 -283.125549 0 0 0 +531 0 -213.6662 0 0 0 +532 0 -272.79837 0 0 0 +533 0 -316.222015 0 0 0 +534 0 -320.737427 0 0 0 +535 0 -267.2987 0 0 0 +538 0 -311.7066 0 0 0 +539 0 -302.675781 0 0 0 +540 0 -262.380951 0 0 0 +541 0 -340.095428 0 0 0 +544 0 -286.656677 0 0 0 +546 1 1069.49353 1 0 1 +547 0 -316.029755 0 0 0 +548 0 -311.514343 0 0 0 +549 1 549.9683 1 0 1 +557 0 -352.0688 0 0 0 +558 0 -320.737427 0 0 0 +559 0 -307.1912 0 0 0 +560 0 -283.3178 0 0 0 +561 0 -283.3178 0 0 0 +563 0 -316.222015 0 0 0 +565 1 880.770142 1 0 1 +566 0 -270.05722 0 0 0 +569 1 740.958 1 0 1 +577 0 -325.252838 0 0 0 +578 0 -325.252838 0 0 0 +581 1 785.143066 1 0 1 +582 1 986.6736 1 0 1 +584 0 -412.1561 0 0 0 +586 1 1092.98242 1 0 1 +590 1 617.891 1 0 1 +593 0 -308.7732 0 0 0 +594 1 774.4863 1 0 1 +600 0 -316.222015 0 0 0 +602 0 -311.7066 0 0 0 +604 1 256.321228 1 0 1 +606 0 -346.0821 0 0 0 +607 0 -349.126221 0 0 0 +609 0 -331.0646 0 0 0 +612 1 1115.01685 1 0 1 +613 0 -169.23941 0 0 0 +614 0 -292.156342 0 0 0 +617 0 +618 0 -311.7066 0 0 0 +619 0 -307.1912 0 0 0 +621 0 -15.8763733 1.27344038E-07 1.8371862313930792E-07 0 +622 0 -296.873169 0 0 0 +624 0 -289.1122 0 0 0 +627 0 -165.369843 0 0 0 +629 0 -335.580017 0 0 0 +633 1 390.5418 1 0 1 +634 0 -340.095428 0 0 0 +638 0 -335.580017 0 0 0 +639 0 -352.0688 0 0 0 +641 0 -316.222015 0 0 0 +642 0 -316.222015 0 0 0 +644 0 -344.61084 0 0 0 +645 0 -316.222015 0 0 0 +649 0 -316.222015 0 0 0 +652 0 -293.988159 0 0 0 +653 0 -311.7066 0 0 0 +654 0 -292.348633 0 0 0 +656 0 -307.1912 0 0 0 +657 0 -72.37637 3.69266974E-32 0 0 +660 0 -325.252838 0 0 
0 +661 0 -287.8332 0 0 0 +665 0 -349.126221 0 0 0 +668 1 342.943665 1 0 1 +670 1 812.6575 1 0 1 +678 0 -349.126221 0 0 0 +679 0 -344.61084 0 0 0 +680 1 1145.28394 1 0 1 +681 1 969.3158 1 0 1 +682 0 -270.114777 0 0 0 +683 0 -349.126221 0 0 0 +685 0 -349.126221 0 0 0 +688 0 -335.580017 0 0 0 +689 0 -331.988342 0 0 0 +691 1 742.5989 1 0 1 +692 0 -340.095428 0 0 0 +693 0 -313.773743 0 0 0 +694 0 -323.866241 0 0 0 +696 1 765.3994 1 0 1 +697 1 661.339233 1 0 1 +698 1 685.243042 1 0 1 +0 0 -655.8448 0 0 0 +1 0 352.095367 1 Infinity 1 +2 0 -559.3646 0 0 0 +3 0 1740.70276 1 Infinity 1 +4 0 -568.0029 0 0 0 +7 0 -495.5454 0 0 0 +12 1 241.3157 1 0 1 +13 0 -462.88446 0 0 0 +14 1 942.8955 1 0 1 +15 1 16.80188 0.99999994 8.5991327994145617E-08 1 +16 0 -553.6872 0 0 0 +17 0 -643.057739 0 0 0 +19 0 -668.631836 0 0 0 +22 0 -540.900146 0 0 0 +23 1 +24 0 -604.696655 0 0 0 +26 0 -270.657074 0 0 0 +27 0 -566.4742 0 0 0 +29 0 -182.078979 0 0 0 +30 0 -411.0862 0 0 0 +33 0 -579.956238 0 0 0 +34 0 -418.961884 0 0 0 +36 1 1591.25452 1 0 1 +38 1 1299.106 1 0 1 +39 1 138.435211 1 0 1 +42 1 1330.6217 1 0 1 +43 1 -316.820679 0 Infinity 0 +47 0 -515.32605 0 0 0 +49 1 1846.631 1 0 1 +53 1 -274.015076 0 Infinity 0 +55 1 1099.324 1 0 1 +57 1 -608.199 0 Infinity 0 +58 1 -331.3812 0 Infinity 0 +59 1 272.168976 1 0 1 +61 0 -444.419952 0 0 0 +62 1 1275.43091 1 0 1 +65 1 -452.4966 0 Infinity 0 +67 1 413.874054 1 0 1 +75 0 -419.5797 0 0 0 +78 0 -488.458527 0 0 0 +80 0 -582.7846 0 0 0 +81 0 -516.1598 0 0 0 +83 0 -916.846863 0 0 0 +84 1 897.471436 1 0 1 +85 1 746.382935 1 0 1 +86 1 635.3207 1 0 1 +87 1 1356.65088 1 0 1 +89 0 -620.399658 0 0 0 +94 0 -617.483643 0 0 0 +101 1 841.3391 1 0 1 +103 1 -1044.82617 0 Infinity 0 +107 1 1459.15247 1 0 1 +110 0 -263.1876 0 0 0 +114 0 -85.66205 6.272564E-38 0 0 +116 0 -16.9219666 4.47593E-08 6.4574021972779571E-08 0 +118 0 -543.7712 0 0 0 +119 0 -418.9355 0 0 0 +124 1 404.972565 1 0 1 +126 1 567.644653 1 0 1 +127 0 -630.2707 0 0 0 +130 0 -322.597717 0 0 0 +134 0 -670.7141 0 0 0 +135 0 -421.652374 0 0 0 +136 0 -553.6872 0 0 0 +139 0 +140 0 -451.529541 0 0 0 +142 1 488.315338 1 0 1 +143 0 -142.331116 0 0 0 +146 1 204.013458 1 0 1 +148 0 -941.3387 0 0 0 +149 1 752.705933 1 0 1 +153 0 -322.504425 0 0 0 +155 1 1089.60254 1 0 1 +157 0 -617.483643 0 0 0 +158 0 +159 1 1923.50525 1 0 1 +160 1 1494.5094 1 0 1 +162 0 -630.2707 0 0 0 +163 0 -310.768677 0 0 0 +165 0 -490.5085 0 0 0 +166 1 1570.026 1 0 1 +168 0 -630.2707 0 0 0 +170 0 -451.529541 0 0 0 +172 0 -515.32605 0 0 0 +175 1 1194.714 1 0 1 +178 0 -643.057739 0 0 0 +182 0 -668.631836 0 0 0 +184 1 1070.428 1 0 1 +185 0 -487.669739 0 0 0 +186 1 1482.4314 1 0 1 +190 1 2258.611 1 0 1 +193 0 -604.696655 0 0 0 +194 0 -630.2707 0 0 0 +195 0 -643.057739 0 0 0 +197 0 -543.2626 0 0 0 +200 1 1415.398 1 0 1 +203 0 -655.8448 0 0 0 +208 0 -474.8827 0 0 0 +213 1 2248.68115 1 0 1 +214 1 2270.2356 1 0 1 +215 1 1645.00281 1 0 1 +217 0 -604.696655 0 0 0 +220 0 -567.16925 0 0 0 +221 1 221.44339 1 0 1 +222 1 -65.19409 4.85921E-29 94.055192962286199 0 +224 1 1289.00232 1 0 1 +225 0 -515.32605 0 0 0 +227 1 1841.02966 1 0 1 +229 1 1514.78943 1 0 1 +230 1 930.4363 1 0 1 +231 1 1347.16248 1 0 1 +232 0 355.92923 1 Infinity 1 +234 0 50.92752 1 Infinity 1 +235 0 +236 1 1204.77161 1 0 1 +238 1 2421.62354 1 0 1 +243 0 -499.813416 0 0 0 +245 0 -547.411255 0 0 0 +251 1 963.113037 1 0 1 +253 1 1330.6217 1 0 1 +255 1 1480.26221 1 0 1 +256 0 -451.529541 0 0 0 +261 1 1146.05249 1 0 1 +263 1 671.7992 1 0 1 +264 1 168.765533 1 0 1 +265 0 -208.3869 0 0 0 +266 1 1778.757 1 0 1 +270 1 
1587.86926 1 0 1 +273 1 71.0156555 1 0 1 +274 0 -548.627563 0 0 0 +281 0 -579.956238 0 0 0 +282 1 1006.21533 1 0 1 +286 1 1933.38391 1 0 1 +287 0 -670.7141 0 0 0 +289 1 1586.01746 1 0 1 +292 1 +294 0 +295 1 906.5393 1 0 1 +298 0 -764.4775 0 0 0 +302 1 1682.504 1 0 1 +305 1 1757.95251 1 0 1 +306 0 -604.696655 0 0 0 +307 0 -604.696655 0 0 0 +310 0 -657.927 0 0 0 +313 0 -425.9555 0 0 0 +315 0 +318 0 -994.1385 0 0 0 +320 1 276.6519 1 0 1 +322 0 -630.2707 0 0 0 +324 0 -604.696655 0 0 0 +325 0 -115.24649 0 0 0 +326 1 -15.4368286 1.97637988E-07 22.270636392005638 0 +330 1 698.895264 1 0 1 +334 1 1535.56677 1 0 1 +335 0 -425.9555 0 0 0 +337 0 -604.696655 0 0 0 +339 1 1217.44373 1 0 1 +340 1 493.843048 1 0 1 +341 0 -604.696655 0 0 0 +342 0 -438.7425 0 0 0 +345 0 -425.9555 0 0 0 +351 0 -617.483643 0 0 0 +356 1 -20.48114 1.27395428E-09 29.54803935647886 0 +357 1 1071.41785 1 0 1 +359 1 718.953 1 0 1 +362 0 -396.3484 0 0 0 +363 0 412.754547 1 Infinity 1 +364 0 -617.483643 0 0 0 +365 0 -528.1131 0 0 0 +367 1 1990.46619 1 0 1 +369 0 -141.63562 0 0 0 +372 0 -431.748932 0 0 0 +374 0 -418.961884 0 0 0 +375 0 -425.9555 0 0 0 +380 0 -425.9555 0 0 0 +382 0 -248.732788 0 0 0 +385 0 -124.032074 0 0 0 +386 1 986.1454 1 0 1 +390 0 -477.798676 0 0 0 +393 0 -296.141541 0 0 0 +394 0 -131.020447 0 0 0 +397 0 -464.3166 0 0 0 +400 1 1161.08484 1 0 1 +401 0 -451.529541 0 0 0 +402 0 -41.73947 7.460671E-19 0 0 +403 0 -238.811249 0 0 0 +405 0 -515.32605 0 0 0 +407 0 -515.32605 0 0 0 +408 0 -106.253662 0 0 0 +410 0 -515.32605 0 0 0 +411 0 +412 1 1142.7179 1 0 1 +417 0 -515.32605 0 0 0 +420 0 -151.036346 0 0 0 +421 1 1101.12451 1 0 1 +424 0 -451.529541 0 0 0 +425 1 1700.89758 1 0 1 +426 0 413.445831 1 Infinity 1 +427 1 1406.78931 1 0 1 +431 0 -758.7747 0 0 0 +432 0 -484.831055 0 0 0 +433 0 -114.107391 0 0 0 +435 1 1690.2677 1 0 1 +437 0 -464.3166 0 0 0 +438 0 -145.385315 0 0 0 +443 0 -355.049377 0 0 0 +444 0 -508.651184 0 0 0 +445 0 -438.7425 0 0 0 +446 0 -425.9555 0 0 0 +447 0 -477.103638 0 0 0 +448 0 -296.141541 0 0 0 +458 0 -355.1654 0 0 0 +459 0 -233.227112 0 0 0 +460 0 -402.048828 0 0 0 +461 0 -167.905182 0 0 0 +462 0 -414.835876 0 0 0 +463 0 -382.673462 0 0 0 +468 0 -464.3166 0 0 0 +469 0 -393.387817 0 0 0 +470 0 -411.0862 0 0 0 +471 0 -414.835876 0 0 0 +472 0 -360.076721 0 0 0 +473 0 -464.3166 0 0 0 +475 0 -451.529541 0 0 0 +476 0 -342.378357 0 0 0 +477 0 -464.3166 0 0 0 +478 0 -336.6745 0 0 0 +479 1 1756.93933 1 0 1 +481 0 38.2750549 1 Infinity 1 +485 0 -79.2746 3.728034E-35 0 0 +486 0 -411.0862 0 0 0 +488 1 1032.36389 1 0 1 +490 0 -425.9555 0 0 0 +491 1 1566.10583 1 0 1 +494 0 -82.79285 1.10541041E-36 0 0 +496 0 -296.141541 0 0 0 +498 0 -553.6872 0 0 0 +499 0 -553.6872 0 0 0 +500 0 -668.631836 0 0 0 +503 0 -643.057739 0 0 0 +505 0 -170.671326 0 0 0 +506 1 1927.56653 1 0 1 +508 0 -477.103638 0 0 0 +509 0 -438.7425 0 0 0 +511 0 -566.4742 0 0 0 +512 0 -477.103638 0 0 0 +515 1 1918.68909 1 0 1 +516 0 -296.141541 0 0 0 +518 0 -292.0639 0 0 0 +524 0 -540.900146 0 0 0 +525 0 -414.002136 0 0 0 +526 0 -464.3166 0 0 0 +530 1 944.298462 1 0 1 +536 0 -655.8448 0 0 0 +537 0 -533.9065 0 0 0 +542 0 -196.245392 0 0 0 +543 0 -553.6872 0 0 0 +545 0 -566.4742 0 0 0 +550 0 -540.900146 0 0 0 +551 0 -604.696655 0 0 0 +552 0 -338.1034 0 0 0 +553 0 240.835052 1 Infinity 1 +554 0 -451.529541 0 0 0 +555 0 119.931915 1 Infinity 1 +556 0 -136.7655 0 0 0 +562 0 -604.696655 0 0 0 +564 0 -561.4146 0 0 0 +567 0 -411.875 0 0 0 +568 1 595.410767 1 0 1 +570 1 542.114868 1 0 1 +571 1 1942.85852 1 0 1 +572 0 -540.900146 0 0 0 +573 0 -515.32605 0 0 0 
+574 1 1153.99219 1 0 1 +575 0 -533.9065 0 0 0 +576 0 -566.4742 0 0 0 +579 0 -604.696655 0 0 0 +580 0 -444.53595 0 0 0 +583 0 -451.529541 0 0 0 +585 0 -425.9555 0 0 0 +587 0 -484.831055 0 0 0 +588 1 962.983643 1 0 1 +589 0 -477.103638 0 0 0 +591 1 1292.75977 1 0 1 +592 1 495.973053 1 0 1 +595 0 -566.4742 0 0 0 +596 0 -431.748932 0 0 0 +597 0 -411.968262 0 0 0 +598 0 -540.900146 0 0 0 +599 0 158.954132 1 Infinity 1 +601 0 -385.512115 0 0 0 +603 1 666.133057 1 0 1 +605 1 1270.70056 1 0 1 +608 1 1017.25757 1 0 1 +610 1 401.661346 1 0 1 +611 1 1604.39221 1 0 1 +615 0 -309.810669 0 0 0 +616 0 -540.900146 0 0 0 +620 0 -540.900146 0 0 0 +623 0 -425.9555 0 0 0 +625 0 -124.748688 0 0 0 +626 1 592.070435 1 0 1 +628 0 -438.7425 0 0 0 +630 0 -105.585052 0 0 0 +631 0 -566.4742 0 0 0 +632 0 -425.9555 0 0 0 +635 0 -85.71478 5.95035441E-38 0 0 +636 1 933.872437 1 0 1 +637 0 98.51761 1 Infinity 1 +640 0 -389.261841 0 0 0 +643 0 -425.9555 0 0 0 +646 0 -163.588135 0 0 0 +647 0 -350.9007 0 0 0 +648 1 893.8468 1 0 1 +650 0 -331.097778 0 0 0 +651 0 -299.842163 0 0 0 +655 0 -540.900146 0 0 0 +658 1 1548.16321 1 0 1 +659 0 -425.9555 0 0 0 +662 0 -271.4496 0 0 0 +663 0 -271.4496 0 0 0 +664 0 -465.845337 0 0 0 +666 0 -209.536652 0 0 0 +667 0 -630.2707 0 0 0 +669 1 2239.64185 1 0 1 +671 0 -452.314178 0 0 0 +672 0 -617.483643 0 0 0 +673 0 -204.121124 0 0 0 +674 0 -515.32605 0 0 0 +675 0 -98.50183 1.667545E-43 0 0 +676 0 -393.387817 0 0 0 +677 0 -477.103638 0 0 0 +684 0 -425.9555 0 0 0 +686 0 -425.9555 0 0 0 +687 0 -377.613831 0 0 0 +690 0 -350.9007 0 0 0 +695 0 -438.7425 0 0 0 diff --git a/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-TrainTest-breast-cancer-out.txt b/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-TrainTest-breast-cancer-out.txt new file mode 100644 index 0000000000..d27eaf83bd --- /dev/null +++ b/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-TrainTest-breast-cancer-out.txt @@ -0,0 +1,45 @@ +maml.exe TrainTest test=%Data% tr=SymSGD{nt=1} norm=No dout=%Output% data=%Data% out=%Output% seed=1 +Not adding a normalizer. +Data fully loaded into memory. +Not training a calibrator because it is not needed. +Warning: The predictor produced non-finite prediction values on 16 instances during testing. Possible causes: abnormal data or the predictor is numerically unstable. +TEST POSITIVE RATIO: 0.3499 (239.0/(239.0+444.0)) +Confusion table + ||====================== +PREDICTED || positive | negative | Recall +TRUTH ||====================== + positive || 152 | 87 | 0.6360 + negative || 2 | 442 | 0.9955 + ||====================== +Precision || 0.9870 | 0.8355 | +OVERALL 0/1 ACCURACY: 0.869693 +LOG LOSS/instance: Infinity +Test-set entropy (prior Log-Loss/instance): 0.934003 +LOG-LOSS REDUCTION (RIG): -Infinity +AUC: 0.984941 + +OVERALL RESULTS +--------------------------------------- +AUC: 0.984941 (0.0000) +Accuracy: 0.869693 (0.0000) +Positive precision: 0.987013 (0.0000) +Positive recall: 0.635983 (0.0000) +Negative precision: 0.835539 (0.0000) +Negative recall: 0.995495 (0.0000) +Log-loss: Infinity (0.0000) +Log-loss reduction: -Infinity (0.0000) +F1 Score: 0.773537 (0.0000) +AUPRC: 0.977633 (0.0000) + +--------------------------------------- +Physical memory usage(MB): %Number% +Virtual memory usage(MB): %Number% +%DateTime% Time elapsed(s): %Number% + +--- Progress log --- +[1] 'Preprocessing' started. +[1] 'Preprocessing' finished in %Time%. +[2] 'Training' started. +[2] 'Training' finished in %Time%. +[3] 'Saving model' started. +[3] 'Saving model' finished in %Time%. 
diff --git a/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-TrainTest-breast-cancer-rp.txt b/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-TrainTest-breast-cancer-rp.txt new file mode 100644 index 0000000000..c056310ab0 --- /dev/null +++ b/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-TrainTest-breast-cancer-rp.txt @@ -0,0 +1,4 @@ +SymSGD +AUC Accuracy Positive precision Positive recall Negative precision Negative recall Log-loss Log-loss reduction F1 Score AUPRC /nt Learner Name Train Dataset Test Dataset Results File Run Time Physical Memory Virtual Memory Command Line Settings +0.984941 0.869693 0.987013 0.635983 0.835539 0.995495 Infinity -Infinity 0.773537 0.977633 1 SymSGD %Data% %Data% %Output% 99 0 0 maml.exe TrainTest test=%Data% tr=SymSGD{nt=1} norm=No dout=%Output% data=%Data% out=%Output% seed=1 /nt:1 + diff --git a/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-TrainTest-breast-cancer-summary.txt b/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-TrainTest-breast-cancer-summary.txt new file mode 100644 index 0000000000..8fb0a3fe22 --- /dev/null +++ b/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-TrainTest-breast-cancer-summary.txt @@ -0,0 +1,12 @@ +Linear Binary Classification Predictor non-zero weights + +(Bias) -448.1 +f1 49.29393 +f4 -25.15009 +f5 23.68305 +f3 16.76877 +f7 13.76585 +f6 -6.658058 +f8 4.843107 +f2 -3.424153 +f0 -0.3852913 diff --git a/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-TrainTest-breast-cancer.txt b/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-TrainTest-breast-cancer.txt new file mode 100644 index 0000000000..f744bc3525 --- /dev/null +++ b/test/BaselineOutput/SingleDebug/SymSGD/SymSGD-TrainTest-breast-cancer.txt @@ -0,0 +1,700 @@ +Instance Label Score Probability Log-loss Assigned +0 0 -415.3703 0 0 0 +1 0 -109.523041 0 0 0 +2 0 -390.916656 0 0 0 +3 0 33.8270569 1 Infinity 1 +4 0 -381.447449 0 0 0 +5 1 243.726929 1 0 1 +6 0 -200.681656 0 0 0 +7 0 -417.63858 0 0 0 +8 0 -381.525879 0 0 0 +9 0 -359.03302 0 0 0 +10 0 -388.679047 0 0 0 +11 0 -407.556366 0 0 0 +12 1 -208.0876 0 Infinity 0 +13 0 -366.463 0 0 0 +14 1 136.450256 1 0 1 +15 1 -314.800262 0 Infinity 0 +16 0 -408.326935 0 0 0 +17 0 -414.985 0 0 0 +18 1 102.798035 1 0 1 +19 0 -415.755585 0 0 0 +20 1 10.1497192 0.9999609 5.6411412351548271E-05 1 +21 1 -61.521698 1.91190378E-27 88.757048641689195 0 +22 0 -407.94165 0 0 0 +23 1 +24 0 -413.829132 0 0 0 +25 1 -111.690765 0 Infinity 0 +26 0 -333.49762 0 0 0 +27 0 -408.712219 0 0 0 +28 0 -407.556366 0 0 0 +29 0 -407.361328 0 0 0 +30 0 -382.791565 0 0 0 +31 0 -414.214417 0 0 0 +32 1 -140.35733 0 Infinity 0 +33 0 -397.445648 0 0 0 +34 0 -411.365784 0 0 0 +35 0 -407.556366 0 0 0 +36 1 89.1488 1 0 1 +37 0 -367.9438 0 0 0 +38 1 125.049805 1 0 1 +39 1 -120.420319 0 Infinity 0 +40 0 +41 1 -214.114883 0 Infinity 0 +42 1 86.67383 1 0 1 +43 1 -299.954132 0 Infinity 0 +44 1 -14.4605713 5.246304E-07 20.86219519410675 0 +45 0 -402.327942 0 0 0 +46 1 139.792358 1 0 1 +47 0 -407.171082 0 0 0 +48 0 -381.447449 0 0 0 +49 1 141.825745 1 0 1 +50 1 -170.3082 0 Infinity 0 +51 1 -160.977692 0 Infinity 0 +52 1 -127.401 0 Infinity 0 +53 1 -119.971008 0 Infinity 0 +54 1 -161.725189 0 Infinity 0 +55 1 -41.1490173 1.34650766E-18 59.365483272233966 0 +56 1 199.693787 1 0 1 +57 1 -410.444183 0 Infinity 0 +58 1 -273.693665 0 Infinity 0 +59 1 -206.296631 0 Infinity 0 +60 1 -123.953339 0 Infinity 0 +61 0 -383.488 0 0 0 +62 1 -11.9711 6.324332E-06 17.270655479328113 0 +63 1 -269.080566 0 Infinity 0 +64 0 -407.171082 0 0 0 +65 1 -209.844543 0 Infinity 0 +66 0 -414.985 0 0 0 +67 1 
-161.339172 0 Infinity 0 +68 1 -58.4673157 4.05478141E-26 84.350506323497086 0 +69 0 -400.063263 0 0 0 +70 0 -415.560547 0 0 0 +71 1 77.64612 1 0 1 +72 0 -318.187164 0 0 0 +73 1 70.44263 1 0 1 +74 1 -111.748413 0 Infinity 0 +75 0 -386.462433 0 0 0 +76 0 -417.4435 0 0 0 +77 0 -293.3556 0 0 0 +78 0 -367.233582 0 0 0 +79 0 -432.706451 0 0 0 +80 0 -369.826782 0 0 0 +81 0 -391.558167 0 0 0 +82 0 -366.076355 0 0 0 +83 0 -417.5489 0 0 0 +84 1 44.29547 1 0 1 +85 1 41.82538 1 0 1 +86 1 -157.855713 0 Infinity 0 +87 1 122.36322 1 0 1 +88 0 -414.985 0 0 0 +89 0 -415.937683 0 0 0 +90 0 -413.829132 0 0 0 +91 0 -384.514832 0 0 0 +92 0 -414.985 0 0 0 +93 0 -407.171082 0 0 0 +94 0 -414.214417 0 0 0 +95 0 -413.829132 0 0 0 +96 0 -384.129547 0 0 0 +97 0 -415.3703 0 0 0 +98 1 -103.478882 1.401298E-45 149 0 +99 1 150.554138 1 0 1 +100 1 -310.138184 0 Infinity 0 +101 1 -119.139008 0 Infinity 0 +102 0 -418.409149 0 0 0 +103 1 -453.947021 0 Infinity 0 +104 1 42.2218933 1 0 1 +105 1 -212.478668 0 Infinity 0 +106 1 319.402039 1 0 1 +107 1 91.11633 1 0 1 +108 0 -379.639343 0 0 0 +109 1 65.18329 1 0 1 +110 0 -255.730743 0 0 0 +111 1 -32.7773132 5.820948E-15 47.287667279651664 0 +112 1 49.09488 1 0 1 +113 1 -39.6409 6.08381538E-18 57.189729333347941 0 +114 0 -272.0699 0 0 0 +115 0 -305.7808 0 0 0 +116 0 -287.3377 0 0 0 +117 1 144.5932 1 0 1 +118 0 -403.447083 0 0 0 +119 0 -341.6227 0 0 0 +120 0 -400.4845 0 0 0 +121 0 -342.008 0 0 0 +122 1 48.14859 1 0 1 +123 1 -254.01651 0 Infinity 0 +124 1 -127.346558 0 Infinity 0 +125 0 -407.171082 0 0 0 +126 1 85.45001 1 0 1 +127 0 -414.5997 0 0 0 +128 1 -56.6429138 2.51359385E-25 81.718450817485731 0 +129 0 -601.7137 0 0 0 +130 0 -415.560547 0 0 0 +131 0 -414.214417 0 0 0 +132 1 295.987366 1 0 1 +133 0 -394.175781 0 0 0 +134 0 -433.091736 0 0 0 +135 0 -364.155518 0 0 0 +136 0 -408.326935 0 0 0 +137 0 -401.2836 0 0 0 +138 0 -411.7511 0 0 0 +139 0 +140 0 -401.2836 0 0 0 +141 0 -400.898315 0 0 0 +142 1 -108.134125 0 Infinity 0 +143 0 -305.7808 0 0 0 +144 0 -407.556366 0 0 0 +145 0 +146 1 -205.1228 0 Infinity 0 +147 0 -408.638123 0 0 0 +148 0 -448.917847 0 0 0 +149 1 69.0268555 1 0 1 +150 0 -388.679047 0 0 0 +151 1 -226.9046 0 Infinity 0 +152 1 221.257813 1 0 1 +153 0 -354.3028 0 0 0 +154 0 -400.513 0 0 0 +155 1 39.9500122 1 0 1 +156 0 -361.30127 0 0 0 +157 0 -414.214417 0 0 0 +158 0 +159 1 214.183228 1 0 1 +160 1 120.047485 1 0 1 +161 0 -401.219147 0 0 0 +162 0 -414.5997 0 0 0 +163 0 -282.1694 0 0 0 +164 0 +165 0 -377.536072 0 0 0 +166 1 123.761658 1 0 1 +167 1 -9.15014648 0.000106192965 13.201024182527696 0 +168 0 -414.5997 0 0 0 +169 0 -358.594147 0 0 0 +170 0 -401.2836 0 0 0 +171 0 -413.829132 0 0 0 +172 0 -407.171082 0 0 0 +173 1 316.5832 1 0 1 +174 1 34.576416 1 0 1 +175 1 90.98065 1 0 1 +176 0 -414.214417 0 0 0 +177 1 0.357513428 0.5884384 0.76503671898226377 1 +178 0 -414.985 0 0 0 +179 1 -177.546082 0 Infinity 0 +180 0 -388.679047 0 0 0 +181 0 -400.513 0 0 0 +182 0 -415.755585 0 0 0 +183 1 230.525391 1 0 1 +184 1 61.9541 1 0 1 +185 0 -389.064331 0 0 0 +186 1 30.8129272 1 0 1 +187 1 125.775085 1 0 1 +188 1 244.60907 1 0 1 +189 0 -371.383484 0 0 0 +190 1 275.354 1 0 1 +191 1 40.0039978 1 0 1 +192 0 -408.712219 0 0 0 +193 0 -413.829132 0 0 0 +194 0 -414.5997 0 0 0 +195 0 -414.985 0 0 0 +196 0 -45.47174 1.785969E-20 0 0 +197 0 -365.063965 0 0 0 +198 0 -400.513 0 0 0 +199 0 -407.94165 0 0 0 +200 1 142.494385 1 0 1 +201 1 -67.245575 6.24622866E-30 97.014857463075344 0 +202 0 -413.829132 0 0 0 +203 0 -415.3703 0 0 0 +204 0 -413.829132 0 0 0 +205 1 360.787842 1 0 1 +206 1 56.5380859 
1 0 1 +207 0 -388.679047 0 0 0 +208 0 -388.679047 0 0 0 +209 0 -390.220184 0 0 0 +210 1 399.735962 1 0 1 +211 1 291.975525 1 0 1 +212 0 -413.829132 0 0 0 +213 1 345.636963 1 0 1 +214 1 356.6704 1 0 1 +215 1 90.58252 1 0 1 +216 0 -407.171082 0 0 0 +217 0 -413.829132 0 0 0 +218 1 173.8518 1 0 1 +219 0 -422.603882 0 0 0 +220 0 -397.060364 0 0 0 +221 1 -51.67093 3.62744371E-23 74.545392954249706 0 +222 1 -254.9071 0 Infinity 0 +223 1 -47.25174 3.01182883E-21 68.169850212425573 0 +224 1 126.361267 1 0 1 +225 0 -407.171082 0 0 0 +226 1 49.32419 1 0 1 +227 1 143.052124 1 0 1 +228 0 -388.679047 0 0 0 +229 1 124.959839 1 0 1 +230 1 -79.35242 3.44892051E-35 114.48133844099975 0 +231 1 122.692749 1 0 1 +232 0 -256.504028 0 0 0 +233 1 -84.9973755 1.21929513E-37 122.62529213957316 0 +234 0 -275.7568 0 0 0 +235 0 +236 1 116.098267 1 0 1 +237 1 26.0986938 1 0 1 +238 1 370.027954 1 0 1 +239 1 -52.4384155 1.68378053E-23 75.652642077469366 0 +240 0 -330.808228 0 0 0 +241 0 -355.912079 0 0 0 +242 0 -414.214417 0 0 0 +243 0 -332.413025 0 0 0 +244 0 -413.829132 0 0 0 +245 0 -374.918457 0 0 0 +246 1 321.109131 1 0 1 +247 1 -61.9207153 1.28284749E-27 89.332708892981643 0 +248 0 -346.1557 0 0 0 +249 0 +250 0 -354.643219 0 0 0 +251 1 108.2807 1 0 1 +252 0 -4.193939 0.0148625113 0.021603009492489122 0 +253 1 86.67383 1 0 1 +254 1 -11.9711 6.324332E-06 17.270655479328113 0 +255 1 62.42392 1 0 1 +256 0 -401.2836 0 0 0 +257 0 -407.94165 0 0 0 +258 0 -414.5997 0 0 0 +259 0 -8.52298 0.0001988065 0.00028684567329613589 0 +260 1 91.19623 1 0 1 +261 1 134.843628 1 0 1 +262 1 158.870361 1 0 1 +263 1 25.5258484 1 0 1 +264 1 -11.0758057 1.54821755E-05 15.979032265129009 0 +265 0 -411.8769 0 0 0 +266 1 256.479553 1 0 1 +267 1 -151.574524 0 Infinity 0 +268 1 49.6661072 1 0 1 +269 0 -413.829132 0 0 0 +270 1 13.7780151 0.999999 1.4618532729665815E-06 1 +271 0 -415.3703 0 0 0 +272 1 -151.574524 0 Infinity 0 +273 1 -303.688629 0 Infinity 0 +274 0 -400.833862 0 0 0 +275 0 +276 0 -407.94165 0 0 0 +277 0 -407.171082 0 0 0 +278 0 -413.829132 0 0 0 +279 1 -28.7467346 3.276814E-13 41.472771430985198 0 +280 0 -414.5997 0 0 0 +281 0 -397.445648 0 0 0 +282 1 96.48364 1 0 1 +283 1 -59.1726379 2.00285667E-26 85.368071284909448 0 +284 1 183.314758 1 0 1 +285 1 255.142639 1 0 1 +286 1 319.219543 1 0 1 +287 0 -433.091736 0 0 0 +288 1 -267.595245 0 Infinity 0 +289 1 175.671082 1 0 1 +290 0 -400.513 0 0 0 +291 0 -413.829132 0 0 0 +292 1 +293 1 51.4936523 1 0 1 +294 0 +295 1 5.8543396 0.997140765 0.0041309123233919023 1 +296 0 -173.148254 0 0 0 +297 0 +298 0 -429.3664 0 0 0 +299 1 -112.8385 0 Infinity 0 +300 1 -114.377228 0 Infinity 0 +301 0 -413.829132 0 0 0 +302 1 274.0891 1 0 1 +303 0 -413.829132 0 0 0 +304 1 21.4684753 1 0 1 +305 1 269.0639 1 0 1 +306 0 -413.829132 0 0 0 +307 0 -413.829132 0 0 0 +308 1 66.88153 1 0 1 +309 0 -333.1836 0 0 0 +310 0 -432.706451 0 0 0 +311 0 -400.513 0 0 0 +312 1 -175.547363 0 Infinity 0 +313 0 -400.513 0 0 0 +314 0 -382.020966 0 0 0 +315 0 +316 1 -56.5515747 2.753995E-25 81.586676422594735 0 +317 1 168.155884 1 0 1 +318 0 -489.2794 0 0 0 +319 0 -232.038055 0 0 0 +320 1 16.8273315 0.99999994 8.5991327994145617E-08 1 +321 0 +322 0 -414.5997 0 0 0 +323 1 72.79901 1 0 1 +324 0 -413.829132 0 0 0 +325 0 -334.540161 0 0 0 +326 1 -176.167816 0 Infinity 0 +327 0 -407.171082 0 0 0 +328 1 131.3811 1 0 1 +329 1 -125.164459 0 Infinity 0 +330 1 -127.383881 0 Infinity 0 +331 0 -410.5272 0 0 0 +332 0 -332.307831 0 0 0 +333 1 -17.0445251 3.95964967E-08 24.590051963836686 0 +334 1 77.07416 1 0 1 +335 0 -400.513 0 0 0 +336 1 
89.2496948 1 0 1 +337 0 -413.829132 0 0 0 +338 0 -382.020966 0 0 0 +339 1 68.04913 1 0 1 +340 1 -70.21268 3.213822E-31 101.29548091112767 0 +341 0 -413.829132 0 0 0 +342 0 -400.898315 0 0 0 +343 0 -400.513 0 0 0 +344 1 -25.81427 6.151839E-12 37.242119386792353 0 +345 0 -400.513 0 0 0 +346 0 -337.034 0 0 0 +347 0 -347.8312 0 0 0 +348 1 -173.990021 0 Infinity 0 +349 1 -122.635986 0 Infinity 0 +350 0 -368.516632 0 0 0 +351 0 -414.214417 0 0 0 +352 0 -263.0901 0 0 0 +353 1 207.045776 1 0 1 +354 0 -407.171082 0 0 0 +355 0 -419.711182 0 0 0 +356 1 -264.1968 0 Infinity 0 +357 1 143.662415 1 0 1 +358 1 94.0296 1 0 1 +359 1 -100.401337 2.522337E-44 144.83007499855768 0 +360 1 294.12854 1 0 1 +361 1 286.876648 1 0 1 +362 0 -365.00592 0 0 0 +363 0 -206.582718 0 0 0 +364 0 -414.214417 0 0 0 +365 0 -407.556366 0 0 0 +366 1 336.557373 1 0 1 +367 1 294.622864 1 0 1 +368 0 -407.361328 0 0 0 +369 0 -388.8693 0 0 0 +370 0 -338.313324 0 0 0 +371 0 -407.361328 0 0 0 +372 0 -411.7511 0 0 0 +373 0 -391.943481 0 0 0 +374 0 -411.365784 0 0 0 +375 0 -400.513 0 0 0 +376 0 -407.171082 0 0 0 +377 0 -382.020966 0 0 0 +378 0 -369.0196 0 0 0 +379 0 -377.355042 0 0 0 +380 0 -400.513 0 0 0 +381 1 153.2738 1 0 1 +382 0 -338.1953 0 0 0 +383 0 -400.898315 0 0 0 +384 0 -400.898315 0 0 0 +385 0 -291.049133 0 0 0 +386 1 33.2608032 1 0 1 +387 0 -332.012054 0 0 0 +388 0 -393.7905 0 0 0 +389 0 -396.413422 0 0 0 +390 0 -390.4023 0 0 0 +391 1 229.0141 1 0 1 +392 0 -407.94165 0 0 0 +393 0 -375.3629 0 0 0 +394 0 -364.725433 0 0 0 +395 0 -407.94165 0 0 0 +396 0 -414.5997 0 0 0 +397 0 -401.668884 0 0 0 +398 0 -344.881836 0 0 0 +399 0 -358.067383 0 0 0 +400 1 216.038391 1 0 1 +401 0 -401.2836 0 0 0 +402 0 -316.972656 0 0 0 +403 0 -330.234436 0 0 0 +404 0 -303.91568 0 0 0 +405 0 -407.171082 0 0 0 +406 0 -362.457153 0 0 0 +407 0 -407.171082 0 0 0 +408 0 -278.598877 0 0 0 +409 0 -411.365784 0 0 0 +410 0 -407.171082 0 0 0 +411 0 +412 1 21.296814 1 0 1 +413 0 -418.794434 0 0 0 +414 1 17.1500549 0.99999994 8.5991327994145617E-08 1 +415 0 -158.312744 0 0 0 +416 1 -78.57623 7.495069E-35 113.36154154761103 0 +417 0 -407.171082 0 0 0 +418 0 -310.439728 0 0 0 +419 0 -377.760681 0 0 0 +420 0 -287.8263 0 0 0 +421 1 88.56616 1 0 1 +422 0 -295.7137 0 0 0 +423 0 -415.560547 0 0 0 +424 0 -401.2836 0 0 0 +425 1 236.068481 1 0 1 +426 0 -320.587067 0 0 0 +427 1 -44.00345 7.754345E-20 63.483556937036411 0 +428 0 -407.171082 0 0 0 +429 0 -407.556366 0 0 0 +430 0 -294.817322 0 0 0 +431 0 -418.3671 0 0 0 +432 0 -394.946381 0 0 0 +433 0 -321.876282 0 0 0 +434 0 122.56488 1 Infinity 1 +435 1 71.62463 1 0 1 +436 1 -2.05007935 0.114044361 3.1323329838202993 0 +437 0 -401.668884 0 0 0 +438 0 -374.979675 0 0 0 +439 0 -402.054169 0 0 0 +440 1 -9.103485 0.000111264941 13.133713305553275 0 +441 0 -234.828949 0 0 0 +442 0 -319.609375 0 0 0 +443 0 -376.829956 0 0 0 +444 0 -350.0185 0 0 0 +445 0 -400.898315 0 0 0 +446 0 -400.513 0 0 0 +447 0 -402.054169 0 0 0 +448 0 -375.3629 0 0 0 +449 1 173.530884 1 0 1 +450 0 -349.6393 0 0 0 +451 0 -402.054169 0 0 0 +452 0 -367.746063 0 0 0 +453 1 34.4112854 1 0 1 +454 0 -327.4605 0 0 0 +455 1 -234.852478 0 Infinity 0 +456 1 106.093872 1 0 1 +457 1 9.033997 0.999880731 0.0001720789042225489 1 +458 0 -405.478333 0 0 0 +459 0 -408.902466 0 0 0 +460 0 -368.516632 0 0 0 +461 0 -306.5514 0 0 0 +462 0 -368.901917 0 0 0 +463 0 -387.903 0 0 0 +464 0 -401.668884 0 0 0 +465 1 131.093628 1 0 1 +466 1 34.3149719 1 0 1 +467 1 90.55585 1 0 1 +468 0 -401.668884 0 0 0 +469 0 -410.5952 0 0 0 +470 0 -382.791565 0 0 0 +471 0 -368.901917 0 0 0 +472 0 
-377.289368 0 0 0 +473 0 -401.668884 0 0 0 +474 0 -402.054169 0 0 0 +475 0 -401.2836 0 0 0 +476 0 -405.093018 0 0 0 +477 0 -401.668884 0 0 0 +478 0 -352.760254 0 0 0 +479 1 251.724915 1 0 1 +480 0 -376.904083 0 0 0 +481 0 -256.584167 0 0 0 +482 1 161.241211 1 0 1 +483 1 143.984924 1 0 1 +484 0 -405.478333 0 0 0 +485 0 -294.4593 0 0 0 +486 0 -382.791565 0 0 0 +487 1 245.079346 1 0 1 +488 1 36.02997 1 0 1 +489 1 -281.350861 0 Infinity 0 +490 0 -400.513 0 0 0 +491 1 113.965393 1 0 1 +492 0 -383.176849 0 0 0 +493 1 300.856079 1 0 1 +494 0 -216.785461 0 0 0 +495 0 -382.791565 0 0 0 +496 0 -375.3629 0 0 0 +497 0 -352.374939 0 0 0 +498 0 -408.326935 0 0 0 +499 0 -408.326935 0 0 0 +500 0 -415.755585 0 0 0 +501 0 -408.326935 0 0 0 +502 0 -391.558167 0 0 0 +503 0 -414.985 0 0 0 +504 0 -400.513 0 0 0 +505 0 -302.69574 0 0 0 +506 1 234.96405 1 0 1 +507 0 -329.463867 0 0 0 +508 0 -402.054169 0 0 0 +509 0 -400.898315 0 0 0 +510 0 -400.513 0 0 0 +511 0 -408.712219 0 0 0 +512 0 -402.054169 0 0 0 +513 0 -382.791565 0 0 0 +514 1 244.530945 1 0 1 +515 1 390.9422 1 0 1 +516 0 -375.3629 0 0 0 +517 0 -382.020966 0 0 0 +518 0 -387.938965 0 0 0 +519 1 13.2459412 0.9999982 2.5797420694119618E-06 1 +520 0 -425.6631 0 0 0 +521 0 -426.81897 0 0 0 +522 1 -162.323669 0 Infinity 0 +523 1 91.895874 1 0 1 +524 0 -407.94165 0 0 0 +525 0 -384.514832 0 0 0 +526 0 -401.668884 0 0 0 +527 0 -414.985 0 0 0 +528 0 -392.519 0 0 0 +529 0 -383.176849 0 0 0 +530 1 8.005951 0.999666631 0.00048102966772982418 1 +531 0 -362.457153 0 0 0 +532 0 -388.679047 0 0 0 +533 0 -407.94165 0 0 0 +534 0 -407.556366 0 0 0 +535 0 -403.908661 0 0 0 +536 0 -415.3703 0 0 0 +537 0 -418.794434 0 0 0 +538 0 -408.326935 0 0 0 +539 0 -409.097534 0 0 0 +540 0 -385.029175 0 0 0 +541 0 -401.2836 0 0 0 +542 0 -303.4663 0 0 0 +543 0 -408.326935 0 0 0 +544 0 -397.6359 0 0 0 +545 0 -408.712219 0 0 0 +546 1 408.09906 1 0 1 +547 0 -375.7482 0 0 0 +548 0 -376.133484 0 0 0 +549 1 141.684875 1 0 1 +550 0 -407.94165 0 0 0 +551 0 -413.829132 0 0 0 +552 0 -344.853363 0 0 0 +553 0 -164.293976 0 0 0 +554 0 -401.2836 0 0 0 +555 0 -226.694183 0 0 0 +556 0 -320.923584 0 0 0 +557 0 -368.516632 0 0 0 +558 0 -407.556366 0 0 0 +559 0 -408.712219 0 0 0 +560 0 -415.3703 0 0 0 +561 0 -415.3703 0 0 0 +562 0 -413.829132 0 0 0 +563 0 -407.94165 0 0 0 +564 0 -401.219147 0 0 0 +565 1 215.478271 1 0 1 +566 0 -418.023865 0 0 0 +567 0 -360.960815 0 0 0 +568 1 -99.8051453 4.484155E-44 144 0 +569 1 128.35553 1 0 1 +570 1 109.875549 1 0 1 +571 1 143.371338 1 0 1 +572 0 -407.94165 0 0 0 +573 0 -407.171082 0 0 0 +574 1 19.3930969 1 0 1 +575 0 -418.794434 0 0 0 +576 0 -408.712219 0 0 0 +577 0 -407.171082 0 0 0 +578 0 -407.171082 0 0 0 +579 0 -413.829132 0 0 0 +580 0 -412.136383 0 0 0 +581 1 99.5480957 1 0 1 +582 1 348.034058 1 0 1 +583 0 -401.2836 0 0 0 +584 0 -343.360443 0 0 0 +585 0 -400.513 0 0 0 +586 1 337.054138 1 0 1 +587 0 -394.946381 0 0 0 +588 1 -21.5651855 4.30882552E-10 31.111986269672936 0 +589 0 -402.054169 0 0 0 +590 1 -25.1401978 1.20712112E-11 36.269638601249056 0 +591 1 51.3045654 1 0 1 +592 1 -80.5669556 1.02380155E-35 116.23354721913097 0 +593 0 -405.478333 0 0 0 +594 1 82.7042847 1 0 1 +595 0 -408.712219 0 0 0 +596 0 -411.7511 0 0 0 +597 0 -422.2186 0 0 0 +598 0 -407.94165 0 0 0 +599 0 -337.8826 0 0 0 +600 0 -407.94165 0 0 0 +601 0 -382.020966 0 0 0 +602 0 -408.326935 0 0 0 +603 1 -126.92569 0 Infinity 0 +604 1 -93.5751953 2.295187E-41 135.00044034278059 0 +605 1 -47.6547852 2.01274975E-21 68.751322185612878 0 +606 0 -384.900116 0 0 0 +607 0 -400.513 0 0 0 +608 1 113.71698 1 0 1 
+609 0 -402.054169 0 0 0 +610 1 40.86389 1 0 1 +611 1 118.382507 1 0 1 +612 1 380.642151 1 0 1 +613 0 -308.9685 0 0 0 +614 0 -382.40625 0 0 0 +615 0 -415.175232 0 0 0 +616 0 -407.94165 0 0 0 +617 0 +618 0 -408.326935 0 0 0 +619 0 -408.712219 0 0 0 +620 0 -407.94165 0 0 0 +621 0 -311.9197 0 0 0 +622 0 -379.369446 0 0 0 +623 0 -400.513 0 0 0 +624 0 -366.793365 0 0 0 +625 0 -362.232849 0 0 0 +626 1 -60.24713 6.839168E-27 86.918237673162878 0 +627 0 -306.166077 0 0 0 +628 0 -400.898315 0 0 0 +629 0 -401.668884 0 0 0 +630 0 -359.993835 0 0 0 +631 0 -408.712219 0 0 0 +632 0 -400.513 0 0 0 +633 1 1.76364136 0.8536651 0.22825787272289805 1 +634 0 -401.2836 0 0 0 +635 0 -411.55603 0 0 0 +636 1 84.4071045 1 0 1 +637 0 -312.084869 0 0 0 +638 0 -401.668884 0 0 0 +639 0 -368.516632 0 0 0 +640 0 -368.131348 0 0 0 +641 0 -407.94165 0 0 0 +642 0 -407.94165 0 0 0 +643 0 -400.513 0 0 0 +644 0 -400.898315 0 0 0 +645 0 -407.94165 0 0 0 +646 0 -354.643219 0 0 0 +647 0 -366.975464 0 0 0 +648 1 91.73328 1 0 1 +649 0 -407.94165 0 0 0 +650 0 -338.615753 0 0 0 +651 0 -324.3396 0 0 0 +652 0 -394.946381 0 0 0 +653 0 -408.326935 0 0 0 +654 0 -414.5997 0 0 0 +655 0 -407.94165 0 0 0 +656 0 -408.712219 0 0 0 +657 0 -363.256348 0 0 0 +658 1 190.37738 1 0 1 +659 0 -400.513 0 0 0 +660 0 -407.171082 0 0 0 +661 0 -414.985 0 0 0 +662 0 -414.019379 0 0 0 +663 0 -414.019379 0 0 0 +664 0 -374.4041 0 0 0 +665 0 -400.513 0 0 0 +666 0 -334.5725 0 0 0 +667 0 -414.5997 0 0 0 +668 1 -123.532867 0 Infinity 0 +669 1 231.38147 1 0 1 +670 1 230.878479 1 0 1 +671 0 -348.537 0 0 0 +672 0 -414.214417 0 0 0 +673 0 -332.040558 0 0 0 +674 0 -407.171082 0 0 0 +675 0 -411.941345 0 0 0 +676 0 -410.5952 0 0 0 +677 0 -402.054169 0 0 0 +678 0 -400.513 0 0 0 +679 0 -400.898315 0 0 0 +680 1 390.4923 1 0 1 +681 1 376.423279 1 0 1 +682 0 -401.604431 0 0 0 +683 0 -400.513 0 0 0 +684 0 -400.513 0 0 0 +685 0 -400.513 0 0 0 +686 0 -400.513 0 0 0 +687 0 -380.409943 0 0 0 +688 0 -401.668884 0 0 0 +689 0 -366.611267 0 0 0 +690 0 -366.975464 0 0 0 +691 1 143.604248 1 0 1 +692 0 -401.2836 0 0 0 +693 0 -403.098541 0 0 0 +694 0 -402.750641 0 0 0 +695 0 -400.898315 0 0 0 +696 1 48.056427 1 0 1 +697 1 31.3800049 1 0 1 +698 1 12.3017273 0.99999547 6.5353555352246199E-06 1 diff --git a/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-CV-breast-cancer-out.txt b/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-CV-breast-cancer-out.txt new file mode 100644 index 0000000000..e8f77de3b1 --- /dev/null +++ b/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-CV-breast-cancer-out.txt @@ -0,0 +1,65 @@ +maml.exe CV tr=SymSGD{nt=1} threads=- norm=No dout=%Output% data=%Data% seed=1 +Not adding a normalizer. +Data fully loaded into memory. +Not training a calibrator because it is not needed. +Not adding a normalizer. +Data fully loaded into memory. +Not training a calibrator because it is not needed. +Warning: The predictor produced non-finite prediction values on 8 instances during testing. Possible causes: abnormal data or the predictor is numerically unstable. 
+TEST POSITIVE RATIO: 0.3785 (134.0/(134.0+220.0)) +Confusion table + ||====================== +PREDICTED || positive | negative | Recall +TRUTH ||====================== + positive || 132 | 2 | 0.9851 + negative || 8 | 212 | 0.9636 + ||====================== +Precision || 0.9429 | 0.9907 | +OVERALL 0/1 ACCURACY: 0.971751 +LOG LOSS/instance: Infinity +Test-set entropy (prior Log-Loss/instance): 0.956998 +LOG-LOSS REDUCTION (RIG): -Infinity +AUC: 0.991045 +Warning: The predictor produced non-finite prediction values on 8 instances during testing. Possible causes: abnormal data or the predictor is numerically unstable. +TEST POSITIVE RATIO: 0.3191 (105.0/(105.0+224.0)) +Confusion table + ||====================== +PREDICTED || positive | negative | Recall +TRUTH ||====================== + positive || 96 | 9 | 0.9143 + negative || 11 | 213 | 0.9509 + ||====================== +Precision || 0.8972 | 0.9595 | +OVERALL 0/1 ACCURACY: 0.939210 +LOG LOSS/instance: Infinity +Test-set entropy (prior Log-Loss/instance): 0.903454 +LOG-LOSS REDUCTION (RIG): -Infinity +AUC: 0.963435 + +OVERALL RESULTS +--------------------------------------- +AUC: 0.977240 (0.0138) +Accuracy: 0.955481 (0.0163) +Positive precision: 0.920027 (0.0228) +Positive recall: 0.949680 (0.0354) +Negative precision: 0.975057 (0.0156) +Negative recall: 0.957265 (0.0064) +Log-loss: Infinity (NaN) +Log-loss reduction: -Infinity (NaN) +F1 Score: 0.934582 (0.0289) +AUPRC: 0.964431 (0.0168) + +--------------------------------------- +Physical memory usage(MB): %Number% +Virtual memory usage(MB): %Number% +%DateTime% Time elapsed(s): %Number% + +--- Progress log --- +[1] 'Preprocessing' started. +[1] 'Preprocessing' finished in %Time%. +[2] 'Training' started. +[2] 'Training' finished in %Time%. +[3] 'Preprocessing #2' started. +[3] 'Preprocessing #2' finished in %Time%. +[4] 'Training #2' started. +[4] 'Training #2' finished in %Time%. 
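Note that the headline metrics in each fold are simple functions of the confusion-table cells. A minimal C# sketch (illustrative only; names are not from the test framework) reproduces the first fold's figures:

    using System;

    class FoldMetrics
    {
        static void Main()
        {
            // First fold above: TP = 132, FN = 2, FP = 8, TN = 212.
            double tp = 132, fn = 2, fp = 8, tn = 212;
            Console.WriteLine(tp / (tp + fn));                  // positive recall:    0.9851
            Console.WriteLine(tn / (tn + fp));                  // negative recall:    0.9636
            Console.WriteLine(tp / (tp + fp));                  // positive precision: 0.9429
            Console.WriteLine(tn / (tn + fn));                  // negative precision: 0.9907
            Console.WriteLine((tp + tn) / (tp + fn + fp + tn)); // accuracy:           0.971751
        }
    }

The Infinity log-loss per instance (and hence the -Infinity log-loss reduction) is expected here: it appears whenever at least one instance is assigned probability 0 for its true label, which is exactly what the non-finite-prediction warnings above report.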
diff --git a/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-CV-breast-cancer-rp.txt b/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-CV-breast-cancer-rp.txt new file mode 100644 index 0000000000..cbe85dbf29 --- /dev/null +++ b/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-CV-breast-cancer-rp.txt @@ -0,0 +1,4 @@ +SymSGD +AUC Accuracy Positive precision Positive recall Negative precision Negative recall Log-loss Log-loss reduction F1 Score AUPRC /nt Learner Name Train Dataset Test Dataset Results File Run Time Physical Memory Virtual Memory Command Line Settings +0.97724 0.955481 0.920027 0.94968 0.975057 0.957265 Infinity -Infinity 0.934582 0.964431 1 SymSGD %Data% %Output% 99 0 0 maml.exe CV tr=SymSGD{nt=1} threads=- norm=No dout=%Output% data=%Data% seed=1 /nt:1 + diff --git a/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-CV-breast-cancer.txt b/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-CV-breast-cancer.txt new file mode 100644 index 0000000000..d0e7499c6d --- /dev/null +++ b/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-CV-breast-cancer.txt @@ -0,0 +1,700 @@ +Instance Label Score Probability Log-loss Assigned +5 1 1021.923 1 0 1 +6 0 101.91272 1 Infinity 1 +8 0 -334.817749 0 0 0 +9 0 -235.957581 0 0 0 +10 0 -272.79837 0 0 0 +11 0 -320.737427 0 0 0 +18 1 647.4429 1 0 1 +20 1 217.233337 1 0 1 +21 1 479.27478 1 0 1 +25 1 140.3844 1 0 1 +28 0 -320.737427 0 0 0 +31 0 -296.864 0 0 0 +32 1 454.251282 1 0 1 +35 0 -320.737427 0 0 0 +37 0 -78.9787 5.011722E-35 0 0 +40 0 +41 1 199.0091 1 0 1 +44 1 656.8247 1 0 1 +45 0 -322.804565 0 0 0 +46 1 682.001953 1 0 1 +48 0 -308.837372 0 0 0 +50 1 269.319824 1 0 1 +51 1 55.29657 1 0 1 +52 1 287.521057 1 0 1 +54 1 316.589355 1 0 1 +56 1 913.6865 1 0 1 +60 1 119.077576 1 0 1 +63 1 4.26116943 0.9860904 0.020208151201720485 1 +64 0 -325.252838 0 0 0 +66 0 -287.8332 0 0 0 +68 1 560.21106 1 0 1 +69 0 -288.1764 0 0 0 +70 0 -262.6084 0 0 0 +71 1 377.826416 1 0 1 +72 0 -7.10409546 0.00082105794 0.0011850227709561453 0 +73 1 361.900269 1 0 1 +74 1 308.8277 1 0 1 +76 0 -258.3786 0 0 0 +77 0 -166.195282 0 0 0 +79 0 -349.318481 0 0 0 +82 0 -207.568787 0 0 0 +88 0 -287.8332 0 0 0 +90 0 -301.379425 0 0 0 +91 0 -350.5975 0 0 0 +92 0 -287.8332 0 0 0 +93 0 -325.252838 0 0 0 +95 0 -301.379425 0 0 0 +96 0 -355.1129 0 0 0 +97 0 -283.3178 0 0 0 +98 1 380.50592 1 0 1 +99 1 571.659058 1 0 1 +100 1 51.7908325 1 0 1 +102 0 -265.5418 0 0 0 +104 1 527.381836 1 0 1 +105 1 1.51104736 0.8192164 0.28768354956614323 1 +106 1 975.2739 1 0 1 +108 0 -298.846741 0 0 0 +109 1 582.2739 1 0 1 +111 1 427.5822 1 0 1 +112 1 370.9096 1 0 1 +113 1 805.0089 1 0 1 +115 0 -169.885254 0 0 0 +117 1 521.3834 1 0 1 +120 0 -289.590118 0 0 0 +121 0 -167.273956 0 0 0 +122 1 864.813965 1 0 1 +123 1 263.304382 1 0 1 +125 0 -325.252838 0 0 0 +128 1 277.651428 1 0 1 +129 0 -577.7748 0 0 0 +131 0 -296.864 0 0 0 +132 1 762.661865 1 0 1 +133 0 -303.018982 0 0 0 +137 0 -340.095428 0 0 0 +138 0 -289.415222 0 0 0 +141 0 -344.61084 0 0 0 +144 0 -320.737427 0 0 0 +145 0 +147 0 -309.023651 0 0 0 +150 0 -272.79837 0 0 0 +151 1 249.55658 1 0 1 +152 1 884.802856 1 0 1 +154 0 -349.126221 0 0 0 +156 0 -227.212433 0 0 0 +161 0 -274.6302 0 0 0 +164 0 +167 1 283.316284 1 0 1 +169 0 -331.047241 0 0 0 +171 0 -301.379425 0 0 0 +173 1 1041.47913 1 0 1 +174 1 612.9641 1 0 1 +176 0 -296.864 0 0 0 +177 1 578.026 1 0 1 +179 1 180.726685 1 0 1 +180 0 -272.79837 0 0 0 +181 0 -349.126221 0 0 0 +183 1 834.9143 1 0 1 +187 1 544.907654 1 0 1 +188 1 709.276855 1 0 1 +189 0 -181.0476 0 0 0 +191 1 457.053833 1 0 1 +192 0 -307.1912 0 0 0 
+196 0 384.403748 1 Infinity 1 +198 0 -349.126221 0 0 0 +199 0 -316.222015 0 0 0 +201 1 688.0464 1 0 1 +202 0 -301.379425 0 0 0 +204 0 -301.379425 0 0 0 +205 1 1028.11108 1 0 1 +206 1 667.4751 1 0 1 +207 0 -272.79837 0 0 0 +209 0 -254.736725 0 0 0 +210 1 1086.058 1 0 1 +211 1 925.716064 1 0 1 +212 0 -301.379425 0 0 0 +216 0 -325.252838 0 0 0 +218 1 734.4236 1 0 1 +219 0 -234.219574 0 0 0 +223 1 539.3552 1 0 1 +226 1 715.6188 1 0 1 +228 0 -272.79837 0 0 0 +233 1 355.085876 1 0 1 +237 1 459.5498 1 0 1 +239 1 451.899231 1 0 1 +240 0 -212.39624 0 0 0 +241 0 -284.771729 0 0 0 +242 0 -296.864 0 0 0 +244 0 -301.379425 0 0 0 +246 1 1044.54126 1 0 1 +247 1 409.237671 1 0 1 +248 0 -221.818024 0 0 0 +249 0 +250 0 -251.085815 0 0 0 +252 0 307.68988 1 Infinity 1 +254 1 728.536743 1 0 1 +257 0 -316.222015 0 0 0 +258 0 -292.348633 0 0 0 +259 0 572.462158 1 Infinity 1 +260 1 531.1308 1 0 1 +262 1 912.760254 1 0 1 +267 1 408.017578 1 0 1 +268 1 736.6128 1 0 1 +269 0 -301.379425 0 0 0 +271 0 -283.3178 0 0 0 +272 1 408.017578 1 0 1 +275 0 +276 0 -316.222015 0 0 0 +277 0 -325.252838 0 0 0 +278 0 -301.379425 0 0 0 +279 1 438.823853 1 0 1 +280 0 -292.348633 0 0 0 +283 1 555.4083 1 0 1 +284 1 445.768616 1 0 1 +285 1 1001.10779 1 0 1 +288 1 54.43335 1 0 1 +290 0 -349.126221 0 0 0 +291 0 -301.379425 0 0 0 +293 1 386.6949 1 0 1 +296 0 139.220642 1 Infinity 1 +297 0 +299 1 227.814941 1 0 1 +300 1 407.6792 1 0 1 +301 0 -301.379425 0 0 0 +303 0 -301.379425 0 0 0 +304 1 265.7011 1 0 1 +308 1 588.7855 1 0 1 +309 0 -65.36084 4.11289773E-29 0 0 +311 0 -349.126221 0 0 0 +312 1 -94.50748 9.035572E-42 136.34536397147204 0 +314 0 -296.671753 0 0 0 +316 1 466.170166 1 0 1 +317 1 736.0132 1 0 1 +319 0 161.598083 1 Infinity 1 +321 0 +323 1 388.03302 1 0 1 +327 0 -325.252838 0 0 0 +328 1 584.984 1 0 1 +329 1 459.499573 1 0 1 +331 0 -280.869537 0 0 0 +332 0 -206.44986 0 0 0 +333 1 399.285278 1 0 1 +336 1 425.5838 1 0 1 +338 0 -296.671753 0 0 0 +343 0 -349.126221 0 0 0 +344 1 473.410339 1 0 1 +346 0 -245.701279 0 0 0 +347 0 -294.1391 0 0 0 +348 1 -152.083282 0 Infinity 0 +349 1 237.0896 1 0 1 +350 0 -352.0688 0 0 0 +352 0 39.6635437 1 Infinity 1 +353 1 660.963257 1 0 1 +354 0 -325.252838 0 0 0 +355 0 -327.084656 0 0 0 +358 1 590.6726 1 0 1 +360 1 943.9149 1 0 1 +361 1 683.516 1 0 1 +366 1 969.8589 1 0 1 +368 0 -304.543427 0 0 0 +370 0 -166.498291 0 0 0 +371 0 -304.543427 0 0 0 +373 0 -317.6933 0 0 0 +376 0 -325.252838 0 0 0 +377 0 -296.671753 0 0 0 +378 0 -363.106323 0 0 0 +379 0 -122.1077 0 0 0 +381 1 651.466431 1 0 1 +383 0 -344.61084 0 0 0 +384 0 -344.61084 0 0 0 +387 0 -126.32016 0 0 0 +388 0 -307.5344 0 0 0 +389 0 -277.759 0 0 0 +391 1 947.5735 1 0 1 +392 0 -316.222015 0 0 0 +395 0 -316.222015 0 0 0 +396 0 -292.348633 0 0 0 +398 0 -227.269958 0 0 0 +399 0 -228.794418 0 0 0 +404 0 -281.178345 0 0 0 +406 0 -213.6662 0 0 0 +409 0 -293.9306 0 0 0 +413 0 -261.0264 0 0 0 +414 1 646.5259 1 0 1 +415 0 -57.21808 1.41417837E-25 0 0 +416 1 582.0941 1 0 1 +418 0 -135.317932 0 0 0 +419 0 -278.856018 0 0 0 +422 0 -65.34631 4.17307942E-29 0 0 +423 0 -262.6084 0 0 0 +428 0 -325.252838 0 0 0 +429 0 -320.737427 0 0 0 +430 0 -160.551788 0 0 0 +434 0 683.9347 1 Infinity 1 +436 1 575.846558 1 0 1 +439 0 -331.0646 0 0 0 +440 1 429.814148 1 0 1 +441 0 -130.0997 0 0 0 +442 0 -280.509918 0 0 0 +449 1 803.155 1 0 1 +450 0 -304.1297 0 0 0 +451 0 -331.0646 0 0 0 +452 0 -361.0996 0 0 0 +453 1 566.1138 1 0 1 +454 0 -221.693909 0 0 0 +455 1 16.0523376 0.9999999 1.7198266111377426E-07 1 +456 1 636.5864 1 0 1 +457 1 647.529541 1 0 1 +464 0 -335.580017 0 0 0 
+465 1 677.602661 1 0 1 +466 1 787.547852 1 0 1 +467 1 679.4784 1 0 1 +474 0 -331.0646 0 0 0 +480 0 -302.483521 0 0 0 +482 1 781.156738 1 0 1 +483 1 878.3706 1 0 1 +484 0 -308.7732 0 0 0 +487 1 945.2323 1 0 1 +489 1 24.710083 1 0 1 +492 0 -283.125549 0 0 0 +493 1 920.3756 1 0 1 +495 0 -287.64093 0 0 0 +497 0 -259.831 0 0 0 +501 0 -311.7066 0 0 0 +502 0 -322.208679 0 0 0 +504 0 -349.126221 0 0 0 +507 0 -214.695511 0 0 0 +510 0 -349.126221 0 0 0 +513 0 -287.64093 0 0 0 +514 1 754.5259 1 0 1 +517 0 -296.671753 0 0 0 +519 1 791.0669 1 0 1 +520 0 -377.7073 0 0 0 +521 0 -364.161072 0 0 0 +522 1 296.72522 1 0 1 +523 1 465.930176 1 0 1 +527 0 -287.8332 0 0 0 +528 0 -292.468475 0 0 0 +529 0 -283.125549 0 0 0 +531 0 -213.6662 0 0 0 +532 0 -272.79837 0 0 0 +533 0 -316.222015 0 0 0 +534 0 -320.737427 0 0 0 +535 0 -267.2987 0 0 0 +538 0 -311.7066 0 0 0 +539 0 -302.675781 0 0 0 +540 0 -262.380951 0 0 0 +541 0 -340.095428 0 0 0 +544 0 -286.656677 0 0 0 +546 1 1069.49353 1 0 1 +547 0 -316.029755 0 0 0 +548 0 -311.514343 0 0 0 +549 1 549.9683 1 0 1 +557 0 -352.0688 0 0 0 +558 0 -320.737427 0 0 0 +559 0 -307.1912 0 0 0 +560 0 -283.3178 0 0 0 +561 0 -283.3178 0 0 0 +563 0 -316.222015 0 0 0 +565 1 880.770142 1 0 1 +566 0 -270.05722 0 0 0 +569 1 740.958 1 0 1 +577 0 -325.252838 0 0 0 +578 0 -325.252838 0 0 0 +581 1 785.143066 1 0 1 +582 1 986.6736 1 0 1 +584 0 -412.1561 0 0 0 +586 1 1092.98242 1 0 1 +590 1 617.891 1 0 1 +593 0 -308.7732 0 0 0 +594 1 774.4863 1 0 1 +600 0 -316.222015 0 0 0 +602 0 -311.7066 0 0 0 +604 1 256.321228 1 0 1 +606 0 -346.0821 0 0 0 +607 0 -349.126221 0 0 0 +609 0 -331.0646 0 0 0 +612 1 1115.01685 1 0 1 +613 0 -169.23941 0 0 0 +614 0 -292.156342 0 0 0 +617 0 +618 0 -311.7066 0 0 0 +619 0 -307.1912 0 0 0 +621 0 -15.8763733 1.27344038E-07 1.8371862313930792E-07 0 +622 0 -296.873169 0 0 0 +624 0 -289.1122 0 0 0 +627 0 -165.369843 0 0 0 +629 0 -335.580017 0 0 0 +633 1 390.5418 1 0 1 +634 0 -340.095428 0 0 0 +638 0 -335.580017 0 0 0 +639 0 -352.0688 0 0 0 +641 0 -316.222015 0 0 0 +642 0 -316.222015 0 0 0 +644 0 -344.61084 0 0 0 +645 0 -316.222015 0 0 0 +649 0 -316.222015 0 0 0 +652 0 -293.988159 0 0 0 +653 0 -311.7066 0 0 0 +654 0 -292.348633 0 0 0 +656 0 -307.1912 0 0 0 +657 0 -72.37637 3.69266974E-32 0 0 +660 0 -325.252838 0 0 0 +661 0 -287.8332 0 0 0 +665 0 -349.126221 0 0 0 +668 1 342.943665 1 0 1 +670 1 812.6575 1 0 1 +678 0 -349.126221 0 0 0 +679 0 -344.61084 0 0 0 +680 1 1145.28394 1 0 1 +681 1 969.3158 1 0 1 +682 0 -270.114777 0 0 0 +683 0 -349.126221 0 0 0 +685 0 -349.126221 0 0 0 +688 0 -335.580017 0 0 0 +689 0 -331.988342 0 0 0 +691 1 742.5989 1 0 1 +692 0 -340.095428 0 0 0 +693 0 -313.773743 0 0 0 +694 0 -323.866241 0 0 0 +696 1 765.3994 1 0 1 +697 1 661.339233 1 0 1 +698 1 685.243042 1 0 1 +0 0 -655.8448 0 0 0 +1 0 352.095367 1 Infinity 1 +2 0 -559.3646 0 0 0 +3 0 1740.70276 1 Infinity 1 +4 0 -568.0029 0 0 0 +7 0 -495.5454 0 0 0 +12 1 241.3157 1 0 1 +13 0 -462.88446 0 0 0 +14 1 942.8955 1 0 1 +15 1 16.80188 0.99999994 8.5991327994145617E-08 1 +16 0 -553.6872 0 0 0 +17 0 -643.057739 0 0 0 +19 0 -668.631836 0 0 0 +22 0 -540.900146 0 0 0 +23 1 +24 0 -604.696655 0 0 0 +26 0 -270.657074 0 0 0 +27 0 -566.4742 0 0 0 +29 0 -182.078979 0 0 0 +30 0 -411.0862 0 0 0 +33 0 -579.956238 0 0 0 +34 0 -418.961884 0 0 0 +36 1 1591.25452 1 0 1 +38 1 1299.106 1 0 1 +39 1 138.435211 1 0 1 +42 1 1330.6217 1 0 1 +43 1 -316.820679 0 Infinity 0 +47 0 -515.32605 0 0 0 +49 1 1846.631 1 0 1 +53 1 -274.015076 0 Infinity 0 +55 1 1099.324 1 0 1 +57 1 -608.199 0 Infinity 0 +58 1 -331.3812 0 Infinity 0 +59 1 
272.168976 1 0 1 +61 0 -444.419952 0 0 0 +62 1 1275.43091 1 0 1 +65 1 -452.4966 0 Infinity 0 +67 1 413.874054 1 0 1 +75 0 -419.5797 0 0 0 +78 0 -488.458527 0 0 0 +80 0 -582.7846 0 0 0 +81 0 -516.1598 0 0 0 +83 0 -916.846863 0 0 0 +84 1 897.471436 1 0 1 +85 1 746.382935 1 0 1 +86 1 635.3207 1 0 1 +87 1 1356.65088 1 0 1 +89 0 -620.399658 0 0 0 +94 0 -617.483643 0 0 0 +101 1 841.3391 1 0 1 +103 1 -1044.82617 0 Infinity 0 +107 1 1459.15247 1 0 1 +110 0 -263.1876 0 0 0 +114 0 -85.66205 6.272564E-38 0 0 +116 0 -16.9219666 4.47593E-08 6.4574021972779571E-08 0 +118 0 -543.7712 0 0 0 +119 0 -418.9355 0 0 0 +124 1 404.972565 1 0 1 +126 1 567.644653 1 0 1 +127 0 -630.2707 0 0 0 +130 0 -322.597717 0 0 0 +134 0 -670.7141 0 0 0 +135 0 -421.652374 0 0 0 +136 0 -553.6872 0 0 0 +139 0 +140 0 -451.529541 0 0 0 +142 1 488.315338 1 0 1 +143 0 -142.331116 0 0 0 +146 1 204.013458 1 0 1 +148 0 -941.3387 0 0 0 +149 1 752.705933 1 0 1 +153 0 -322.504425 0 0 0 +155 1 1089.60254 1 0 1 +157 0 -617.483643 0 0 0 +158 0 +159 1 1923.50525 1 0 1 +160 1 1494.5094 1 0 1 +162 0 -630.2707 0 0 0 +163 0 -310.768677 0 0 0 +165 0 -490.5085 0 0 0 +166 1 1570.026 1 0 1 +168 0 -630.2707 0 0 0 +170 0 -451.529541 0 0 0 +172 0 -515.32605 0 0 0 +175 1 1194.714 1 0 1 +178 0 -643.057739 0 0 0 +182 0 -668.631836 0 0 0 +184 1 1070.428 1 0 1 +185 0 -487.669739 0 0 0 +186 1 1482.4314 1 0 1 +190 1 2258.611 1 0 1 +193 0 -604.696655 0 0 0 +194 0 -630.2707 0 0 0 +195 0 -643.057739 0 0 0 +197 0 -543.2626 0 0 0 +200 1 1415.398 1 0 1 +203 0 -655.8448 0 0 0 +208 0 -474.8827 0 0 0 +213 1 2248.68115 1 0 1 +214 1 2270.2356 1 0 1 +215 1 1645.00281 1 0 1 +217 0 -604.696655 0 0 0 +220 0 -567.16925 0 0 0 +221 1 221.44339 1 0 1 +222 1 -65.19409 4.85921E-29 94.055192962286199 0 +224 1 1289.00232 1 0 1 +225 0 -515.32605 0 0 0 +227 1 1841.02966 1 0 1 +229 1 1514.78943 1 0 1 +230 1 930.4363 1 0 1 +231 1 1347.16248 1 0 1 +232 0 355.92923 1 Infinity 1 +234 0 50.92752 1 Infinity 1 +235 0 +236 1 1204.77161 1 0 1 +238 1 2421.62354 1 0 1 +243 0 -499.813416 0 0 0 +245 0 -547.411255 0 0 0 +251 1 963.113037 1 0 1 +253 1 1330.6217 1 0 1 +255 1 1480.26221 1 0 1 +256 0 -451.529541 0 0 0 +261 1 1146.05249 1 0 1 +263 1 671.7992 1 0 1 +264 1 168.765533 1 0 1 +265 0 -208.3869 0 0 0 +266 1 1778.757 1 0 1 +270 1 1587.86926 1 0 1 +273 1 71.0156555 1 0 1 +274 0 -548.627563 0 0 0 +281 0 -579.956238 0 0 0 +282 1 1006.21533 1 0 1 +286 1 1933.38391 1 0 1 +287 0 -670.7141 0 0 0 +289 1 1586.01746 1 0 1 +292 1 +294 0 +295 1 906.5393 1 0 1 +298 0 -764.4775 0 0 0 +302 1 1682.504 1 0 1 +305 1 1757.95251 1 0 1 +306 0 -604.696655 0 0 0 +307 0 -604.696655 0 0 0 +310 0 -657.927 0 0 0 +313 0 -425.9555 0 0 0 +315 0 +318 0 -994.1385 0 0 0 +320 1 276.6519 1 0 1 +322 0 -630.2707 0 0 0 +324 0 -604.696655 0 0 0 +325 0 -115.24649 0 0 0 +326 1 -15.4368286 1.97637988E-07 22.270636392005638 0 +330 1 698.895264 1 0 1 +334 1 1535.56677 1 0 1 +335 0 -425.9555 0 0 0 +337 0 -604.696655 0 0 0 +339 1 1217.44373 1 0 1 +340 1 493.843048 1 0 1 +341 0 -604.696655 0 0 0 +342 0 -438.7425 0 0 0 +345 0 -425.9555 0 0 0 +351 0 -617.483643 0 0 0 +356 1 -20.48114 1.27395428E-09 29.54803935647886 0 +357 1 1071.41785 1 0 1 +359 1 718.953 1 0 1 +362 0 -396.3484 0 0 0 +363 0 412.754547 1 Infinity 1 +364 0 -617.483643 0 0 0 +365 0 -528.1131 0 0 0 +367 1 1990.46619 1 0 1 +369 0 -141.63562 0 0 0 +372 0 -431.748932 0 0 0 +374 0 -418.961884 0 0 0 +375 0 -425.9555 0 0 0 +380 0 -425.9555 0 0 0 +382 0 -248.732788 0 0 0 +385 0 -124.032074 0 0 0 +386 1 986.1454 1 0 1 +390 0 -477.798676 0 0 0 +393 0 -296.141541 0 0 0 +394 0 -131.020447 0 0 
0 +397 0 -464.3166 0 0 0 +400 1 1161.08484 1 0 1 +401 0 -451.529541 0 0 0 +402 0 -41.73947 7.460671E-19 0 0 +403 0 -238.811249 0 0 0 +405 0 -515.32605 0 0 0 +407 0 -515.32605 0 0 0 +408 0 -106.253662 0 0 0 +410 0 -515.32605 0 0 0 +411 0 +412 1 1142.7179 1 0 1 +417 0 -515.32605 0 0 0 +420 0 -151.036346 0 0 0 +421 1 1101.12451 1 0 1 +424 0 -451.529541 0 0 0 +425 1 1700.89758 1 0 1 +426 0 413.445831 1 Infinity 1 +427 1 1406.78931 1 0 1 +431 0 -758.7747 0 0 0 +432 0 -484.831055 0 0 0 +433 0 -114.107391 0 0 0 +435 1 1690.2677 1 0 1 +437 0 -464.3166 0 0 0 +438 0 -145.385315 0 0 0 +443 0 -355.049377 0 0 0 +444 0 -508.651184 0 0 0 +445 0 -438.7425 0 0 0 +446 0 -425.9555 0 0 0 +447 0 -477.103638 0 0 0 +448 0 -296.141541 0 0 0 +458 0 -355.1654 0 0 0 +459 0 -233.227112 0 0 0 +460 0 -402.048828 0 0 0 +461 0 -167.905182 0 0 0 +462 0 -414.835876 0 0 0 +463 0 -382.673462 0 0 0 +468 0 -464.3166 0 0 0 +469 0 -393.387817 0 0 0 +470 0 -411.0862 0 0 0 +471 0 -414.835876 0 0 0 +472 0 -360.076721 0 0 0 +473 0 -464.3166 0 0 0 +475 0 -451.529541 0 0 0 +476 0 -342.378357 0 0 0 +477 0 -464.3166 0 0 0 +478 0 -336.6745 0 0 0 +479 1 1756.93933 1 0 1 +481 0 38.2750549 1 Infinity 1 +485 0 -79.2746 3.728034E-35 0 0 +486 0 -411.0862 0 0 0 +488 1 1032.36389 1 0 1 +490 0 -425.9555 0 0 0 +491 1 1566.10583 1 0 1 +494 0 -82.79285 1.10541041E-36 0 0 +496 0 -296.141541 0 0 0 +498 0 -553.6872 0 0 0 +499 0 -553.6872 0 0 0 +500 0 -668.631836 0 0 0 +503 0 -643.057739 0 0 0 +505 0 -170.671326 0 0 0 +506 1 1927.56653 1 0 1 +508 0 -477.103638 0 0 0 +509 0 -438.7425 0 0 0 +511 0 -566.4742 0 0 0 +512 0 -477.103638 0 0 0 +515 1 1918.68909 1 0 1 +516 0 -296.141541 0 0 0 +518 0 -292.0639 0 0 0 +524 0 -540.900146 0 0 0 +525 0 -414.002136 0 0 0 +526 0 -464.3166 0 0 0 +530 1 944.298462 1 0 1 +536 0 -655.8448 0 0 0 +537 0 -533.9065 0 0 0 +542 0 -196.245392 0 0 0 +543 0 -553.6872 0 0 0 +545 0 -566.4742 0 0 0 +550 0 -540.900146 0 0 0 +551 0 -604.696655 0 0 0 +552 0 -338.1034 0 0 0 +553 0 240.835052 1 Infinity 1 +554 0 -451.529541 0 0 0 +555 0 119.931915 1 Infinity 1 +556 0 -136.7655 0 0 0 +562 0 -604.696655 0 0 0 +564 0 -561.4146 0 0 0 +567 0 -411.875 0 0 0 +568 1 595.410767 1 0 1 +570 1 542.114868 1 0 1 +571 1 1942.85852 1 0 1 +572 0 -540.900146 0 0 0 +573 0 -515.32605 0 0 0 +574 1 1153.99219 1 0 1 +575 0 -533.9065 0 0 0 +576 0 -566.4742 0 0 0 +579 0 -604.696655 0 0 0 +580 0 -444.53595 0 0 0 +583 0 -451.529541 0 0 0 +585 0 -425.9555 0 0 0 +587 0 -484.831055 0 0 0 +588 1 962.983643 1 0 1 +589 0 -477.103638 0 0 0 +591 1 1292.75977 1 0 1 +592 1 495.973053 1 0 1 +595 0 -566.4742 0 0 0 +596 0 -431.748932 0 0 0 +597 0 -411.968262 0 0 0 +598 0 -540.900146 0 0 0 +599 0 158.954132 1 Infinity 1 +601 0 -385.512115 0 0 0 +603 1 666.133057 1 0 1 +605 1 1270.70056 1 0 1 +608 1 1017.25757 1 0 1 +610 1 401.661346 1 0 1 +611 1 1604.39221 1 0 1 +615 0 -309.810669 0 0 0 +616 0 -540.900146 0 0 0 +620 0 -540.900146 0 0 0 +623 0 -425.9555 0 0 0 +625 0 -124.748688 0 0 0 +626 1 592.070435 1 0 1 +628 0 -438.7425 0 0 0 +630 0 -105.585052 0 0 0 +631 0 -566.4742 0 0 0 +632 0 -425.9555 0 0 0 +635 0 -85.71478 5.95035441E-38 0 0 +636 1 933.872437 1 0 1 +637 0 98.51761 1 Infinity 1 +640 0 -389.261841 0 0 0 +643 0 -425.9555 0 0 0 +646 0 -163.588135 0 0 0 +647 0 -350.9007 0 0 0 +648 1 893.8468 1 0 1 +650 0 -331.097778 0 0 0 +651 0 -299.842163 0 0 0 +655 0 -540.900146 0 0 0 +658 1 1548.16321 1 0 1 +659 0 -425.9555 0 0 0 +662 0 -271.4496 0 0 0 +663 0 -271.4496 0 0 0 +664 0 -465.845337 0 0 0 +666 0 -209.536652 0 0 0 +667 0 -630.2707 0 0 0 +669 1 2239.64185 1 0 1 +671 0 -452.314178 
0 0 0 +672 0 -617.483643 0 0 0 +673 0 -204.121124 0 0 0 +674 0 -515.32605 0 0 0 +675 0 -98.50183 1.667545E-43 0 0 +676 0 -393.387817 0 0 0 +677 0 -477.103638 0 0 0 +684 0 -425.9555 0 0 0 +686 0 -425.9555 0 0 0 +687 0 -377.613831 0 0 0 +690 0 -350.9007 0 0 0 +695 0 -438.7425 0 0 0 diff --git a/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-TrainTest-breast-cancer-out.txt b/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-TrainTest-breast-cancer-out.txt new file mode 100644 index 0000000000..d27eaf83bd --- /dev/null +++ b/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-TrainTest-breast-cancer-out.txt @@ -0,0 +1,45 @@ +maml.exe TrainTest test=%Data% tr=SymSGD{nt=1} norm=No dout=%Output% data=%Data% out=%Output% seed=1 +Not adding a normalizer. +Data fully loaded into memory. +Not training a calibrator because it is not needed. +Warning: The predictor produced non-finite prediction values on 16 instances during testing. Possible causes: abnormal data or the predictor is numerically unstable. +TEST POSITIVE RATIO: 0.3499 (239.0/(239.0+444.0)) +Confusion table + ||====================== +PREDICTED || positive | negative | Recall +TRUTH ||====================== + positive || 152 | 87 | 0.6360 + negative || 2 | 442 | 0.9955 + ||====================== +Precision || 0.9870 | 0.8355 | +OVERALL 0/1 ACCURACY: 0.869693 +LOG LOSS/instance: Infinity +Test-set entropy (prior Log-Loss/instance): 0.934003 +LOG-LOSS REDUCTION (RIG): -Infinity +AUC: 0.984941 + +OVERALL RESULTS +--------------------------------------- +AUC: 0.984941 (0.0000) +Accuracy: 0.869693 (0.0000) +Positive precision: 0.987013 (0.0000) +Positive recall: 0.635983 (0.0000) +Negative precision: 0.835539 (0.0000) +Negative recall: 0.995495 (0.0000) +Log-loss: Infinity (0.0000) +Log-loss reduction: -Infinity (0.0000) +F1 Score: 0.773537 (0.0000) +AUPRC: 0.977633 (0.0000) + +--------------------------------------- +Physical memory usage(MB): %Number% +Virtual memory usage(MB): %Number% +%DateTime% Time elapsed(s): %Number% + +--- Progress log --- +[1] 'Preprocessing' started. +[1] 'Preprocessing' finished in %Time%. +[2] 'Training' started. +[2] 'Training' finished in %Time%. +[3] 'Saving model' started. +[3] 'Saving model' finished in %Time%. 
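Across the per-instance baseline files in this change, the Probability column behaves as a plain logistic (sigmoid) of the raw Score, and the Log-loss column is the negative log-probability of the true label measured in bits (base 2); a positive instance driven to probability 0 therefore logs Infinity, which is what the non-finite-prediction warnings count. A small C# sketch of that relationship, inferred from the baseline numbers rather than taken from the trainer's source:

    using System;

    static class ScoreMapping
    {
        // Probability column: logistic of the raw score.
        public static double Sigmoid(double score) => 1.0 / (1.0 + Math.Exp(-score));

        // Log-loss column: -log2 of the probability assigned to the true label.
        public static double LogLossBits(double probability, int label)
        {
            double p = label == 1 ? probability : 1.0 - probability;
            return -Math.Log(p) / Math.Log(2.0); // Infinity when p == 0
        }
    }

For example, instance 530 in the TrainTest file below (label 1, score 8.005951) gives Sigmoid(8.005951) ≈ 0.9996666 and LogLossBits(0.999666631, 1) ≈ 0.000481030, matching that row up to single-precision rounding.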
diff --git a/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-TrainTest-breast-cancer-rp.txt b/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-TrainTest-breast-cancer-rp.txt new file mode 100644 index 0000000000..c056310ab0 --- /dev/null +++ b/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-TrainTest-breast-cancer-rp.txt @@ -0,0 +1,4 @@ +SymSGD +AUC Accuracy Positive precision Positive recall Negative precision Negative recall Log-loss Log-loss reduction F1 Score AUPRC /nt Learner Name Train Dataset Test Dataset Results File Run Time Physical Memory Virtual Memory Command Line Settings +0.984941 0.869693 0.987013 0.635983 0.835539 0.995495 Infinity -Infinity 0.773537 0.977633 1 SymSGD %Data% %Data% %Output% 99 0 0 maml.exe TrainTest test=%Data% tr=SymSGD{nt=1} norm=No dout=%Output% data=%Data% out=%Output% seed=1 /nt:1 + diff --git a/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-TrainTest-breast-cancer-summary.txt b/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-TrainTest-breast-cancer-summary.txt new file mode 100644 index 0000000000..8fb0a3fe22 --- /dev/null +++ b/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-TrainTest-breast-cancer-summary.txt @@ -0,0 +1,12 @@ +Linear Binary Classification Predictor non-zero weights + +(Bias) -448.1 +f1 49.29393 +f4 -25.15009 +f5 23.68305 +f3 16.76877 +f7 13.76585 +f6 -6.658058 +f8 4.843107 +f2 -3.424153 +f0 -0.3852913 diff --git a/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-TrainTest-breast-cancer.txt b/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-TrainTest-breast-cancer.txt new file mode 100644 index 0000000000..f744bc3525 --- /dev/null +++ b/test/BaselineOutput/SingleRelease/SymSGD/SymSGD-TrainTest-breast-cancer.txt @@ -0,0 +1,700 @@ +Instance Label Score Probability Log-loss Assigned +0 0 -415.3703 0 0 0 +1 0 -109.523041 0 0 0 +2 0 -390.916656 0 0 0 +3 0 33.8270569 1 Infinity 1 +4 0 -381.447449 0 0 0 +5 1 243.726929 1 0 1 +6 0 -200.681656 0 0 0 +7 0 -417.63858 0 0 0 +8 0 -381.525879 0 0 0 +9 0 -359.03302 0 0 0 +10 0 -388.679047 0 0 0 +11 0 -407.556366 0 0 0 +12 1 -208.0876 0 Infinity 0 +13 0 -366.463 0 0 0 +14 1 136.450256 1 0 1 +15 1 -314.800262 0 Infinity 0 +16 0 -408.326935 0 0 0 +17 0 -414.985 0 0 0 +18 1 102.798035 1 0 1 +19 0 -415.755585 0 0 0 +20 1 10.1497192 0.9999609 5.6411412351548271E-05 1 +21 1 -61.521698 1.91190378E-27 88.757048641689195 0 +22 0 -407.94165 0 0 0 +23 1 +24 0 -413.829132 0 0 0 +25 1 -111.690765 0 Infinity 0 +26 0 -333.49762 0 0 0 +27 0 -408.712219 0 0 0 +28 0 -407.556366 0 0 0 +29 0 -407.361328 0 0 0 +30 0 -382.791565 0 0 0 +31 0 -414.214417 0 0 0 +32 1 -140.35733 0 Infinity 0 +33 0 -397.445648 0 0 0 +34 0 -411.365784 0 0 0 +35 0 -407.556366 0 0 0 +36 1 89.1488 1 0 1 +37 0 -367.9438 0 0 0 +38 1 125.049805 1 0 1 +39 1 -120.420319 0 Infinity 0 +40 0 +41 1 -214.114883 0 Infinity 0 +42 1 86.67383 1 0 1 +43 1 -299.954132 0 Infinity 0 +44 1 -14.4605713 5.246304E-07 20.86219519410675 0 +45 0 -402.327942 0 0 0 +46 1 139.792358 1 0 1 +47 0 -407.171082 0 0 0 +48 0 -381.447449 0 0 0 +49 1 141.825745 1 0 1 +50 1 -170.3082 0 Infinity 0 +51 1 -160.977692 0 Infinity 0 +52 1 -127.401 0 Infinity 0 +53 1 -119.971008 0 Infinity 0 +54 1 -161.725189 0 Infinity 0 +55 1 -41.1490173 1.34650766E-18 59.365483272233966 0 +56 1 199.693787 1 0 1 +57 1 -410.444183 0 Infinity 0 +58 1 -273.693665 0 Infinity 0 +59 1 -206.296631 0 Infinity 0 +60 1 -123.953339 0 Infinity 0 +61 0 -383.488 0 0 0 +62 1 -11.9711 6.324332E-06 17.270655479328113 0 +63 1 -269.080566 0 Infinity 0 +64 0 -407.171082 0 0 0 +65 1 -209.844543 0 Infinity 0 +66 0 -414.985 0 0 
0 +67 1 -161.339172 0 Infinity 0 +68 1 -58.4673157 4.05478141E-26 84.350506323497086 0 +69 0 -400.063263 0 0 0 +70 0 -415.560547 0 0 0 +71 1 77.64612 1 0 1 +72 0 -318.187164 0 0 0 +73 1 70.44263 1 0 1 +74 1 -111.748413 0 Infinity 0 +75 0 -386.462433 0 0 0 +76 0 -417.4435 0 0 0 +77 0 -293.3556 0 0 0 +78 0 -367.233582 0 0 0 +79 0 -432.706451 0 0 0 +80 0 -369.826782 0 0 0 +81 0 -391.558167 0 0 0 +82 0 -366.076355 0 0 0 +83 0 -417.5489 0 0 0 +84 1 44.29547 1 0 1 +85 1 41.82538 1 0 1 +86 1 -157.855713 0 Infinity 0 +87 1 122.36322 1 0 1 +88 0 -414.985 0 0 0 +89 0 -415.937683 0 0 0 +90 0 -413.829132 0 0 0 +91 0 -384.514832 0 0 0 +92 0 -414.985 0 0 0 +93 0 -407.171082 0 0 0 +94 0 -414.214417 0 0 0 +95 0 -413.829132 0 0 0 +96 0 -384.129547 0 0 0 +97 0 -415.3703 0 0 0 +98 1 -103.478882 1.401298E-45 149 0 +99 1 150.554138 1 0 1 +100 1 -310.138184 0 Infinity 0 +101 1 -119.139008 0 Infinity 0 +102 0 -418.409149 0 0 0 +103 1 -453.947021 0 Infinity 0 +104 1 42.2218933 1 0 1 +105 1 -212.478668 0 Infinity 0 +106 1 319.402039 1 0 1 +107 1 91.11633 1 0 1 +108 0 -379.639343 0 0 0 +109 1 65.18329 1 0 1 +110 0 -255.730743 0 0 0 +111 1 -32.7773132 5.820948E-15 47.287667279651664 0 +112 1 49.09488 1 0 1 +113 1 -39.6409 6.08381538E-18 57.189729333347941 0 +114 0 -272.0699 0 0 0 +115 0 -305.7808 0 0 0 +116 0 -287.3377 0 0 0 +117 1 144.5932 1 0 1 +118 0 -403.447083 0 0 0 +119 0 -341.6227 0 0 0 +120 0 -400.4845 0 0 0 +121 0 -342.008 0 0 0 +122 1 48.14859 1 0 1 +123 1 -254.01651 0 Infinity 0 +124 1 -127.346558 0 Infinity 0 +125 0 -407.171082 0 0 0 +126 1 85.45001 1 0 1 +127 0 -414.5997 0 0 0 +128 1 -56.6429138 2.51359385E-25 81.718450817485731 0 +129 0 -601.7137 0 0 0 +130 0 -415.560547 0 0 0 +131 0 -414.214417 0 0 0 +132 1 295.987366 1 0 1 +133 0 -394.175781 0 0 0 +134 0 -433.091736 0 0 0 +135 0 -364.155518 0 0 0 +136 0 -408.326935 0 0 0 +137 0 -401.2836 0 0 0 +138 0 -411.7511 0 0 0 +139 0 +140 0 -401.2836 0 0 0 +141 0 -400.898315 0 0 0 +142 1 -108.134125 0 Infinity 0 +143 0 -305.7808 0 0 0 +144 0 -407.556366 0 0 0 +145 0 +146 1 -205.1228 0 Infinity 0 +147 0 -408.638123 0 0 0 +148 0 -448.917847 0 0 0 +149 1 69.0268555 1 0 1 +150 0 -388.679047 0 0 0 +151 1 -226.9046 0 Infinity 0 +152 1 221.257813 1 0 1 +153 0 -354.3028 0 0 0 +154 0 -400.513 0 0 0 +155 1 39.9500122 1 0 1 +156 0 -361.30127 0 0 0 +157 0 -414.214417 0 0 0 +158 0 +159 1 214.183228 1 0 1 +160 1 120.047485 1 0 1 +161 0 -401.219147 0 0 0 +162 0 -414.5997 0 0 0 +163 0 -282.1694 0 0 0 +164 0 +165 0 -377.536072 0 0 0 +166 1 123.761658 1 0 1 +167 1 -9.15014648 0.000106192965 13.201024182527696 0 +168 0 -414.5997 0 0 0 +169 0 -358.594147 0 0 0 +170 0 -401.2836 0 0 0 +171 0 -413.829132 0 0 0 +172 0 -407.171082 0 0 0 +173 1 316.5832 1 0 1 +174 1 34.576416 1 0 1 +175 1 90.98065 1 0 1 +176 0 -414.214417 0 0 0 +177 1 0.357513428 0.5884384 0.76503671898226377 1 +178 0 -414.985 0 0 0 +179 1 -177.546082 0 Infinity 0 +180 0 -388.679047 0 0 0 +181 0 -400.513 0 0 0 +182 0 -415.755585 0 0 0 +183 1 230.525391 1 0 1 +184 1 61.9541 1 0 1 +185 0 -389.064331 0 0 0 +186 1 30.8129272 1 0 1 +187 1 125.775085 1 0 1 +188 1 244.60907 1 0 1 +189 0 -371.383484 0 0 0 +190 1 275.354 1 0 1 +191 1 40.0039978 1 0 1 +192 0 -408.712219 0 0 0 +193 0 -413.829132 0 0 0 +194 0 -414.5997 0 0 0 +195 0 -414.985 0 0 0 +196 0 -45.47174 1.785969E-20 0 0 +197 0 -365.063965 0 0 0 +198 0 -400.513 0 0 0 +199 0 -407.94165 0 0 0 +200 1 142.494385 1 0 1 +201 1 -67.245575 6.24622866E-30 97.014857463075344 0 +202 0 -413.829132 0 0 0 +203 0 -415.3703 0 0 0 +204 0 -413.829132 0 0 0 +205 1 360.787842 1 0 1 +206 1 
56.5380859 1 0 1 +207 0 -388.679047 0 0 0 +208 0 -388.679047 0 0 0 +209 0 -390.220184 0 0 0 +210 1 399.735962 1 0 1 +211 1 291.975525 1 0 1 +212 0 -413.829132 0 0 0 +213 1 345.636963 1 0 1 +214 1 356.6704 1 0 1 +215 1 90.58252 1 0 1 +216 0 -407.171082 0 0 0 +217 0 -413.829132 0 0 0 +218 1 173.8518 1 0 1 +219 0 -422.603882 0 0 0 +220 0 -397.060364 0 0 0 +221 1 -51.67093 3.62744371E-23 74.545392954249706 0 +222 1 -254.9071 0 Infinity 0 +223 1 -47.25174 3.01182883E-21 68.169850212425573 0 +224 1 126.361267 1 0 1 +225 0 -407.171082 0 0 0 +226 1 49.32419 1 0 1 +227 1 143.052124 1 0 1 +228 0 -388.679047 0 0 0 +229 1 124.959839 1 0 1 +230 1 -79.35242 3.44892051E-35 114.48133844099975 0 +231 1 122.692749 1 0 1 +232 0 -256.504028 0 0 0 +233 1 -84.9973755 1.21929513E-37 122.62529213957316 0 +234 0 -275.7568 0 0 0 +235 0 +236 1 116.098267 1 0 1 +237 1 26.0986938 1 0 1 +238 1 370.027954 1 0 1 +239 1 -52.4384155 1.68378053E-23 75.652642077469366 0 +240 0 -330.808228 0 0 0 +241 0 -355.912079 0 0 0 +242 0 -414.214417 0 0 0 +243 0 -332.413025 0 0 0 +244 0 -413.829132 0 0 0 +245 0 -374.918457 0 0 0 +246 1 321.109131 1 0 1 +247 1 -61.9207153 1.28284749E-27 89.332708892981643 0 +248 0 -346.1557 0 0 0 +249 0 +250 0 -354.643219 0 0 0 +251 1 108.2807 1 0 1 +252 0 -4.193939 0.0148625113 0.021603009492489122 0 +253 1 86.67383 1 0 1 +254 1 -11.9711 6.324332E-06 17.270655479328113 0 +255 1 62.42392 1 0 1 +256 0 -401.2836 0 0 0 +257 0 -407.94165 0 0 0 +258 0 -414.5997 0 0 0 +259 0 -8.52298 0.0001988065 0.00028684567329613589 0 +260 1 91.19623 1 0 1 +261 1 134.843628 1 0 1 +262 1 158.870361 1 0 1 +263 1 25.5258484 1 0 1 +264 1 -11.0758057 1.54821755E-05 15.979032265129009 0 +265 0 -411.8769 0 0 0 +266 1 256.479553 1 0 1 +267 1 -151.574524 0 Infinity 0 +268 1 49.6661072 1 0 1 +269 0 -413.829132 0 0 0 +270 1 13.7780151 0.999999 1.4618532729665815E-06 1 +271 0 -415.3703 0 0 0 +272 1 -151.574524 0 Infinity 0 +273 1 -303.688629 0 Infinity 0 +274 0 -400.833862 0 0 0 +275 0 +276 0 -407.94165 0 0 0 +277 0 -407.171082 0 0 0 +278 0 -413.829132 0 0 0 +279 1 -28.7467346 3.276814E-13 41.472771430985198 0 +280 0 -414.5997 0 0 0 +281 0 -397.445648 0 0 0 +282 1 96.48364 1 0 1 +283 1 -59.1726379 2.00285667E-26 85.368071284909448 0 +284 1 183.314758 1 0 1 +285 1 255.142639 1 0 1 +286 1 319.219543 1 0 1 +287 0 -433.091736 0 0 0 +288 1 -267.595245 0 Infinity 0 +289 1 175.671082 1 0 1 +290 0 -400.513 0 0 0 +291 0 -413.829132 0 0 0 +292 1 +293 1 51.4936523 1 0 1 +294 0 +295 1 5.8543396 0.997140765 0.0041309123233919023 1 +296 0 -173.148254 0 0 0 +297 0 +298 0 -429.3664 0 0 0 +299 1 -112.8385 0 Infinity 0 +300 1 -114.377228 0 Infinity 0 +301 0 -413.829132 0 0 0 +302 1 274.0891 1 0 1 +303 0 -413.829132 0 0 0 +304 1 21.4684753 1 0 1 +305 1 269.0639 1 0 1 +306 0 -413.829132 0 0 0 +307 0 -413.829132 0 0 0 +308 1 66.88153 1 0 1 +309 0 -333.1836 0 0 0 +310 0 -432.706451 0 0 0 +311 0 -400.513 0 0 0 +312 1 -175.547363 0 Infinity 0 +313 0 -400.513 0 0 0 +314 0 -382.020966 0 0 0 +315 0 +316 1 -56.5515747 2.753995E-25 81.586676422594735 0 +317 1 168.155884 1 0 1 +318 0 -489.2794 0 0 0 +319 0 -232.038055 0 0 0 +320 1 16.8273315 0.99999994 8.5991327994145617E-08 1 +321 0 +322 0 -414.5997 0 0 0 +323 1 72.79901 1 0 1 +324 0 -413.829132 0 0 0 +325 0 -334.540161 0 0 0 +326 1 -176.167816 0 Infinity 0 +327 0 -407.171082 0 0 0 +328 1 131.3811 1 0 1 +329 1 -125.164459 0 Infinity 0 +330 1 -127.383881 0 Infinity 0 +331 0 -410.5272 0 0 0 +332 0 -332.307831 0 0 0 +333 1 -17.0445251 3.95964967E-08 24.590051963836686 0 +334 1 77.07416 1 0 1 +335 0 -400.513 0 0 0 
+336 1 89.2496948 1 0 1 +337 0 -413.829132 0 0 0 +338 0 -382.020966 0 0 0 +339 1 68.04913 1 0 1 +340 1 -70.21268 3.213822E-31 101.29548091112767 0 +341 0 -413.829132 0 0 0 +342 0 -400.898315 0 0 0 +343 0 -400.513 0 0 0 +344 1 -25.81427 6.151839E-12 37.242119386792353 0 +345 0 -400.513 0 0 0 +346 0 -337.034 0 0 0 +347 0 -347.8312 0 0 0 +348 1 -173.990021 0 Infinity 0 +349 1 -122.635986 0 Infinity 0 +350 0 -368.516632 0 0 0 +351 0 -414.214417 0 0 0 +352 0 -263.0901 0 0 0 +353 1 207.045776 1 0 1 +354 0 -407.171082 0 0 0 +355 0 -419.711182 0 0 0 +356 1 -264.1968 0 Infinity 0 +357 1 143.662415 1 0 1 +358 1 94.0296 1 0 1 +359 1 -100.401337 2.522337E-44 144.83007499855768 0 +360 1 294.12854 1 0 1 +361 1 286.876648 1 0 1 +362 0 -365.00592 0 0 0 +363 0 -206.582718 0 0 0 +364 0 -414.214417 0 0 0 +365 0 -407.556366 0 0 0 +366 1 336.557373 1 0 1 +367 1 294.622864 1 0 1 +368 0 -407.361328 0 0 0 +369 0 -388.8693 0 0 0 +370 0 -338.313324 0 0 0 +371 0 -407.361328 0 0 0 +372 0 -411.7511 0 0 0 +373 0 -391.943481 0 0 0 +374 0 -411.365784 0 0 0 +375 0 -400.513 0 0 0 +376 0 -407.171082 0 0 0 +377 0 -382.020966 0 0 0 +378 0 -369.0196 0 0 0 +379 0 -377.355042 0 0 0 +380 0 -400.513 0 0 0 +381 1 153.2738 1 0 1 +382 0 -338.1953 0 0 0 +383 0 -400.898315 0 0 0 +384 0 -400.898315 0 0 0 +385 0 -291.049133 0 0 0 +386 1 33.2608032 1 0 1 +387 0 -332.012054 0 0 0 +388 0 -393.7905 0 0 0 +389 0 -396.413422 0 0 0 +390 0 -390.4023 0 0 0 +391 1 229.0141 1 0 1 +392 0 -407.94165 0 0 0 +393 0 -375.3629 0 0 0 +394 0 -364.725433 0 0 0 +395 0 -407.94165 0 0 0 +396 0 -414.5997 0 0 0 +397 0 -401.668884 0 0 0 +398 0 -344.881836 0 0 0 +399 0 -358.067383 0 0 0 +400 1 216.038391 1 0 1 +401 0 -401.2836 0 0 0 +402 0 -316.972656 0 0 0 +403 0 -330.234436 0 0 0 +404 0 -303.91568 0 0 0 +405 0 -407.171082 0 0 0 +406 0 -362.457153 0 0 0 +407 0 -407.171082 0 0 0 +408 0 -278.598877 0 0 0 +409 0 -411.365784 0 0 0 +410 0 -407.171082 0 0 0 +411 0 +412 1 21.296814 1 0 1 +413 0 -418.794434 0 0 0 +414 1 17.1500549 0.99999994 8.5991327994145617E-08 1 +415 0 -158.312744 0 0 0 +416 1 -78.57623 7.495069E-35 113.36154154761103 0 +417 0 -407.171082 0 0 0 +418 0 -310.439728 0 0 0 +419 0 -377.760681 0 0 0 +420 0 -287.8263 0 0 0 +421 1 88.56616 1 0 1 +422 0 -295.7137 0 0 0 +423 0 -415.560547 0 0 0 +424 0 -401.2836 0 0 0 +425 1 236.068481 1 0 1 +426 0 -320.587067 0 0 0 +427 1 -44.00345 7.754345E-20 63.483556937036411 0 +428 0 -407.171082 0 0 0 +429 0 -407.556366 0 0 0 +430 0 -294.817322 0 0 0 +431 0 -418.3671 0 0 0 +432 0 -394.946381 0 0 0 +433 0 -321.876282 0 0 0 +434 0 122.56488 1 Infinity 1 +435 1 71.62463 1 0 1 +436 1 -2.05007935 0.114044361 3.1323329838202993 0 +437 0 -401.668884 0 0 0 +438 0 -374.979675 0 0 0 +439 0 -402.054169 0 0 0 +440 1 -9.103485 0.000111264941 13.133713305553275 0 +441 0 -234.828949 0 0 0 +442 0 -319.609375 0 0 0 +443 0 -376.829956 0 0 0 +444 0 -350.0185 0 0 0 +445 0 -400.898315 0 0 0 +446 0 -400.513 0 0 0 +447 0 -402.054169 0 0 0 +448 0 -375.3629 0 0 0 +449 1 173.530884 1 0 1 +450 0 -349.6393 0 0 0 +451 0 -402.054169 0 0 0 +452 0 -367.746063 0 0 0 +453 1 34.4112854 1 0 1 +454 0 -327.4605 0 0 0 +455 1 -234.852478 0 Infinity 0 +456 1 106.093872 1 0 1 +457 1 9.033997 0.999880731 0.0001720789042225489 1 +458 0 -405.478333 0 0 0 +459 0 -408.902466 0 0 0 +460 0 -368.516632 0 0 0 +461 0 -306.5514 0 0 0 +462 0 -368.901917 0 0 0 +463 0 -387.903 0 0 0 +464 0 -401.668884 0 0 0 +465 1 131.093628 1 0 1 +466 1 34.3149719 1 0 1 +467 1 90.55585 1 0 1 +468 0 -401.668884 0 0 0 +469 0 -410.5952 0 0 0 +470 0 -382.791565 0 0 0 +471 0 -368.901917 0 0 0 +472 
0 -377.289368 0 0 0 +473 0 -401.668884 0 0 0 +474 0 -402.054169 0 0 0 +475 0 -401.2836 0 0 0 +476 0 -405.093018 0 0 0 +477 0 -401.668884 0 0 0 +478 0 -352.760254 0 0 0 +479 1 251.724915 1 0 1 +480 0 -376.904083 0 0 0 +481 0 -256.584167 0 0 0 +482 1 161.241211 1 0 1 +483 1 143.984924 1 0 1 +484 0 -405.478333 0 0 0 +485 0 -294.4593 0 0 0 +486 0 -382.791565 0 0 0 +487 1 245.079346 1 0 1 +488 1 36.02997 1 0 1 +489 1 -281.350861 0 Infinity 0 +490 0 -400.513 0 0 0 +491 1 113.965393 1 0 1 +492 0 -383.176849 0 0 0 +493 1 300.856079 1 0 1 +494 0 -216.785461 0 0 0 +495 0 -382.791565 0 0 0 +496 0 -375.3629 0 0 0 +497 0 -352.374939 0 0 0 +498 0 -408.326935 0 0 0 +499 0 -408.326935 0 0 0 +500 0 -415.755585 0 0 0 +501 0 -408.326935 0 0 0 +502 0 -391.558167 0 0 0 +503 0 -414.985 0 0 0 +504 0 -400.513 0 0 0 +505 0 -302.69574 0 0 0 +506 1 234.96405 1 0 1 +507 0 -329.463867 0 0 0 +508 0 -402.054169 0 0 0 +509 0 -400.898315 0 0 0 +510 0 -400.513 0 0 0 +511 0 -408.712219 0 0 0 +512 0 -402.054169 0 0 0 +513 0 -382.791565 0 0 0 +514 1 244.530945 1 0 1 +515 1 390.9422 1 0 1 +516 0 -375.3629 0 0 0 +517 0 -382.020966 0 0 0 +518 0 -387.938965 0 0 0 +519 1 13.2459412 0.9999982 2.5797420694119618E-06 1 +520 0 -425.6631 0 0 0 +521 0 -426.81897 0 0 0 +522 1 -162.323669 0 Infinity 0 +523 1 91.895874 1 0 1 +524 0 -407.94165 0 0 0 +525 0 -384.514832 0 0 0 +526 0 -401.668884 0 0 0 +527 0 -414.985 0 0 0 +528 0 -392.519 0 0 0 +529 0 -383.176849 0 0 0 +530 1 8.005951 0.999666631 0.00048102966772982418 1 +531 0 -362.457153 0 0 0 +532 0 -388.679047 0 0 0 +533 0 -407.94165 0 0 0 +534 0 -407.556366 0 0 0 +535 0 -403.908661 0 0 0 +536 0 -415.3703 0 0 0 +537 0 -418.794434 0 0 0 +538 0 -408.326935 0 0 0 +539 0 -409.097534 0 0 0 +540 0 -385.029175 0 0 0 +541 0 -401.2836 0 0 0 +542 0 -303.4663 0 0 0 +543 0 -408.326935 0 0 0 +544 0 -397.6359 0 0 0 +545 0 -408.712219 0 0 0 +546 1 408.09906 1 0 1 +547 0 -375.7482 0 0 0 +548 0 -376.133484 0 0 0 +549 1 141.684875 1 0 1 +550 0 -407.94165 0 0 0 +551 0 -413.829132 0 0 0 +552 0 -344.853363 0 0 0 +553 0 -164.293976 0 0 0 +554 0 -401.2836 0 0 0 +555 0 -226.694183 0 0 0 +556 0 -320.923584 0 0 0 +557 0 -368.516632 0 0 0 +558 0 -407.556366 0 0 0 +559 0 -408.712219 0 0 0 +560 0 -415.3703 0 0 0 +561 0 -415.3703 0 0 0 +562 0 -413.829132 0 0 0 +563 0 -407.94165 0 0 0 +564 0 -401.219147 0 0 0 +565 1 215.478271 1 0 1 +566 0 -418.023865 0 0 0 +567 0 -360.960815 0 0 0 +568 1 -99.8051453 4.484155E-44 144 0 +569 1 128.35553 1 0 1 +570 1 109.875549 1 0 1 +571 1 143.371338 1 0 1 +572 0 -407.94165 0 0 0 +573 0 -407.171082 0 0 0 +574 1 19.3930969 1 0 1 +575 0 -418.794434 0 0 0 +576 0 -408.712219 0 0 0 +577 0 -407.171082 0 0 0 +578 0 -407.171082 0 0 0 +579 0 -413.829132 0 0 0 +580 0 -412.136383 0 0 0 +581 1 99.5480957 1 0 1 +582 1 348.034058 1 0 1 +583 0 -401.2836 0 0 0 +584 0 -343.360443 0 0 0 +585 0 -400.513 0 0 0 +586 1 337.054138 1 0 1 +587 0 -394.946381 0 0 0 +588 1 -21.5651855 4.30882552E-10 31.111986269672936 0 +589 0 -402.054169 0 0 0 +590 1 -25.1401978 1.20712112E-11 36.269638601249056 0 +591 1 51.3045654 1 0 1 +592 1 -80.5669556 1.02380155E-35 116.23354721913097 0 +593 0 -405.478333 0 0 0 +594 1 82.7042847 1 0 1 +595 0 -408.712219 0 0 0 +596 0 -411.7511 0 0 0 +597 0 -422.2186 0 0 0 +598 0 -407.94165 0 0 0 +599 0 -337.8826 0 0 0 +600 0 -407.94165 0 0 0 +601 0 -382.020966 0 0 0 +602 0 -408.326935 0 0 0 +603 1 -126.92569 0 Infinity 0 +604 1 -93.5751953 2.295187E-41 135.00044034278059 0 +605 1 -47.6547852 2.01274975E-21 68.751322185612878 0 +606 0 -384.900116 0 0 0 +607 0 -400.513 0 0 0 +608 1 113.71698 1 0 
1 +609 0 -402.054169 0 0 0 +610 1 40.86389 1 0 1 +611 1 118.382507 1 0 1 +612 1 380.642151 1 0 1 +613 0 -308.9685 0 0 0 +614 0 -382.40625 0 0 0 +615 0 -415.175232 0 0 0 +616 0 -407.94165 0 0 0 +617 0 +618 0 -408.326935 0 0 0 +619 0 -408.712219 0 0 0 +620 0 -407.94165 0 0 0 +621 0 -311.9197 0 0 0 +622 0 -379.369446 0 0 0 +623 0 -400.513 0 0 0 +624 0 -366.793365 0 0 0 +625 0 -362.232849 0 0 0 +626 1 -60.24713 6.839168E-27 86.918237673162878 0 +627 0 -306.166077 0 0 0 +628 0 -400.898315 0 0 0 +629 0 -401.668884 0 0 0 +630 0 -359.993835 0 0 0 +631 0 -408.712219 0 0 0 +632 0 -400.513 0 0 0 +633 1 1.76364136 0.8536651 0.22825787272289805 1 +634 0 -401.2836 0 0 0 +635 0 -411.55603 0 0 0 +636 1 84.4071045 1 0 1 +637 0 -312.084869 0 0 0 +638 0 -401.668884 0 0 0 +639 0 -368.516632 0 0 0 +640 0 -368.131348 0 0 0 +641 0 -407.94165 0 0 0 +642 0 -407.94165 0 0 0 +643 0 -400.513 0 0 0 +644 0 -400.898315 0 0 0 +645 0 -407.94165 0 0 0 +646 0 -354.643219 0 0 0 +647 0 -366.975464 0 0 0 +648 1 91.73328 1 0 1 +649 0 -407.94165 0 0 0 +650 0 -338.615753 0 0 0 +651 0 -324.3396 0 0 0 +652 0 -394.946381 0 0 0 +653 0 -408.326935 0 0 0 +654 0 -414.5997 0 0 0 +655 0 -407.94165 0 0 0 +656 0 -408.712219 0 0 0 +657 0 -363.256348 0 0 0 +658 1 190.37738 1 0 1 +659 0 -400.513 0 0 0 +660 0 -407.171082 0 0 0 +661 0 -414.985 0 0 0 +662 0 -414.019379 0 0 0 +663 0 -414.019379 0 0 0 +664 0 -374.4041 0 0 0 +665 0 -400.513 0 0 0 +666 0 -334.5725 0 0 0 +667 0 -414.5997 0 0 0 +668 1 -123.532867 0 Infinity 0 +669 1 231.38147 1 0 1 +670 1 230.878479 1 0 1 +671 0 -348.537 0 0 0 +672 0 -414.214417 0 0 0 +673 0 -332.040558 0 0 0 +674 0 -407.171082 0 0 0 +675 0 -411.941345 0 0 0 +676 0 -410.5952 0 0 0 +677 0 -402.054169 0 0 0 +678 0 -400.513 0 0 0 +679 0 -400.898315 0 0 0 +680 1 390.4923 1 0 1 +681 1 376.423279 1 0 1 +682 0 -401.604431 0 0 0 +683 0 -400.513 0 0 0 +684 0 -400.513 0 0 0 +685 0 -400.513 0 0 0 +686 0 -400.513 0 0 0 +687 0 -380.409943 0 0 0 +688 0 -401.668884 0 0 0 +689 0 -366.611267 0 0 0 +690 0 -366.975464 0 0 0 +691 1 143.604248 1 0 1 +692 0 -401.2836 0 0 0 +693 0 -403.098541 0 0 0 +694 0 -402.750641 0 0 0 +695 0 -400.898315 0 0 0 +696 1 48.056427 1 0 1 +697 1 31.3800049 1 0 1 +698 1 12.3017273 0.99999547 6.5353555352246199E-06 1 diff --git a/test/Microsoft.ML.Predictor.Tests/Microsoft.ML.Predictor.Tests.csproj b/test/Microsoft.ML.Predictor.Tests/Microsoft.ML.Predictor.Tests.csproj index 63234b0900..568bd94e3f 100644 --- a/test/Microsoft.ML.Predictor.Tests/Microsoft.ML.Predictor.Tests.csproj +++ b/test/Microsoft.ML.Predictor.Tests/Microsoft.ML.Predictor.Tests.csproj @@ -28,6 +28,11 @@ + + + + + diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs index 60c9b21495..f0c992dc4c 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs @@ -231,6 +231,14 @@ public void BinaryClassifierLogisticRegressionTest() Done(); } + [Fact] + [TestCategory("Binary")] + public void BinaryClassifierSymSgdTest() + { + RunOneAllTests(TestLearners.symSGD, TestDatasets.breastCancer, summary: true); + Done(); + } + [Fact] [TestCategory("Binary")] public void BinaryClassifierTesterThresholdingTest() diff --git a/test/Microsoft.ML.TestFramework/Learners.cs b/test/Microsoft.ML.TestFramework/Learners.cs index 9672c88fe8..4eaad61355 100644 --- a/test/Microsoft.ML.TestFramework/Learners.cs +++ b/test/Microsoft.ML.TestFramework/Learners.cs @@ -163,6 +163,14 @@ static TestLearnersBase() BaselineProgress = true }; + // New. 
+ public static PredictorAndArgs symSGD = new PredictorAndArgs + { + Trainer = new SubComponent("SymSGD", "nt=1"), + MamlArgs = new[] { "norm=no" }, + BaselineProgress = true + }; + // New. public static PredictorAndArgs logisticRegressionNonNegative = new PredictorAndArgs { diff --git a/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj b/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj index 2eb04a1437..6c99725ffe 100644 --- a/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj +++ b/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj @@ -11,6 +11,7 @@ + diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index b088cfb05e..7360f0f4d0 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -22,5 +22,10 @@ + + + + + \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index 9c4283df45..9e07e9448b 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -11,6 +11,7 @@ using Microsoft.ML.Transforms; using System.Collections.Generic; using System.Linq; +using System.Runtime.InteropServices; using Xunit; namespace Microsoft.ML.Scenarios @@ -32,6 +33,18 @@ public void TrainAndPredictSentimentModelTest() ValidateBinaryMetrics(metrics); } + [Fact] + public void TrainAndPredictSymSGDSentimentModelTest() + { + var pipeline = PreparePipelineSymSGD(); + var model = pipeline.Train(); + var testData = PrepareTextLoaderTestData(); + var evaluator = new BinaryClassificationEvaluator(); + var metrics = evaluator.Evaluate(model, testData); + ValidateExamplesSymSGD(model); + ValidateBinaryMetricsSymSGD(metrics); + } + [Fact] public void TrainAndPredictLightGBMSentimentModelTest() { @@ -175,6 +188,39 @@ public void CrossValidateSentimentModelTest() Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); } + private void ValidateBinaryMetricsSymSGD(BinaryClassificationMetrics metrics) + { + + Assert.Equal(.8889, metrics.Accuracy, 4); + Assert.Equal(1, metrics.Auc, 1); + Assert.Equal(0.96, metrics.Auprc, 2); + Assert.Equal(1, metrics.Entropy, 3); + Assert.Equal(.9, metrics.F1Score, 4); + Assert.Equal(.97, metrics.LogLoss, 3); + Assert.Equal(3.030, metrics.LogLossReduction, 3); + Assert.Equal(1, metrics.NegativePrecision, 3); + Assert.Equal(.778, metrics.NegativeRecall, 3); + Assert.Equal(.818, metrics.PositivePrecision, 3); + Assert.Equal(1, metrics.PositiveRecall); + + var matrix = metrics.ConfusionMatrix; + Assert.Equal(2, matrix.Order); + Assert.Equal(2, matrix.ClassNames.Count); + Assert.Equal("positive", matrix.ClassNames[0]); + Assert.Equal("negative", matrix.ClassNames[1]); + + Assert.Equal(9, matrix[0, 0]); + Assert.Equal(9, matrix["positive", "positive"]); + Assert.Equal(0, matrix[0, 1]); + Assert.Equal(0, matrix["positive", "negative"]); + + Assert.Equal(2, matrix[1, 0]); + Assert.Equal(2, matrix["negative", "positive"]); + Assert.Equal(7, matrix[1, 1]); + Assert.Equal(7, matrix["negative", "negative"]); + + } + private void ValidateBinaryMetricsLightGBM(BinaryClassificationMetrics metrics) { @@ -338,6 +384,55 @@ private LearningPipeline PreparePipelineLightGBM() return pipeline; } + private LearningPipeline PreparePipelineSymSGD() + { + var dataPath = GetDataPath(SentimentDataPath); + var pipeline = new 
LearningPipeline(); + + pipeline.Add(new Data.TextLoader(dataPath) + { + Arguments = new TextLoaderArguments + { + Separator = new[] { '\t' }, + HasHeader = true, + Column = new[] + { + new TextLoaderColumn() + { + Name = "Label", + Source = new [] { new TextLoaderRange(0) }, + Type = Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "SentimentText", + Source = new [] { new TextLoaderRange(1) }, + Type = Data.DataKind.Text + } + } + } + }); + + pipeline.Add(new TextFeaturizer("Features", "SentimentText") + { + KeepDiacritics = false, + KeepPunctuations = false, + TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, + OutputTokens = true, + StopWordsRemover = new PredefinedStopWordsRemover(), + VectorNormalizer = TextTransformTextNormKind.L2, + CharFeatureExtractor = new NGramNgramExtractor() { NgramLength = 3, AllLengths = false }, + WordFeatureExtractor = new NGramNgramExtractor() { NgramLength = 2, AllLengths = true } + }); + + pipeline.Add(new SymSgdBinaryClassifier() { NumberOfThreads = 1 }); + + pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); + return pipeline; + } + private void ValidateExamples(PredictionModel model, bool useLightGBM = false) { var sentiments = GetTestData(); @@ -359,6 +454,16 @@ private void ValidateExamplesLightGBM(PredictionModel model) + private void ValidateExamplesSymSGD(PredictionModel<SentimentData, SentimentPrediction> model) + { + var sentiments = GetTestData(); + var predictions = model.Predict(sentiments); + Assert.Equal(2, predictions.Count()); + + Assert.True(predictions.ElementAt(0).Sentiment.IsFalse); + Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); + } + private Data.TextLoader PrepareTextLoaderTestData() { var testDataPath = GetDataPath(SentimentTestPath);