Skip to content

Commit 2d8059a

Browse files
committed
review comments and further cleanup
1 parent 78a287e commit 2d8059a

File tree

5 files changed

+122
-24
lines changed

5 files changed

+122
-24
lines changed

src/Microsoft.ML.HalLearners/VectorWhitening.cs

+39-10
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,15 @@ namespace Microsoft.ML.Transforms.Projections
3535
{
3636
public enum WhiteningKind
3737
{
38+
/// <summary>
39+
/// PCA whitening.
40+
/// </summary>
3841
[TGUI(Label = "PCA whitening")]
3942
Pca,
4043

44+
/// <summary>
45+
/// ZCA whitening.
46+
/// </summary>
4147
[TGUI(Label = "ZCA whitening")]
4248
Zca
4349
}
@@ -186,9 +192,9 @@ internal static VectorWhiteningTransformer Create(IHostEnvironment env, ModelLoa
186192
}
187193

188194
// Factory method for SignatureDataTransform.
189-
internal static IDataTransform Create(IHostEnvironment env, Options args, IDataView input)
195+
internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
190196
{
191-
var infos = args.Columns.Select(colPair => new VectorWhiteningEstimator.ColumnInfo(colPair, args)).ToArray();
197+
var infos = options.Columns.Select(colPair => new VectorWhiteningEstimator.ColumnInfo(colPair, options)).ToArray();
192198
(var models, var invModels) = TrainVectorWhiteningTransform(env, input, infos);
193199
return new VectorWhiteningTransformer(env, models, invModels, infos).MakeDataTransform(input);
194200
}
@@ -665,7 +671,8 @@ private static float DotProduct(float[] a, int aOffset, ReadOnlySpan<float> b, R
665671
/// <include file='doc.xml' path='doc/members/member[@name="Whitening"]/*'/>
666672
public sealed class VectorWhiteningEstimator : IEstimator<VectorWhiteningTransformer>
667673
{
668-
public static class Defaults
674+
[BestFriend]
675+
internal static class Defaults
669676
{
670677
public const WhiteningKind Kind = WhiteningKind.Zca;
671678
public const float Eps = 1e-5f;
@@ -679,11 +686,29 @@ public static class Defaults
679686
/// </summary>
680687
public sealed class ColumnInfo
681688
{
689+
/// <summary>
690+
/// Name of the column resulting from the transformation of <see cref="InputColumnName"/>.
691+
/// </summary>
682692
public readonly string Name;
693+
/// <summary>
694+
/// Name of column to transform. If set to <see langword="null"/>, the value of the <see cref="Name"/> will be used as source.
695+
/// </summary>
683696
public readonly string InputColumnName;
697+
/// <summary>
698+
/// Whitening kind (PCA/ZCA).
699+
/// </summary>
684700
public readonly WhiteningKind Kind;
701+
/// <summary>
702+
/// Whitening constant, prevents division by zero.
703+
/// </summary>
685704
public readonly float Epsilon;
705+
/// <summary>
706+
/// Maximum number of rows used to train the transform.
707+
/// </summary>
686708
public readonly int MaxRow;
709+
/// <summary>
710+
/// In case of PCA whitening, indicates the number of components to retain.
711+
/// </summary>
687712
public readonly int PcaNum;
688713
internal readonly bool SaveInv;
689714

@@ -714,20 +739,20 @@ public ColumnInfo(string name, string inputColumnName = null, WhiteningKind kind
714739
Contracts.CheckUserArg(PcaNum >= 0, nameof(PcaNum));
715740
}
716741

717-
internal ColumnInfo(VectorWhiteningTransformer.Column item, VectorWhiteningTransformer.Options args)
742+
internal ColumnInfo(VectorWhiteningTransformer.Column item, VectorWhiteningTransformer.Options options)
718743
{
719744
Name = item.Name;
720745
Contracts.CheckValue(Name, nameof(Name));
721746
InputColumnName = item.Source ?? item.Name;
722747
Contracts.CheckValue(InputColumnName, nameof(InputColumnName));
723-
Kind = item.Kind ?? args.Kind;
748+
Kind = item.Kind ?? options.Kind;
724749
Contracts.CheckUserArg(Kind == WhiteningKind.Pca || Kind == WhiteningKind.Zca, nameof(item.Kind));
725-
Epsilon = item.Eps ?? args.Eps;
750+
Epsilon = item.Eps ?? options.Eps;
726751
Contracts.CheckUserArg(0 <= Epsilon && Epsilon < float.PositiveInfinity, nameof(item.Eps));
727-
MaxRow = item.MaxRows ?? args.MaxRows;
752+
MaxRow = item.MaxRows ?? options.MaxRows;
728753
Contracts.CheckUserArg(MaxRow > 0, nameof(item.MaxRows));
729-
SaveInv = item.SaveInverse ?? args.SaveInverse;
730-
PcaNum = item.PcaNum ?? args.PcaNum;
754+
SaveInv = item.SaveInverse ?? options.SaveInverse;
755+
PcaNum = item.PcaNum ?? options.PcaNum;
731756
Contracts.CheckUserArg(PcaNum >= 0, nameof(item.PcaNum));
732757
}
733758

@@ -803,6 +828,9 @@ internal VectorWhiteningEstimator(IHostEnvironment env, string outputColumnName,
803828
{
804829
}
805830

831+
/// <summary>
832+
/// Trains and returns a <see cref="VectorWhiteningTransformer"/>.
833+
/// </summary>
806834
public VectorWhiteningTransformer Fit(IDataView input)
807835
{
808836
// Build transformation matrices for whitening process, then construct a trained transform.
@@ -811,7 +839,8 @@ public VectorWhiteningTransformer Fit(IDataView input)
811839
}
812840

813841
/// <summary>
814-
/// Returns the schema that would be produced by the transformation.
842+
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
843+
/// Used for schema propagation and verification in a pipeline.
815844
/// </summary>
816845
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
817846
{

src/Microsoft.ML.PCA/PcaTransformer.cs

+25-2
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,8 @@ internal static CommonOutputs.TransformOutput Calculate(IHostEnvironment env, Op
619619
/// <include file='doc.xml' path='doc/members/member[@name="PCA"]/*'/>
620620
public sealed class PrincipalComponentAnalysisEstimator : IEstimator<PcaTransformer>
621621
{
622-
public static class Defaults
622+
[BestFriend]
623+
internal static class Defaults
623624
{
624625
public const string WeightColumn = null;
625626
public const int Rank = 20;
@@ -633,12 +634,34 @@ public static class Defaults
633634
/// </summary>
634635
public sealed class ColumnInfo
635636
{
637+
/// <summary>
638+
/// Name of the column resulting from the transformation of <see cref="InputColumnName"/>.
639+
/// </summary>
636640
public readonly string Name;
641+
/// <summary>
642+
/// Name of column to transform.
643+
/// If set to <see langword="null"/>, the value of the <see cref="Name"/> will be used as source.
644+
/// </summary>
637645
public readonly string InputColumnName;
646+
/// <summary>
647+
/// The name of the weight column.
648+
/// </summary>
638649
public readonly string WeightColumn;
650+
/// <summary>
651+
/// The number of components in the PCA.
652+
/// </summary>
639653
public readonly int Rank;
654+
/// <summary>
655+
/// Oversampling parameter for randomized PCA training.
656+
/// </summary>
640657
public readonly int Oversampling;
658+
/// <summary>
659+
/// If enabled, data is centered to be zero mean.
660+
/// </summary>
641661
public readonly bool Center;
662+
/// <summary>
663+
/// The seed for random number generation.
664+
/// </summary>
642665
public readonly int? Seed;
643666

644667
/// <summary>
@@ -706,7 +729,7 @@ internal PrincipalComponentAnalysisEstimator(IHostEnvironment env, params Column
706729
}
707730

708731
/// <summary>
709-
/// Train and return a transformer.
732+
/// Trains and returns a <see cref="PcaTransformer"/>.
710733
/// </summary>
711734
public PcaTransformer Fit(IDataView input) => new PcaTransformer(_host, input, _columns);
712735

src/Microsoft.ML.Transforms/GcnTransform.cs

+26-8
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,9 @@ private static VersionInfo GetVersionInfo()
191191
// REVIEW: should this be an argument instead?
192192
private const float MinScale = 1e-8f;
193193

194+
/// <summary>
195+
/// The objects describing how the transformation is applied on the input data.
196+
/// </summary>
194197
public IReadOnlyCollection<LpNormalizingEstimatorBase.ColumnInfoBase> Columns => _columns.AsReadOnly();
195198
private readonly LpNormalizingEstimatorBase.ColumnInfoBase[] _columns;
196199

@@ -660,13 +663,28 @@ public enum NormalizerKind : byte
660663
/// </summary>
661664
public abstract class ColumnInfoBase
662665
{
666+
/// <summary>
667+
/// Name of the column resulting from the transformation of <see cref="InputColumnName"/>.
668+
/// </summary>
663669
public readonly string Name;
670+
/// <summary>
671+
/// Name of column to transform. If set to <see langword="null"/>, the value of the <see cref="Name"/> will be used as source.
672+
/// </summary>
664673
public readonly string InputColumnName;
674+
/// <summary>
675+
/// Subtract mean from each value before normalizing.
676+
/// </summary>
665677
public readonly bool SubtractMean;
678+
/// <summary>
679+
/// The norm to use to normalize each sample.
680+
/// </summary>
666681
public readonly NormalizerKind NormKind;
682+
/// <summary>
683+
/// Scale features by this value.
684+
/// </summary>
667685
public readonly float Scale;
668686

669-
internal ColumnInfoBase(string name, string inputColumnName, bool substractMean, LpNormalizingEstimatorBase.NormalizerKind normalizerKind, float scale)
687+
internal ColumnInfoBase(string name, string inputColumnName, bool substractMean, NormalizerKind normalizerKind, float scale)
670688
{
671689
Contracts.CheckNonWhiteSpace(name, nameof(name));
672690
Contracts.CheckNonWhiteSpace(inputColumnName, nameof(inputColumnName));
@@ -692,14 +710,14 @@ internal ColumnInfoBase(ModelLoadContext ctx, string name, string inputColumnNam
692710
// Float: Scale
693711
SubtractMean = ctx.Reader.ReadBoolByte();
694712
byte normKindVal = ctx.Reader.ReadByte();
695-
Contracts.CheckDecode(Enum.IsDefined(typeof(LpNormalizingEstimatorBase.NormalizerKind), normKindVal));
696-
NormKind = (LpNormalizingEstimatorBase.NormalizerKind)normKindVal;
713+
Contracts.CheckDecode(Enum.IsDefined(typeof(NormalizerKind), normKindVal));
714+
NormKind = (NormalizerKind)normKindVal;
697715
// Note: In early versions, a bool option (useStd) to whether to normalize by StdDev rather than
698716
// L2 norm was used. normKind was added in version=verVectorNormalizerSupported.
699717
// normKind was defined in a way such that the serialized boolean (0: use StdDev, 1: use L2) is
700718
// still valid.
701719
Contracts.CheckDecode(normKindSerialized ||
702-
(NormKind == LpNormalizingEstimatorBase.NormalizerKind.L2Norm || NormKind == LpNormalizingEstimatorBase.NormalizerKind.StdDev));
720+
(NormKind == NormalizerKind.L2Norm || NormKind == NormalizerKind.StdDev));
703721
Scale = ctx.Reader.ReadFloat();
704722
Contracts.CheckDecode(0 < Scale && Scale < float.PositiveInfinity);
705723
}
@@ -776,17 +794,17 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
776794
}
777795

778796
/// <summary>
779-
/// Lp Normalizing estimator allow you take columns and normalize them individually by rescaling them to unit norm.
797+
/// Lp Normalizing estimator takes columns and normalizes them individually by rescaling them to unit norm.
780798
/// </summary>
781799
public sealed class LpNormalizingEstimator : LpNormalizingEstimatorBase
782800
{
783801
/// <summary>
784-
/// Describes how the transformer handles one LpNorm column pair.
802+
/// Describes how the transformer handles one column pair.
785803
/// </summary>
786804
public sealed class LpNormColumnInfo : ColumnInfoBase
787805
{
788806
/// <summary>
789-
/// Describes how the transformer handles one LpNorm column pair.
807+
/// Describes how the transformer handles one column pair.
790808
/// </summary>
791809
/// <param name="name">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
792810
/// <param name="inputColumnName">Name of column to transform. If set to <see langword="null"/>, the value of the <paramref name="name"/> will be used as source.</param>
@@ -832,7 +850,7 @@ internal LpNormalizingEstimator(IHostEnvironment env, params LpNormColumnInfo[]
832850
}
833851

834852
/// <summary>
835-
/// Global contrast normalizing estimator allow you take columns and performs global constrast normalization on them.
853+
/// Global contrast normalizing estimator takes columns and performs global contrast normalization.
836854
/// </summary>
837855
public sealed class GlobalContrastNormalizingEstimator : LpNormalizingEstimatorBase
838856
{

src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs

+29-4
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131

3232
namespace Microsoft.ML.Transforms.Projections
3333
{
34+
/// <summary>
35+
/// Maps vector columns to a random low-dimensional feature space.
36+
/// </summary>
3437
public sealed class RandomFourierFeaturizingTransformer : OneToOneTransformerBase
3538
{
3639
internal sealed class Options
@@ -609,23 +612,45 @@ private void TransformFeatures(in VBuffer<float> src, ref VBuffer<float> dst, Tr
609612
}
610613

611614
/// <summary>
612-
/// Estimator which takes set of vector columns and maps its input to a random low-dimensional feature space.
615+
/// Maps vector columns to a random low-dimensional feature space.
613616
/// </summary>
614617
public sealed class RandomFourierFeaturizingEstimator : IEstimator<RandomFourierFeaturizingTransformer>
615618
{
616-
public static class Defaults
619+
[BestFriend]
620+
internal static class Defaults
617621
{
618622
public const int NewDim = 1000;
619623
public const bool UseSin = false;
620624
}
621625

626+
/// <summary>
627+
/// Describes how the transformer handles one column pair.
628+
/// </summary>
622629
public sealed class ColumnInfo
623630
{
631+
/// <summary>
632+
/// Name of the column resulting from the transformation of <see cref="InputColumnName"/>.
633+
/// </summary>
624634
public readonly string Name;
635+
/// <summary>
636+
/// Name of the column to transform. If set to <see langword="null"/>, the value of the <see cref="Name"/> will be used as source.
637+
/// </summary>
625638
public readonly string InputColumnName;
639+
/// <summary>
640+
/// Which Fourier generator to use.
641+
/// </summary>
626642
public readonly IComponentFactory<float, IFourierDistributionSampler> Generator;
643+
/// <summary>
644+
/// The number of random Fourier features to create.
645+
/// </summary>
627646
public readonly int NewDim;
647+
/// <summary>
648+
/// Create two features for every random Fourier frequency? (one for cos and one for sin).
649+
/// </summary>
628650
public readonly bool UseSin;
651+
/// <summary>
652+
/// The seed of the random number generator for generating the new features (if unspecified, the global random is used).
653+
/// </summary>
629654
public readonly int? Seed;
630655

631656
/// <summary>
@@ -636,7 +661,7 @@ public sealed class ColumnInfo
636661
/// <param name="useSin">Create two features for every random Fourier frequency? (one for cos and one for sin).</param>
637662
/// <param name="inputColumnName">Name of column to transform. </param>
638663
/// <param name="generator">Which fourier generator to use.</param>
639-
/// <param name="seed">The seed of the random number generator for generating the new features (if unspecified, the global random is used.</param>
664+
/// <param name="seed">The seed of the random number generator for generating the new features (if unspecified, the global random is used).</param>
640665
public ColumnInfo(string name, int newDim, bool useSin, string inputColumnName = null, IComponentFactory<float, IFourierDistributionSampler> generator = null, int? seed = null)
641666
{
642667
Contracts.CheckUserArg(newDim > 0, nameof(newDim), "must be positive.");
@@ -673,7 +698,7 @@ internal RandomFourierFeaturizingEstimator(IHostEnvironment env, params ColumnIn
673698
}
674699

675700
/// <summary>
676-
/// Train and return a transformer.
701+
/// Trains and returns a <see cref="RandomFourierFeaturizingTransformer"/>.
677702
/// </summary>
678703
public RandomFourierFeaturizingTransformer Fit(IDataView input) => new RandomFourierFeaturizingTransformer(_host, input, _columns);
679704

src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs

+3
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,9 @@ private void LoadStopWords(IChannel ch, ReadOnlyMemory<char> stopwords, string d
809809
}
810810
}
811811

812+
/// <summary>
813+
/// The names of the output and input column pairs on which this transformation is applied.
814+
/// </summary>
812815
public IReadOnlyCollection<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly();
813816

814817
/// <summary>

0 commit comments

Comments
 (0)