diff --git a/package.json b/package.json index 1e035cfda..06ad78156 100644 --- a/package.json +++ b/package.json @@ -6,7 +6,7 @@ "main": "dist/index.js", "types": "dist/index.d.ts", "devDependencies": { - "@tensorflow/tfjs-core": "0.8.5", + "@tensorflow/tfjs-core": "0.9.1", "@types/jasmine": "~2.5.53", "browserify": "~16.1.0", "clang-format": "~1.2.2", @@ -42,6 +42,6 @@ "lint": "tslint -p . --type-check -t verbose" }, "peerDependencies": { - "@tensorflow/tfjs-core": "0.8.5" + "@tensorflow/tfjs-core": "0.9.1" } } diff --git a/src/engine/topology.ts b/src/engine/topology.ts index 6ad142a18..53893b692 100644 --- a/src/engine/topology.ts +++ b/src/engine/topology.ts @@ -985,9 +985,13 @@ export abstract class Layer extends Serializable { /** * Returns the current values of the weights of the layer. + * + * @param trainableOnly Whether to get the values of only trainable weights. + * @returns Weight values as an `Array` of `Tensor`s. */ - getWeights(): Tensor[] { - return K.batchGetValue(this.weights); + getWeights(trainableOnly = false): Tensor[] { + return K.batchGetValue( + trainableOnly ? this.trainableWeights : this.weights); } /** diff --git a/src/exports.ts b/src/exports.ts index 35fff688a..79566831d 100644 --- a/src/exports.ts +++ b/src/exports.ts @@ -13,7 +13,7 @@ */ // tslint:disable:max-line-length -import {doc, Tensor} from '@tensorflow/tfjs-core'; +import {doc, io, Tensor} from '@tensorflow/tfjs-core'; import {Constraint, MaxNorm, MaxNormConfig, MinMaxNorm, MinMaxNormConfig, NonNeg, UnitNorm, UnitNormConfig} from './constraints'; import {ContainerConfig, Input, InputConfig, InputLayer, InputLayerConfig, Layer, LayerConfig} from './engine/topology'; @@ -155,8 +155,8 @@ export class ModelExports { subheading: 'Loading', useDocsFrom: 'loadModelInternal' }) - static loadModel(modelConfigPath: string): Promise { - return loadModelInternal(modelConfigPath); + static loadModel(pathOrIOHandler: string|io.IOHandler): Promise { + return loadModelInternal(pathOrIOHandler); } @doc({ diff --git a/src/model_save_test.ts b/src/model_save_test.ts new file mode 100644 index 000000000..0a0aad2f5 --- /dev/null +++ b/src/model_save_test.ts @@ -0,0 +1,104 @@ +/** + * @license + * Copyright 2018 Google LLC + * + * Use of this source code is governed by an MIT-style + * license that can be found in the LICENSE file or at + * https://opensource.org/licenses/MIT. + * ============================================================================= + */ + +import {io} from '@tensorflow/tfjs-core'; + +import {Dense} from './layers/core'; +import {Sequential} from './models'; +import {describeMathCPUAndGPU} from './utils/test_utils'; + +describeMathCPUAndGPU('Model.save', () => { + class IOHandlerForTest implements io.IOHandler { + savedArtifacts: io.ModelArtifacts; + + async save(modelArtifacts: io.ModelArtifacts): Promise { + this.savedArtifacts = modelArtifacts; + return {modelArtifactsInfo: null}; + } + } + + class EmptyIOHandler implements io.IOHandler {} + + it('Saving all weights succeeds', async done => { + const model = new Sequential(); + model.add(new Dense({units: 3, inputShape: [5]})); + const handler = new IOHandlerForTest(); + + model.save(handler) + .then(saveResult => { + expect(handler.savedArtifacts.modelTopology) + .toEqual(model.toJSON(null, false)); + expect(handler.savedArtifacts.weightSpecs.length).toEqual(2); + expect(handler.savedArtifacts.weightSpecs[0].name.indexOf('/kernel')) + .toBeGreaterThan(0); + expect(handler.savedArtifacts.weightSpecs[0].shape).toEqual([5, 3]); + expect(handler.savedArtifacts.weightSpecs[0].dtype) + .toEqual('float32'); + expect(handler.savedArtifacts.weightSpecs[1].name.indexOf('/bias')) + .toBeGreaterThan(0); + expect(handler.savedArtifacts.weightSpecs[1].shape).toEqual([3]); + expect(handler.savedArtifacts.weightSpecs[1].dtype) + .toEqual('float32'); + done(); + }) + .catch(err => { + console.error(err.stack); + }); + }); + + it('Saving only trainable weights succeeds', async done => { + const model = new Sequential(); + model.add(new Dense({units: 3, inputShape: [5], trainable: false})); + model.add(new Dense({units: 2})); + const handler = new IOHandlerForTest(); + + model.save(handler, {trainableOnly: true}) + .then(saveResult => { + expect(handler.savedArtifacts.modelTopology) + .toEqual(model.toJSON(null, false)); + // Verify that only the trainable weights (i.e., weights from the + // 2nd, trainable Dense layer) are saved. + expect(handler.savedArtifacts.weightSpecs.length).toEqual(2); + expect(handler.savedArtifacts.weightSpecs[0].name.indexOf('/kernel')) + .toBeGreaterThan(0); + expect(handler.savedArtifacts.weightSpecs[0].shape).toEqual([3, 2]); + expect(handler.savedArtifacts.weightSpecs[0].dtype) + .toEqual('float32'); + expect(handler.savedArtifacts.weightSpecs[1].name.indexOf('/bias')) + .toBeGreaterThan(0); + expect(handler.savedArtifacts.weightSpecs[1].shape).toEqual([2]); + expect(handler.savedArtifacts.weightSpecs[1].dtype) + .toEqual('float32'); + done(); + }) + .catch(err => { + console.error(err.stack); + }); + }); + + it('Saving to a handler without save method fails', async done => { + const model = new Sequential(); + model.add(new Dense({units: 3, inputShape: [5]})); + const handler = new EmptyIOHandler(); + model.save(handler) + .then(saveResult => { + fail( + 'Saving with an IOHandler without `save` succeeded ' + + 'unexpectedly.'); + }) + .catch(err => { + expect(err.message) + .toEqual( + 'Model.save() cannot proceed because the IOHandler ' + + 'provided does not have the `save` attribute defined.'); + done(); + }); + }); +}); diff --git a/src/models.ts b/src/models.ts index 781dd3e2a..fa3b8e96f 100644 --- a/src/models.ts +++ b/src/models.ts @@ -17,7 +17,7 @@ import * as K from './backend/tfjs_backend'; import {History} from './callbacks'; import {getSourceInputs, Input, Layer, Node} from './engine/topology'; import {Model, ModelCompileConfig, ModelEvaluateConfig, ModelFitConfig, ModelPredictConfig} from './engine/training'; -import {RuntimeError, ValueError} from './errors'; +import {NotImplementedError, RuntimeError, ValueError} from './errors'; import {deserialize} from './layers/serialization'; import {NamedTensorMap, Serializable, Shape} from './types'; import {ConfigDict, ConfigDictArray, Constructor, JsonDict, SymbolicTensor} from './types'; @@ -40,7 +40,7 @@ export async function modelFromJSON( let modelTopology = modelAndWeightsConfig.modelTopology; if (modelTopology['model_config'] != null) { // If the model-topology JSON contains a 'model_config' field, then it is - // a full model JSON (e.g., from `keras.models.save_model`), which contains + // a full model JSON (e.g., from `keras.Model.save()`), which contains // not only the model's architecture in its 'model_config' field, but // additional information such as the model's optimizer. We use only the // 'model_config' field currently. @@ -85,7 +85,7 @@ export interface ModelAndWeightsConfig { * return value of`keras.Model.to_json()`. * - A full model config, containing not only model architecture, but also * training options and state, i.e., a format consistent with the return - * value of `keras.models.save_model(). + * value of `keras.models.save_model()`. */ modelTopology: JsonDict; @@ -102,6 +102,72 @@ export interface ModelAndWeightsConfig { pathPrefix?: string; } +/** + * Load a model, including its topology and optionally weights. See the + * Tutorial named "How to import a Keras Model" for usage examples. + * + * @param pathOrIOHandler Can be either of the two formats + * 1. A string path to the `ModelAndWeightsConfig` JSON describing + * the model in the canonical TensorFlow.js format. This path will be + * interpreted as a relative HTTP path, to which `fetch` will be used to + * request the model topology and weight manifest JSON. + * The content of the JSON file is assumed to be a JSON object with the + * following fields and values: + * - 'modelTopology': A JSON object that can be either of: + * 1. a model architecture JSON consistent with the format of the return + * value of `keras.Model.to_json()` + * 2. a full model JSON in the format of `keras.models.save_model()`. + * - 'weightsManifest': A TensorFlow.js weights manifest. + * See the Python converter function `save_model()` for more details. + * It is also assumed that model weights can be accessed from relative + * paths described by the `paths` fields in weights manifest. + * 2. An `tf.io.IOHandler` object that loads model artifacts with its `load` + * method. + * + * @returns A `Promise` of `Model`, with the topology and weights loaded. + */ +export async function loadModelInternal(pathOrIOHandler: string| + io.IOHandler): Promise { + return (typeof pathOrIOHandler === 'string') ? + loadModelFromPath(pathOrIOHandler) : + loadModelFromIOHandler(pathOrIOHandler as io.IOHandler); +} + +/** + * Load a model and optionally its weights, using an IOHandler object. + */ +export async function loadModelFromIOHandler( + handler: io.IOHandler, customObjects?: ConfigDict): Promise { + if (handler.load == null) { + throw new ValueError( + 'Cannot proceed with model loading because the IOHandler provided ' + + 'does not have the `load` method implemented.'); + } + const artifacts = await handler.load(); + const model = deserialize( + convertPythonicToTs( + artifacts.modelTopology as ConfigDict) as ConfigDict, + customObjects) as Model; + + // If weightData is present, load the weights into the model. + if (artifacts.weightData != null) { + // Loading weights requires weightSpecs. + if (artifacts.weightSpecs == null) { + throw new ValueError( + 'Model artifacts contains weight data, but not weight specs. ' + + 'Therefore loading of weights cannot proceed.'); + } + + const skipMismatch = false; + const isNamedTensorMap = true; + model.loadWeights( + io.decodeWeights(artifacts.weightData, artifacts.weightSpecs), + skipMismatch, isNamedTensorMap); + } + return model; +} + +// tslint:disable:max-line-length /** * Load a model, including its topology and optionally weights. See the * Tutorial named "How to import a Keras Model" for usage examples. @@ -114,18 +180,21 @@ export interface ModelAndWeightsConfig { * - 'modelTopology': A JSON object that can be either of: * 1. a model architecture JSON consistent with the format of the return * value of `keras.Model.to_json()` - * 2. a full model JSON in the format of `keras.models.save_model()`. + * 2. a full model JSON in the format of + * [`keras.Model.save()`](https://keras.io/getting-started/faq/#how-can-i-save-a-keras-model). * - 'weightsManifest': A TensorFlow.js weights manifest. - * - * See the Python converter function `save_model()` for more details. + * See the Python converter function + * [`save_keras_model()`](https://js.tensorflow.org/tutorials/import-keras.html) + * for more details. * * It is also assumed that model weights can be accessed from relative paths * described by the `paths` fields in weights manifest. * * @returns A `Promise` of `Model`, with the topology and weights loaded. */ +// tslint:enable:max-line-length // TODO(cais): Add link to the core's documentation of `WeightManifestConfig`. -export async function loadModelInternal(modelConfigPath: string): +export async function loadModelFromPath(modelConfigPath: string): Promise { // TODO(soergel): accept a Request object too, not just a url string. const modelConfigRequest = await fetch(modelConfigPath); @@ -540,6 +609,77 @@ export class Sequential extends Model { // TODO(cais): Override get trainableWeights() here + /** + * Extract weight values of the model. + * + * @param config: An instance of `io.SaveConfig`, which specifies model-saving + * options such as whether only trainable weights are to be saved. + * @returns A `NamedTensorMap` mapping original weight names (i.e., + * non-uniqueified weight names) to their values. + */ + protected getNamedWeights(config?: io.SaveConfig): NamedTensorMap { + const namedWeights: NamedTensorMap = {}; + + const trainableOnly = config != null && config.trainableOnly; + const weights = trainableOnly ? this.trainableWeights : this.weights; + const weightValues = this.getWeights(trainableOnly); + for (let i = 0; i < weights.length; ++i) { + if (trainableOnly && !weights[i].trainable) { + // Optionally skip non-trainable weights. + continue; + } + namedWeights[weights[i].originalName] = weightValues[i]; + } + return namedWeights; + } + + /** + * Save the configuration and/or weights of the Model. + * + * An `IOHandler` is an object that has a `save` method of the proper + * signature defined. The `save` method manages the storing or transmission of + * serialized data ("artifacts") that represent the model's topology and + * weights onto or via a specific medium, such as file downloads, local + * storage, IndexedDB in the web browser and HTTP requests to a server. + * TensorFlow.js provides `IOHandler` implementations for a number of + * frequently used saving mediums, such as `tf.io.browserDownloads` and + * `tf.io.browserLocalStorage`. See `tf.io` for more details. + * + * This method also allows you to refer to certain types of `IOHandler`s as + * URL-like string shortcuts, such as 'localstorage://' and 'indexeddb://'. + * + * @param handlerOrURL An instance of `IOHandler` or a URL-like, scheme-based + * string shortcut for `IOHandler`. + * @param config Options for saving the model. + * @returns A `Promise` of `SaveResult`, which summarizes the result of the + * saving, such as byte sizes of the saved artifacts for the model's + * topology and weight values. + */ + async save(handlerOrURL: io.IOHandler|string, config?: io.SaveConfig): + Promise { + if (typeof handlerOrURL === 'string') { + throw new NotImplementedError( + 'String URLs support in Model.save() is not implemented yet.'); + } + if (handlerOrURL.save == null) { + throw new ValueError( + 'Model.save() cannot proceed because the IOHandler provided does ' + + 'not have the `save` attribute defined.'); + } + + const weightDataAndSpecs = + await io.encodeWeights(this.getNamedWeights(config)); + + const returnString = false; + const modelConfig = this.toJSON(null, returnString); + + return handlerOrURL.save({ + modelTopology: modelConfig, + weightData: weightDataAndSpecs.data, + weightSpecs: weightDataAndSpecs.specs + }); + } + // tslint:disable-next-line:no-any getConfig(): any { // NOTE(cais): We override the return type of getConfig() to `any` here, diff --git a/src/models_test.ts b/src/models_test.ts index 69271c7ae..47b3f5dd5 100644 --- a/src/models_test.ts +++ b/src/models_test.ts @@ -9,14 +9,14 @@ */ // tslint:disable:max-line-length -import {io, ones, Scalar, scalar, Tensor, zeros} from '@tensorflow/tfjs-core'; +import {io, ones, Scalar, scalar, Tensor, tensor1d, tensor2d, zeros} from '@tensorflow/tfjs-core'; import * as K from './backend/tfjs_backend'; import {Model} from './engine/training'; import * as tfl from './index'; import {Reshape} from './layers/core'; import {deserialize} from './layers/serialization'; -import {ModelAndWeightsConfig, modelFromJSON} from './models'; +import {loadModelInternal, ModelAndWeightsConfig, modelFromJSON} from './models'; import {ConfigDict, JsonDict} from './types'; import {convertPythonicToTs} from './utils/serialization_utils'; import {describeMathCPU, describeMathCPUAndGPU, expectTensorsClose} from './utils/test_utils'; @@ -161,7 +161,7 @@ describeMathCPU('model_from_json', () => { }); }); -describeMathCPU('loadModel', () => { +describeMathCPU('loadModel from URL', () => { const setupFakeWeightFiles = (fileBufferMap: {[filename: string]: Float32Array|Int32Array|ArrayBuffer}) => { @@ -327,6 +327,133 @@ describeMathCPU('loadModel', () => { }); }); +describeMathCPU('loadModel from IOHandler', () => { + // The model topology JSON can be obtained with the following Python Keras + // code: + // + // ```python + // import keras + // model = keras.Sequential([ + // keras.layers.Dense(1, input_shape=[4], activation='sigmoid') + // ]) + // print(model.to_json()) + // ``` + const modelTopology: {} = { + 'class_name': 'Sequential', + 'keras_version': '2.1.4', + 'config': [{ + 'class_name': 'Dense', + 'config': { + 'kernel_initializer': { + 'class_name': 'VarianceScaling', + 'config': { + 'distribution': 'uniform', + 'scale': 1.0, + 'seed': null, + 'mode': 'fan_avg' + } + }, + 'name': 'dense_1', + 'kernel_constraint': null, + 'bias_regularizer': null, + 'bias_constraint': null, + 'dtype': 'float32', + 'activation': 'sigmoid', + 'trainable': true, + 'kernel_regularizer': null, + 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, + 'units': 1, + 'batch_input_shape': [null, 4], + 'use_bias': true, + 'activity_regularizer': null + } + }], + 'backend': 'tensorflow' + }; + const weightSpecs: io.WeightsManifestEntry[] = [ + { + name: 'dense_1/kernel', + shape: [4, 1], + dtype: 'float32', + }, + { + name: 'dense_1/bias', + shape: [1], + dtype: 'float32', + } + ]; + const weightData = new Float32Array([1.1, 2.2, 3.3, 4.4, 5.5]).buffer; + + // A dummy IOHandler that returns hard-coded model artifacts when its `load` + // method is called. + class IOHandlerForTest implements io.IOHandler { + private readonly includeWeights: boolean; + + constructor(includeWeights = true) { + this.includeWeights = includeWeights; + } + + async load(): Promise { + return this.includeWeights ? {modelTopology, weightSpecs, weightData} : + {modelTopology}; + } + } + + // A dummy IOHandler that doesn't have the `load` method implemented and is + // expected to cause `loadModel` or `loadModelInternal` to fail. + class IOHandlerWithoutLoad implements io.IOHandler { + constructor() {} + } + + it('load topology and weights', async done => { + loadModelInternal(new IOHandlerForTest(true)) + .then(model => { + expect(model.layers.length).toEqual(1); + expect(model.inputs.length).toEqual(1); + expect(model.inputs[0].shape).toEqual([null, 4]); + expect(model.outputs.length).toEqual(1); + expect(model.outputs[0].shape).toEqual([null, 1]); + const weightValues = model.getWeights(); + expect(weightValues.length).toEqual(2); + expectTensorsClose( + weightValues[0], tensor2d([1.1, 2.2, 3.3, 4.4], [4, 1])); + expectTensorsClose(weightValues[1], tensor1d([5.5])); + done(); + }) + .catch(err => { + done.fail(err.stack); + }); + }); + + it('load topology only', async done => { + loadModelInternal(new IOHandlerForTest(false)) + .then(model => { + expect(model.layers.length).toEqual(1); + expect(model.inputs.length).toEqual(1); + expect(model.inputs[0].shape).toEqual([null, 4]); + expect(model.outputs.length).toEqual(1); + expect(model.outputs[0].shape).toEqual([null, 1]); + done(); + }) + .catch(err => { + done.fail(err.stack); + }); + }); + + it('IOHandler without load method causes error', async done => { + loadModelInternal(new IOHandlerWithoutLoad()) + .then(model => { + done.fail( + 'Loading with an IOHandler without load method succeeded ' + + 'unexpectedly.'); + }) + .catch(err => { + expect(err.message).toMatch(/does not have .*load.* method/); + done(); + }); + }); +}); + describeMathCPUAndGPU('Sequential', () => { const inputShape = [1, 6]; const batchInputShape = [1].concat(inputShape); diff --git a/yarn.lock b/yarn.lock index 6907d4072..226dc8923 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2,9 +2,9 @@ # yarn lockfile v1 -"@tensorflow/tfjs-core@0.8.5": - version "0.8.5" - resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-0.8.5.tgz#8b0d5e99094eae47806a98a76bc3cadf8b7efd4d" +"@tensorflow/tfjs-core@0.9.1": + version "0.9.1" + resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-0.9.1.tgz#c36763068a5b8ec99c60fd161dc58dd3fb97001b" dependencies: seedrandom "~2.4.3"