Skip to content

Commit 3e3bd50

Browse files
authored
Convert RFF transform to estimators (#1122)
1 parent 2a4681b commit 3e3bd50

File tree

7 files changed

+654
-333
lines changed

7 files changed

+654
-333
lines changed

src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -172,22 +172,16 @@ private static KeyToVectorTransform Create(IHostEnvironment env, ModelLoadContext ctx)
172172

173173
host.CheckValue(ctx, nameof(ctx));
174174
ctx.CheckAtModel(GetVersionInfo());
175-
176-
return new KeyToVectorTransform(host, ctx);
177-
}
178-
179-
private static ModelLoadContext ReadFloatFromCtx(IHostEnvironment env, ModelLoadContext ctx)
180-
{
181175
if (ctx.Header.ModelVerWritten == 0x00010001)
182176
{
183177
int cbFloat = ctx.Reader.ReadInt32();
184178
env.CheckDecode(cbFloat == sizeof(float));
185179
}
186-
return ctx;
180+
return new KeyToVectorTransform(host, ctx);
187181
}
188182

189183
private KeyToVectorTransform(IHost host, ModelLoadContext ctx)
190-
: base(host, ReadFloatFromCtx(host, ctx))
184+
: base(host, ctx)
191185
{
192186
var columnsLength = ColumnPairs.Length;
193187
// *** Binary format ***

src/Microsoft.ML.Transforms/FourierDistributionSampler.cs

Lines changed: 29 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,12 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using Float = System.Single;
6-
75
using Microsoft.ML.Runtime;
86
using Microsoft.ML.Runtime.CommandLine;
9-
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Runtime.EntryPoints;
108
using Microsoft.ML.Runtime.Internal.Utilities;
119
using Microsoft.ML.Runtime.Model;
10+
using Microsoft.ML.Transforms;
1211

1312
[assembly: LoadableClass(typeof(GaussianFourierSampler), typeof(GaussianFourierSampler.Arguments), typeof(SignatureFourierDistributionSampler),
1413
"Gaussian Kernel", GaussianFourierSampler.LoadName, "Gaussian")]
@@ -25,26 +24,33 @@
2524
"Laplacian Fourier Sampler Executor", "LaplacianSamplerExecutor", LaplacianFourierSampler.LoaderSignature)]
2625

