diff --git a/src/io/browser_files.ts b/src/io/browser_files.ts index 04a95c3e93..9c5f5987c6 100644 --- a/src/io/browser_files.ts +++ b/src/io/browser_files.ts @@ -29,7 +29,7 @@ const DEFAULT_FILE_NAME_PREFIX = 'model'; const DEFAULT_JSON_EXTENSION_NAME = '.json'; const DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin'; -export class BrowserDownloads implements IOHandler { +class BrowserDownloads implements IOHandler { private readonly modelTopologyFileName: string; private readonly weightDataFileName: string; private readonly jsonAnchor: HTMLAnchorElement; @@ -102,7 +102,7 @@ export class BrowserDownloads implements IOHandler { } } -export class BrowserFiles implements IOHandler { +class BrowserFiles implements IOHandler { private readonly files: File[]; constructor(files: File[]) { diff --git a/src/io/browser_http.ts b/src/io/browser_http.ts new file mode 100644 index 0000000000..70cf49f7a9 --- /dev/null +++ b/src/io/browser_http.ts @@ -0,0 +1,236 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * IOHandler implementations based on HTTP requests in the web browser. + * + * Uses [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API). + */ + +import {assert} from '../util'; + +import {getModelArtifactsInfoForKerasJSON} from './io_utils'; +// tslint:disable-next-line:max-line-length +import {IOHandler, ModelArtifacts, SaveResult, WeightsManifestConfig} from './types'; + +class BrowserHTTPRequest implements IOHandler { + protected readonly path: string; + protected readonly requestInit: RequestInit; + + readonly DEFAULT_METHOD = 'POST'; + + constructor(path: string, requestInit?: RequestInit) { + assert( + path != null && path.length > 0, + 'URL path for browserHTTPRequest must not be null, undefined or ' + + 'empty.'); + this.path = path; + + if (requestInit != null && requestInit.body != null) { + throw new Error( + 'requestInit is expected to have no pre-existing body, but has one.'); + } + this.requestInit = requestInit || {}; + } + + async save(modelArtifacts: ModelArtifacts): Promise { + if (modelArtifacts.modelTopology instanceof ArrayBuffer) { + throw new Error( + 'BrowserHTTPRequest.save() does not support saving model topology ' + + 'in binary formats yet.'); + } + + const init = Object.assign({method: this.DEFAULT_METHOD}, this.requestInit); + init.body = new FormData(); + + const weightsManifest: WeightsManifestConfig = [{ + paths: ['./model.weights.bin'], + weights: modelArtifacts.weightSpecs, + }]; + const modelTopologyAndWeightManifest = { + modelTopology: modelArtifacts.modelTopology, + weightsManifest + }; + + init.body.append( + 'model.json', + new Blob( + [JSON.stringify(modelTopologyAndWeightManifest)], + {type: 'application/json'}), + 'model.json'); + + if (modelArtifacts.weightData != null) { + init.body.append( + 'model.weights.bin', + new Blob( + [modelArtifacts.weightData], {type: 'application/octet-stream'}), + 'model.weights.bin'); + } + + const response = await fetch(this.path, init); + + if (response.status === 200) { + return { + modelArtifactsInfo: getModelArtifactsInfoForKerasJSON(modelArtifacts), + responses: [response], + }; + } else { + throw new Error( + `BrowserHTTPRequest.save() failed due to HTTP response status ` + + `${response.status}.`); + } + } + + // TODO(cais): Add load to unify this IOHandler type and the mechanism + // that currently underlies `tf.loadModel('path')` in tfjs-layers. + // See: https://github.com/tensorflow/tfjs/issues/290 +} + +// tslint:disable:max-line-length +/** + * Creates an IOHandler subtype that sends model artifacts to HTTP server. + * + * An HTTP request of the `multipart/form-data` mime type will be sent to the + * `path` URL. The form data includes artifacts that represent the topology + * and/or weights of the model. In the case of Keras-style `tf.Model`, two + * blobs (files) exist in form-data: + * - A JSON file consisting of `modelTopology` and `weightsManifest`. + * - A binary weights file consisting of the concatenated weight values. + * These files are in the same format as the one generated by + * [tensorflowjs_converter](https://js.tensorflow.org/tutorials/import-keras.html). + * + * The following code snippet exemplifies the client-side code that uses this + * function: + * + * ```js + * const model = tf.sequential(); + * model.add( + * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'})); + * + * const saveResult = await model.save(tf.io.browserHTTPRequest( + * 'http://model-server:5000/upload', {method: 'PUT'})); + * console.log(saveResult); + * ``` + * + * The following Python code snippet based on the + * [flask](https://github.com/pallets/flask) server framework implements a + * server that can receive the request. Upon receiving the model artifacts + * via the requst, this particular server reconsistutes instances of + * [Keras Models](https://keras.io/models/model/) in memory. + * + * ```python + * # pip install -U flask flask-cors keras tensorflow tensorflowjs + * + * from __future__ import absolute_import + * from __future__ import division + * from __future__ import print_function + * + * import io + * + * from flask import Flask, Response, request + * from flask_cors import CORS, cross_origin + * import tensorflow as tf + * import tensorflowjs as tfjs + * import werkzeug.formparser + * + * + * class ModelReceiver(object): + * + * def __init__(self): + * self._model = None + * self._model_json_bytes = None + * self._model_json_writer = None + * self._weight_bytes = None + * self._weight_writer = None + * + * @property + * def model(self): + * self._model_json_writer.flush() + * self._weight_writer.flush() + * self._model_json_writer.seek(0) + * self._weight_writer.seek(0) + * + * json_content = self._model_json_bytes.read() + * weights_content = self._weight_bytes.read() + * return tfjs.converters.deserialize_keras_model( + * json_content, + * weight_data=[weights_content], + * use_unique_name_scope=True) + * + * def stream_factory(self, + * total_content_length, + * content_type, + * filename, + * content_length=None): + * # Note: this example code is *not* thread-safe. + * if filename == 'model.json': + * self._model_json_bytes = io.BytesIO() + * self._model_json_writer = io.BufferedWriter(self._model_json_bytes) + * return self._model_json_writer + * elif filename == 'model.weights.bin': + * self._weight_bytes = io.BytesIO() + * self._weight_writer = io.BufferedWriter(self._weight_bytes) + * return self._weight_writer + * + * + * def main(): + * app = Flask('model-server') + * CORS(app) + * app.config['CORS_HEADER'] = 'Content-Type' + * + * model_receiver = ModelReceiver() + * + * @app.route('/upload', methods=['POST']) + * @cross_origin() + * def upload(): + * print('Handling request...') + * werkzeug.formparser.parse_form_data( + * request.environ, stream_factory=model_receiver.stream_factory) + * print('Received model:') + * with tf.Graph().as_default(), tf.Session(): + * model = model_receiver.model + * model.summary() + * # You can perform `model.predict()`, `model.fit()`, + * # `model.evaluate()` etc. here. + * return Response(status=200) + * + * app.run('localhost', 5000) + * + * + * if __name__ == '__main__': + * main() + * ``` + * + * @param path URL path. Can be an absolute HTTP path (e.g., + * 'http://localhost:8000/model-upload)') or a relative path (e.g., + * './model-upload'). + * @param requestInit Request configurations to be used when sending + * HTTP request to server using `fetch`. It can contain fields such as + * `method`, `credentials`, `headers`, `mode`, etc. See + * https://developer.mozilla.org/en-US/docs/Web/API/Request/Request + * for more information. `requestInit` must not have a body, because the body + * will be set by TensorFlow.js. File blobs representing + * the model topology (filename: 'model.json') and the weights of the + * model (filename: 'model.weights.bin') will be appended to the body. + * If `requestInit` has a `body`, an Error will be thrown. + * @returns An instance of `IOHandler`. + */ +// tslint:enable:max-line-length +export function browserHTTPRequest( + path: string, requestInit?: RequestInit): IOHandler { + return new BrowserHTTPRequest(path, requestInit); +} diff --git a/src/io/browser_http_test.ts b/src/io/browser_http_test.ts new file mode 100644 index 0000000000..0689c463e9 --- /dev/null +++ b/src/io/browser_http_test.ts @@ -0,0 +1,283 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * Unit tests for browser_http.ts. + */ + +import * as tf from '../index'; +import {describeWithFlags} from '../jasmine_util'; +import {CPU_ENVS} from '../test_util'; + +describeWithFlags('browserHTTPRequest', CPU_ENVS, () => { + // Test data. + const modelTopology1: {} = { + 'class_name': 'Sequential', + 'keras_version': '2.1.4', + 'config': [{ + 'class_name': 'Dense', + 'config': { + 'kernel_initializer': { + 'class_name': 'VarianceScaling', + 'config': { + 'distribution': 'uniform', + 'scale': 1.0, + 'seed': null, + 'mode': 'fan_avg' + } + }, + 'name': 'dense', + 'kernel_constraint': null, + 'bias_regularizer': null, + 'bias_constraint': null, + 'dtype': 'float32', + 'activation': 'linear', + 'trainable': true, + 'kernel_regularizer': null, + 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, + 'units': 1, + 'batch_input_shape': [null, 3], + 'use_bias': true, + 'activity_regularizer': null + } + }], + 'backend': 'tensorflow' + }; + const weightSpecs1: tf.io.WeightsManifestEntry[] = [ + { + name: 'dense/kernel', + shape: [3, 1], + dtype: 'float32', + }, + { + name: 'dense/bias', + shape: [1], + dtype: 'float32', + } + ]; + const weightData1 = new ArrayBuffer(16); + const artifacts1: tf.io.ModelArtifacts = { + modelTopology: modelTopology1, + weightSpecs: weightSpecs1, + weightData: weightData1, + }; + + let requestInits: RequestInit[] = []; + + beforeEach(() => { + requestInits = []; + spyOn(window, 'fetch').and.callFake((path: string, init: RequestInit) => { + if (path === 'model-upload-test') { + requestInits.push(init); + return new Response(null, {status: 200}); + } else { + return new Response(null, {status: 404}); + } + }); + }); + + it('Save topology and weights, default POST method', done => { + const testStartDate = new Date(); + const handler = tf.io.browserHTTPRequest('model-upload-test'); + handler.save(artifacts1) + .then(saveResult => { + expect(saveResult.modelArtifactsInfo.dateSaved.getTime()) + .toBeGreaterThanOrEqual(testStartDate.getTime()); + // Note: The following two assertions work only because there is no + // non-ASCII characters in `modelTopology1` and `weightSpecs1`. + expect(saveResult.modelArtifactsInfo.modelTopologyBytes) + .toEqual(JSON.stringify(modelTopology1).length); + expect(saveResult.modelArtifactsInfo.weightSpecsBytes) + .toEqual(JSON.stringify(weightSpecs1).length); + expect(saveResult.modelArtifactsInfo.weightDataBytes) + .toEqual(weightData1.byteLength); + + expect(requestInits.length).toEqual(1); + const init = requestInits[0]; + expect(init.method).toEqual('POST'); + const body = init.body as FormData; + const jsonFile = body.get('model.json') as File; + const jsonFileReader = new FileReader(); + jsonFileReader.onload = (event: Event) => { + // tslint:disable-next-line:no-any + const modelJSON = JSON.parse((event.target as any).result); + expect(modelJSON.modelTopology).toEqual(modelTopology1); + expect(modelJSON.weightsManifest.length).toEqual(1); + expect(modelJSON.weightsManifest[0].weights).toEqual(weightSpecs1); + + const weightsFile = body.get('model.weights.bin') as File; + const weightsFileReader = new FileReader(); + weightsFileReader.onload = (event: Event) => { + // tslint:disable-next-line:no-any + const weightData = (event.target as any).result as ArrayBuffer; + expect(new Uint8Array(weightData)) + .toEqual(new Uint8Array(weightData1)); + done(); + }; + weightsFileReader.onerror = (error: ErrorEvent) => { + done.fail(error.message); + }; + weightsFileReader.readAsArrayBuffer(weightsFile); + }; + jsonFileReader.onerror = (error: ErrorEvent) => { + done.fail(error.message); + }; + jsonFileReader.readAsText(jsonFile); + }) + .catch(err => { + done.fail(err.stack); + }); + }); + + it('Save topology only, default POST method', done => { + const testStartDate = new Date(); + const handler = tf.io.browserHTTPRequest('model-upload-test'); + const topologyOnlyArtifacts = {modelTopology: modelTopology1}; + handler.save(topologyOnlyArtifacts) + .then(saveResult => { + expect(saveResult.modelArtifactsInfo.dateSaved.getTime()) + .toBeGreaterThanOrEqual(testStartDate.getTime()); + // Note: The following two assertions work only because there is no + // non-ASCII characters in `modelTopology1` and `weightSpecs1`. + expect(saveResult.modelArtifactsInfo.modelTopologyBytes) + .toEqual(JSON.stringify(modelTopology1).length); + expect(saveResult.modelArtifactsInfo.weightSpecsBytes).toEqual(0); + expect(saveResult.modelArtifactsInfo.weightDataBytes).toEqual(0); + + expect(requestInits.length).toEqual(1); + const init = requestInits[0]; + expect(init.method).toEqual('POST'); + const body = init.body as FormData; + const jsonFile = body.get('model.json') as File; + const jsonFileReader = new FileReader(); + jsonFileReader.onload = (event: Event) => { + // tslint:disable-next-line:no-any + const modelJSON = JSON.parse((event.target as any).result); + expect(modelJSON.modelTopology).toEqual(modelTopology1); + // No weights should have been sent to the server. + expect(body.get('model.weights.bin')).toEqual(null); + done(); + }; + jsonFileReader.onerror = (error: ErrorEvent) => { + done.fail(error.message); + }; + jsonFileReader.readAsText(jsonFile); + }) + .catch(err => { + done.fail(err.stack); + }); + }); + + it('Save topology and weights, PUT method, extra headers', done => { + const testStartDate = new Date(); + const handler = tf.io.browserHTTPRequest('model-upload-test', { + method: 'PUT', + headers: { + 'header_key_1': 'header_value_1', + 'header_key_2': 'header_value_2' + } + }); + handler.save(artifacts1) + .then(saveResult => { + expect(saveResult.modelArtifactsInfo.dateSaved.getTime()) + .toBeGreaterThanOrEqual(testStartDate.getTime()); + // Note: The following two assertions work only because there is no + // non-ASCII characters in `modelTopology1` and `weightSpecs1`. + expect(saveResult.modelArtifactsInfo.modelTopologyBytes) + .toEqual(JSON.stringify(modelTopology1).length); + expect(saveResult.modelArtifactsInfo.weightSpecsBytes) + .toEqual(JSON.stringify(weightSpecs1).length); + expect(saveResult.modelArtifactsInfo.weightDataBytes) + .toEqual(weightData1.byteLength); + + expect(requestInits.length).toEqual(1); + const init = requestInits[0]; + expect(init.method).toEqual('PUT'); + + // Check headers. + expect(init.headers).toEqual({ + 'header_key_1': 'header_value_1', + 'header_key_2': 'header_value_2' + }); + + const body = init.body as FormData; + const jsonFile = body.get('model.json') as File; + const jsonFileReader = new FileReader(); + jsonFileReader.onload = (event: Event) => { + // tslint:disable-next-line:no-any + const modelJSON = JSON.parse((event.target as any).result); + expect(modelJSON.modelTopology).toEqual(modelTopology1); + expect(modelJSON.weightsManifest.length).toEqual(1); + expect(modelJSON.weightsManifest[0].weights).toEqual(weightSpecs1); + + const weightsFile = body.get('model.weights.bin') as File; + const weightsFileReader = new FileReader(); + weightsFileReader.onload = (event: Event) => { + // tslint:disable-next-line:no-any + const weightData = (event.target as any).result as ArrayBuffer; + expect(new Uint8Array(weightData)) + .toEqual(new Uint8Array(weightData1)); + done(); + }; + weightsFileReader.onerror = (error: ErrorEvent) => { + done.fail(error.message); + }; + weightsFileReader.readAsArrayBuffer(weightsFile); + }; + jsonFileReader.onerror = (error: ErrorEvent) => { + done.fail(error.message); + }; + jsonFileReader.readAsText(jsonFile); + }) + .catch(err => { + done.fail(err.stack); + }); + }); + + it('404 response causes Error', done => { + const handler = tf.io.browserHTTPRequest('invalid/path'); + handler.save(artifacts1) + .then(saveResult => { + done.fail( + 'Calling browserHTTPRequest at invalid URL succeeded ' + + 'unexpectedly'); + }) + .catch(err => { + done(); + }); + }); + + it('Existing body leads to Error', () => { + const key1Data = '1337'; + const key2Data = '42'; + const extraFormData = new FormData(); + extraFormData.set('key1', key1Data); + extraFormData.set('key2', key2Data); + expect(() => tf.io.browserHTTPRequest('model-upload-test', { + body: extraFormData + })).toThrowError(/requestInit is expected to have no pre-existing body/); + }); + + it('Empty, null or undefined URL paths lead to Error', () => { + expect(() => tf.io.browserHTTPRequest(null)) + .toThrowError(/must not be null, undefined or empty/); + expect(() => tf.io.browserHTTPRequest(undefined)) + .toThrowError(/must not be null, undefined or empty/); + expect(() => tf.io.browserHTTPRequest('')) + .toThrowError(/must not be null, undefined or empty/); + }); +}); diff --git a/src/io/indexed_db.ts b/src/io/indexed_db.ts index a9b62dfcd6..2e40f71156 100644 --- a/src/io/indexed_db.ts +++ b/src/io/indexed_db.ts @@ -65,7 +65,7 @@ function getIndexedDBFactory(): IDBFactory { * * See the doc string of `browserIndexedDB` for more details. */ -export class BrowserIndexedDB implements IOHandler { +class BrowserIndexedDB implements IOHandler { protected readonly indexedDB: IDBFactory; protected readonly modelPath: string; diff --git a/src/io/io.ts b/src/io/io.ts index bc86e4fba6..6225e444bf 100644 --- a/src/io/io.ts +++ b/src/io/io.ts @@ -16,6 +16,7 @@ */ import {browserDownloads, browserFiles} from './browser_files'; +import {browserHTTPRequest} from './browser_http'; import {browserIndexedDB} from './indexed_db'; import {decodeWeights, encodeWeights} from './io_utils'; import {browserLocalStorage} from './local_storage'; @@ -26,6 +27,7 @@ import {loadWeights} from './weights_loader'; export { browserDownloads, browserFiles, + browserHTTPRequest, browserIndexedDB, browserLocalStorage, decodeWeights, diff --git a/src/io/local_storage.ts b/src/io/local_storage.ts index f2e3c9d037..44a843a049 100644 --- a/src/io/local_storage.ts +++ b/src/io/local_storage.ts @@ -61,7 +61,7 @@ export function purgeLocalStorageArtifacts(): string[] { * * See the doc string to `browserLocalStorage` for more details. */ -export class BrowserLocalStorage implements IOHandler { +class BrowserLocalStorage implements IOHandler { protected readonly LS: Storage; protected readonly modelPath: string; protected readonly paths: {[key: string]: string}; @@ -217,8 +217,8 @@ export class BrowserLocalStorage implements IOHandler { * * @param modelPath A unique identifier for the model to be saved. Must be a * non-empty string. - * @returns An instance of `BrowserLocalStorage` (sublcass of `IOHandler`), - * which can be used with, e.g., `tf.Model.save`. + * @returns An instance of `IOHandler`, which can be used with, e.g., + * `tf.Model.save`. */ export function browserLocalStorage(modelPath: string): IOHandler { return new BrowserLocalStorage(modelPath);