Skip to content

Commit e3db0e7

Browse files
author
Ivan Matantsev
committed
introspective training
1 parent 86cf8c9 commit e3db0e7

File tree

2 files changed

+121
-3
lines changed

2 files changed

+121
-3
lines changed
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
using Microsoft.ML.Runtime.Data;
2+
using Microsoft.ML.Runtime.FastTree;
3+
using Microsoft.ML.Runtime.Internal.Calibration;
4+
using Microsoft.ML.Runtime.Internal.Internallearn;
5+
using Microsoft.ML.Runtime.Learners;
6+
using Microsoft.ML.Runtime.TextAnalytics;
7+
using System;
8+
using System.Collections.Generic;
9+
using System.Text;
10+
using Xunit;
11+
12+
namespace Microsoft.ML.Tests.Scenarios.Api
13+
{
14+
15+
public partial class ApiScenariosTests
16+
{
17+
private TOut GetValue<TOut>(Dictionary<string, object> keyValues, string key)
18+
{
19+
if (keyValues.ContainsKey(key))
20+
return (TOut)keyValues[key];
21+
22+
return default(TOut);
23+
}
24+
25+
/// <summary>
26+
/// Introspective training: Models that produce outputs and are otherwise black boxes are of limited use;
27+
/// it is also necessary often to understand at least to some degree what was learnt. To outline critical
28+
/// scenarios that have come up multiple times:
29+
/// *) When I train a linear model, I should be able to inspect coefficients.
30+
/// *) The tree ensemble learners, I should be able to inspect the trees.
31+
/// *) The LDA transform, I should be able to inspect the topics.
32+
/// I view it as essential from a usability perspective that this be discoverable to someone without
33+
/// having to read documentation.E.g.: if I have var lda = new LdaTransform().Fit(data)(I don't insist on that
34+
/// exact signature, just giving the idea), then if I were to type lda.
35+
/// In Visual Studio, one of the auto-complete targets should be something like GetTopics.
36+
/// </summary>
37+
38+
[Fact]
39+
void IntrospectiveTraining()
40+
{
41+
var dataPath = GetDataPath(SentimentDataPath);
42+
43+
using (var env = new TlcEnvironment(seed: 1, conc: 1))
44+
{
45+
// Pipeline
46+
var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
47+
48+
var words = WordBagTransform.Create(env, new WordBagTransform.Arguments()
49+
{
50+
NgramLength = 1,
51+
Column = new[] { new WordBagTransform.Column() { Name = "Tokenize", Source = new[] { "SentimentText" } } }
52+
},
53+
loader
54+
);
55+
var lda = new LdaTransform(env, new LdaTransform.Arguments()
56+
{
57+
NumTopic = 10,
58+
NumIterations = 3,
59+
NumThreads = 1,
60+
Column = new[] { new LdaTransform.Column { Source = "Tokenize", Name = "Features"}
61+
}
62+
}, words);
63+
var trainData = lda;
64+
65+
var cachedTrain = new CacheDataView(env, trainData, prefetch: null);
66+
// Train the first predictor.
67+
var linearTrainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
68+
{
69+
NumThreads = 1
70+
});
71+
var trainRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");
72+
var linearPredictor = linearTrainer.Train(new Runtime.TrainContext(trainRoles));
73+
VBuffer<float> weights = default;
74+
linearPredictor.GetFeatureWeights(ref weights);
75+
76+
var topicSummary = lda.GetTopicSummary();
77+
var treeTrainer = new FastTreeBinaryClassificationTrainer(env,
78+
new FastTreeBinaryClassificationTrainer.Arguments
79+
{
80+
NumTrees = 2
81+
}
82+
);
83+
var ftPredictor = treeTrainer.Train(new Runtime.TrainContext(trainRoles));
84+
FastTreeBinaryPredictor treePredictor;
85+
if (ftPredictor is CalibratedPredictorBase calibrator)
86+
treePredictor = (FastTreeBinaryPredictor)calibrator.SubPredictor;
87+
else
88+
treePredictor = (FastTreeBinaryPredictor)ftPredictor;
89+
var featureNameCollection = FeatureNameCollection.Create(trainRoles.Schema);
90+
foreach (var tree in treePredictor.GetTrees())
91+
{
92+
var lteChild = tree.LteChild;
93+
var gteChild = tree.GtChild;
94+
//get Nodes;
95+
for (var i = 0; i < tree.NumNodes; i++)
96+
{
97+
var node = tree.GetNode(i, false, featureNameCollection);
98+
var gainValue = GetValue<double>(node.KeyValues, "GainValue");
99+
var splitGain = GetValue<double>(node.KeyValues, "SplitGain");
100+
var featureName = GetValue<string>(node.KeyValues, "SplitName");
101+
var previousLeafValue = GetValue<double>(node.KeyValues, "PreviousLeafValue");
102+
var threshold = GetValue<string>(node.KeyValues, "Threshold").Split(new[] { ' ' }, 2)[1];
103+
var nodeIndex = i;
104+
}
105+
// get leafs
106+
for (var i = 0; i < tree.NumLeaves; i++)
107+
{
108+
var node = tree.GetNode(i, true, featureNameCollection);
109+
var leafValue = GetValue<double>(node.KeyValues, "LeafValue");
110+
var extras = GetValue<string>(node.KeyValues, "Extras");
111+
var nodeIndex = ~i;
112+
}
113+
}
114+
115+
116+
117+
}
118+
}
119+
}
120+
}

test/Microsoft.ML.Tests/Scenarios/Api/TrainSaveModelAndPredict.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
using Microsoft.ML.Core.Data;
2-
using Microsoft.ML.Runtime.Api;
1+
using Microsoft.ML.Runtime.Api;
32
using Microsoft.ML.Runtime.Data;
43
using Microsoft.ML.Runtime.Learners;
5-
using Microsoft.ML.Runtime.Model;
64
using System.Linq;
75
using Xunit;
86

0 commit comments

Comments
 (0)