diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs
index de68f9be1e..a0b0d35ce6 100644
--- a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs
+++ b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs
@@ -89,8 +89,19 @@ namespace Microsoft.ML.Trainers
     public sealed class MatrixFactorizationTrainer : TrainerBase<MatrixFactorizationPredictor>,
         IEstimator<MatrixFactorizationPredictionTransformer>
     {
+        public enum LossFunctionType { SquareLossRegression = 0, SquareLossOneClass = 12 };
+
         public sealed class Arguments
         {
+            /// <summary>
+            /// Loss function minimized for finding factor matrices. Two values are allowed, 0 and 12. The value 0 means the traditional
+            /// collaborative filtering problem with squared loss. The value 12 triggers one-class matrix factorization for the
+            /// implicit-feedback recommendation problem.
+            /// </summary>
+            [Argument(ArgumentType.AtMostOnce, HelpText = "Loss function minimized for finding factor matrices.")]
+            [TGUI(SuggestedSweeps = "0,12")]
+            [TlcModule.SweepableDiscreteParam("LossFunction", new object[] { LossFunctionType.SquareLossRegression, LossFunctionType.SquareLossOneClass })]
+            public LossFunctionType LossFunction = LossFunctionType.SquareLossRegression;
+
             [Argument(ArgumentType.AtMostOnce, HelpText = "Regularization parameter. " +
                 "It's the weight of factor matrices' norms in the objective function minimized by matrix factorization's algorithm. " +
                 "A small value could cause over-fitting.")]
@@ -116,6 +127,33 @@ public sealed class Arguments
             [TlcModule.SweepableDiscreteParam("Eta", new object[] { 0.001f, 0.01f, 0.1f })]
             public double Eta = 0.1;
 
+            /// <summary>
+            /// Importance of unobserved (i.e., negative) entries' loss in one-class matrix factorization.
+            /// In general, only a small fraction of the matrix entries (e.g., less than 1%) in the training matrix are observed (i.e., positive).
+            /// To balance the contributions from unobserved and observed entries in the overall loss function, this parameter is
+            /// usually a small value, so that the solver can find a factorization equally good for unobserved and observed
+            /// entries. If only 10000 observed entries are present in a 200000-by-300000 training matrix, one can try Alpha = 10000 / (200000*300000 - 10000).
+            /// When most entries in the training matrix are observed, one can use Alpha >> 1; for example, if only 10000 entries in the previous
+            /// matrix are unobserved, one can try Alpha = (200000 * 300000 - 10000) / 10000. In general,
+            /// Alpha = (# of observed entries) / (# of unobserved entries) makes observed and unobserved entries equally important
+            /// in the minimized loss function. However, the best setting in machine learning is always data-dependent, so the user still needs
+            /// to try multiple values.
+            /// </summary>
+            [Argument(ArgumentType.AtMostOnce, HelpText = "Importance of unobserved entries' loss in one-class matrix factorization.")]
+            [TGUI(SuggestedSweeps = "1,0.01,0.0001,0.000001")]
+            [TlcModule.SweepableDiscreteParam("Alpha", new object[] { 1f, 0.01f, 0.0001f, 0.000001f })]
+            public double Alpha = 0.0001;
+
+            /// <summary>
+            /// Desired value of unobserved (i.e., negative) entries in one-class matrix factorization. In one-class matrix factorization, all observed
+            /// matrix values are one (which can be viewed as positive cases in binary classification), while unobserved values (which can be viewed as
+            /// negative cases in binary classification) need to be specified manually using this option.
+            /// </summary>
+            [Argument(ArgumentType.AtMostOnce, HelpText = "Desired negative entries' value in one-class matrix factorization")]
+            [TGUI(SuggestedSweeps = "0.000001,0.0001,0.01")]
+            [TlcModule.SweepableDiscreteParam("C", new object[] { 0.000001f, 0.0001f, 0.01f })]
+            public double C = 0.000001f;
+
             [Argument(ArgumentType.AtMostOnce, HelpText = "Number of threads can be used in the training procedure.", ShortName = "t")]
             public int? NumThreads;
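The Alpha heuristic described in the doc comment above reduces to a ratio of observed to unobserved entry counts. A minimal sketch of that arithmetic, with a hypothetical helper name that is not part of this change:

    // Sketch only: a starting Alpha that balances the observed and unobserved
    // parts of the one-class loss, per the doc comment on Alpha above.
    static double SuggestAlpha(long rows, long cols, long observedCount)
    {
        long unobservedCount = rows * cols - observedCount;
        // Alpha = (# of observed entries) / (# of unobserved entries).
        return (double)observedCount / unobservedCount;
    }

    // SuggestAlpha(200000, 300000, 10000) ≈ 1.7e-7; treat it as a sweep
    // starting point rather than a final value.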
@@ -131,10 +169,13 @@ public sealed class Arguments
             + "and the values of the matrix are ratings. ";
 
         // LIBMF's parameter
+        private readonly int _fun;
         private readonly double _lambda;
         private readonly int _k;
         private readonly int _iter;
         private readonly double _eta;
+        private readonly double _alpha;
+        private readonly double _c;
         private readonly int _threads;
         private readonly bool _quiet;
         private readonly bool _doNmf;
@@ -192,11 +233,15 @@ public MatrixFactorizationTrainer(IHostEnvironment env, Arguments args) : base(e
             Host.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), posError);
             Host.CheckUserArg(args.Lambda > 0, nameof(args.Lambda), posError);
             Host.CheckUserArg(args.Eta > 0, nameof(args.Eta), posError);
+            Host.CheckUserArg(args.Alpha > 0, nameof(args.Alpha), posError);
 
+            _fun = (int)args.LossFunction;
             _lambda = args.Lambda;
             _k = args.K;
             _iter = args.NumIterations;
             _eta = args.Eta;
+            _alpha = args.Alpha;
+            _c = args.C;
             _threads = args.NumThreads ?? Environment.ProcessorCount;
             _quiet = args.Quiet;
             _doNmf = args.NonNegative;
@@ -224,10 +269,13 @@ public MatrixFactorizationTrainer(IHostEnvironment env,
             var args = new Arguments();
             advancedSettings?.Invoke(args);
 
+            _fun = (int)args.LossFunction;
             _lambda = args.Lambda;
             _k = args.K;
             _iter = args.NumIterations;
             _eta = args.Eta;
+            _alpha = args.Alpha;
+            _c = args.C;
             _threads = args.NumThreads ?? Environment.ProcessorCount;
             _quiet = args.Quiet;
             _doNmf = args.NonNegative;
@@ -338,8 +386,8 @@ private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data,
 
         private SafeTrainingAndModelBuffer PrepareBuffer()
        {
-            return new SafeTrainingAndModelBuffer(Host, _k, Math.Max(20, 2 * _threads),
-                _threads, _iter, _lambda, _eta, _doNmf, _quiet, copyData: false);
+            return new SafeTrainingAndModelBuffer(Host, _fun, _k, _threads, Math.Max(20, 2 * _threads),
+                _iter, _lambda, _eta, _alpha, _c, _doNmf, _quiet, copyData: false);
         }
 
         /// <summary>
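Taken together, the trainer now forwards LossFunction, Alpha, and C to LIBMF. A minimal configuration sketch, mirroring the test added at the bottom of this diff (the mlContext value and column names are placeholders):

    // Sketch: opting into one-class matrix factorization via advancedSettings.
    var mlContext = new MLContext(seed: 1, conc: 1);
    var trainer = new MatrixFactorizationTrainer(mlContext,
        "MatrixColumnIndex", "MatrixRowIndex", "Value",
        advancedSettings: s =>
        {
            s.LossFunction = MatrixFactorizationTrainer.LossFunctionType.SquareLossOneClass;
            s.Alpha = 0.01; // weight of the loss over unobserved entries
            s.C = 0.15;     // value the model should predict for unobserved entries
        });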
diff --git a/src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs b/src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs
index 615b0875f8..33bb90ae0b 100644
--- a/src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs
+++ b/src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs
@@ -44,23 +44,107 @@ private unsafe struct MFProblem
         [StructLayout(LayoutKind.Explicit)]
         private struct MFParameter
         {
+            /// <summary>
+            /// Enum of loss functions which can be minimized.
+            /// 0: square loss for regression.
+            /// 1: absolute loss for regression.
+            /// 2: KL-divergence for regression.
+            /// 5: logistic loss for binary classification.
+            /// 6: squared hinge loss for binary classification.
+            /// 7: hinge loss for binary classification.
+            /// 10: row-wise Bayesian personalized ranking.
+            /// 11: column-wise Bayesian personalized ranking.
+            /// 12: squared loss for implicit-feedback matrix factorization.
+            /// Fun 12 is solved by a coordinate descent method, while the other loss functions invoke
+            /// a stochastic gradient method.
+            /// </summary>
             [FieldOffset(0)]
-            public int K;
+            public int Fun;
+
+            /// <summary>
+            /// Rank of factor matrices.
+            /// </summary>
             [FieldOffset(4)]
-            public int NrThreads;
+            public int K;
+
+            /// <summary>
+            /// Number of threads which can be used for training.
+            /// </summary>
             [FieldOffset(8)]
-            public int NrBins;
+            public int NrThreads;
+
+            /// <summary>
+            /// Number of blocks that the training matrix is divided into. The parallel stochastic gradient
+            /// method in LIBMF assigns each thread one block at a time. The ratings in a block are accessed
+            /// sequentially (not randomly, as in standard stochastic gradient methods).
+            /// </summary>
             [FieldOffset(12)]
-            public int NrIters;
+            public int NrBins;
+
+            /// <summary>
+            /// Number of training iterations. In one iteration, all values in the training matrix are roughly accessed once.
+            /// </summary>
             [FieldOffset(16)]
-            public float Lambda;
+            public int NrIters;
+
+            /// <summary>
+            /// L1-norm regularization coefficient of the left factor matrix.
+            /// </summary>
             [FieldOffset(20)]
-            public float Eta;
+            public float LambdaP1;
+
+            /// <summary>
+            /// L2-norm regularization coefficient of the left factor matrix.
+            /// </summary>
             [FieldOffset(24)]
-            public int DoNmf;
+            public float LambdaP2;
+
+            /// <summary>
+            /// L1-norm regularization coefficient of the right factor matrix.
+            /// </summary>
             [FieldOffset(28)]
-            public int Quiet;
+            public float LambdaQ1;
+
+            /// <summary>
+            /// L2-norm regularization coefficient of the right factor matrix.
+            /// </summary>
             [FieldOffset(32)]
+            public float LambdaQ2;
+
+            /// <summary>
+            /// Learning rate of LIBMF's stochastic gradient method.
+            /// </summary>
+            [FieldOffset(36)]
+            public float Eta;
+
+            /// <summary>
+            /// Coefficient of the loss function on unobserved entries in the training matrix. It's used only with Fun = 12.
+            /// </summary>
+            [FieldOffset(40)]
+            public float Alpha;
+
+            /// <summary>
+            /// Desired value of unobserved entries in the training matrix. It's used only with Fun = 12.
+            /// </summary>
+            [FieldOffset(44)]
+            public float C;
+
+            /// <summary>
+            /// Specify if the factor matrices should be non-negative.
+            /// </summary>
+            [FieldOffset(48)]
+            public int DoNmf;
+
+            /// <summary>
+            /// Set to true so that LIBMF produces less information to STDOUT.
+            /// </summary>
+            [FieldOffset(52)]
+            public int Quiet;
+
+            /// <summary>
+            /// Set to false so that LIBMF may reuse and modify the data passed in.
+            /// </summary>
+            [FieldOffset(56)]
             public int CopyData;
         }
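Since MFParameter mirrors a native LIBMF struct with LayoutKind.Explicit, every field inserted above had to shift all later offsets. A hedged sanity-check sketch (illustrative only; MFParameter is private, so a real check would have to live next to the struct):

    // Sketch: read the declared offsets back via Marshal.OffsetOf to catch a
    // future edit that inserts a field without shifting the offsets below it.
    // Requires System.Runtime.InteropServices and System.Diagnostics.
    Debug.Assert((int)Marshal.OffsetOf(typeof(MFParameter), "Alpha") == 40);
    Debug.Assert((int)Marshal.OffsetOf(typeof(MFParameter), "C") == 44);
    Debug.Assert((int)Marshal.OffsetOf(typeof(MFParameter), "CopyData") == 56);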
@@ -68,14 +152,36 @@ private struct MFParameter
         private unsafe struct MFModel
         {
             [FieldOffset(0)]
-            public int M;
+            public int Fun;
+            /// <summary>
+            /// Number of rows in the training matrix.
+            /// </summary>
             [FieldOffset(4)]
-            public int N;
+            public int M;
+            /// <summary>
+            /// Number of columns in the training matrix.
+            /// </summary>
             [FieldOffset(8)]
+            public int N;
+            /// <summary>
+            /// Rank of factor matrices.
+            /// </summary>
+            [FieldOffset(12)]
             public int K;
+            /// <summary>
+            /// Average value in the training matrix.
+            /// </summary>
             [FieldOffset(16)]
+            public float B;
+            /// <summary>
+            /// Left factor matrix. Its shape is M-by-K, stored in row-major format.
+            /// </summary>
+            [FieldOffset(24)] // pointer is 8-byte on 64-bit machine.
             public float* P;
-            [FieldOffset(24)]
+            /// <summary>
+            /// Right factor matrix. Its shape is N-by-K, stored in row-major format.
+            /// </summary>
+            [FieldOffset(32)] // pointer is 8-byte on 64-bit machine.
             public float* Q;
         }
@@ -100,16 +206,23 @@ private unsafe struct MFModel
         private unsafe MFModel* _pMFModel;
         private readonly IHost _host;
 
-        public SafeTrainingAndModelBuffer(IHostEnvironment env, int k, int nrBins, int nrThreads, int nrIters, double lambda, double eta,
+        public SafeTrainingAndModelBuffer(IHostEnvironment env, int fun, int k, int nrThreads,
+            int nrBins, int nrIters, double lambda, double eta, double alpha, double c,
             bool doNmf, bool quiet, bool copyData)
         {
             _host = env.Register("SafeTrainingAndModelBuffer");
+            _mfParam.Fun = fun;
             _mfParam.K = k;
-            _mfParam.NrBins = nrBins;
             _mfParam.NrThreads = nrThreads;
+            _mfParam.NrBins = nrBins;
             _mfParam.NrIters = nrIters;
-            _mfParam.Lambda = (float)lambda;
+            _mfParam.LambdaP1 = 0;
+            _mfParam.LambdaP2 = (float)lambda;
+            _mfParam.LambdaQ1 = 0;
+            _mfParam.LambdaQ2 = (float)lambda;
             _mfParam.Eta = (float)eta;
+            _mfParam.Alpha = (float)alpha;
+            _mfParam.C = (float)c;
             _mfParam.DoNmf = doNmf ? 1 : 0;
             _mfParam.Quiet = quiet ? 1 : 0;
             _mfParam.CopyData = copyData ? 1 : 0;
diff --git a/src/Native/MatrixFactorizationNative/libmf b/src/Native/MatrixFactorizationNative/libmf
index 1ecc365249..f92a18161b 160000
--- a/src/Native/MatrixFactorizationNative/libmf
+++ b/src/Native/MatrixFactorizationNative/libmf
@@ -1 +1 @@
-Subproject commit 1ecc365249e5cac5e72c66317a141298dc52f6e3
+Subproject commit f92a18161b6824fda4c4ab698a69d299a836841a
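The submodule bump brings in LIBMF's one-class solver. For orientation (this is my reading of LIBMF's one-class formulation, not spelled out in this diff), with $\Omega$ the set of observed entries, Fun = 12 approximately minimizes

    \min_{P,Q} \sum_{(u,v)\in\Omega} \left(r_{uv} - \mathbf{p}_u^\top \mathbf{q}_v\right)^2
             + \alpha \sum_{(u,v)\notin\Omega} \left(c - \mathbf{p}_u^\top \mathbf{q}_v\right)^2
             + \lambda_{P2}\|P\|_F^2 + \lambda_{Q2}\|Q\|_F^2

which is why Alpha scales the loss over unobserved entries and C sets the value those entries are pulled toward, presumably with the coordinate descent method keeping the dense unobserved term tractable.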
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs
index 3fd11ec6f2..4b1e8cf92e 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs
@@ -331,5 +331,112 @@ public void MatrixFactorizationInMemoryDataZeroBaseIndex()
             // The presence of out-of-range indexes may lead to NaN
             Assert.True(float.IsNaN(pred.Score));
         }
+
+        // The following ingredients are used to define a 3-by-2 one-class
+        // matrix used in a test, OneClassMatrixFactorizationInMemoryDataZeroBaseIndex,
+        // for one-class matrix factorization. A one-class matrix means that all
+        // the available elements in the training matrix are 1. Such a matrix
+        // is common. Let's use an online game store as an example. Assume that
+        // user IDs are row indexes and game IDs are column indexes. By
+        // encoding all users' purchase history as a matrix (i.e., if the value
+        // at the u-th row and the v-th column is 1, then the u-th user owns
+        // the v-th game), a one-class matrix gets created because all specified
+        // matrix elements are 1. If you train a prediction model from that
+        // matrix using standard collaborative filtering, all your predictions
+        // would be 1! One-class matrix factorization assumes unspecified matrix
+        // entries are all 0 (or a small constant value selected by the user)
+        // so that the trained model can assign purchased items higher
+        // scores than those not purchased.
+        private const int _oneClassMatrixColumnCount = 2;
+        private const int _oneClassMatrixRowCount = 3;
+
+        private class OneClassMatrixElementZeroBased
+        {
+            [KeyType(Contiguous = true, Count = _oneClassMatrixColumnCount, Min = 0)]
+            public uint MatrixColumnIndex;
+            [KeyType(Contiguous = true, Count = _oneClassMatrixRowCount, Min = 0)]
+            public uint MatrixRowIndex;
+            public float Value;
+        }
+
+        private class OneClassMatrixElementZeroBasedForScore
+        {
+            [KeyType(Contiguous = true, Count = _oneClassMatrixColumnCount, Min = 0)]
+            public uint MatrixColumnIndex;
+            [KeyType(Contiguous = true, Count = _oneClassMatrixRowCount, Min = 0)]
+            public uint MatrixRowIndex;
+            public float Value;
+            public float Score;
+        }
+
+        [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // This test is being fixed as part of issue #1441.
+        public void OneClassMatrixFactorizationInMemoryDataZeroBaseIndex()
+        {
+            // Create an in-memory matrix as a list of tuples (column index, row index, value). For the one-class matrix
+            // factorization problem, unspecified matrix elements are all assumed to be a constant provided by the user.
+            // If that constant is 0.15, the following list means a 3-by-2 training matrix with elements:
+            //   (0, 0, 1), (1, 1, 1), (0, 2, 1), (0, 1, 0.15), (1, 0, 0.15), (1, 2, 0.15),
+            // because matrix elements at (0, 1), (1, 0), and (1, 2) are not specified.
+            var dataMatrix = new List<OneClassMatrixElementZeroBased>();
+            dataMatrix.Add(new OneClassMatrixElementZeroBased() { MatrixColumnIndex = 0, MatrixRowIndex = 0, Value = 1 });
+            dataMatrix.Add(new OneClassMatrixElementZeroBased() { MatrixColumnIndex = 1, MatrixRowIndex = 1, Value = 1 });
+            dataMatrix.Add(new OneClassMatrixElementZeroBased() { MatrixColumnIndex = 0, MatrixRowIndex = 2, Value = 1 });
+
+            // Convert the in-memory matrix into an IDataView so that ML.NET components can consume it.
+            var dataView = ComponentCreation.CreateDataView(Env, dataMatrix);
+
+            // Create a matrix factorization trainer which consumes "Value" as the training label, "MatrixColumnIndex" as the
+            // matrix's column index, and "MatrixRowIndex" as the matrix's row index.
+            var mlContext = new MLContext(seed: 1, conc: 1);
+            var pipeline = new MatrixFactorizationTrainer(mlContext,
+                nameof(OneClassMatrixElementZeroBased.MatrixColumnIndex),
+                nameof(OneClassMatrixElementZeroBased.MatrixRowIndex),
+                nameof(OneClassMatrixElementZeroBased.Value),
+                advancedSettings: s =>
+                {
+                    s.LossFunction = MatrixFactorizationTrainer.LossFunctionType.SquareLossOneClass;
+                    s.NumIterations = 100;
+                    s.NumThreads = 1; // To eliminate randomness, # of threads must be 1.
+                    // Let's test a non-default regularization coefficient.
+                    s.Lambda = 0.025;
+                    s.K = 16;
+                    // Importance coefficient of the loss function over matrix elements not specified in the input matrix.
+                    s.Alpha = 0.01;
+                    // Desired value for matrix elements not specified in the input matrix.
+                    s.C = 0.15;
+                });
+
+            // Train a matrix factorization model.
+            var model = pipeline.Fit(dataView);
+
+            // Apply the trained model to the training set.
+            var prediction = model.Transform(dataView);
+
+            // Calculate regression metrics for the prediction result.
+            var metrics = mlContext.Regression.Evaluate(prediction, label: "Value", score: "Score");
+
+            // Make sure the prediction error is not too large.
+            Assert.InRange(metrics.L2, 0, 0.0016);
+
+            // Create data for testing. Note that the 2nd element is not specified in the training data, so its prediction should
+            // be close to the constant specified by s.C = 0.15. Compared with the data structure used in the training phase,
+            // OneClassMatrixElementZeroBasedForScore carries one extra float, Score, for storing the prediction result. Note
+            // that the prediction engine may ignore Value and assign the predicted value to Score.
+            var testDataMatrix = new List<OneClassMatrixElementZeroBasedForScore>();
+            testDataMatrix.Add(new OneClassMatrixElementZeroBasedForScore() { MatrixColumnIndex = 0, MatrixRowIndex = 0, Value = 0, Score = 0 });
+            testDataMatrix.Add(new OneClassMatrixElementZeroBasedForScore() { MatrixColumnIndex = 1, MatrixRowIndex = 2, Value = 0, Score = 0 });
+
+            // Convert the in-memory matrix into an IDataView so that ML.NET components can consume it.
+            var testDataView = ComponentCreation.CreateDataView(Env, testDataMatrix);
+
+            // Apply the trained model to the test data.
+            var testPrediction = model.Transform(testDataView);
+
+            var testResults = new List<OneClassMatrixElementZeroBasedForScore>(testPrediction.AsEnumerable<OneClassMatrixElementZeroBasedForScore>(mlContext, false));
+            // The positive example (i.e., one that can be found in dataMatrix) gets a score close to 1.
+            CompareNumbersWithTolerance(0.982391, testResults[0].Score, digitsOfPrecision: 5);
+            // The negative example (i.e., one that cannot be found in dataMatrix) gets a score close to 0.15 (specified by s.C = 0.15 in the trainer).
+            CompareNumbersWithTolerance(0.141411, testResults[1].Score, digitsOfPrecision: 5);
+        }
     }
 }
\ No newline at end of file
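As a usage note beyond the assertions above, the same transform-then-read pattern can rank all items for one user. A hedged sketch reusing only the types and calls exercised by this test (requires System.Linq):

    // Sketch: score every column (game) for row 0 (one user) and sort by score.
    var candidates = new List<OneClassMatrixElementZeroBasedForScore>();
    for (uint game = 0; game < _oneClassMatrixColumnCount; game++)
        candidates.Add(new OneClassMatrixElementZeroBasedForScore() { MatrixColumnIndex = game, MatrixRowIndex = 0, Value = 0, Score = 0 });

    var rankedGames = model.Transform(ComponentCreation.CreateDataView(Env, candidates))
        .AsEnumerable<OneClassMatrixElementZeroBasedForScore>(mlContext, false)
        .OrderByDescending(x => x.Score) // scores near 1 suggest likely purchases; scores near C, unlikely ones
        .ToList();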