Skip to content
This repository was archived by the owner on Oct 17, 2021. It is now read-only.

Add Model.save(); Let loadModel() support IOHandlers #161

Merged
merged 5 commits into from
May 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"main": "dist/index.js",
"types": "dist/index.d.ts",
"devDependencies": {
"@tensorflow/tfjs-core": "0.8.5",
"@tensorflow/tfjs-core": "0.9.1",
"@types/jasmine": "~2.5.53",
"browserify": "~16.1.0",
"clang-format": "~1.2.2",
Expand Down Expand Up @@ -42,6 +42,6 @@
"lint": "tslint -p . --type-check -t verbose"
},
"peerDependencies": {
"@tensorflow/tfjs-core": "0.8.5"
"@tensorflow/tfjs-core": "0.9.1"
}
}
8 changes: 6 additions & 2 deletions src/engine/topology.ts
Original file line number Diff line number Diff line change
Expand Up @@ -985,9 +985,13 @@ export abstract class Layer extends Serializable {

/**
* Returns the current values of the weights of the layer.
*
* @param trainableOnly Whether to get the values of only trainable weights.
* @returns Weight values as an `Array` of `Tensor`s.
*/
getWeights(): Tensor[] {
return K.batchGetValue(this.weights);
getWeights(trainableOnly = false): Tensor[] {
  // Select which weight list to read: only trainable variables, or all.
  const weightsToRead = trainableOnly ? this.trainableWeights : this.weights;
  return K.batchGetValue(weightsToRead);
}

/**
Expand Down
6 changes: 3 additions & 3 deletions src/exports.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*/

// tslint:disable:max-line-length
import {doc, Tensor} from '@tensorflow/tfjs-core';
import {doc, io, Tensor} from '@tensorflow/tfjs-core';

import {Constraint, MaxNorm, MaxNormConfig, MinMaxNorm, MinMaxNormConfig, NonNeg, UnitNorm, UnitNormConfig} from './constraints';
import {ContainerConfig, Input, InputConfig, InputLayer, InputLayerConfig, Layer, LayerConfig} from './engine/topology';
Expand Down Expand Up @@ -155,8 +155,8 @@ export class ModelExports {
subheading: 'Loading',
useDocsFrom: 'loadModelInternal'
})
static loadModel(modelConfigPath: string): Promise<Model> {
return loadModelInternal(modelConfigPath);
static loadModel(pathOrIOHandler: string|io.IOHandler): Promise<Model> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the design settled on is not backward compatible, so we'll need to make sure we bump the middle version number, right? (While we still accept strings like before, we require a prefix that we didn't in the past)

return loadModelInternal(pathOrIOHandler);
}

