@@ -23,6 +23,7 @@ import {AddScaledMatProgram} from './webgl/addscaledmat_gpu';
23
23
import { ArgMaxEqualsProgram } from './webgl/argmaxequals_gpu' ;
24
24
import { ArgMinMaxProgram } from './webgl/argminmax_gpu' ;
25
25
import { BatchNormProgram } from './webgl/batchnorm_gpu' ;
26
+ import * as binaryop_gpu from './webgl/binaryop_gpu' ;
26
27
import { BinaryOpProgram } from './webgl/binaryop_gpu' ;
27
28
import { Concat3DProgram } from './webgl/concat3d_gpu' ;
28
29
// tslint:disable-next-line:max-line-length
@@ -41,7 +42,8 @@ import {Pool2DProgram} from './webgl/pool_gpu';
41
42
import { ReduceSumProgram } from './webgl/reducesum_gpu' ;
42
43
import { ResizeBilinear3DProgram } from './webgl/resize_bilinear_gpu' ;
43
44
import { TextureManager } from './webgl/texture_manager' ;
44
- import { UnaryOp , UnaryOpProgram } from './webgl/unaryop_gpu' ;
45
+ import * as unary_op from './webgl/unaryop_gpu' ;
46
+ import { UnaryOpProgram } from './webgl/unaryop_gpu' ;
45
47
import * as webgl_util from './webgl/webgl_util' ;
46
48
47
49
export class NDArrayMathGPU extends NDArrayMath {
@@ -113,7 +115,7 @@ export class NDArrayMathGPU extends NDArrayMath {
113
115
}
114
116
115
117
protected negInternal < T extends NDArray > ( a : T ) : T {
116
- const program = new UnaryOpProgram ( a . shape , UnaryOp . NEG ) ;
118
+ const program = new UnaryOpProgram ( a . shape , unary_op . NEG ) ;
117
119
return this . compileAndRun < T , T > ( program , [ a ] ) ;
118
120
}
119
121
@@ -147,7 +149,7 @@ export class NDArrayMathGPU extends NDArrayMath {
147
149
}
148
150
149
151
protected multiplyInternal < T extends NDArray > ( a : T , b : T ) : T {
150
- const program = new BinaryOpProgram ( '*' , a . shape , b . shape ) ;
152
+ const program = new BinaryOpProgram ( binaryop_gpu . MUL , a . shape , b . shape ) ;
151
153
return this . compileAndRun < T , T > ( program , [ a , b ] ) ;
152
154
}
153
155
@@ -219,17 +221,17 @@ export class NDArrayMathGPU extends NDArrayMath {
219
221
}
220
222
221
223
protected divideInternal < T extends NDArray > ( a : T , b : T ) : T {
222
- const program = new BinaryOpProgram ( '/' , a . shape , b . shape ) ;
224
+ const program = new BinaryOpProgram ( binaryop_gpu . DIV , a . shape , b . shape ) ;
223
225
return this . compileAndRun < NDArray , T > ( program , [ a , b ] ) ;
224
226
}
225
227
226
228
protected addInternal < T extends NDArray > ( a : T , b : T ) : T {
227
- const program = new BinaryOpProgram ( '+' , a . shape , b . shape ) ;
229
+ const program = new BinaryOpProgram ( binaryop_gpu . ADD , a . shape , b . shape ) ;
228
230
return this . compileAndRun < NDArray , T > ( program , [ a , b ] ) ;
229
231
}
230
232
231
233
protected subInternal < T extends NDArray > ( a : T , b : T ) : T {
232
- const program = new BinaryOpProgram ( '-' , a . shape , b . shape ) ;
234
+ const program = new BinaryOpProgram ( binaryop_gpu . SUB , a . shape , b . shape ) ;
233
235
return this . compileAndRun < NDArray , T > ( program , [ a , b ] ) ;
234
236
}
235
237
@@ -239,42 +241,83 @@ export class NDArrayMathGPU extends NDArrayMath {
239
241
}
240
242
241
243
protected expInternal < T extends NDArray > ( a : T ) : T {
242
- const program = new UnaryOpProgram ( a . shape , UnaryOp . EXP ) ;
244
+ const program = new UnaryOpProgram ( a . shape , unary_op . EXP ) ;
243
245
return this . compileAndRun ( program , [ a ] ) ;
244
246
}
245
247
246
248
protected logInternal < T extends NDArray > ( a : T ) : T {
247
- const program = new UnaryOpProgram ( a . shape , UnaryOp . LOG ) ;
249
+ const program = new UnaryOpProgram ( a . shape , unary_op . LOG ) ;
248
250
return this . compileAndRun ( program , [ a ] ) ;
249
251
}
250
252
251
253
protected sqrtInternal < T extends NDArray > ( a : T ) : T {
252
- const program = new UnaryOpProgram ( a . shape , UnaryOp . SQRT ) ;
254
+ const program = new UnaryOpProgram ( a . shape , unary_op . SQRT ) ;
253
255
return this . compileAndRun ( program , [ a ] ) ;
254
256
}
255
257
256
258
protected reluInternal < T extends NDArray > ( a : T ) : T {
257
- const program = new UnaryOpProgram ( a . shape , UnaryOp . RELU ) ;
259
+ const program = new UnaryOpProgram ( a . shape , unary_op . RELU ) ;
260
+ return this . compileAndRun ( program , [ a ] ) ;
261
+ }
262
+
263
+ protected absInternal < T extends NDArray > ( a : T ) : T {
264
+ const program = new UnaryOpProgram ( a . shape , unary_op . ABS ) ;
258
265
return this . compileAndRun ( program , [ a ] ) ;
259
266
}
260
267
261
268
protected sigmoidInternal < T extends NDArray > ( a : T ) : T {
262
- const program = new UnaryOpProgram ( a . shape , UnaryOp . SIGMOID ) ;
269
+ const program = new UnaryOpProgram ( a . shape , unary_op . SIGMOID ) ;
263
270
return this . compileAndRun < T , T > ( program , [ a ] ) ;
264
271
}
265
272
266
- protected tanhInternal < T extends NDArray > ( a : T ) : T {
267
- const program = new UnaryOpProgram ( a . shape , UnaryOp . TANH ) ;
273
+ protected sinInternal < T extends NDArray > ( a : T ) : T {
274
+ const program = new UnaryOpProgram ( a . shape , unary_op . SIN ) ;
268
275
return this . compileAndRun ( program , [ a ] ) ;
269
276
}
270
277
271
- protected sinInternal < T extends NDArray > ( a : T ) : T {
272
- const program = new UnaryOpProgram ( a . shape , UnaryOp . SIN ) ;
278
+ protected cosInternal < T extends NDArray > ( a : T ) : T {
279
+ const program = new UnaryOpProgram ( a . shape , unary_op . COS ) ;
280
+ return this . compileAndRun ( program , [ a ] ) ;
281
+ }
282
+
283
+ protected tanInternal < T extends NDArray > ( a : T ) : T {
284
+ const program = new UnaryOpProgram ( a . shape , unary_op . TAN ) ;
285
+ return this . compileAndRun ( program , [ a ] ) ;
286
+ }
287
+
288
+ protected asinInternal < T extends NDArray > ( a : T ) : T {
289
+ const program = new UnaryOpProgram ( a . shape , unary_op . ASIN ) ;
290
+ return this . compileAndRun ( program , [ a ] ) ;
291
+ }
292
+
293
+ protected acosInternal < T extends NDArray > ( a : T ) : T {
294
+ const program = new UnaryOpProgram ( a . shape , unary_op . ACOS ) ;
273
295
return this . compileAndRun ( program , [ a ] ) ;
274
296
}
275
297
298
+ protected atanInternal < T extends NDArray > ( a : T ) : T {
299
+ const program = new UnaryOpProgram ( a . shape , unary_op . ATAN ) ;
300
+ return this . compileAndRun ( program , [ a ] ) ;
301
+ }
302
+
303
+ protected sinhInternal < T extends NDArray > ( a : T ) : T {
304
+ const program = new UnaryOpProgram ( a . shape , unary_op . SINH ) ;
305
+ return this . compileAndRun ( program , [ a ] ) ;
306
+ }
307
+
308
+ protected coshInternal < T extends NDArray > ( a : T ) : T {
309
+ const program = new UnaryOpProgram ( a . shape , unary_op . COSH ) ;
310
+ return this . compileAndRun ( program , [ a ] ) ;
311
+ }
312
+
313
+ protected tanhInternal < T extends NDArray > ( a : T ) : T {
314
+ const program = new UnaryOpProgram ( a . shape , unary_op . TANH ) ;
315
+ return this . compileAndRun ( program , [ a ] ) ;
316
+ }
317
+
318
+
276
319
protected stepInternal < T extends NDArray > ( a : T ) : T {
277
- const program = new UnaryOpProgram ( a . shape , UnaryOp . STEP ) ;
320
+ const program = new UnaryOpProgram ( a . shape , unary_op . STEP ) ;
278
321
return this . compileAndRun ( program , [ a ] ) ;
279
322
}
280
323
0 commit comments