diff --git a/src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs b/src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs index 9bd70d9e1a..fd80cc216f 100644 --- a/src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs +++ b/src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs @@ -133,19 +133,19 @@ private struct MFParameter /// Specify if the factor matrices should be non-negative. /// [FieldOffset(48)] - public int DoNmf; + public byte DoNmf; /// /// Set to true so that LIBMF may produce less information to STDOUT. /// - [FieldOffset(52)] - public int Quiet; + [FieldOffset(49)] + public byte Quiet; /// /// Set to false so that LIBMF may reuse and modifiy the data passed in. /// - [FieldOffset(56)] - public int CopyData; + [FieldOffset(50)] + public byte CopyData; } [StructLayout(LayoutKind.Explicit)] @@ -223,9 +223,9 @@ public SafeTrainingAndModelBuffer(IHostEnvironment env, int fun, int k, int nrTh _mfParam.Eta = (float)eta; _mfParam.Alpha = (float)alpha; _mfParam.C = (float)c; - _mfParam.DoNmf = doNmf ? 1 : 0; - _mfParam.Quiet = quiet ? 1 : 0; - _mfParam.CopyData = copyData ? 1 : 0; + _mfParam.DoNmf = doNmf ? (byte)1 : (byte)0; + _mfParam.Quiet = quiet ? (byte)1 : (byte)0; + _mfParam.CopyData = copyData ? (byte)1 : (byte)0; } ~SafeTrainingAndModelBuffer() diff --git a/src/Native/MatrixFactorizationNative/UnmanagedMemory.cpp b/src/Native/MatrixFactorizationNative/UnmanagedMemory.cpp index 6f93cf817e..75b6ccac93 100644 --- a/src/Native/MatrixFactorizationNative/UnmanagedMemory.cpp +++ b/src/Native/MatrixFactorizationNative/UnmanagedMemory.cpp @@ -9,25 +9,48 @@ using namespace mf; +mf_parameter make_param(const mf_parameter_bridge *param_bridge) +{ + mf_parameter param; + param.fun = param_bridge->fun; + param.k = param_bridge->k; + param.nr_threads = param_bridge->nr_threads; + param.nr_bins = param_bridge->nr_bins; + param.nr_iters = param_bridge->nr_iters; + param.lambda_p1 = param_bridge->lambda_p1; + param.lambda_p2 = param_bridge->lambda_p2; + param.lambda_q1 = param_bridge->lambda_q1; + param.lambda_q2 = param_bridge->lambda_q2; + param.eta = param_bridge->eta; + param.alpha = param_bridge->alpha; + param.c = param_bridge->c; + param.do_nmf = param_bridge->do_nmf != 0 ? true : false; + param.quiet = param_bridge->quiet != 0 ? true : false; + param.copy_data = param_bridge->copy_data != 0 ? true : false; + return param; +} + EXPORT_API(void) MFDestroyModel(mf_model *&model) { return mf_destroy_model(&model); } -EXPORT_API(mf_model*) MFTrain(const mf_problem *prob, const mf_parameter *param) +EXPORT_API(mf_model*) MFTrain(const mf_problem *prob, const mf_parameter_bridge *param_bridge) { - return mf_train(prob, *param); + auto param = make_param(param_bridge); + return mf_train(prob, param); } -EXPORT_API(mf_model*) MFTrainWithValidation(const mf_problem *tr, const mf_problem *va, const mf_parameter *param) +EXPORT_API(mf_model*) MFTrainWithValidation(const mf_problem *tr, const mf_problem *va, const mf_parameter_bridge *param_bridge) { - return mf_train_with_validation(tr, va, *param); + auto param = make_param(param_bridge); + return mf_train_with_validation(tr, va, param); } - -EXPORT_API(float) MFCrossValidation(const mf_problem *prob, int nr_folds, const mf_parameter *param) +EXPORT_API(float) MFCrossValidation(const mf_problem *prob, int nr_folds, const mf_parameter_bridge *param_bridge) { - return mf_cross_validation(prob, nr_folds, *param); + auto param = make_param(param_bridge); + return mf_cross_validation(prob, nr_folds, param); } EXPORT_API(float) MFPredict(const mf_model *model, int p_idx, int q_idx) diff --git a/src/Native/MatrixFactorizationNative/UnmanagedMemory.h b/src/Native/MatrixFactorizationNative/UnmanagedMemory.h index 6007d35e30..2b07d7843b 100644 --- a/src/Native/MatrixFactorizationNative/UnmanagedMemory.h +++ b/src/Native/MatrixFactorizationNative/UnmanagedMemory.h @@ -8,12 +8,31 @@ using namespace mf; +struct mf_parameter_bridge +{ + int32_t fun; + int32_t k; + int32_t nr_threads; + int32_t nr_bins; + int32_t nr_iters; + float lambda_p1; + float lambda_p2; + float lambda_q1; + float lambda_q2; + float eta; + float alpha; + float c; + uint8_t do_nmf; + uint8_t quiet; + uint8_t copy_data; +}; + EXPORT_API(void) MFDestroyModel(mf_model *&model); -EXPORT_API(mf_model*) MFTrain(const mf_problem *prob, const mf_parameter *param); +EXPORT_API(mf_model*) MFTrain(const mf_problem *prob, const mf_parameter_bridge *parameter_bridge); -EXPORT_API(mf_model*) MFTrainWithValidation(const mf_problem *tr, const mf_problem *va, const mf_parameter *param); +EXPORT_API(mf_model*) MFTrainWithValidation(const mf_problem *tr, const mf_problem *va, const mf_parameter_bridge *parameter_bridge); -EXPORT_API(float) MFCrossValidation(const mf_problem *prob, int nr_folds, const mf_parameter* param); +EXPORT_API(float) MFCrossValidation(const mf_problem *prob, int nr_folds, const mf_parameter_bridge* parameter_bridge); EXPORT_API(float) MFPredict(const mf_model *model, int p_idx, int q_idx); diff --git a/src/Native/MatrixFactorizationNative/libmf b/src/Native/MatrixFactorizationNative/libmf index 5b055ea473..8262f339db 160000 --- a/src/Native/MatrixFactorizationNative/libmf +++ b/src/Native/MatrixFactorizationNative/libmf @@ -1 +1 @@ -Subproject commit 5b055ea473756bd14f56b49db7e0483271788cc2 +Subproject commit 8262f339dba0792bf0f3892bae92b3dd4432afc5