Skip to content

Commit b2ac8e0

Browse files
authored
Fix marshalling of bool flags in MF (#3210)
1 parent d9df721 commit b2ac8e0

File tree

4 files changed

+61
-19
lines changed

4 files changed

+61
-19
lines changed

src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -133,19 +133,19 @@ private struct MFParameter
133133
/// Specify if the factor matrices should be non-negative.
134134
/// </summary>
135135
[FieldOffset(48)]
136-
public int DoNmf;
136+
public byte DoNmf;
137137

138138
/// <summary>
139139
/// Set to true so that LIBMF may produce less information to STDOUT.
140140
/// </summary>
141-
[FieldOffset(52)]
142-
public int Quiet;
141+
[FieldOffset(49)]
142+
public byte Quiet;
143143

144144
/// <summary>
145145
/// Set to false so that LIBMF may reuse and modifiy the data passed in.
146146
/// </summary>
147-
[FieldOffset(56)]
148-
public int CopyData;
147+
[FieldOffset(50)]
148+
public byte CopyData;
149149
}
150150

151151
[StructLayout(LayoutKind.Explicit)]
@@ -223,9 +223,9 @@ public SafeTrainingAndModelBuffer(IHostEnvironment env, int fun, int k, int nrTh
223223
_mfParam.Eta = (float)eta;
224224
_mfParam.Alpha = (float)alpha;
225225
_mfParam.C = (float)c;
226-
_mfParam.DoNmf = doNmf ? 1 : 0;
227-
_mfParam.Quiet = quiet ? 1 : 0;
228-
_mfParam.CopyData = copyData ? 1 : 0;
226+
_mfParam.DoNmf = doNmf ? (byte)1 : (byte)0;
227+
_mfParam.Quiet = quiet ? (byte)1 : (byte)0;
228+
_mfParam.CopyData = copyData ? (byte)1 : (byte)0;
229229
}
230230

231231
~SafeTrainingAndModelBuffer()

src/Native/MatrixFactorizationNative/UnmanagedMemory.cpp

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,48 @@
99

1010
using namespace mf;
1111

12+
mf_parameter make_param(const mf_parameter_bridge *param_bridge)
13+
{
14+
mf_parameter param;
15+
param.fun = param_bridge->fun;
16+
param.k = param_bridge->k;
17+
param.nr_threads = param_bridge->nr_threads;
18+
param.nr_bins = param_bridge->nr_bins;
19+
param.nr_iters = param_bridge->nr_iters;
20+
param.lambda_p1 = param_bridge->lambda_p1;
21+
param.lambda_p2 = param_bridge->lambda_p2;
22+
param.lambda_q1 = param_bridge->lambda_q1;
23+
param.lambda_q2 = param_bridge->lambda_q2;
24+
param.eta = param_bridge->eta;
25+
param.alpha = param_bridge->alpha;
26+
param.c = param_bridge->c;
27+
param.do_nmf = param_bridge->do_nmf != 0 ? true : false;
28+
param.quiet = param_bridge->quiet != 0 ? true : false;
29+
param.copy_data = param_bridge->copy_data != 0 ? true : false;
30+
return param;
31+
}
32+
1233
EXPORT_API(void) MFDestroyModel(mf_model *&model)
1334
{
1435
return mf_destroy_model(&model);
1536
}
1637

17-
EXPORT_API(mf_model*) MFTrain(const mf_problem *prob, const mf_parameter *param)
38+
EXPORT_API(mf_model*) MFTrain(const mf_problem *prob, const mf_parameter_bridge *param_bridge)
1839
{
19-
return mf_train(prob, *param);
40+
auto param = make_param(param_bridge);
41+
return mf_train(prob, param);
2042
}
2143

22-
EXPORT_API(mf_model*) MFTrainWithValidation(const mf_problem *tr, const mf_problem *va, const mf_parameter *param)
44+
EXPORT_API(mf_model*) MFTrainWithValidation(const mf_problem *tr, const mf_problem *va, const mf_parameter_bridge *param_bridge)
2345
{
24-
return mf_train_with_validation(tr, va, *param);
46+
auto param = make_param(param_bridge);
47+
return mf_train_with_validation(tr, va, param);
2548
}
2649

27-
28-
EXPORT_API(float) MFCrossValidation(const mf_problem *prob, int nr_folds, const mf_parameter *param)
50+
EXPORT_API(float) MFCrossValidation(const mf_problem *prob, int nr_folds, const mf_parameter_bridge *param_bridge)
2951
{
30-
return mf_cross_validation(prob, nr_folds, *param);
52+
auto param = make_param(param_bridge);
53+
return mf_cross_validation(prob, nr_folds, param);
3154
}
3255

3356
EXPORT_API(float) MFPredict(const mf_model *model, int p_idx, int q_idx)

src/Native/MatrixFactorizationNative/UnmanagedMemory.h

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,31 @@
88

99
using namespace mf;
1010

11+
struct mf_parameter_bridge
12+
{
13+
int32_t fun;
14+
int32_t k;
15+
int32_t nr_threads;
16+
int32_t nr_bins;
17+
int32_t nr_iters;
18+
float lambda_p1;
19+
float lambda_p2;
20+
float lambda_q1;
21+
float lambda_q2;
22+
float eta;
23+
float alpha;
24+
float c;
25+
uint8_t do_nmf;
26+
uint8_t quiet;
27+
uint8_t copy_data;
28+
};
29+
1130
EXPORT_API(void) MFDestroyModel(mf_model *&model);
1231

13-
EXPORT_API(mf_model*) MFTrain(const mf_problem *prob, const mf_parameter *param);
32+
EXPORT_API(mf_model*) MFTrain(const mf_problem *prob, const mf_parameter_bridge *parameter_bridge);
1433

15-
EXPORT_API(mf_model*) MFTrainWithValidation(const mf_problem *tr, const mf_problem *va, const mf_parameter *param);
34+
EXPORT_API(mf_model*) MFTrainWithValidation(const mf_problem *tr, const mf_problem *va, const mf_parameter_bridge *parameter_bridge);
1635

17-
EXPORT_API(float) MFCrossValidation(const mf_problem *prob, int nr_folds, const mf_parameter* param);
36+
EXPORT_API(float) MFCrossValidation(const mf_problem *prob, int nr_folds, const mf_parameter_bridge* parameter_bridge);
1837

1938
EXPORT_API(float) MFPredict(const mf_model *model, int p_idx, int q_idx);
Submodule libmf updated 1 file

0 commit comments

Comments
 (0)