Create model file V1 scenario tests #2899

Merged 13 commits on Mar 25, 2019
6 changes: 3 additions & 3 deletions test/Microsoft.ML.Functional.Tests/Datasets/CommonColumns.cs
@@ -13,23 +13,23 @@ internal sealed class FeatureColumn
}

/// <summary>
- /// A class to hold the output of FeatureContributionCalculator.
+ /// A class to hold the output of FeatureContributionCalculator
/// </summary>
internal sealed class FeatureContributionOutput
{
public float[] FeatureContributions { get; set; }
}

/// <summary>
- /// A class to hold the Score column.
+ /// A class to hold a score column.
/// </summary>
internal sealed class ScoreColumn
{
public float Score { get; set; }
}

/// <summary>
- /// A class to hold a vector Score column.
+ /// A class to hold a vector score column.
/// </summary>
internal sealed class VectorScoreColumn
{
Original file line number Diff line number Diff line change
@@ -4,9 +4,11 @@

using System;
using System.IO;
using System.IO.Compression;
using System.Linq;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Functional.Tests.Datasets;
using Microsoft.ML.RunTests;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Transforms;
@@ -15,9 +17,9 @@

namespace Microsoft.ML.Functional.Tests
{
- public partial class ModelLoadingTests : TestDataPipeBase
+ public partial class ModelFiles : TestDataPipeBase
{
- public ModelLoadingTests(ITestOutputHelper output) : base(output)
+ public ModelFiles(ITestOutputHelper output) : base(output)
{
}

@@ -30,6 +32,101 @@ private class InputData
public float[] Features { get; set; }
}

/// <summary>
/// Model Files: The (minimum) nuget version can be found in the model file.

model file

Is this an old model file or a current one? If it's current, shouldn't we create the model on the fly (similar to the other scenario below) instead of reading a static keep-model.zip file?

/// </summary>
[Fact]
public void DetermineNugetVersionFromModel()
{
var mlContext = new MLContext(seed: 1);

// Get the dataset.
var data = mlContext.Data.LoadFromTextFile<HousingRegression>(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true);

// Create a pipeline to train on the housing data.
var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features)
.Append(mlContext.Regression.Trainers.FastTree(
new FastTreeRegressionTrainer.Options { NumberOfThreads = 1, NumberOfTrees = 10 }));

// Fit the pipeline.
var model = pipeline.Fit(data);

// Save model to a file.
var modelPath = DeleteOutputPath("determineNugetVersionFromModel.zip");
mlContext.Model.Save(model, data.Schema, modelPath);

// Check that the version can be extracted from the model.
var versionFileName = @"TrainingInfo" + Path.DirectorySeparatorChar + "Version.txt";
using (ZipArchive archive = ZipFile.OpenRead(modelPath))
{
// The version of the entire model is kept in the version file.
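// FirstOrDefault returns null when the entry is missing, so the NotNull assertion below is meaningful.
var versionPath = archive.Entries.FirstOrDefault(x => x.FullName == versionFileName);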
Assert.NotNull(versionPath);
using (var stream = versionPath.Open())
using (var reader = new StreamReader(stream))
{
// The only line in the file is the version of the model.
var line = reader.ReadLine();
Assert.Equal(@"1.0.0.0", line);
}
}
}

/// <summary>
/// Model Files: Save a model, including all transforms, then load and make predictions.
/// </summary>
/// <remarks>
/// Serves two scenarios:
/// 1. I can train a model and save it to a file, including transforms.
/// 2. Training and prediction happen in different processes (or even different machines).
/// The actual test will not run in different processes, but will simulate the idea that the
/// "communication pipe" is just a serialized model of some form.

Not necessary for test files, but FYI: lists in doc comments need to be in XML style as well.
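
For reference, a sketch of how the remarks on this test might look with an XML-style list. It assumes the standard C# documentation `<list>` element is what is meant here; the item text is copied from the diff, and the wrapping is only a suggestion.

```csharp
/// <remarks>
/// Serves two scenarios:
/// <list type="number">
/// <item><description>I can train a model and save it to a file, including transforms.</description></item>
/// <item><description>Training and prediction happen in different processes (or even different machines).
/// The actual test will not run in different processes, but will simulate the idea that the
/// "communication pipe" is just a serialized model of some form.</description></item>
/// </list>
/// </remarks>
```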

/// </remarks>
[Fact]
public void FitPipelineSaveModelAndPredict()
{
var mlContext = new MLContext(seed: 1);

// Get the dataset.
var data = mlContext.Data.LoadFromTextFile<HousingRegression>(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true);

// Create a pipeline to train on the housing data.
var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features)
.Append(mlContext.Regression.Trainers.FastTree(
new FastTreeRegressionTrainer.Options { NumberOfThreads = 1, NumberOfTrees = 10 }));

// Fit the pipeline.
var model = pipeline.Fit(data);

var modelPath = DeleteOutputPath("fitPipelineSaveModelAndPredict.zip");
// Save model to a file.
mlContext.Model.Save(model, data.Schema, modelPath);

// Load model from a file.
ITransformer serializedModel;
using (var file = File.OpenRead(modelPath))
{
serializedModel = mlContext.Model.Load(file, out var serializedSchema);
@yaeldekel yaeldekel Mar 22, 2019

serializedSchema

You can verify that this is the same as data.Schema. #Resolved

CheckSameSchemas(data.Schema, serializedSchema);
}

// Create prediction engine and test predictions.
var originalPredictionEngine = mlContext.Model.CreatePredictionEngine<HousingRegression, ScoreColumn>(model);
var serializedPredictionEngine = mlContext.Model.CreatePredictionEngine<HousingRegression, ScoreColumn>(serializedModel);

// Take a handful of examples out of the dataset and compute predictions.
var dataEnumerator = mlContext.Data.CreateEnumerable<HousingRegression>(mlContext.Data.TakeRows(data, 5), false);
foreach (var row in dataEnumerator)
{
var originalPrediction = originalPredictionEngine.Predict(row);
var serializedPrediction = serializedPredictionEngine.Predict(row);
// Check that the predictions are identical.
Assert.Equal(originalPrediction.Score, serializedPrediction.Score);
}

Done();
}

[Fact]
public void LoadModelAndExtractPredictor()
{

This file was deleted.