Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Enable passing custom fetch function to BrowserHTTPRequest constructor #1422

Merged
merged 14 commits into from
Dec 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions src/io/browser_http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,29 @@ export class BrowserHTTPRequest implements IOHandler {
protected readonly path: string|string[];
protected readonly requestInit: RequestInit;

private readonly fetchFunc: Function;

readonly DEFAULT_METHOD = 'POST';

static readonly URL_SCHEME_REGEX = /^https?:\/\//;

constructor(
path: string|string[], requestInit?: RequestInit,
private readonly weightPathPrefix?: string) {
if (typeof fetch === 'undefined') {
throw new Error(
// tslint:disable-next-line:max-line-length
'browserHTTPRequest is not supported outside the web browser without a fetch polyfill.');
private readonly weightPathPrefix?: string, fetchFunc?: Function) {
if (fetchFunc == null) {
if (typeof fetch === 'undefined') {
throw new Error(
'browserHTTPRequest is not supported outside the web browser ' +
'without a fetch polyfill.');
}
this.fetchFunc = fetch;
} else {
assert(
typeof fetchFunc === 'function',
'Must pass a function that matches the signature of ' +
'`fetch` (see ' +
'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)');
this.fetchFunc = fetchFunc;
}

assert(
Expand Down Expand Up @@ -98,7 +110,7 @@ export class BrowserHTTPRequest implements IOHandler {
'model.weights.bin');
}

const response = await fetch(this.path as string, init);
const response = await this.fetchFunc(this.path as string, init);

if (response.ok) {
return {
Expand Down Expand Up @@ -130,7 +142,7 @@ export class BrowserHTTPRequest implements IOHandler {
*/
private async loadBinaryTopology(): Promise<ArrayBuffer> {
try {
const response = await fetch(this.path[0], this.requestInit);
const response = await this.fetchFunc(this.path[0], this.requestInit);
if (!response.ok) {
throw new Error(
`BrowserHTTPRequest.load() failed due to HTTP response: ${
Expand All @@ -144,7 +156,8 @@ export class BrowserHTTPRequest implements IOHandler {

protected async loadBinaryModel(): Promise<ModelArtifacts> {
const graphPromise = this.loadBinaryTopology();
const manifestPromise = await fetch(this.path[1], this.requestInit);
const manifestPromise =
await this.fetchFunc(this.path[1], this.requestInit);
if (!manifestPromise.ok) {
throw new Error(`BrowserHTTPRequest.load() failed due to HTTP response: ${
manifestPromise.statusText}`);
Expand All @@ -168,7 +181,7 @@ export class BrowserHTTPRequest implements IOHandler {

protected async loadJSONModel(): Promise<ModelArtifacts> {
const modelConfigRequest =
await fetch(this.path as string, this.requestInit);
await this.fetchFunc(this.path as string, this.requestInit);
if (!modelConfigRequest.ok) {
throw new Error(`BrowserHTTPRequest.load() failed due to HTTP response: ${
modelConfigRequest.statusText}`);
Expand Down Expand Up @@ -216,8 +229,8 @@ export class BrowserHTTPRequest implements IOHandler {

return [
weightSpecs,
concatenateArrayBuffers(
await loadWeightsAsArrayBuffer(fetchURLs, this.requestInit))
concatenateArrayBuffers(await loadWeightsAsArrayBuffer(
fetchURLs, this.requestInit, this.fetchFunc))
];
}
}
Expand All @@ -242,7 +255,7 @@ export function parseUrl(url: string): [string, string] {
return [prefix + '/', suffix];
}

function isHTTPScheme(url: string): boolean {
export function isHTTPScheme(url: string): boolean {
return url.match(BrowserHTTPRequest.URL_SCHEME_REGEX) != null;
}

Expand Down Expand Up @@ -404,11 +417,13 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter);
* 'model.weights.bin') will be appended to the body. If `requestInit` has a
* `body`, an Error will be thrown.
* @param weightPathPrefix Optional, this specifies the path prefix for weight
* files, by default this is calculated from the path param.
* files, by default this is calculated from the path param.
* @param fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
* the `fetch` from node-fetch can be used here.
* @returns An instance of `IOHandler`.
*/
export function browserHTTPRequest(
path: string|string[], requestInit?: RequestInit,
weightPathPrefix?: string): IOHandler {
return new BrowserHTTPRequest(path, requestInit, weightPathPrefix);
path: string|string[], requestInit?: RequestInit, weightPathPrefix?: string,
fetchFunc?: Function): IOHandler {
return new BrowserHTTPRequest(path, requestInit, weightPathPrefix, fetchFunc);
}
52 changes: 52 additions & 0 deletions src/io/browser_http_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1032,4 +1032,56 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
expect(() => tf.io.browserHTTPRequest(['path1/model.pb'])).toThrow();
});
});

it('Overriding BrowserHTTPRequest fetchFunc', async () => {
const weightManifest1: tf.io.WeightsManifestConfig = [{
paths: ['weightfile0'],
weights: [
{
name: 'dense/kernel',
shape: [3, 1],
dtype: 'float32',
},
{
name: 'dense/bias',
shape: [2],
dtype: 'float32',
}
]
}];
const floatData = new Float32Array([1, 3, 3, 7, 4]);

const fetchInputs: RequestInfo[] = [];
const fetchInits: RequestInit[] = [];
async function customFetch(
input: RequestInfo, init?: RequestInit): Promise<Response> {
fetchInputs.push(input);
fetchInits.push(init);

if (input === './model.json') {
return new Response(
JSON.stringify({
modelTopology: modelTopology1,
weightsManifest: weightManifest1
}),
{status: 200});
} else if (input === './weightfile0') {
return new Response(floatData, {status: 200});
} else {
return new Response(null, {status: 404});
}
}

const handler = tf.io.browserHTTPRequest(
'./model.json', {credentials: 'include'}, null, customFetch);
const modelArtifacts = await handler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);

expect(fetchInputs).toEqual(['./model.json', './weightfile0']);
expect(fetchInits).toEqual([
{credentials: 'include'}, {credentials: 'include'}
]);
});
});
3 changes: 2 additions & 1 deletion src/io/io.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import './indexed_db';
import './local_storage';

import {browserFiles} from './browser_files';
import {browserHTTPRequest} from './browser_http';
import {browserHTTPRequest, isHTTPScheme} from './browser_http';
import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsInfoForJSON} from './io_utils';
import {fromMemory, withSaveHandler} from './passthrough';
import {IORouterRegistry} from './router_registry';
Expand All @@ -46,6 +46,7 @@ export {
getModelArtifactsInfoForJSON,
getSaveHandlers,
IOHandler,
isHTTPScheme,
LoadHandler,
loadWeights,
ModelArtifacts,
Expand Down
11 changes: 9 additions & 2 deletions src/io/weights_loader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,20 @@ import {DTYPE_VALUE_SIZE_MAP, WeightsManifestConfig, WeightsManifestEntry} from
*
* @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls.
* @param requestOptions RequestInit (options) for the HTTP requests.
* @param fetchFunc Optional overriding value for the `window.fetch` function.
* @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same
* length as `fetchURLs`.
*/
export async function loadWeightsAsArrayBuffer(
fetchURLs: string[], requestOptions?: RequestInit): Promise<ArrayBuffer[]> {
fetchURLs: string[], requestOptions?: RequestInit, fetchFunc?: Function):
Promise<ArrayBuffer[]> {
if (fetchFunc == null) {
fetchFunc = fetch;
}

// Create the requests for all of the weights in parallel.
const requests = fetchURLs.map(fetchURL => fetch(fetchURL, requestOptions));
const requests = fetchURLs.map(
fetchURL => fetchFunc(fetchURL, requestOptions));
const responses = await Promise.all(requests);
const buffers =
await Promise.all(responses.map(response => response.arrayBuffer()));
Expand Down