@doc({
Expand Down
104 changes: 104 additions & 0 deletions src/model_save_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/

import {io} from '@tensorflow/tfjs-core';

import {Dense} from './layers/core';
import {Sequential} from './models';
import {describeMathCPUAndGPU} from './utils/test_utils';

describeMathCPUAndGPU('Model.save', () => {
  /**
   * IOHandler that captures the artifacts passed to its `save` method so
   * tests can inspect exactly what the model serialized.
   */
  class IOHandlerForTest implements io.IOHandler {
    savedArtifacts: io.ModelArtifacts;

    async save(modelArtifacts: io.ModelArtifacts): Promise<io.SaveResult> {
      this.savedArtifacts = modelArtifacts;
      return {modelArtifactsInfo: null};
    }
  }

  /** IOHandler with no `save` method, used to exercise the error path. */
  class EmptyIOHandler implements io.IOHandler {}

  it('Saving all weights succeeds', async done => {
    const model = new Sequential();
    model.add(new Dense({units: 3, inputShape: [5]}));
    const handler = new IOHandlerForTest();

    try {
      await model.save(handler);
      // The saved topology must match the model's own JSON serialization.
      expect(handler.savedArtifacts.modelTopology)
          .toEqual(model.toJSON(null, false));
      // A single Dense layer has exactly two weights: kernel and bias.
      expect(handler.savedArtifacts.weightSpecs.length).toEqual(2);
      expect(handler.savedArtifacts.weightSpecs[0].name.indexOf('/kernel'))
          .toBeGreaterThan(0);
      expect(handler.savedArtifacts.weightSpecs[0].shape).toEqual([5, 3]);
      expect(handler.savedArtifacts.weightSpecs[0].dtype).toEqual('float32');
      expect(handler.savedArtifacts.weightSpecs[1].name.indexOf('/bias'))
          .toBeGreaterThan(0);
      expect(handler.savedArtifacts.weightSpecs[1].shape).toEqual([3]);
      expect(handler.savedArtifacts.weightSpecs[1].dtype).toEqual('float32');
      done();
    } catch (err) {
      // Fail the test explicitly; a swallowed rejection would otherwise
      // surface only as an opaque jasmine timeout.
      done.fail(err.stack);
    }
  });

  it('Saving only trainable weights succeeds', async done => {
    const model = new Sequential();
    model.add(new Dense({units: 3, inputShape: [5], trainable: false}));
    model.add(new Dense({units: 2}));
    const handler = new IOHandlerForTest();

    try {
      await model.save(handler, {trainableOnly: true});
      expect(handler.savedArtifacts.modelTopology)
          .toEqual(model.toJSON(null, false));
      // Verify that only the trainable weights (i.e., weights from the
      // 2nd, trainable Dense layer) are saved.
      expect(handler.savedArtifacts.weightSpecs.length).toEqual(2);
      expect(handler.savedArtifacts.weightSpecs[0].name.indexOf('/kernel'))
          .toBeGreaterThan(0);
      expect(handler.savedArtifacts.weightSpecs[0].shape).toEqual([3, 2]);
      expect(handler.savedArtifacts.weightSpecs[0].dtype).toEqual('float32');
      expect(handler.savedArtifacts.weightSpecs[1].name.indexOf('/bias'))
          .toBeGreaterThan(0);
      expect(handler.savedArtifacts.weightSpecs[1].shape).toEqual([2]);
      expect(handler.savedArtifacts.weightSpecs[1].dtype).toEqual('float32');
      done();
    } catch (err) {
      // Fail the test explicitly instead of timing out.
      done.fail(err.stack);
    }
  });

  it('Saving to a handler without save method fails', async done => {
    const model = new Sequential();
    model.add(new Dense({units: 3, inputShape: [5]}));
    const handler = new EmptyIOHandler();
    try {
      await model.save(handler);
      done.fail(
          'Saving with an IOHandler without `save` succeeded ' +
          'unexpectedly.');
    } catch (err) {
      expect(err.message)
          .toEqual(
              'Model.save() cannot proceed because the IOHandler ' +
              'provided does not have the `save` attribute defined.');
      done();
    }
  });
});
154 changes: 147 additions & 7 deletions src/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import * as K from './backend/tfjs_backend';
import {History} from './callbacks';
import {getSourceInputs, Input, Layer, Node} from './engine/topology';
import {Model, ModelCompileConfig, ModelEvaluateConfig, ModelFitConfig, ModelPredictConfig} from './engine/training';
import {RuntimeError, ValueError} from './errors';
import {NotImplementedError, RuntimeError, ValueError} from './errors';
import {deserialize} from './layers/serialization';
import {NamedTensorMap, Serializable, Shape} from './types';
import {ConfigDict, ConfigDictArray, Constructor, JsonDict, SymbolicTensor} from './types';
Expand All @@ -40,7 +40,7 @@ export async function modelFromJSON(
let modelTopology = modelAndWeightsConfig.modelTopology;
if (modelTopology['model_config'] != null) {
// If the model-topology JSON contains a 'model_config' field, then it is
// a full model JSON (e.g., from `keras.models.save_model`), which contains
// a full model JSON (e.g., from `keras.Model.save()`), which contains
// not only the model's architecture in its 'model_config' field, but
// additional information such as the model's optimizer. We use only the
// 'model_config' field currently.
Expand Down Expand Up @@ -85,7 +85,7 @@ export interface ModelAndWeightsConfig {
* return value of `keras.Model.to_json()`.
* - A full model config, containing not only model architecture, but also
* training options and state, i.e., a format consistent with the return
* value of `keras.models.save_model().
* value of `keras.models.save_model()`.
*/
modelTopology: JsonDict;

Expand All @@ -102,6 +102,72 @@ export interface ModelAndWeightsConfig {
pathPrefix?: string;
}

/**
 * Load a model, including its topology and optionally weights. See the
 * Tutorial named "How to import a Keras Model" for usage examples.
 *
 * @param pathOrIOHandler Either of the two formats:
 *   1. A string path to the `ModelAndWeightsConfig` JSON describing the
 *      model in the canonical TensorFlow.js format. The path is treated as
 *      a relative HTTP path, fetched with `fetch` to obtain the model
 *      topology and weight-manifest JSON. The fetched JSON object is
 *      expected to have the following fields:
 *      - 'modelTopology': either
 *        1. a model architecture JSON consistent with the format of the
 *           return value of `keras.Model.to_json()`, or
 *        2. a full model JSON in the format of `keras.models.save_model()`.
 *      - 'weightsManifest': A TensorFlow.js weights manifest.
 *      See the Python converter function `save_model()` for more details.
 *      Model weights are assumed to be reachable via the relative paths in
 *      the manifest's `paths` fields.
 *   2. A `tf.io.IOHandler` object whose `load` method provides the model
 *      artifacts.
 *
 * @returns A `Promise` of `Model`, with the topology and weights loaded.
 */
export async function loadModelInternal(pathOrIOHandler: string|
                                        io.IOHandler): Promise<Model> {
  // Dispatch on the argument type: a string is an HTTP path; anything else
  // is treated as an IOHandler (narrowed automatically by `typeof`).
  if (typeof pathOrIOHandler === 'string') {
    return loadModelFromPath(pathOrIOHandler);
  }
  return loadModelFromIOHandler(pathOrIOHandler);
}

/**
 * Load a model and optionally its weights, using an IOHandler object.
 *
 * @param handler Source of the model artifacts; must implement `load`.
 * @param customObjects Optional mapping of custom class names to their
 *   constructors, forwarded to deserialization.
 * @returns A `Promise` of the deserialized `Model`.
 */
export async function loadModelFromIOHandler(
    handler: io.IOHandler, customObjects?: ConfigDict): Promise<Model> {
  // The handler must be able to load; saving-only handlers are rejected.
  if (handler.load == null) {
    throw new ValueError(
        'Cannot proceed with model loading because the IOHandler provided ' +
        'does not have the `load` method implemented.');
  }

  const artifacts = await handler.load();
  const tsConfig =
      convertPythonicToTs(artifacts.modelTopology as ConfigDict) as ConfigDict;
  const model = deserialize(tsConfig, customObjects) as Model;

  // Without weight data there is nothing further to restore.
  if (artifacts.weightData == null) {
    return model;
  }

  // Weight values cannot be interpreted without their specs.
  if (artifacts.weightSpecs == null) {
    throw new ValueError(
        'Model artifacts contains weight data, but not weight specs. ' +
        'Therefore loading of weights cannot proceed.');
  }

  const skipMismatch = false;
  const isNamedTensorMap = true;
  model.loadWeights(
      io.decodeWeights(artifacts.weightData, artifacts.weightSpecs),
      skipMismatch, isNamedTensorMap);
  return model;
}

// tslint:disable:max-line-length
/**
* Load a model, including its topology and optionally weights. See the
* Tutorial named "How to import a Keras Model" for usage examples.
Expand All @@ -114,18 +180,21 @@ export interface ModelAndWeightsConfig {
* - 'modelTopology': A JSON object that can be either of:
* 1. a model architecture JSON consistent with the format of the return
* value of `keras.Model.to_json()`
* 2. a full model JSON in the format of `keras.models.save_model()`.
* 2. a full model JSON in the format of
* [`keras.Model.save()`](https://keras.io/getting-started/faq/#how-can-i-save-a-keras-model).
* - 'weightsManifest': A TensorFlow.js weights manifest.
*
* See the Python converter function `save_model()` for more details.
* See the Python converter function
* [`save_keras_model()`](https://js.tensorflow.org/tutorials/import-keras.html)
* for more details.
*
* It is also assumed that model weights can be accessed from relative paths
* described by the `paths` fields in weights manifest.
*
* @returns A `Promise` of `Model`, with the topology and weights loaded.
*/
// tslint:enable:max-line-length
// TODO(cais): Add link to the core's documentation of `WeightManifestConfig`.
export async function loadModelInternal(modelConfigPath: string):
export async function loadModelFromPath(modelConfigPath: string):
Promise<Model> {
// TODO(soergel): accept a Request object too, not just a url string.
const modelConfigRequest = await fetch(modelConfigPath);
Expand Down Expand Up @@ -540,6 +609,77 @@ export class Sequential extends Model {

// TODO(cais): Override get trainableWeights() here

/**
 * Extract weight values of the model.
 *
 * @param config An instance of `io.SaveConfig`, which specifies
 *   model-saving options such as whether only trainable weights are to be
 *   saved.
 * @returns A `NamedTensorMap` mapping original weight names (i.e.,
 *   non-uniqueified weight names) to their values.
 */
protected getNamedWeights(config?: io.SaveConfig): NamedTensorMap {
  const trainableOnly = config != null && config.trainableOnly;
  // The weight variables and their values are read from the same filtered
  // list, so indices line up one-to-one.
  const weightVars = trainableOnly ? this.trainableWeights : this.weights;
  const weightValues = this.getWeights(trainableOnly);

  const namedWeights: NamedTensorMap = {};
  weightVars.forEach((weightVar, i) => {
    // Defensive: skip any non-trainable entry when only trainable weights
    // were requested.
    if (trainableOnly && !weightVar.trainable) {
      return;
    }
    namedWeights[weightVar.originalName] = weightValues[i];
  });
  return namedWeights;
}

/**
 * Save the configuration and/or weights of the Model.
 *
 * An `IOHandler` is an object that has a `save` method of the proper
 * signature defined. The `save` method manages the storing or transmission
 * of serialized data ("artifacts") that represent the model's topology and
 * weights onto or via a specific medium, such as file downloads, local
 * storage, IndexedDB in the web browser and HTTP requests to a server.
 * TensorFlow.js provides `IOHandler` implementations for a number of
 * frequently used saving mediums, such as `tf.io.browserDownloads` and
 * `tf.io.browserLocalStorage`. See `tf.io` for more details.
 *
 * This method also allows you to refer to certain types of `IOHandler`s as
 * URL-like string shortcuts, such as 'localstorage://' and 'indexeddb://'.
 *
 * @param handlerOrURL An instance of `IOHandler` or a URL-like,
 *   scheme-based string shortcut for `IOHandler`.
 * @param config Options for saving the model.
 * @returns A `Promise` of `SaveResult`, which summarizes the result of the
 *   saving, such as byte sizes of the saved artifacts for the model's
 *   topology and weight values.
 */
async save(handlerOrURL: io.IOHandler|string, config?: io.SaveConfig):
    Promise<io.SaveResult> {
  // URL-scheme shortcuts are not wired up yet.
  if (typeof handlerOrURL === 'string') {
    throw new NotImplementedError(
        'String URLs support in Model.save() is not implemented yet.');
  }
  // A load-only handler cannot accept artifacts.
  if (handlerOrURL.save == null) {
    throw new ValueError(
        'Model.save() cannot proceed because the IOHandler provided does ' +
        'not have the `save` attribute defined.');
  }

  // Serialize the topology as a JSON object (not a string).
  const modelTopology = this.toJSON(null, /* returnString= */ false);
  // Encode the (possibly trainable-only) weights into binary data + specs.
  const encodedWeights = await io.encodeWeights(this.getNamedWeights(config));

  return handlerOrURL.save({
    modelTopology,
    weightData: encodedWeights.data,
    weightSpecs: encodedWeights.specs
  });
}

// tslint:disable-next-line:no-any
getConfig(): any {
// NOTE(cais): We override the return type of getConfig() to `any` here,
Expand Down
Loading