2726
// REVIEW: Roll all of this in with the RffTransform.
28-
namespace Microsoft.ML.Runtime.Data
27+
namespace Microsoft.ML.Transforms
2928
{
3029
/// <summary>
3130
/// Signature for an IFourierDistributionSampler constructor.
3231
/// </summary>
33-
public delegate void SignatureFourierDistributionSampler(Float avgDist);
32+
public delegate void SignatureFourierDistributionSampler(float avgDist);
3433

3534
public interface IFourierDistributionSampler : ICanSaveModel
3635
{
37-
Float Next(IRandom rand);
36+
float Next(IRandom rand);
37+
}
38+
39+
[TlcModule.ComponentKind("FourierDistributionSampler")]
40+
public interface IFourierDistributionSamplerFactory : IComponentFactory<float, IFourierDistributionSampler>
41+
{
3842
}
3943

4044
public sealed class GaussianFourierSampler : IFourierDistributionSampler
4145
{
4246
private readonly IHost _host;
4347

44-
public class Arguments
48+
public class Arguments : IFourierDistributionSamplerFactory
4549
{
4650
[Argument(ArgumentType.AtMostOnce, HelpText = "gamma in the kernel definition: exp(-gamma*||x-y||^2 / r^2). r is an estimate of the average intra-example distance", ShortName = "g")]
47-
public Float Gamma = 1;
51+
public float Gamma = 1;
52+
53+
public IFourierDistributionSampler CreateComponent(IHostEnvironment env, float avgDist) => new GaussianFourierSampler(env, this, avgDist);
4854
}
4955

5056
public const string LoaderSignature = "RandGaussFourierExec";
@@ -61,9 +67,9 @@ private static VersionInfo GetVersionInfo()
6167

6268
public const string LoadName = "GaussianRandom";
6369

64-
private readonly Float _gamma;
70+
private readonly float _gamma;
6571

66-
public GaussianFourierSampler(IHostEnvironment env, Arguments args, Float avgDist)
72+
public GaussianFourierSampler(IHostEnvironment env, Arguments args, float avgDist)
6773
{
6874
Contracts.CheckValue(env, nameof(env));
6975
_host = env.Register(LoadName);
@@ -91,7 +97,7 @@ private GaussianFourierSampler(IHostEnvironment env, ModelLoadContext ctx)
9197
// Float: gamma
9298

9399
int cbFloat = ctx.Reader.ReadInt32();
94-
_host.CheckDecode(cbFloat == sizeof(Float));
100+
_host.CheckDecode(cbFloat == sizeof(float));
95101

96102
_gamma = ctx.Reader.ReadFloat();
97103
_host.CheckDecode(FloatUtils.IsFinite(_gamma));
@@ -105,23 +111,25 @@ public void Save(ModelSaveContext ctx)
105111
// int: sizeof(Float)
106112
// Float: gamma
107113

108-
ctx.Writer.Write(sizeof(Float));
114+
ctx.Writer.Write(sizeof(float));
109115
_host.Assert(FloatUtils.IsFinite(_gamma));
110116
ctx.Writer.Write(_gamma);
111117
}
112118

113-
public Float Next(IRandom rand)
119+
public float Next(IRandom rand)
114120
{
115-
return (Float)Stats.SampleFromGaussian(rand) * MathUtils.Sqrt(2 * _gamma);
121+
return (float)Stats.SampleFromGaussian(rand) * MathUtils.Sqrt(2 * _gamma);
116122
}
117123
}
118124

119125
public sealed class LaplacianFourierSampler : IFourierDistributionSampler
120126
{
121-
public class Arguments
127+
public class Arguments : IFourierDistributionSamplerFactory
122128
{
123129
[Argument(ArgumentType.AtMostOnce, HelpText = "a in the term exp(-a|x| / r). r is an estimate of the average intra-example L1 distance")]
124-
public Float A = 1;
130+
public float A = 1;
131+
132+
public IFourierDistributionSampler CreateComponent(IHostEnvironment env, float avgDist) => new LaplacianFourierSampler(env, this, avgDist);
125133
}
126134

127135
private static VersionInfo GetVersionInfo()
@@ -139,9 +147,9 @@ private static VersionInfo GetVersionInfo()
139147
public const string RegistrationName = "LaplacianRandom";
140148

141149
private readonly IHost _host;
142-
private readonly Float _a;
150+
private readonly float _a;
143151

144-
public LaplacianFourierSampler(IHostEnvironment env, Arguments args, Float avgDist)
152+
public LaplacianFourierSampler(IHostEnvironment env, Arguments args, float avgDist)
145153
{
146154
Contracts.CheckValue(env, nameof(env));
147155
_host = env.Register(RegistrationName);
@@ -170,7 +178,7 @@ private LaplacianFourierSampler(IHostEnvironment env, ModelLoadContext ctx)
170178
// Float: a
171179

172180
int cbFloat = ctx.Reader.ReadInt32();
173-
_host.CheckDecode(cbFloat == sizeof(Float));
181+
_host.CheckDecode(cbFloat == sizeof(float));
174182

175183
_a = ctx.Reader.ReadFloat();
176184
_host.CheckDecode(FloatUtils.IsFinite(_a));
@@ -184,12 +192,12 @@ public void Save(ModelSaveContext ctx)
184192
// int: sizeof(Float)
185193
// Float: a
186194

187-
ctx.Writer.Write(sizeof(Float));
195+
ctx.Writer.Write(sizeof(float));
188196
_host.Assert(FloatUtils.IsFinite(_a));
189197
ctx.Writer.Write(_a);
190198
}
191199

192-
public Float Next(IRandom rand)
200+
public float Next(IRandom rand)
193201
{
194202
return _a * Stats.SampleFromCauchy(rand);
195203
}

0 commit comments

Comments
 (0)