Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Adds ability to load quantized weights by de-quantizing them. #965

Merged
merged 11 commits into from
Apr 25, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 49 additions & 10 deletions src/weights_loader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,20 @@ export interface WeightsManifestGroupConfig {
/**
 * One weight listed in a weights manifest group: its name, shape and dtype,
 * plus optional information for de-quantizing its stored values.
 *
 * NOTE(review): this span is taken from a rendered diff, so the pre-change and
 * post-change `dtype` lines both appear below; presumably only the commented
 * one belongs in the resulting file — confirm against the repository.
 */
export interface WeightsManifestEntry {
name: string;
shape: number[];
dtype: 'float32'|'int32';
dtype: 'float32'|'int32'; // Dtype of the (unquantized) weights.
// When present, the stored values use `quantization.dtype` and the loader
// converts each one as `value * scale + min` (rounded for int32 weights).
quantization?: {
// Information to dequantize the weights.
scale: number, // The scaling constant to multiply by.
min: number, // The (possibly nudged) minimum weight to add.
dtype: 'uint16'|'uint8' // The dtype of the quantized weights.
};
}

/**
 * Size in bytes of a single value of each dtype. The loader multiplies this by
 * the number of elements in a weight's shape to get the weight's byte length,
 * using the quantization dtype when the weight is quantized.
 *
 * NOTE(review): this span is taken from a rendered diff, so the pre-change
 * `'int32': 4` line (no trailing comma) and its post-change replacement both
 * appear below — confirm against the repository.
 */
const DTYPE_VALUE_SIZE_MAP: {[dtype: string]: number} = {
'float32': 4,
'int32': 4
'int32': 4,
'uint16': 2,
'uint8': 1
};

/**
Expand Down Expand Up @@ -68,7 +76,11 @@ export async function loadWeights(
manifest.forEach((manifestGroupConfig, groupIndex) => {
let groupOffset = 0;
manifestGroupConfig.weights.forEach(weightsEntry => {
const weightsBytes = DTYPE_VALUE_SIZE_MAP[weightsEntry.dtype] *
const rawDtype = ('quantization' in weightsEntry) ?
weightsEntry.quantization.dtype :
weightsEntry.dtype;

const weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] *
util.sizeFromShape(weightsEntry.shape);

const enqueueWeightsForFetchingFn = () => {
Expand Down Expand Up @@ -161,14 +173,41 @@ export async function loadWeights(
weightsEntry.groupOffset + weightsEntry.sizeBytes);

let typedArray: Float32Array|Int32Array;
if (weightsEntry.manifestEntry.dtype === 'float32') {
typedArray = new Float32Array(byteBuffer);
} else if (weightsEntry.manifestEntry.dtype === 'int32') {
typedArray = new Int32Array(byteBuffer);

const dtype = weightsEntry.manifestEntry.dtype;

if ('quantization' in weightsEntry.manifestEntry) {
const quantization = weightsEntry.manifestEntry.quantization;
if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') {
throw new Error(
`Weight ${weightsEntry.manifestEntry.name} has unknown ` +
`quantization dtype ${quantization.dtype}.`);
}
const quantizedArray = (quantization.dtype === 'uint8') ?
new Uint8Array(byteBuffer) :
new Uint16Array(byteBuffer);
if (dtype === 'float32') {
typedArray = Float32Array.from(
quantizedArray, v => v * quantization.scale + quantization.min);
} else if (dtype === 'int32') {
typedArray = Int32Array.from(
quantizedArray,
v => Math.round(v * quantization.scale + quantization.min));
} else {
throw new Error(
`Weight ${weightsEntry.manifestEntry.name} has a dtype not ` +
`supported by quantization: ${dtype}`);
}
} else {
throw new Error(
`Weight ${weightsEntry.manifestEntry.name} has unknown dtype ` +
`${weightsEntry.manifestEntry.dtype}.`);
if (dtype === 'float32') {
typedArray = new Float32Array(byteBuffer);
} else if (dtype === 'int32') {
typedArray = new Int32Array(byteBuffer);
} else {
throw new Error(
`Weight ${weightsEntry.manifestEntry.name} has unknown dtype ` +
`${dtype}.`);
}
}

const weightName = weightsEntry.manifestEntry.name;
Expand Down
129 changes: 121 additions & 8 deletions src/weights_loader_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@
* =============================================================================
*/
import * as tf from './index';
import {CPU_ENVS, expectArraysClose} from './test_util';
import {describeWithFlags} from './jasmine_util';
import {CPU_ENVS, expectArraysClose, expectArraysEqual} from './test_util';
import {WeightsManifestConfig} from './weights_loader';

describeWithFlags('loadWeights', CPU_ENVS, () => {
// Test helper: replaces `window.fetch` with a spy that answers every requested
// path with a Response wrapping the buffer registered for that path.
const setupFakeWeightFiles =
    (fileBufferMap:
         {[filename: string]: Float32Array|Int32Array|ArrayBuffer}) => {
      const fakeFetch = (path: string) => new Response(fileBufferMap[path]);
      spyOn(window, 'fetch').and.callFake(fakeFetch);
    };
// Test helper: installs a spy on `window.fetch` whose fake implementation
// looks the requested path up in `fileBufferMap` and wraps the matching
// buffer (raw or quantized) in a Response.
const setupFakeWeightFiles = (fileBufferMap: {
  [filename: string]: Float32Array|Int32Array|ArrayBuffer|Uint8Array|
  Uint16Array
}) => {
  const respondWithBuffer = (path: string) =>
      new Response(fileBufferMap[path]);
  spyOn(window, 'fetch').and.callFake(respondWithBuffer);
};

it('1 group, 1 weight, 1 requested weight', done => {
setupFakeWeightFiles({'./weightfile0': new Float32Array([1, 2, 3])});
Expand Down Expand Up @@ -465,4 +466,116 @@ describeWithFlags('loadWeights', CPU_ENVS, () => {
.then(done)
.catch(done.fail);
});

/**
 * Shared driver for the quantized-weight specs.
 *
 * Serves one fake weight file holding six quantized values (three per weight)
 * stored as `quantizationDtype`, loads a float32 and an int32 weight from it,
 * and checks the de-quantized values, shapes and dtypes.
 *
 * @param quantizationDtype Storage dtype of the quantized values in the file.
 * @param done Jasmine completion callback; `done.fail` handles a rejection.
 */
const quantizationTest =
    (quantizationDtype: 'uint8'|'uint16', done: DoneFn) => {
      const QuantizedArray =
          quantizationDtype === 'uint8' ? Uint8Array : Uint16Array;
      const quantizedValues = new QuantizedArray([0, 48, 255, 0, 48, 255]);
      setupFakeWeightFiles({'./weightfile0': quantizedValues});

      // Builds a fresh quantization record for each manifest entry.
      const quantizationInfo = () =>
          ({min: -1, scale: 0.1, dtype: quantizationDtype});

      const manifest: WeightsManifestConfig = [{
        paths: ['weightfile0'],
        weights: [
          {
            name: 'weight0',
            dtype: 'float32',
            shape: [3],
            quantization: quantizationInfo()
          },
          {
            name: 'weight1',
            dtype: 'int32',
            shape: [3],
            quantization: quantizationInfo()
          }
        ]
      }];

      const requestedNames = ['weight0', 'weight1'];
      tf.loadWeights(manifest, './', requestedNames)
          .then(loaded => {
            // Both weights come from the same file, so one fetch is expected.
            expect((window.fetch as jasmine.Spy).calls.count()).toBe(1);
            expect(Object.keys(loaded).length).toEqual(requestedNames.length);

            // Expected values are `v * 0.1 - 1` for v in [0, 48, 255].
            const floatWeight = loaded['weight0'];
            expectArraysClose(floatWeight, [-1, 3.8, 24.5]);
            expect(floatWeight.shape).toEqual([3]);
            expect(floatWeight.dtype).toEqual('float32');

            // Same formula, rounded to the nearest integer for int32.
            const intWeight = loaded['weight1'];
            expectArraysEqual(intWeight, [-1, 4, 25]);
            expect(intWeight.shape).toEqual([3]);
            expect(intWeight.dtype).toEqual('int32');
          })
          .then(done)
          .catch(done.fail);
    };

// Run the shared quantization scenario once for each quantized storage dtype.
it('quantized weights (uint8)', done => quantizationTest('uint8', done));

it('quantized weights (uint16)', done => quantizationTest('uint16', done));

// Loads one weight from a uint8-quantized group and one from a plain float32
// group in a single call, checking that both files are fetched and that each
// weight comes back with the expected values, shape and dtype.
it('2 groups, 1 quantized, 1 unquantized', done => {
  const quantizedFile = new Uint8Array([0, 48, 255, 0, 48, 255]);
  const floatFile = new Float32Array([6, 7, 8, 9]);
  setupFakeWeightFiles(
      {'./weightfile0': quantizedFile, './weightfile1': floatFile});

  const manifest: WeightsManifestConfig = [
    {
      paths: ['weightfile0'],
      weights: [
        {
          name: 'weight0',
          dtype: 'float32',
          shape: [3],
          quantization: {min: -1, scale: 0.1, dtype: 'uint8'}
        },
        {
          name: 'weight1',
          dtype: 'int32',
          shape: [3],
          quantization: {min: -1, scale: 0.1, dtype: 'uint8'}
        }
      ]
    },
    {
      paths: ['weightfile1'],
      weights: [
        {name: 'weight2', dtype: 'float32', shape: [3, 1]},
        {name: 'weight3', dtype: 'float32', shape: []}
      ]
    }
  ];

  tf.loadWeights(manifest, './', ['weight0', 'weight2'])
      .then(loaded => {
        // Both groups need to be fetched.
        expect((window.fetch as jasmine.Spy).calls.count()).toBe(2);
        expect(Object.keys(loaded).length).toEqual(2);

        // De-quantized as `v * 0.1 - 1` for v in [0, 48, 255].
        const dequantized = loaded['weight0'];
        expectArraysClose(dequantized, [-1, 3.8, 24.5]);
        expect(dequantized.shape).toEqual([3]);
        expect(dequantized.dtype).toEqual('float32');

        // Stored as raw float32; the first three values of the file.
        const unquantized = loaded['weight2'];
        expectArraysClose(unquantized, [6, 7, 8]);
        expect(unquantized.shape).toEqual([3, 1]);
        expect(unquantized.dtype).toEqual('float32');
      })
      .then(done)
      .catch(done.fail);
});
});