| 1 | +using Microsoft.ML.Runtime.Data; |
| 2 | +using Microsoft.ML.Runtime.FastTree; |
| 3 | +using Microsoft.ML.Runtime.Internal.Calibration; |
| 4 | +using Microsoft.ML.Runtime.Internal.Internallearn; |
| 5 | +using Microsoft.ML.Runtime.Learners; |
| 6 | +using Microsoft.ML.Runtime.TextAnalytics; |
| 7 | +using System; |
| 8 | +using System.Collections.Generic; |
| 9 | +using System.Text; |
| 10 | +using Xunit; |
| 11 | + |
| 12 | +namespace Microsoft.ML.Tests.Scenarios.Api |
| 13 | +{ |
| 14 | + |
| 15 | + public partial class ApiScenariosTests |
| 16 | + { |
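|  | +        // Helper: read a typed value out of a tree node's key/value bag, or default(TOut) when the key is absent. |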
| 17 | +        private TOut GetValue<TOut>(Dictionary<string, object> keyValues, string key) |
| 18 | +        { |
| 19 | +            if (keyValues.TryGetValue(key, out object value)) |
| 20 | +                return (TOut)value; |
| 21 | + |
| 22 | +            return default(TOut); |
| 23 | +        } |
| 24 | + |
| 25 | +        /// <summary> |
| 26 | +        /// Introspective training: models that produce outputs but are otherwise black boxes are of limited use; |
| 27 | +        /// often it is also necessary to understand, at least to some degree, what was learnt. To outline critical |
| 28 | +        /// scenarios that have come up multiple times: |
| 29 | +        /// *) When I train a linear model, I should be able to inspect the coefficients. |
| 30 | +        /// *) When I train a tree ensemble, I should be able to inspect the trees. |
| 31 | +        /// *) When I use the LDA transform, I should be able to inspect the topics. |
| 32 | +        /// I view it as essential from a usability perspective that this be discoverable without having to read |
| 33 | +        /// documentation. E.g.: if I have var lda = new LdaTransform().Fit(data) (I don't insist on that exact |
| 34 | +        /// signature, just giving the idea), then if I were to type lda. in Visual Studio, one of the |
| 35 | +        /// auto-complete targets should be something like GetTopics. |
| 36 | +        /// </summary> |
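|  | +        // A rough sketch of the usage shape this scenario is after (hypothetical names, only illustrating discoverability): |
|  | +        //     var lda = new LdaTransform(...).Fit(data);   // hypothetical fluent signature |
|  | +        //     var topics = lda.GetTopics();                 // surfaced directly by auto-complete |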
| 37 | + |
| 38 | + [Fact] |
| 39 | +        public void IntrospectiveTraining() |
| 40 | + { |
| 41 | + var dataPath = GetDataPath(SentimentDataPath); |
| 42 | + |
| 43 | + using (var env = new TlcEnvironment(seed: 1, conc: 1)) |
| 44 | + { |
| 45 | + // Pipeline |
| 46 | + var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath)); |
| 47 | + |
| 48 | + var words = WordBagTransform.Create(env, new WordBagTransform.Arguments() |
| 49 | + { |
| 50 | + NgramLength = 1, |
| 51 | + Column = new[] { new WordBagTransform.Column() { Name = "Tokenize", Source = new[] { "SentimentText" } } } |
| 52 | +                }, |
| 53 | +                loader |
| 54 | +                ); |
| 55 | + var lda = new LdaTransform(env, new LdaTransform.Arguments() |
| 56 | + { |
| 57 | + NumTopic = 10, |
| 58 | + NumIterations = 3, |
| 59 | + NumThreads = 1, |
| 60 | +                    Column = new[] { new LdaTransform.Column { Source = "Tokenize", Name = "Features" } } |
| 61 | +                }, |
| 62 | +                words); |
| 63 | + var trainData = lda; |
| 64 | + |
| 65 | + var cachedTrain = new CacheDataView(env, trainData, prefetch: null); |
| 66 | + // Train the first predictor. |
| 67 | + var linearTrainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments |
| 68 | + { |
| 69 | + NumThreads = 1 |
| 70 | + }); |
| 71 | + var trainRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features"); |
| 72 | + var linearPredictor = linearTrainer.Train(new Runtime.TrainContext(trainRoles)); |
| 73 | + VBuffer<float> weights = default; |
| 74 | + linearPredictor.GetFeatureWeights(ref weights); |
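|  | +                // Minimal sketch of the "inspect the coefficients" scenario, assuming the VBuffer API of this |
|  | +                // era (public Count/Values members): count how many learned weights are non-zero. |
|  | +                var nonZeroWeights = 0; |
|  | +                for (int i = 0; i < weights.Count; i++) |
|  | +                { |
|  | +                    if (weights.Values[i] != 0) |
|  | +                        nonZeroWeights++; |
|  | +                } |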
| 75 | + |
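|  | +                // Inspect what the LDA transform learnt: its per-topic summaries (the "inspect the topics" scenario). |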
| 76 | + var topicSummary = lda.GetTopicSummary(); |
| 77 | + var treeTrainer = new FastTreeBinaryClassificationTrainer(env, |
| 78 | + new FastTreeBinaryClassificationTrainer.Arguments |
| 79 | + { |
| 80 | + NumTrees = 2 |
| 81 | + } |
| 82 | + ); |
| 83 | + var ftPredictor = treeTrainer.Train(new Runtime.TrainContext(trainRoles)); |
| 84 | + FastTreeBinaryPredictor treePredictor; |
| 85 | + if (ftPredictor is CalibratedPredictorBase calibrator) |
| 86 | + treePredictor = (FastTreeBinaryPredictor)calibrator.SubPredictor; |
| 87 | + else |
| 88 | + treePredictor = (FastTreeBinaryPredictor)ftPredictor; |
| 89 | + var featureNameCollection = FeatureNameCollection.Create(trainRoles.Schema); |
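|  | +                // Walk every tree in the ensemble and read back its structure (the "inspect the trees" scenario). |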
| 90 | + foreach (var tree in treePredictor.GetTrees()) |
| 91 | + { |
| 92 | + var lteChild = tree.LteChild; |
| 93 | +                    var gtChild = tree.GtChild; |
| 94 | +                    // Get the internal (non-leaf) nodes. |
| 95 | + for (var i = 0; i < tree.NumNodes; i++) |
| 96 | + { |
| 97 | + var node = tree.GetNode(i, false, featureNameCollection); |
| 98 | + var gainValue = GetValue<double>(node.KeyValues, "GainValue"); |
| 99 | + var splitGain = GetValue<double>(node.KeyValues, "SplitGain"); |
| 100 | + var featureName = GetValue<string>(node.KeyValues, "SplitName"); |
| 101 | + var previousLeafValue = GetValue<double>(node.KeyValues, "PreviousLeafValue"); |
| 102 | + var threshold = GetValue<string>(node.KeyValues, "Threshold").Split(new[] { ' ' }, 2)[1]; |
| 103 | + var nodeIndex = i; |
| 104 | + } |
| 105 | +                    // Get the leaves. |
| 106 | + for (var i = 0; i < tree.NumLeaves; i++) |
| 107 | + { |
| 108 | + var node = tree.GetNode(i, true, featureNameCollection); |
| 109 | + var leafValue = GetValue<double>(node.KeyValues, "LeafValue"); |
| 110 | + var extras = GetValue<string>(node.KeyValues, "Extras"); |
| 111 | + var nodeIndex = ~i; |
| 112 | + } |
| 113 | + } |
| 114 | + |
| 115 | + |
| 116 | + |
| 117 | + } |
| 118 | + } |
| 119 | + } |
| 120 | +} |