Skip to content

Commit fdd24cf

Browse files
authored
Polish marshalling of MF model and MF problem and enable 32-bit tests (#3227)
1 parent 8a03cf9 commit fdd24cf

File tree

4 files changed

+146
-54
lines changed

4 files changed

+146
-54
lines changed

src/Microsoft.ML.Recommender/SafeTrainingAndModelBuffer.cs

+39-33
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,50 @@ namespace Microsoft.ML.Recommender.Internal
1717
/// </summary>
1818
internal sealed class SafeTrainingAndModelBuffer : IDisposable
1919
{
20-
[StructLayout(LayoutKind.Explicit)]
20+
[StructLayout(LayoutKind.Sequential)]
2121
private struct MFNode
2222
{
23-
[FieldOffset(0)]
23+
/// <summary>
24+
/// Row index.
25+
/// </summary>
2426
public int U;
25-
[FieldOffset(4)]
27+
28+
/// <summary>
29+
/// Column index;
30+
/// </summary>
2631
public int V;
27-
[FieldOffset(8)]
32+
33+
/// <summary>
34+
/// Matrix element's value at <see cref="U"/>-th row and <see cref="V"/>-th column.
35+
/// </summary>
2836
public float R;
2937
}
3038

31-
[StructLayout(LayoutKind.Explicit)]
39+
[StructLayout(LayoutKind.Sequential)]
3240
private unsafe struct MFProblem
3341
{
34-
[FieldOffset(0)]
42+
/// <summary>
43+
/// Number of rows.
44+
/// </summary>
3545
public int M;
36-
[FieldOffset(4)]
46+
47+
/// <summary>
48+
/// Number of columns.
49+
/// </summary>
3750
public int N;
38-
[FieldOffset(8)]
51+
52+
/// <summary>
53+
/// Number of specified matrix elements in <see cref="R"/>.
54+
/// </summary>
3955
public long Nnz;
40-
[FieldOffset(16)]
56+
57+
/// <summary>
58+
/// Specified matrix elements.
59+
/// </summary>
4160
public MFNode* R;
4261
}
4362

44-
[StructLayout(LayoutKind.Explicit)]
63+
[StructLayout(LayoutKind.Sequential)]
4564
private struct MFParameter
4665
{
4766
/// <summary>
@@ -58,130 +77,117 @@ private struct MFParameter
5877
/// Fun 12 is solved by a coordinate descent method while other functions invoke
5978
/// a stochastic gradient method.
6079
/// </summary>
61-
[FieldOffset(0)]
6280
public int Fun;
6381

6482
/// <summary>
6583
/// Rank of factor matrices.
6684
/// </summary>
67-
[FieldOffset(4)]
6885
public int K;
6986

7087
/// <summary>
7188
/// Number of threads which can be used for training.
7289
/// </summary>
73-
[FieldOffset(8)]
7490
public int NrThreads;
7591

7692
/// <summary>
7793
/// Number of blocks that the training matrix is divided into. The parallel stochastic gradient
7894
/// method in LIBMF processes assigns each thread a block at one time. The ratings in one block
7995
/// would be sequentially accessed (not randomaly accessed like standard stochastic gradient methods).
8096
/// </summary>
81-
[FieldOffset(12)]
8297
public int NrBins;
8398

8499
/// <summary>
85100
/// Number of training iteration. At one iteration, all values in the training matrix are roughly accessed once.
86101
/// </summary>
87-
[FieldOffset(16)]
88102
public int NrIters;
89103

90104
/// <summary>
91105
/// L1-norm regularization coefficient of left factor matrix.
92106
/// </summary>
93-
[FieldOffset(20)]
94107
public float LambdaP1;
95108

96109
/// <summary>
97110
/// L2-norm regularization coefficient of left factor matrix.
98111
/// </summary>
99-
[FieldOffset(24)]
100112
public float LambdaP2;
101113

102114
/// <summary>
103115
/// L1-norm regularization coefficient of right factor matrix.
104116
/// </summary>
105-
[FieldOffset(28)]
106117
public float LambdaQ1;
107118

108119
/// <summary>
109120
/// L2-norm regularization coefficient of right factor matrix.
110121
/// </summary>
111-
[FieldOffset(32)]
112122
public float LambdaQ2;
113123

114124
/// <summary>
115125
/// Learning rate of LIBMF's stochastic gradient method.
116126
/// </summary>
117-
[FieldOffset(36)]
118127
public float Eta;
119128

120129
/// <summary>
121130
/// Coefficient of loss function on unobserved entries in the training matrix. It's used only with fun=12.
122131
/// </summary>
123-
[FieldOffset(40)]
124132
public float Alpha;
125133

126134
/// <summary>
127135
/// Desired value of unobserved entries in the training matrix. It's used only with fun=12.
128136
/// </summary>
129-
[FieldOffset(44)]
130137
public float C;
131138

132139
/// <summary>
133140
/// Specify if the factor matrices should be non-negative.
134141
/// </summary>
135-
[FieldOffset(48)]
136142
public byte DoNmf;
137143

138144
/// <summary>
139145
/// Set to true so that LIBMF may produce less information to STDOUT.
140146
/// </summary>
141-
[FieldOffset(49)]
142147
public byte Quiet;
143148

144149
/// <summary>
145150
/// Set to false so that LIBMF may reuse and modifiy the data passed in.
146151
/// </summary>
147-
[FieldOffset(50)]
148152
public byte CopyData;
149153
}
150154

151-
[StructLayout(LayoutKind.Explicit)]
155+
[StructLayout(LayoutKind.Sequential)]
152156
private unsafe struct MFModel
153157
{
154-
[FieldOffset(0)]
158+
/// <summary>
159+
/// See <see cref="MFParameter.Fun"/>.
160+
/// </summary>
155161
public int Fun;
162+
156163
/// <summary>
157164
/// Number of rows in the training matrix.
158165
/// </summary>
159-
[FieldOffset(4)]
160166
public int M;
167+
161168
/// <summary>
162169
/// Number of columns in the training matrix.
163170
/// </summary>
164-
[FieldOffset(8)]
165171
public int N;
172+
166173
/// <summary>
167174
/// Rank of factor matrices.
168175
/// </summary>
169-
[FieldOffset(12)]
170176
public int K;
177+
171178
/// <summary>
172179
/// Average value in the training matrix.
173180
/// </summary>
174-
[FieldOffset(16)]
175181
public float B;
182+
176183
/// <summary>
177184
/// Left factor matrix. Its shape is M-by-K stored in row-major format.
178185
/// </summary>
179-
[FieldOffset(24)] // pointer is 8-byte on 64-bit machine.
180186
public float* P;
187+
181188
/// <summary>
182189
/// Right factor matrix. Its shape is N-by-K stored in row-major format.
183190
/// </summary>
184-
[FieldOffset(32)] // pointer is 8-byte on 64-bit machine.
185191
public float* Q;
186192
}
187193

src/Native/MatrixFactorizationNative/UnmanagedMemory.cpp

+81-14
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
using namespace mf;
1111

12-
mf_parameter make_param(const mf_parameter_bridge *param_bridge)
12+
inline mf_parameter TranslateToParam(const mf_parameter_bridge *param_bridge)
1313
{
1414
mf_parameter param;
1515
param.fun = param_bridge->fun;
@@ -30,30 +30,97 @@ mf_parameter make_param(const mf_parameter_bridge *param_bridge)
3030
return param;
3131
}
3232

33-
EXPORT_API(void) MFDestroyModel(mf_model *&model)
33+
inline mf_problem TranslateToProblem(const mf_problem_bridge *prob_bridge)
3434
{
35-
return mf_destroy_model(&model);
35+
mf_problem prob;
36+
prob.m = prob_bridge->m;
37+
prob.n = prob_bridge->n;
38+
prob.nnz = prob_bridge->nnz;
39+
prob.R = prob_bridge->R;
40+
return prob;
3641
}
3742

38-
EXPORT_API(mf_model*) MFTrain(const mf_problem *prob, const mf_parameter_bridge *param_bridge)
43+
inline void TranslateToModelBridge(const mf_model *model, mf_model_bridge *model_bridge)
3944
{
40-
auto param = make_param(param_bridge);
41-
return mf_train(prob, param);
45+
model_bridge->fun = model->fun;
46+
model_bridge->m = model->m;
47+
model_bridge->n = model->n;
48+
model_bridge->k = model->k;
49+
model_bridge->b = model->b;
50+
model_bridge->P = model->P;
51+
model_bridge->Q = model->Q;
4252
}
4353

44-
EXPORT_API(mf_model*) MFTrainWithValidation(const mf_problem *tr, const mf_problem *va, const mf_parameter_bridge *param_bridge)
54+
inline void TranslateToModel(const mf_model_bridge *model_bridge, mf_model *model)
4555
{
46-
auto param = make_param(param_bridge);
47-
return mf_train_with_validation(tr, va, param);
56+
model->fun = model_bridge->fun;
57+
model->m = model_bridge->m;
58+
model->n = model_bridge->n;
59+
model->k = model_bridge->k;
60+
model->b = model_bridge->b;
61+
model->P = model_bridge->P;
62+
model->Q = model_bridge->Q;
4863
}
4964

50-
EXPORT_API(float) MFCrossValidation(const mf_problem *prob, int nr_folds, const mf_parameter_bridge *param_bridge)
65+
EXPORT_API(void) MFDestroyModel(mf_model_bridge *&model_bridge)
5166
{
52-
auto param = make_param(param_bridge);
53-
return mf_cross_validation(prob, nr_folds, param);
67+
// Transfer the ownership of P and Q back to the original LIBMF class, so that
68+
// mf_destroy_model can be called.
69+
auto model = new mf_model;
70+
model->P = model_bridge->P;
71+
model->Q = model_bridge->Q;
72+
mf_destroy_model(&model); // delete model, model->P, amd model->Q.
73+
74+
// Delete bridge class allocated in MFTrain, MFTrainWithValidation, or MFCrossValidation.
75+
delete model_bridge;
76+
model_bridge = nullptr;
77+
}
78+
79+
EXPORT_API(mf_model_bridge*) MFTrain(const mf_problem_bridge *prob_bridge, const mf_parameter_bridge *param_bridge)
80+
{
81+
// Convert objects created outside LIBMF. Notice that the called LIBMF function doesn't take the ownership of
82+
// allocated memory in those external objects.
83+
auto prob = TranslateToProblem(prob_bridge);
84+
auto param = TranslateToParam(param_bridge);
85+
86+
// The model contains 3 allocated things --- itself, P, and Q.
87+
// We will delete itself and transfer the ownership of P and Q to the associated bridge class. The bridge class
88+
// will then be sent to C#.
89+
auto model = mf_train(&prob, param);
90+
auto model_bridge = new mf_model_bridge;
91+
TranslateToModelBridge(model, model_bridge);
92+
delete model;
93+
return model_bridge; // To clean memory up, we need to delete model_bridge, model_bridge->P, and model_bridge->Q.
94+
}
95+
96+
EXPORT_API(mf_model_bridge*) MFTrainWithValidation(const mf_problem_bridge *tr_bridge, const mf_problem_bridge *va_bridge, const mf_parameter_bridge *param_bridge)
97+
{
98+
// Convert objects created outside LIBMF. Notice that the called LIBMF function doesn't take the ownership of
99+
// allocated memory in those external objects.
100+
auto tr = TranslateToProblem(tr_bridge);
101+
auto va = TranslateToProblem(va_bridge);
102+
auto param = TranslateToParam(param_bridge);
103+
104+
// The model contains 3 allocated things --- itself, P, and Q.
105+
// We will delete itself and transfer the ownership of P and Q to the associated bridge class. The bridge class
106+
// will then be sent to C#.
107+
auto model = mf_train_with_validation(&tr, &va, param);
108+
auto model_bridge = new mf_model_bridge;
109+
TranslateToModelBridge(model, model_bridge);
110+
delete model;
111+
return model_bridge; // To clean memory up, we need to delete model_bridge, model_bridge->P, and model_bridge->Q.
112+
}
113+
114+
EXPORT_API(float) MFCrossValidation(const mf_problem_bridge *prob_bridge, int32_t nr_folds, const mf_parameter_bridge *param_bridge)
115+
{
116+
auto param = TranslateToParam(param_bridge);
117+
auto prob = TranslateToProblem(prob_bridge);
118+
return mf_cross_validation(&prob, nr_folds, param);
54119
}
55120

56-
EXPORT_API(float) MFPredict(const mf_model *model, int p_idx, int q_idx)
121+
EXPORT_API(float) MFPredict(const mf_model_bridge *model_bridge, int32_t p_idx, int32_t q_idx)
57122
{
58-
return mf_predict(model, p_idx, q_idx);
123+
mf_model model;
124+
TranslateToModel(model_bridge, &model);
125+
return mf_predict(&model, p_idx, q_idx);
59126
}

src/Native/MatrixFactorizationNative/UnmanagedMemory.h

+24-5
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,31 @@ struct mf_parameter_bridge
2727
uint8_t copy_data;
2828
};
2929

30-
EXPORT_API(void) MFDestroyModel(mf_model *&model);
30+
struct mf_problem_bridge
31+
{
32+
int32_t m;
33+
int32_t n;
34+
int64_t nnz;
35+
struct mf_node *R;
36+
};
37+
38+
struct mf_model_bridge
39+
{
40+
int32_t fun;
41+
int32_t m;
42+
int32_t n;
43+
int32_t k;
44+
float b;
45+
float *P;
46+
float *Q;
47+
};
48+
49+
EXPORT_API(void) MFDestroyModel(mf_model_bridge *&model);
3150

32-
EXPORT_API(mf_model*) MFTrain(const mf_problem *prob, const mf_parameter_bridge *parameter_bridge);
51+
EXPORT_API(mf_model_bridge*) MFTrain(const mf_problem_bridge *prob_bridge, const mf_parameter_bridge *parameter_bridge);
3352

34-
EXPORT_API(mf_model*) MFTrainWithValidation(const mf_problem *tr, const mf_problem *va, const mf_parameter_bridge *parameter_bridge);
53+
EXPORT_API(mf_model_bridge*) MFTrainWithValidation(const mf_problem_bridge *tr, const mf_problem_bridge *va, const mf_parameter_bridge *parameter_bridge);
3554

36-
EXPORT_API(float) MFCrossValidation(const mf_problem *prob, int nr_folds, const mf_parameter_bridge* parameter_bridge);
55+
EXPORT_API(float) MFCrossValidation(const mf_problem_bridge *prob, int32_t nr_folds, const mf_parameter_bridge* parameter_bridge);
3756

38-
EXPORT_API(float) MFPredict(const mf_model *model, int p_idx, int q_idx);
57+
EXPORT_API(float) MFPredict(const mf_model_bridge *model, int32_t p_idx, int32_t q_idx);

test/Microsoft.ML.TestFramework/Attributes/MatrixFactorizationFactAttribute.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ namespace Microsoft.ML.TestFramework.Attributes
1010
/// </summary>
1111
public sealed class MatrixFactorizationFactAttribute : EnvironmentSpecificFactAttribute
1212
{
13-
public MatrixFactorizationFactAttribute() : base("Disabled - this test is being fixed as part of https://github.com/dotnet/machinelearning/issues/1441")
13+
public MatrixFactorizationFactAttribute() : base("")
1414
{
1515
}
1616

1717
/// <inheritdoc />
1818
protected override bool IsEnvironmentSupported()
1919
{
20-
return Environment.Is64BitProcess;
20+
return true;
2121
}
2222
}
2323
}

0 commit comments

Comments
 (0)