Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 31f0b07

Browse files
author
Nikhil Thorat
authored
Add more math operations. Highlight error lines in webgl shaders (#123)
* add more math functionality * add test for binary op * add tsdoc * remove check nan snippet * remove unary and binary custom sugar methods. * add cosh(0) and sinh(0) tests. check for nans in sin/cos, they fail on mac
1 parent efb099c commit 31f0b07

17 files changed

+805
-153
lines changed

src/math/math.ts

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,15 @@ export abstract class NDArrayMath {
790790
}
791791
protected abstract sqrtInternal<T extends NDArray>(ndarray: T): T;
792792

793+
/**
794+
* Computes absolute value element-wise.
795+
* @param ndarray The input NDArray.
796+
*/
797+
abs<T extends NDArray>(ndarray: T): T {
798+
return this.executeOp('abs', () => this.absInternal(ndarray));
799+
}
800+
protected abstract absInternal<T extends NDArray>(ndarray: T): T;
801+
793802
/**
794803
* Computes rectified linear element-wise, max(x, 0).
795804
* @param ndarray The input NDArray.
@@ -809,22 +818,86 @@ export abstract class NDArrayMath {
809818
protected abstract sigmoidInternal<T extends NDArray>(ndarray: T): T;
810819

811820
/**
812-
* Computes hyperbolic tangent of the input NDArray element-wise.
821+
* Computes sin of the input NDArray element-wise, y = sin(x).
813822
* @param ndarray The input NDArray.
814823
*/
815-
tanh<T extends NDArray>(ndarray: T): T {
816-
return this.executeOp('tanh', () => this.tanhInternal(ndarray));
824+
sin<T extends NDArray>(ndarray: T): T {
825+
return this.executeOp('sin', () => this.sinInternal(ndarray));
817826
}
818-
protected abstract tanhInternal<T extends NDArray>(ndarray: T): T;
827+
protected abstract sinInternal<T extends NDArray>(ndarray: T): T;
819828

820829
/**
821-
* Computes sin of the input NDArray element-wise, y = sin(x).
830+
* Computes cos of the input NDArray element-wise, y = cos(x).
822831
* @param ndarray The input NDArray.
823832
*/
824-
sin<T extends NDArray>(ndarray: T): T {
825-
return this.executeOp('sin', () => this.sinInternal(ndarray));
833+
cos<T extends NDArray>(ndarray: T): T {
834+
return this.executeOp('cos', () => this.cosInternal(ndarray));
826835
}
827-
protected abstract sinInternal<T extends NDArray>(ndarray: T): T;
836+
protected abstract cosInternal<T extends NDArray>(ndarray: T): T;
837+
838+
/**
839+
* Computes tan of the input NDArray element-wise, y = tan(x).
840+
* @param ndarray The input NDArray.
841+
*/
842+
tan<T extends NDArray>(ndarray: T): T {
843+
return this.executeOp('tan', () => this.tanInternal(ndarray));
844+
}
845+
protected abstract tanInternal<T extends NDArray>(ndarray: T): T;
846+
847+
/**
848+
* Computes asin of the input NDArray element-wise, y = asin(x).
849+
* @param ndarray The input NDArray.
850+
*/
851+
asin<T extends NDArray>(ndarray: T): T {
852+
return this.executeOp('asin', () => this.asinInternal(ndarray));
853+
}
854+
protected abstract asinInternal<T extends NDArray>(ndarray: T): T;
855+
856+
/**
857+
* Computes acos of the input NDArray element-wise, y = acos(x).
858+
* @param ndarray The input NDArray.
859+
*/
860+
acos<T extends NDArray>(ndarray: T): T {
861+
return this.executeOp('acos', () => this.acosInternal(ndarray));
862+
}
863+
protected abstract acosInternal<T extends NDArray>(ndarray: T): T;
864+
865+
/**
866+
* Computes atan of the input NDArray element-wise, y = atan(x).
867+
* @param ndarray The input NDArray.
868+
*/
869+
atan<T extends NDArray>(ndarray: T): T {
870+
return this.executeOp('atan', () => this.atanInternal(ndarray));
871+
}
872+
protected abstract atanInternal<T extends NDArray>(ndarray: T): T;
873+
874+
/**
875+
* Computes hyperbolic sin of the input NDArray element-wise, y = sinh(x).
876+
* @param ndarray The input NDArray.
877+
*/
878+
sinh<T extends NDArray>(ndarray: T): T {
879+
return this.executeOp('sinh', () => this.sinhInternal(ndarray));
880+
}
881+
protected abstract sinhInternal<T extends NDArray>(ndarray: T): T;
882+
883+
/**
884+
* Computes hyperbolic cos of the input NDArray element-wise, y = cosh(x).
885+
* @param ndarray The input NDArray.
886+
*/
887+
cosh<T extends NDArray>(ndarray: T): T {
888+
return this.executeOp('cosh', () => this.coshInternal(ndarray));
889+
}
890+
protected abstract coshInternal<T extends NDArray>(ndarray: T): T;
891+
892+
/**
893+
* Computes hyperbolic tangent of the input NDArray element-wise.
894+
* @param ndarray The input NDArray.
895+
*/
896+
tanh<T extends NDArray>(ndarray: T): T {
897+
return this.executeOp('tanh', () => this.tanhInternal(ndarray));
898+
}
899+
protected abstract tanhInternal<T extends NDArray>(ndarray: T): T;
900+
828901

829902
/**
830903
* Computes step of the input NDArray element-wise, y = 1 if x > 0 | 0 if x <=

src/math/math_cpu.ts

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -337,20 +337,20 @@ export class NDArrayMathCPU extends NDArrayMath {
337337
return NDArray.make<T>(ndarray.shape, {values: resultValues});
338338
}
339339

340-
protected sigmoidInternal<T extends NDArray>(ndarray: T): T {
340+
protected absInternal<T extends NDArray>(ndarray: T): T {
341341
const resultValues = new Float32Array(ndarray.size);
342342
const values = ndarray.getValues();
343343
for (let i = 0; i < values.length; ++i) {
344-
resultValues[i] = 1 / (1 + Math.exp(-values[i]));
344+
resultValues[i] = Math.abs(values[i]);
345345
}
346346
return NDArray.make<T>(ndarray.shape, {values: resultValues});
347347
}
348348

349-
protected tanhInternal<T extends NDArray>(ndarray: T): T {
349+
protected sigmoidInternal<T extends NDArray>(ndarray: T): T {
350350
const resultValues = new Float32Array(ndarray.size);
351351
const values = ndarray.getValues();
352352
for (let i = 0; i < values.length; ++i) {
353-
resultValues[i] = util.tanh(values[i]);
353+
resultValues[i] = 1 / (1 + Math.exp(-values[i]));
354354
}
355355
return NDArray.make<T>(ndarray.shape, {values: resultValues});
356356
}
@@ -364,6 +364,78 @@ export class NDArrayMathCPU extends NDArrayMath {
364364
return NDArray.make<T>(ndarray.shape, {values: resultValues});
365365
}
366366

367+
protected cosInternal<T extends NDArray>(ndarray: T): T {
368+
const resultValues = new Float32Array(ndarray.size);
369+
const values = ndarray.getValues();
370+
for (let i = 0; i < values.length; ++i) {
371+
resultValues[i] = Math.cos(values[i]);
372+
}
373+
return NDArray.make<T>(ndarray.shape, {values: resultValues});
374+
}
375+
376+
protected tanInternal<T extends NDArray>(ndarray: T): T {
377+
const resultValues = new Float32Array(ndarray.size);
378+
const values = ndarray.getValues();
379+
for (let i = 0; i < values.length; ++i) {
380+
resultValues[i] = Math.tan(values[i]);
381+
}
382+
return NDArray.make<T>(ndarray.shape, {values: resultValues});
383+
}
384+
385+
protected asinInternal<T extends NDArray>(ndarray: T): T {
386+
const resultValues = new Float32Array(ndarray.size);
387+
const values = ndarray.getValues();
388+
for (let i = 0; i < values.length; ++i) {
389+
resultValues[i] = Math.asin(values[i]);
390+
}
391+
return NDArray.make<T>(ndarray.shape, {values: resultValues});
392+
}
393+
394+
protected acosInternal<T extends NDArray>(ndarray: T): T {
395+
const resultValues = new Float32Array(ndarray.size);
396+
const values = ndarray.getValues();
397+
for (let i = 0; i < values.length; ++i) {
398+
resultValues[i] = Math.acos(values[i]);
399+
}
400+
return NDArray.make<T>(ndarray.shape, {values: resultValues});
401+
}
402+
403+
protected atanInternal<T extends NDArray>(ndarray: T): T {
404+
const resultValues = new Float32Array(ndarray.size);
405+
const values = ndarray.getValues();
406+
for (let i = 0; i < values.length; ++i) {
407+
resultValues[i] = Math.atan(values[i]);
408+
}
409+
return NDArray.make<T>(ndarray.shape, {values: resultValues});
410+
}
411+
412+
protected sinhInternal<T extends NDArray>(ndarray: T): T {
413+
const resultValues = new Float32Array(ndarray.size);
414+
const values = ndarray.getValues();
415+
for (let i = 0; i < values.length; ++i) {
416+
resultValues[i] = Math.sinh(values[i]);
417+
}
418+
return NDArray.make<T>(ndarray.shape, {values: resultValues});
419+
}
420+
421+
protected coshInternal<T extends NDArray>(ndarray: T): T {
422+
const resultValues = new Float32Array(ndarray.size);
423+
const values = ndarray.getValues();
424+
for (let i = 0; i < values.length; ++i) {
425+
resultValues[i] = Math.cosh(values[i]);
426+
}
427+
return NDArray.make<T>(ndarray.shape, {values: resultValues});
428+
}
429+
430+
protected tanhInternal<T extends NDArray>(ndarray: T): T {
431+
const resultValues = new Float32Array(ndarray.size);
432+
const values = ndarray.getValues();
433+
for (let i = 0; i < values.length; ++i) {
434+
resultValues[i] = util.tanh(values[i]);
435+
}
436+
return NDArray.make<T>(ndarray.shape, {values: resultValues});
437+
}
438+
367439
protected stepInternal<T extends NDArray>(ndarray: T): T {
368440
const resultValues = new Float32Array(ndarray.size);
369441
const values = ndarray.getValues();

src/math/math_gpu.ts

Lines changed: 59 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import {AddScaledMatProgram} from './webgl/addscaledmat_gpu';
2323
import {ArgMaxEqualsProgram} from './webgl/argmaxequals_gpu';
2424
import {ArgMinMaxProgram} from './webgl/argminmax_gpu';
2525
import {BatchNormProgram} from './webgl/batchnorm_gpu';
26+
import * as binaryop_gpu from './webgl/binaryop_gpu';
2627
import {BinaryOpProgram} from './webgl/binaryop_gpu';
2728
import {Concat3DProgram} from './webgl/concat3d_gpu';
2829
// tslint:disable-next-line:max-line-length
@@ -41,7 +42,8 @@ import {Pool2DProgram} from './webgl/pool_gpu';
4142
import {ReduceSumProgram} from './webgl/reducesum_gpu';
4243
import {ResizeBilinear3DProgram} from './webgl/resize_bilinear_gpu';
4344
import {TextureManager} from './webgl/texture_manager';
44-
import {UnaryOp, UnaryOpProgram} from './webgl/unaryop_gpu';
45+
import * as unary_op from './webgl/unaryop_gpu';
46+
import {UnaryOpProgram} from './webgl/unaryop_gpu';
4547
import * as webgl_util from './webgl/webgl_util';
4648

4749
export class NDArrayMathGPU extends NDArrayMath {
@@ -113,7 +115,7 @@ export class NDArrayMathGPU extends NDArrayMath {
113115
}
114116

115117
protected negInternal<T extends NDArray>(a: T): T {
116-
const program = new UnaryOpProgram(a.shape, UnaryOp.NEG);
118+
const program = new UnaryOpProgram(a.shape, unary_op.NEG);
117119
return this.compileAndRun<T, T>(program, [a]);
118120
}
119121

@@ -147,7 +149,7 @@ export class NDArrayMathGPU extends NDArrayMath {
147149
}
148150

149151
protected multiplyInternal<T extends NDArray>(a: T, b: T): T {
150-
const program = new BinaryOpProgram('*', a.shape, b.shape);
152+
const program = new BinaryOpProgram(binaryop_gpu.MUL, a.shape, b.shape);
151153
return this.compileAndRun<T, T>(program, [a, b]);
152154
}
153155

@@ -219,17 +221,17 @@ export class NDArrayMathGPU extends NDArrayMath {
219221
}
220222

221223
protected divideInternal<T extends NDArray>(a: T, b: T): T {
222-
const program = new BinaryOpProgram('/', a.shape, b.shape);
224+
const program = new BinaryOpProgram(binaryop_gpu.DIV, a.shape, b.shape);
223225
return this.compileAndRun<NDArray, T>(program, [a, b]);
224226
}
225227

226228
protected addInternal<T extends NDArray>(a: T, b: T): T {
227-
const program = new BinaryOpProgram('+', a.shape, b.shape);
229+
const program = new BinaryOpProgram(binaryop_gpu.ADD, a.shape, b.shape);
228230
return this.compileAndRun<NDArray, T>(program, [a, b]);
229231
}
230232

231233
protected subInternal<T extends NDArray>(a: T, b: T): T {
232-
const program = new BinaryOpProgram('-', a.shape, b.shape);
234+
const program = new BinaryOpProgram(binaryop_gpu.SUB, a.shape, b.shape);
233235
return this.compileAndRun<NDArray, T>(program, [a, b]);
234236
}
235237

@@ -239,42 +241,83 @@ export class NDArrayMathGPU extends NDArrayMath {
239241
}
240242

241243
protected expInternal<T extends NDArray>(a: T): T {
242-
const program = new UnaryOpProgram(a.shape, UnaryOp.EXP);
244+
const program = new UnaryOpProgram(a.shape, unary_op.EXP);
243245
return this.compileAndRun(program, [a]);
244246
}
245247

246248
protected logInternal<T extends NDArray>(a: T): T {
247-
const program = new UnaryOpProgram(a.shape, UnaryOp.LOG);
249+
const program = new UnaryOpProgram(a.shape, unary_op.LOG);
248250
return this.compileAndRun(program, [a]);
249251
}
250252

251253
protected sqrtInternal<T extends NDArray>(a: T): T {
252-
const program = new UnaryOpProgram(a.shape, UnaryOp.SQRT);
254+
const program = new UnaryOpProgram(a.shape, unary_op.SQRT);
253255
return this.compileAndRun(program, [a]);
254256
}
255257

256258
protected reluInternal<T extends NDArray>(a: T): T {
257-
const program = new UnaryOpProgram(a.shape, UnaryOp.RELU);
259+
const program = new UnaryOpProgram(a.shape, unary_op.RELU);
260+
return this.compileAndRun(program, [a]);
261+
}
262+
263+
protected absInternal<T extends NDArray>(a: T): T {
264+
const program = new UnaryOpProgram(a.shape, unary_op.ABS);
258265
return this.compileAndRun(program, [a]);
259266
}
260267

261268
protected sigmoidInternal<T extends NDArray>(a: T): T {
262-
const program = new UnaryOpProgram(a.shape, UnaryOp.SIGMOID);
269+
const program = new UnaryOpProgram(a.shape, unary_op.SIGMOID);
263270
return this.compileAndRun<T, T>(program, [a]);
264271
}
265272

266-
protected tanhInternal<T extends NDArray>(a: T): T {
267-
const program = new UnaryOpProgram(a.shape, UnaryOp.TANH);
273+
protected sinInternal<T extends NDArray>(a: T): T {
274+
const program = new UnaryOpProgram(a.shape, unary_op.SIN);
268275
return this.compileAndRun(program, [a]);
269276
}
270277

271-
protected sinInternal<T extends NDArray>(a: T): T {
272-
const program = new UnaryOpProgram(a.shape, UnaryOp.SIN);
278+
protected cosInternal<T extends NDArray>(a: T): T {
279+
const program = new UnaryOpProgram(a.shape, unary_op.COS);
280+
return this.compileAndRun(program, [a]);
281+
}
282+
283+
protected tanInternal<T extends NDArray>(a: T): T {
284+
const program = new UnaryOpProgram(a.shape, unary_op.TAN);
285+
return this.compileAndRun(program, [a]);
286+
}
287+
288+
protected asinInternal<T extends NDArray>(a: T): T {
289+
const program = new UnaryOpProgram(a.shape, unary_op.ASIN);
290+
return this.compileAndRun(program, [a]);
291+
}
292+
293+
protected acosInternal<T extends NDArray>(a: T): T {
294+
const program = new UnaryOpProgram(a.shape, unary_op.ACOS);
273295
return this.compileAndRun(program, [a]);
274296
}
275297

298+
protected atanInternal<T extends NDArray>(a: T): T {
299+
const program = new UnaryOpProgram(a.shape, unary_op.ATAN);
300+
return this.compileAndRun(program, [a]);
301+
}
302+
303+
protected sinhInternal<T extends NDArray>(a: T): T {
304+
const program = new UnaryOpProgram(a.shape, unary_op.SINH);
305+
return this.compileAndRun(program, [a]);
306+
}
307+
308+
protected coshInternal<T extends NDArray>(a: T): T {
309+
const program = new UnaryOpProgram(a.shape, unary_op.COSH);
310+
return this.compileAndRun(program, [a]);
311+
}
312+
313+
protected tanhInternal<T extends NDArray>(a: T): T {
314+
const program = new UnaryOpProgram(a.shape, unary_op.TANH);
315+
return this.compileAndRun(program, [a]);
316+
}
317+
318+
276319
protected stepInternal<T extends NDArray>(a: T): T {
277-
const program = new UnaryOpProgram(a.shape, UnaryOp.STEP);
320+
const program = new UnaryOpProgram(a.shape, unary_op.STEP);
278321
return this.compileAndRun(program, [a]);
279322
}
280323

0 commit comments

Comments
 (0)