diff --git a/src/graph/optimizers/adam_optimizer.ts b/src/graph/optimizers/adam_optimizer.ts
new file mode 100644
index 0000000000..b28578c095
--- /dev/null
+++ b/src/graph/optimizers/adam_optimizer.ts
@@ -0,0 +1,130 @@
+/**
+ * @license
+ * Copyright 2017 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {NDArrayMath} from '../../math/math';
+import {NDArray, Scalar} from '../../math/ndarray';
+import {Node} from '../graph';
+import {SessionRuntime} from '../session';
+import {SummedTensorArrayMap, TensorArrayMap} from '../tensor_array_map';
+
+import {Optimizer} from './optimizer';
+
+export class AdamOptimizer extends Optimizer {
+  constructor(
+      protected learningRate: number,
+      private beta1: number, private beta2: number,
+      specifiedVariableList?: Node[]) {
+    super(learningRate, specifiedVariableList);
+    this.eps = Scalar.new(1e-8);
+    // b1, b2 keep the initial values of the beta1, beta2 hyperparameters.
+    this.b1 = Scalar.new(this.beta1);
+    this.b2 = Scalar.new(this.beta2);
+    // accB* are updated after every batch.
+    this.accB1 = Scalar.new(this.beta1);
+    this.accB2 = Scalar.new(this.beta2);
+  }
+
+  beforeBatch(
+      math: NDArrayMath, batchSize: number, runtime: SessionRuntime,
+      activationArrayMap: TensorArrayMap,
+      gradientArrayMap: SummedTensorArrayMap) {
+    super.beforeBatch(
+        math, batchSize, runtime, activationArrayMap, gradientArrayMap);
+
+    if (this.firstMoment.size() === 0) {
+      this.variableNodes.forEach(node => {
+        this.firstMoment.set(node.output, NDArray.zeros(node.output.shape));
+      });
+    }
+
+    if (this.secondMoment.size() === 0) {
+      this.variableNodes.forEach(node => {
+        this.secondMoment.set(node.output, NDArray.zeros(node.output.shape));
+      });
+    }
+  }
+
+  afterBatch(
+      math: NDArrayMath, batchSize: number, runtime: SessionRuntime,
+      activationArrayMap: TensorArrayMap,
+      gradientArrayMap: SummedTensorArrayMap) {
+    math.scope((keep) => {
+      this.variableNodes.forEach(node => {
+        const oldVariable = activationArrayMap.get(node.output);
+        const gradient = this.variableGradients.get(node.output);
+
+        const oldFirstMoment = this.firstMoment.get(node.output);
+        const oldSecondMoment = this.secondMoment.get(node.output);
+
+        const newFirstMoment = math.scaledArrayAdd(
+            this.b1, oldFirstMoment, math.sub(this.one, this.b1), gradient);
+        const gradientSquare = math.multiply(gradient, gradient);
+        const newSecondMoment = math.scaledArrayAdd(
+            this.b2, oldSecondMoment, math.sub(this.one, this.b2),
+            gradientSquare);
+
+        const biasCorrectedFirstMoment = math.divide(
+            newFirstMoment, math.sub(this.one, this.accB1));
+        const biasCorrectedSecondMoment = math.divide(
+            newSecondMoment, math.sub(this.one, this.accB2));
+
+        const variable = math.scaledArrayAdd(
+            this.c, math.divide(biasCorrectedFirstMoment,
+                math.add(math.sqrt(biasCorrectedSecondMoment), this.eps)),
+            this.one, oldVariable);
+        activationArrayMap.set(node.output, keep(variable));
+        node.data = variable;
+
+        this.firstMoment.set(node.output, keep(newFirstMoment));
+        this.secondMoment.set(node.output, keep(newSecondMoment));
+
+        oldVariable.dispose();
+        gradient.dispose();
+        oldFirstMoment.dispose();
+        oldSecondMoment.dispose();
+      });
+      // accB* represent beta1 and beta2 raised to
+      // the power t (the number of iterations).
+      this.accB1 = keep(math.multiply(this.accB1, this.b1));
+      this.accB2 = keep(math.multiply(this.accB2, this.b2));
+    });
+
+    this.variableGradients.dispose();
+    this.variableGradients = new TensorArrayMap();
+  }
+
+  dispose() {
+    super.dispose();
+    this.firstMoment.dispose();
+    this.secondMoment.dispose();
+    this.eps.dispose();
+    this.b1.dispose();
+    this.b2.dispose();
+    this.accB1.dispose();
+    this.accB2.dispose();
+  }
+
+  // Moving average of the gradient (first moment).
+  private firstMoment = new TensorArrayMap();
+  // Moving average of the squared gradient (second moment).
+  private secondMoment = new TensorArrayMap();
+  private eps: Scalar;
+  private b1: Scalar;
+  private b2: Scalar;
+  private accB1: Scalar;
+  private accB2: Scalar;
+}
diff --git a/src/graph/session_test.ts b/src/graph/session_test.ts
index f24334456a..cc02a2bb84 100644
--- a/src/graph/session_test.ts
+++ b/src/graph/session_test.ts
@@ -27,6 +27,7 @@ import {MomentumOptimizer} from './optimizers/momentum_optimizer';
 import {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
 import {SGDOptimizer} from './optimizers/sgd_optimizer';
 import {AdadeltaOptimizer} from './optimizers/adadelta_optimizer';
+import {AdamOptimizer} from './optimizers/adam_optimizer';
 import {FeedDictionary, FeedEntry, Session} from './session';
 
 describe('FeedDictionary', () => {
@@ -498,4 +499,63 @@ describe('Session', () => {
           dydw2, new Float32Array([-.4, -.8]), 2e-5);
     });
   });
+
+  it('adam', () => {
+    const x = g.placeholder('x', [2]);
+    const w = g.variable('w', NDArray.zeros([1, 2]));
+    const b = g.variable('b', NDArray.zeros([1]));
+    const y = g.reduceSum(g.add(g.matmul(w, x), b));
+
+    const safeMode = true;
+    const optimizer = new AdamOptimizer(0.1, 0.8, 0.9);
+    const math = new NDArrayMathCPU(safeMode);
+    const session = new Session(g, math);
+    const inputProvider: InputProvider = {
+      getNextCopy() {
+        return Array1D.new([2, 4]);
+      },
+      disposeCopy(math, example) {}
+    };
+
+    math.scope(() => {
+      // y = reduce_sum(w_1*x_1 + w_2*x_2 + b)
+      // new_first_m = [beta1*old_first_m_w1 + (1-beta1)*grad_w1,
+      //               beta1*old_first_m_w2 + (1-beta1)*grad_w2]
+      //             = [.4, .8]
+      // new_second_m = [beta2*old_second_m_w1 + (1-beta2)*grad_w1**2,
+      //                beta2*old_second_m_w2 + (1-beta2)*grad_w2**2]
+      //              = [.4, 1.6]
+      // m = [new_first_m/(1-acc_beta1)] = [2, 4]
+      // v = [new_second_m/(1-acc_beta2)] = [4, 16]
+      // updates = [m_1/(sqrt(v_1) + eps),
+      //            m_2/(sqrt(v_2) + eps)]
+      //         = [1.0, 1.0]
+      // w = [w1_old - lr*updates_1, w2_old - lr*updates_2]
+      //   = [-0.1, -0.1]
+      //
+      session.train(y, [{tensor: x, data: inputProvider}], 1, optimizer);
+      const dydw = session.activationArrayMap.get(w).getValues();
+      test_util.expectArraysClose(
+          dydw, new Float32Array([-0.1, -0.1]), 1e-5);
+      // new_first_m = [beta1*old_first_m_w1 + (1-beta1)*grad_w1,
+      //               beta1*old_first_m_w2 + (1-beta1)*grad_w2]
+      //             = [0.8*0.4 + 0.2*2, 0.8*0.8 + 0.2*4]
+      //             = [0.72, 1.44]
+      // new_second_m = [beta2*old_second_m_w1 + (1-beta2)*grad_w1**2,
+      //                beta2*old_second_m_w2 + (1-beta2)*grad_w2**2]
+      //              = [0.9*0.4 + 0.1*4, 0.9*1.6 + 0.1*16]
+      //              = [0.76, 3.04]
+      // m = [new_first_m/(1-acc_beta1)] = [2, 4]
+      // v = [new_second_m/(1-acc_beta2)] = [4, 16]
+      // updates = [m_1/(sqrt(v_1) + eps),
+      //            m_2/(sqrt(v_2) + eps)]
+      //         = [1.0, 1.0]
+      // w = [w1_old - lr*updates_1, w2_old - lr*updates_2]
+      //   = [-0.2, -0.2]
+      session.train(y, [{tensor: x, data: inputProvider}], 1, optimizer);
+      const dydw2 = session.activationArrayMap.get(w).getValues();
+      test_util.expectArraysClose(
+          dydw2, new Float32Array([-.2, -.2]), 2e-5);
+    });
+  });
 });
diff --git a/src/index.ts b/src/index.ts
index ec32db1b91..98a4ce6c5d 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -36,6 +36,7 @@ export {MomentumOptimizer} from './graph/optimizers/momentum_optimizer';
 export {Optimizer} from './graph/optimizers/optimizer';
 export {RMSPropOptimizer} from './graph/optimizers/rmsprop_optimizer';
 export {SGDOptimizer} from './graph/optimizers/sgd_optimizer';
+export {AdamOptimizer} from './graph/optimizers/adam_optimizer';
 export {CostReduction, FeedEntry, Session} from './graph/session';
 // tslint:disable-next-line:max-line-length
 export {GraphRunner, GraphRunnerEventObserver, MetricReduction} from './graph_runner';
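
Note (not part of the diff): below is a minimal usage sketch of the newly exported AdamOptimizer, mirroring the 'adam' test above. It assumes the package entry point (src/index.ts) also exports Graph, Session, NDArrayMathCPU, NDArray, Array1D, and InputProvider, which is not shown in this diff; the hyperparameters match the test, so after one step the weights should move to roughly [-0.1, -0.1].

// Usage sketch only; import names other than AdamOptimizer are assumed exports.
import {AdamOptimizer, Array1D, Graph, InputProvider, NDArray, NDArrayMathCPU, Session} from './index';

// Tiny model: y = reduce_sum(w * x + b), as in the test above.
const g = new Graph();
const x = g.placeholder('x', [2]);
const w = g.variable('w', NDArray.zeros([1, 2]));
const b = g.variable('b', NDArray.zeros([1]));
const y = g.reduceSum(g.add(g.matmul(w, x), b));

const math = new NDArrayMathCPU();
const session = new Session(g, math);

// Same hyperparameters as the test: learning rate 0.1, beta1 0.8, beta2 0.9.
const optimizer = new AdamOptimizer(0.1, 0.8, 0.9);

// Always feed the same input so the Adam update is easy to verify by hand.
const inputProvider: InputProvider = {
  getNextCopy: () => Array1D.new([2, 4]),
  disposeCopy: (math, example) => {}
};

math.scope(() => {
  // One batch of size 1: the optimizer's afterBatch() applies the
  // bias-corrected first/second-moment update to w and b.
  session.train(y, [{tensor: x, data: inputProvider}], 1, optimizer);
  // After the first step the weights should be close to [-0.1, -0.1].
  const weights = session.activationArrayMap.get(w).getValues();
  console.log(weights);
});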