Skip to content

Commit e25d3e7

Browse files
committed
Reducing public surface of remaining predictors
1 parent 76d26d0 commit e25d3e7

38 files changed

+348
-317
lines changed

src/Microsoft.ML.CpuMath/AlignedArray.cs

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77

88
namespace Microsoft.ML.Runtime.Internal.CpuMath
99
{
10-
using Float = System.Single;
11-
1210
/// <summary>
1311
/// This implements a logical array of Floats that is automatically aligned for SSE/AVX operations.
1412
/// To pin and force alignment, call the GetPin method, typically wrapped in a using (since it
@@ -24,7 +22,7 @@ internal sealed class AlignedArray
2422
// items, also filled with NaN. Note that _size * sizeof(Float) is divisible by _cbAlign.
2523
// It is illegal to access any slot outside [_base, _base + _size). This is internal so clients
2624
// can easily pin it.
27-
public Float[] Items;
25+
public float[] Items;
2826

2927
private readonly int _size; // Must be divisible by (_cbAlign / sizeof(Float)).
3028
private readonly int _cbAlign; // The alignment in bytes, a power of two, divisible by sizeof(Float).
@@ -40,12 +38,12 @@ public AlignedArray(int size, int cbAlign)
4038
{
4139
Contracts.Assert(0 < size);
4240
// cbAlign should be a power of two.
43-
Contracts.Assert(sizeof(Float) <= cbAlign);
41+
Contracts.Assert(sizeof(float) <= cbAlign);
4442
Contracts.Assert((cbAlign & (cbAlign - 1)) == 0);
4543
// cbAlign / sizeof(Float) should divide size.
46-
Contracts.Assert((size * sizeof(Float)) % cbAlign == 0);
44+
Contracts.Assert((size * sizeof(float)) % cbAlign == 0);
4745

48-
Items = new Float[size + cbAlign / sizeof(Float)];
46+
Items = new float[size + cbAlign / sizeof(float)];
4947
_size = size;
5048
_cbAlign = cbAlign;
5149
_lock = new object();
@@ -54,15 +52,15 @@ public AlignedArray(int size, int cbAlign)
5452
public unsafe int GetBase(long addr)
5553
{
5654
#if DEBUG
57-
fixed (Float* pv = Items)
58-
Contracts.Assert((Float*)addr == pv);
55+
fixed (float* pv = Items)
56+
Contracts.Assert((float*)addr == pv);
5957
#endif
6058

6159
int cbLow = (int)(addr & (_cbAlign - 1));
6260
int ibMin = cbLow == 0 ? 0 : _cbAlign - cbLow;
63-
Contracts.Assert(ibMin % sizeof(Float) == 0);
61+
Contracts.Assert(ibMin % sizeof(float) == 0);
6462

65-
int ifltMin = ibMin / sizeof(Float);
63+
int ifltMin = ibMin / sizeof(float);
6664
if (ifltMin == _base)
6765
return _base;
6866

@@ -71,9 +69,9 @@ public unsafe int GetBase(long addr)
7169
// Anything outside [_base, _base + _size) should not be accessed, so
7270
// set them to NaN, for debug validation.
7371
for (int i = 0; i < _base; i++)
74-
Items[i] = Float.NaN;
72+
Items[i] = float.NaN;
7573
for (int i = _base + _size; i < Items.Length; i++)
76-
Items[i] = Float.NaN;
74+
Items[i] = float.NaN;
7775
#endif
7876
return _base;
7977
}
@@ -96,7 +94,7 @@ private void MoveData(int newBase)
9694

9795
public int CbAlign { get { return _cbAlign; } }
9896

99-
public Float this[int index]
97+
public float this[int index]
10098
{
10199
get
102100
{
@@ -110,15 +108,21 @@ public Float this[int index]
110108
}
111109
}
112110

113-
public void CopyTo(Span<Float> dst, int index, int count)
111+
public void CopyTo(Span<float> dst)
112+
{
113+
Contracts.Assert(dst != null);
114+
Items.AsSpan().CopyTo(dst);
115+
}
116+
117+
public void CopyTo(Span<float> dst, int index, int count)
114118
{
115119
Contracts.Assert(0 <= count && count <= _size);
116120
Contracts.Assert(dst != null);
117121
Contracts.Assert(0 <= index && index <= dst.Length - count);
118122
Items.AsSpan(_base, count).CopyTo(dst.Slice(index));
119123
}
120124

121-
public void CopyTo(int start, Span<Float> dst, int index, int count)
125+
public void CopyTo(int start, Span<float> dst, int index, int count)
122126
{
123127
Contracts.Assert(0 <= count);
124128
Contracts.Assert(0 <= start && start <= _size - count);
@@ -127,13 +131,13 @@ public void CopyTo(int start, Span<Float> dst, int index, int count)
127131
Items.AsSpan(start + _base, count).CopyTo(dst.Slice(index));
128132
}
129133

130-
public void CopyFrom(ReadOnlySpan<Float> src)
134+
public void CopyFrom(ReadOnlySpan<float> src)
131135
{
132136
Contracts.Assert(src.Length <= _size);
133137
src.CopyTo(Items.AsSpan(_base));
134138
}
135139

136-
public void CopyFrom(int start, ReadOnlySpan<Float> src)
140+
public void CopyFrom(int start, ReadOnlySpan<float> src)
137141
{
138142
Contracts.Assert(0 <= start && start <= _size - src.Length);
139143
src.CopyTo(Items.AsSpan(start + _base));
@@ -143,7 +147,7 @@ public void CopyFrom(int start, ReadOnlySpan<Float> src)
143147
// valuesSrc contains only the non-zero entries. Those are copied into their logical positions in the dense array.
144148
// rgposSrc contains the logical positions + offset of the non-zero entries in the dense array.
145149
// rgposSrc runs parallel to the valuesSrc array.
146-
public void CopyFrom(ReadOnlySpan<int> rgposSrc, ReadOnlySpan<Float> valuesSrc, int posMin, int iposMin, int iposLim, bool zeroItems)
150+
public void CopyFrom(ReadOnlySpan<int> rgposSrc, ReadOnlySpan<float> valuesSrc, int posMin, int iposMin, int iposLim, bool zeroItems)
147151
{
148152
Contracts.Assert(rgposSrc != null);
149153
Contracts.Assert(valuesSrc != null);
@@ -202,7 +206,7 @@ public void ZeroItems(int[] rgposSrc, int posMin, int iposMin, int iposLim)
202206
// REVIEW: This is hackish and slightly dangerous. Perhaps we should wrap this in an
203207
// IDisposable that "locks" this, prohibiting GetBase from being called, while the buffer
204208
// is "checked out".
205-
public void GetRawBuffer(out Float[] items, out int offset)
209+
public void GetRawBuffer(out float[] items, out int offset)
206210
{
207211
items = Items;
208212
offset = _base;

src/Microsoft.ML.Data/Dirty/PredictorBase.cs renamed to src/Microsoft.ML.Data/Dirty/ModelParametersBase.cs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using Float = System.Single;
6-
75
using System;
86
using Microsoft.ML.Runtime;
97
using Microsoft.ML.Runtime.Model;
@@ -15,22 +13,22 @@ namespace Microsoft.ML.Runtime.Internal.Internallearn
1513
/// Note: This provides essentially no value going forward. New predictors should just
1614
/// derive from the interfaces they need.
1715
/// </summary>
18-
public abstract class PredictorBase<TOutput> : IPredictorProducing<TOutput>
16+
public abstract class ModelParametersBase<TOutput> : ICanSaveModel, IPredictorProducing<TOutput>
1917
{
2018
public const string NormalizerWarningFormat =
2119
"Ignoring integrated normalizer while loading a predictor of type {0}.{1}" +
2220
" Please refer to https://aka.ms/MLNetIssue for assistance with converting legacy models.";
2321

2422
protected readonly IHost Host;
2523

26-
protected PredictorBase(IHostEnvironment env, string name)
24+
protected ModelParametersBase(IHostEnvironment env, string name)
2725
{
2826
Contracts.CheckValue(env, nameof(env));
2927
env.CheckNonWhiteSpace(name, nameof(name));
3028
Host = env.Register(name);
3129
}
3230

33-
protected PredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx)
31+
protected ModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
3432
{
3533
Contracts.CheckValue(env, nameof(env));
3634
env.CheckNonWhiteSpace(name, nameof(name));
@@ -42,11 +40,14 @@ protected PredictorBase(IHostEnvironment env, string name, ModelLoadContext ctx)
4240
// Verify that the Float type matches.
4341
int cbFloat = ctx.Reader.ReadInt32();
4442
#pragma warning disable MSML_NoMessagesForLoadContext // This one is actually useful.
45-
Host.CheckDecode(cbFloat == sizeof(Float), "This file was saved by an incompatible version");
43+
Host.CheckDecode(cbFloat == sizeof(float), "This file was saved by an incompatible version");
4644
#pragma warning restore MSML_NoMessagesForLoadContext
4745
}
4846

49-
public virtual void Save(ModelSaveContext ctx)
47+
void ICanSaveModel.Save(ModelSaveContext ctx) => Save(ctx);
48+
49+
[BestFriend]
50+
private protected virtual void Save(ModelSaveContext ctx)
5051
{
5152
Host.CheckValue(ctx, nameof(ctx));
5253
ctx.CheckAtModel();
@@ -61,7 +62,7 @@ private protected virtual void SaveCore(ModelSaveContext ctx)
6162
// *** Binary format ***
6263
// int: sizeof(Float)
6364
// <Derived type stuff>
64-
ctx.Writer.Write(sizeof(Float));
65+
ctx.Writer.Write(sizeof(float));
6566
}
6667

6768
public abstract PredictionKind PredictionKind { get; }

src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ internal interface ICanGetSummaryAsIRow
143143
Row GetStatsIRowOrNull(RoleMappedSchema schema);
144144
}
145145

146-
public interface ICanGetSummaryAsIDataView
146+
[BestFriend]
147+
internal interface ICanGetSummaryAsIDataView
147148
{
148149
IDataView GetSummaryDataView(RoleMappedSchema schema);
149150
}

src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ private EnsembleTrainer(IHostEnvironment env, Arguments args, PredictionKind pre
8585
private protected override TScalarPredictor CreatePredictor(List<FeatureSubsetModel<TScalarPredictor>> models)
8686
{
8787
if (models.All(m => m.Predictor is TDistPredictor))
88-
return new EnsembleDistributionPredictor(Host, PredictionKind, CreateModels<TDistPredictor>(models), Combiner);
89-
return new EnsemblePredictor(Host, PredictionKind, CreateModels<TScalarPredictor>(models), Combiner);
88+
return new EnsembleDistributionModelParameters(Host, PredictionKind, CreateModels<TDistPredictor>(models), Combiner);
89+
return new EnsembleModelParameters(Host, PredictionKind, CreateModels<TScalarPredictor>(models), Combiner);
9090
}
9191

9292
public IPredictor CombineModels(IEnumerable<IPredictor> models)
@@ -98,12 +98,12 @@ public IPredictor CombineModels(IEnumerable<IPredictor> models)
9898
if (p is TDistPredictor)
9999
{
100100
Host.CheckParam(models.All(m => m is TDistPredictor), nameof(models));
101-
return new EnsembleDistributionPredictor(Host, p.PredictionKind,
101+
return new EnsembleDistributionModelParameters(Host, p.PredictionKind,
102102
models.Select(k => new FeatureSubsetModel<TDistPredictor>((TDistPredictor)k)).ToArray(), combiner);
103103
}
104104

105105
Host.CheckParam(models.All(m => m is TScalarPredictor), nameof(models));
106-
return new EnsemblePredictor(Host, p.PredictionKind,
106+
return new EnsembleModelParameters(Host, p.PredictionKind,
107107
models.Select(k => new FeatureSubsetModel<TScalarPredictor>((TScalarPredictor)k)).ToArray(), combiner);
108108
}
109109
}

src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs renamed to src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
using Microsoft.ML.Runtime.Model;
1515

1616
// These are for deserialization from a model repository.
17-
[assembly: LoadableClass(typeof(EnsembleDistributionPredictor), null, typeof(SignatureLoadModel),
18-
EnsembleDistributionPredictor.UserName, EnsembleDistributionPredictor.LoaderSignature)]
17+
[assembly: LoadableClass(typeof(EnsembleDistributionModelParameters), null, typeof(SignatureLoadModel),
18+
EnsembleDistributionModelParameters.UserName, EnsembleDistributionModelParameters.LoaderSignature)]
1919

2020
namespace Microsoft.ML.Runtime.Ensemble
2121
{
2222
using TDistPredictor = IDistPredictorProducing<Single, Single>;
2323

24-
public sealed class EnsembleDistributionPredictor : EnsemblePredictorBase<TDistPredictor, Single>,
24+
public sealed class EnsembleDistributionModelParameters : EnsembleModelParametersBase<TDistPredictor, Single>,
2525
TDistPredictor, IValueMapperDist
2626
{
2727
internal const string UserName = "Ensemble Distribution Executor";
@@ -38,7 +38,7 @@ private static VersionInfo GetVersionInfo()
3838
verReadableCur: 0x00010003,
3939
verWeCanReadBack: 0x00010002,
4040
loaderSignature: LoaderSignature,
41-
loaderAssemblyName: typeof(EnsembleDistributionPredictor).Assembly.FullName);
41+
loaderAssemblyName: typeof(EnsembleDistributionModelParameters).Assembly.FullName);
4242
}
4343

4444
private readonly Single[] _averagedWeights;
@@ -53,7 +53,7 @@ private static VersionInfo GetVersionInfo()
5353

5454
public override PredictionKind PredictionKind { get; }
5555

56-
internal EnsembleDistributionPredictor(IHostEnvironment env, PredictionKind kind,
56+
public EnsembleDistributionModelParameters(IHostEnvironment env, PredictionKind kind,
5757
FeatureSubsetModel<TDistPredictor>[] models, IOutputCombiner<Single> combiner, Single[] weights = null)
5858
: base(env, RegistrationName, models, combiner, weights)
5959
{
@@ -63,7 +63,7 @@ internal EnsembleDistributionPredictor(IHostEnvironment env, PredictionKind kind
6363
ComputeAveragedWeights(out _averagedWeights);
6464
}
6565

66-
private EnsembleDistributionPredictor(IHostEnvironment env, ModelLoadContext ctx)
66+
private EnsembleDistributionModelParameters(IHostEnvironment env, ModelLoadContext ctx)
6767
: base(env, RegistrationName, ctx)
6868
{
6969
PredictionKind = (PredictionKind)ctx.Reader.ReadInt32();
@@ -103,12 +103,12 @@ private bool IsValid(IValueMapperDist mapper)
103103
&& mapper.DistType == NumberType.Float;
104104
}
105105

106-
private static EnsembleDistributionPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
106+
private static EnsembleDistributionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
107107
{
108108
Contracts.CheckValue(env, nameof(env));
109109
env.CheckValue(ctx, nameof(ctx));
110110
ctx.CheckAtModel(GetVersionInfo());
111-
return new EnsembleDistributionPredictor(env, ctx);
111+
return new EnsembleDistributionModelParameters(env, ctx);
112112
}
113113

114114
private protected override void SaveCore(ModelSaveContext ctx)

src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs renamed to src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,20 @@
1111
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
1212
using Microsoft.ML.Runtime.EntryPoints;
1313

14-
[assembly: LoadableClass(typeof(EnsemblePredictor), null, typeof(SignatureLoadModel), EnsemblePredictor.UserName,
15-
EnsemblePredictor.LoaderSignature)]
14+
[assembly: LoadableClass(typeof(EnsembleModelParameters), null, typeof(SignatureLoadModel), EnsembleModelParameters.UserName,
15+
EnsembleModelParameters.LoaderSignature)]
1616

17-
[assembly: EntryPointModule(typeof(EnsemblePredictor))]
17+
[assembly: EntryPointModule(typeof(EnsembleModelParameters))]
1818

1919
namespace Microsoft.ML.Runtime.Ensemble
2020
{
2121
using TScalarPredictor = IPredictorProducing<Single>;
2222

23-
public sealed class EnsemblePredictor : EnsemblePredictorBase<TScalarPredictor, Single>, IValueMapper
23+
public sealed class EnsembleModelParameters : EnsembleModelParametersBase<TScalarPredictor, Single>, IValueMapper
2424
{
25-
public const string UserName = "Ensemble Executor";
26-
public const string LoaderSignature = "EnsembleFloatExec";
27-
public const string RegistrationName = "EnsemblePredictor";
25+
internal const string UserName = "Ensemble Executor";
26+
internal const string LoaderSignature = "EnsembleFloatExec";
27+
internal const string RegistrationName = "EnsemblePredictor";
2828

2929
private static VersionInfo GetVersionInfo()
3030
{
@@ -36,28 +36,29 @@ private static VersionInfo GetVersionInfo()
3636
verReadableCur: 0x00010003,
3737
verWeCanReadBack: 0x00010002,
3838
loaderSignature: LoaderSignature,
39-
loaderAssemblyName: typeof(EnsemblePredictor).Assembly.FullName);
39+
loaderAssemblyName: typeof(EnsembleModelParameters).Assembly.FullName);
4040
}
4141

4242
private readonly IValueMapper[] _mappers;
4343

44-
public ColumnType InputType { get; }
45-
public ColumnType OutputType => NumberType.Float;
44+
private readonly ColumnType _inputType;
45+
ColumnType IValueMapper.InputType => _inputType;
46+
ColumnType IValueMapper.OutputType => NumberType.Float;
4647
public override PredictionKind PredictionKind { get; }
4748

48-
internal EnsemblePredictor(IHostEnvironment env, PredictionKind kind,
49+
public EnsembleModelParameters(IHostEnvironment env, PredictionKind kind,
4950
FeatureSubsetModel<TScalarPredictor>[] models, IOutputCombiner<Single> combiner, Single[] weights = null)
5051
: base(env, LoaderSignature, models, combiner, weights)
5152
{
5253
PredictionKind = kind;
53-
InputType = InitializeMappers(out _mappers);
54+
_inputType = InitializeMappers(out _mappers);
5455
}
5556

56-
private EnsemblePredictor(IHostEnvironment env, ModelLoadContext ctx)
57+
private EnsembleModelParameters(IHostEnvironment env, ModelLoadContext ctx)
5758
: base(env, RegistrationName, ctx)
5859
{
5960
PredictionKind = (PredictionKind)ctx.Reader.ReadInt32();
60-
InputType = InitializeMappers(out _mappers);
61+
_inputType = InitializeMappers(out _mappers);
6162
}
6263

6364
private ColumnType InitializeMappers(out IValueMapper[] mappers)
@@ -91,12 +92,12 @@ private bool IsValid(IValueMapper mapper)
9192
&& mapper.OutputType == NumberType.Float;
9293
}
9394

94-
public static EnsemblePredictor Create(IHostEnvironment env, ModelLoadContext ctx)
95+
private static EnsembleModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
9596
{
9697
Contracts.CheckValue(env, nameof(env));
9798
env.CheckValue(ctx, nameof(ctx));
9899
ctx.CheckAtModel(GetVersionInfo());
99-
return new EnsemblePredictor(env, ctx);
100+
return new EnsembleModelParameters(env, ctx);
100101
}
101102

102103
private protected override void SaveCore(ModelSaveContext ctx)
@@ -124,8 +125,8 @@ ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
124125
ValueMapper<VBuffer<Single>, Single> del =
125126
(in VBuffer<Single> src, ref Single dst) =>
126127
{
127-
if (InputType.VectorSize > 0)
128-
Host.Check(src.Length == InputType.VectorSize);
128+
if (_inputType.VectorSize > 0)
129+
Host.Check(src.Length == _inputType.VectorSize);
129130

130131
var tmp = src;
131132
Parallel.For(0, maps.Length, i =>

0 commit comments

Comments
 (0)