diff --git a/src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs b/src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs
index 9bd70d9e1a..fd80cc216f 100644
--- a/src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs
+++ b/src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs
@@ -133,19 +133,19 @@ private struct MFParameter
/// Specify if the factor matrices should be non-negative.
///
[FieldOffset(48)]
- public int DoNmf;
+ public byte DoNmf;
///
/// Set to true so that LIBMF may produce less information to STDOUT.
///
- [FieldOffset(52)]
- public int Quiet;
+ [FieldOffset(49)]
+ public byte Quiet;
///
/// Set to false so that LIBMF may reuse and modifiy the data passed in.
///
- [FieldOffset(56)]
- public int CopyData;
+ [FieldOffset(50)]
+ public byte CopyData;
}
[StructLayout(LayoutKind.Explicit)]
@@ -223,9 +223,9 @@ public SafeTrainingAndModelBuffer(IHostEnvironment env, int fun, int k, int nrTh
_mfParam.Eta = (float)eta;
_mfParam.Alpha = (float)alpha;
_mfParam.C = (float)c;
- _mfParam.DoNmf = doNmf ? 1 : 0;
- _mfParam.Quiet = quiet ? 1 : 0;
- _mfParam.CopyData = copyData ? 1 : 0;
+ _mfParam.DoNmf = doNmf ? (byte)1 : (byte)0;
+ _mfParam.Quiet = quiet ? (byte)1 : (byte)0;
+ _mfParam.CopyData = copyData ? (byte)1 : (byte)0;
}
~SafeTrainingAndModelBuffer()
diff --git a/src/Native/MatrixFactorizationNative/UnmanagedMemory.cpp b/src/Native/MatrixFactorizationNative/UnmanagedMemory.cpp
index 6f93cf817e..75b6ccac93 100644
--- a/src/Native/MatrixFactorizationNative/UnmanagedMemory.cpp
+++ b/src/Native/MatrixFactorizationNative/UnmanagedMemory.cpp
@@ -9,25 +9,48 @@
using namespace mf;
+mf_parameter make_param(const mf_parameter_bridge *param_bridge)
+{
+ mf_parameter param;
+ param.fun = param_bridge->fun;
+ param.k = param_bridge->k;
+ param.nr_threads = param_bridge->nr_threads;
+ param.nr_bins = param_bridge->nr_bins;
+ param.nr_iters = param_bridge->nr_iters;
+ param.lambda_p1 = param_bridge->lambda_p1;
+ param.lambda_p2 = param_bridge->lambda_p2;
+ param.lambda_q1 = param_bridge->lambda_q1;
+ param.lambda_q2 = param_bridge->lambda_q2;
+ param.eta = param_bridge->eta;
+ param.alpha = param_bridge->alpha;
+ param.c = param_bridge->c;
+ param.do_nmf = param_bridge->do_nmf != 0 ? true : false;
+ param.quiet = param_bridge->quiet != 0 ? true : false;
+ param.copy_data = param_bridge->copy_data != 0 ? true : false;
+ return param;
+}
+
EXPORT_API(void) MFDestroyModel(mf_model *&model)
{
return mf_destroy_model(&model);
}
-EXPORT_API(mf_model*) MFTrain(const mf_problem *prob, const mf_parameter *param)
+EXPORT_API(mf_model*) MFTrain(const mf_problem *prob, const mf_parameter_bridge *param_bridge)
{
- return mf_train(prob, *param);
+ auto param = make_param(param_bridge);
+ return mf_train(prob, param);
}
-EXPORT_API(mf_model*) MFTrainWithValidation(const mf_problem *tr, const mf_problem *va, const mf_parameter *param)
+EXPORT_API(mf_model*) MFTrainWithValidation(const mf_problem *tr, const mf_problem *va, const mf_parameter_bridge *param_bridge)
{
- return mf_train_with_validation(tr, va, *param);
+ auto param = make_param(param_bridge);
+ return mf_train_with_validation(tr, va, param);
}
-
-EXPORT_API(float) MFCrossValidation(const mf_problem *prob, int nr_folds, const mf_parameter *param)
+EXPORT_API(float) MFCrossValidation(const mf_problem *prob, int nr_folds, const mf_parameter_bridge *param_bridge)
{
- return mf_cross_validation(prob, nr_folds, *param);
+ auto param = make_param(param_bridge);
+ return mf_cross_validation(prob, nr_folds, param);
}
EXPORT_API(float) MFPredict(const mf_model *model, int p_idx, int q_idx)
diff --git a/src/Native/MatrixFactorizationNative/UnmanagedMemory.h b/src/Native/MatrixFactorizationNative/UnmanagedMemory.h
index 6007d35e30..2b07d7843b 100644
--- a/src/Native/MatrixFactorizationNative/UnmanagedMemory.h
+++ b/src/Native/MatrixFactorizationNative/UnmanagedMemory.h
@@ -8,12 +8,31 @@
using namespace mf;
+struct mf_parameter_bridge
+{
+ int32_t fun;
+ int32_t k;
+ int32_t nr_threads;
+ int32_t nr_bins;
+ int32_t nr_iters;
+ float lambda_p1;
+ float lambda_p2;
+ float lambda_q1;
+ float lambda_q2;
+ float eta;
+ float alpha;
+ float c;
+ uint8_t do_nmf;
+ uint8_t quiet;
+ uint8_t copy_data;
+};
+
EXPORT_API(void) MFDestroyModel(mf_model *&model);
-EXPORT_API(mf_model*) MFTrain(const mf_problem *prob, const mf_parameter *param);
+EXPORT_API(mf_model*) MFTrain(const mf_problem *prob, const mf_parameter_bridge *parameter_bridge);
-EXPORT_API(mf_model*) MFTrainWithValidation(const mf_problem *tr, const mf_problem *va, const mf_parameter *param);
+EXPORT_API(mf_model*) MFTrainWithValidation(const mf_problem *tr, const mf_problem *va, const mf_parameter_bridge *parameter_bridge);
-EXPORT_API(float) MFCrossValidation(const mf_problem *prob, int nr_folds, const mf_parameter* param);
+EXPORT_API(float) MFCrossValidation(const mf_problem *prob, int nr_folds, const mf_parameter_bridge* parameter_bridge);
EXPORT_API(float) MFPredict(const mf_model *model, int p_idx, int q_idx);
diff --git a/src/Native/MatrixFactorizationNative/libmf b/src/Native/MatrixFactorizationNative/libmf
index 5b055ea473..8262f339db 160000
--- a/src/Native/MatrixFactorizationNative/libmf
+++ b/src/Native/MatrixFactorizationNative/libmf
@@ -1 +1 @@
-Subproject commit 5b055ea473756bd14f56b49db7e0483271788cc2
+Subproject commit 8262f339dba0792bf0f3892bae92b3dd4432afc5