
Commit ccbe264

[js/webgpu] Add LeakyRelu activation for fusedConv (#19369)
### Description

This PR 1) adds the LeakyRelu activation for fusedConv and 2) makes `vec4<f16>` values work with `float32` uniform attributes. For example, `clamp(value, vec4<f16>(uniforms.clip_min), vec4<f16>(uniforms.clip_max))` fails to compile because `uniforms.clip_min` and `uniforms.clip_max` are `f32`, not `f16`, so it has to be written as `clamp(value, vec4<f16>(f16(uniforms.clip_min)), vec4<f16>(f16(uniforms.clip_max)))`. This problem was introduced when the activation attributes were changed from shader constants to uniforms. With LeakyRelu added, the `realesrgan-t256` model now passes.
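A minimal sketch of that cast pattern (the `buildClipSnippet` helper below is only illustrative, not part of this change): the `f32` uniform is first narrowed to the tensor's storage type and only then widened to the vector type, so `f16` and `f32` never meet in one WGSL expression.

```ts
// Illustrative only: shows why the extra scalar cast is needed when the
// value vector is f16 while uniforms are always declared as f32.
const buildClipSnippet = (valueType: string, baseType: string): string =>
    `value = clamp(value, ${valueType}(${baseType}(uniforms.clip_min)), ${valueType}(${baseType}(uniforms.clip_max)));`;

// For f32 tensors the inner cast is a no-op.
console.log(buildClipSnippet('vec4<f32>', 'f32'));
// For f16 tensors the inner f16(...) cast avoids mixing f32 and f16 types.
console.log(buildClipSnippet('vec4<f16>', 'f16'));
```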
1 parent 50806a7 commit ccbe264

File tree

6 files changed (+184, -25 lines)


js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts

+1 -1

@@ -130,7 +130,7 @@ const conv2dCommonSnippet =
        isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType);
    const bType =
        isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType);
-   const applyActivation = getActivationSnippet(attributes, resType);
+   const applyActivation = getActivationSnippet(attributes, resType, dataType);
    const userCode = `
  fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} {
    ${isChannelsLast ? sampleX : sampleW}

js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts

+2 -1

@@ -479,7 +479,8 @@ export const createMatmulProgramInfo =
      const uniforms: UniformsArrayType =
          [{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}];
      appendActivationUniforms(activationAttributes, uniforms);
-     const applyActivation = getActivationSnippet(activationAttributes, output.type.value);
+     const baseType = tensorTypeToWsglStorageType(output.type.tensor);
+     const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType);
      const declareFunctions = matMulReadWriteFnSource(
          components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims],
          isChannelsLast);

js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts

+5 -3

@@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

-import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
+import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
import {calculateOutputShape, ConvAttributes} from './conv';
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from './fuse-utils';

@@ -45,7 +45,8 @@ export const createGroupedConvProgramInfo =

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const output = outputVariable('output', inputs[0].dataType, outputShape.length);
-   const applyActivation = getActivationSnippet(attributes, output.type.value);
+   const baseType = tensorTypeToWsglStorageType(output.type.tensor);
+   const applyActivation = getActivationSnippet(attributes, output.type.value, baseType);
    const x = inputVariable('x', inputs[0].dataType, xShape.length);
    const w = inputVariable('w', inputs[1].dataType, wShape.length);
    const inputVars = [x, w];
@@ -136,7 +137,8 @@ export const createGroupedConvVectorizeProgramInfo =
  const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1];
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
-   const applyActivation = getActivationSnippet(attributes, output.type.value);
+   const baseType = tensorTypeToWsglStorageType(output.type.tensor);
+   const applyActivation = getActivationSnippet(attributes, output.type.value, baseType);
    const x = inputVariable('x', inputs[0].dataType, xShape.length, components);
    const w = inputVariable('w', inputs[1].dataType, wShape.length, components);
    const inputVars = [x, w];

js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts

+29 -18

@@ -15,24 +15,28 @@ export interface InternalActivationAttributes {
  readonly beta?: number;
}

-export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): string => {
-  switch (attributes.activation) {
-    case 'Relu':
-      return `value = max(value, ${valueType}(0.0));`;
-    case 'Sigmoid':
-      return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`;
-    case 'Clip':
-      return `value = clamp(value, ${valueType}(uniforms.clip_min), ${valueType}(uniforms.clip_max));`;
-    case 'HardSigmoid':
-      return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${valueType}(uniforms.alpha) * value + ${
-          valueType}(uniforms.beta)));`;
-    case '':
-      return '';
-    // TODO: adding other activations that can be fused.
-    default:
-      throw new Error(`Unsupported activation ${attributes.activation}`);
-  }
-};
+export const getActivationSnippet =
+    (attributes: InternalActivationAttributes, valueType: string, baseType = 'f32'): string => {
+      switch (attributes.activation) {
+        case 'Relu':
+          return `value = max(value, ${valueType}(0.0));`;
+        case 'Sigmoid':
+          return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`;
+        case 'Clip':
+          return `value = clamp(value, ${valueType}(${baseType}(uniforms.clip_min)), ${valueType}(${
+              baseType}(uniforms.clip_max)));`;
+        case 'HardSigmoid':
+          return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${baseType}(uniforms.alpha) * value + ${
+              baseType}(uniforms.beta)));`;
+        case 'LeakyRelu':
+          return `value = select(${baseType}(uniforms.alpha) * value, value, value >= ${valueType}(0.0));`;
+        case '':
+          return '';
+        // TODO: adding other activations that can be fused.
+        default:
+          throw new Error(`Unsupported activation ${attributes.activation}`);
+      }
+    };

export const appendActivationUniformsData =
    (attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => {
@@ -42,6 +46,8 @@ export const appendActivationUniformsData =
      } else if (attributes.activation === 'HardSigmoid') {
        programUniform.push(
            {type: DataType.float, data: attributes.alpha!}, {type: DataType.float, data: attributes.beta!});
+      } else if (attributes.activation === 'LeakyRelu') {
+        programUniform.push({type: DataType.float, data: attributes.alpha!});
      }
    };

@@ -50,6 +56,8 @@ export const appendActivationUniforms = (attributes: InternalActivationAttribute
    uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
  } else if (attributes.activation === 'HardSigmoid') {
    uniforms.push({name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'});
+  } else if (attributes.activation === 'LeakyRelu') {
+    uniforms.push({name: 'alpha', type: 'f32'});
  }
};

@@ -62,6 +70,9 @@ export const parseInternalActivationAttributes =
  } else if (activation === 'Clip') {
    const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP];
    return {activation, clipMax, clipMin};
+  } else if (activation === 'LeakyRelu') {
+    const [alpha] = attributes?.activation_params as [number] || [0.01];
+    return {activation, alpha};
  }
  return {activation};
};
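For reference, a rough sketch of what the new `LeakyRelu` branch produces when a kernel requests an `f16` output; the argument values are illustrative, mirroring how the conv and matmul kernels above call the helper.

```ts
import {getActivationSnippet} from './fuse-utils';

// valueType is the vectorized output type, baseType the tensor's WGSL storage type.
const snippet = getActivationSnippet({activation: 'LeakyRelu', alpha: 0.01}, 'vec4<f16>', 'f16');
console.log(snippet);
// -> value = select(f16(uniforms.alpha) * value, value, value >= vec4<f16>(0.0));
// Negative lanes are scaled by uniforms.alpha; non-negative lanes pass through unchanged.
```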

js/web/lib/wasm/jsep/webgpu/ops/matmul.ts

+3 -2

@@ -7,7 +7,7 @@ import {BroadcastUtil, ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu';
-import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, UniformsArrayType,} from './common';
+import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from './fuse-utils';

export const createNaiveMatmulProgramInfo =
@@ -45,7 +45,8 @@ export const createNaiveMatmulProgramInfo =
      const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents);
      const b = inputVariable('b', inputs[1].dataType, bShape.length, components);
      const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
-     const applyActivation = getActivationSnippet(activationAttributes, output.type.value);
+     const baseType = tensorTypeToWsglStorageType(output.type.tensor);
+     const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType);
      const inputVariables = [a, b];
      let processBias = '';
      if (hasBias) {

js/web/test/data/ops/fused-conv.jsonc

+144

@@ -286,5 +286,149 @@
        ]
      }
    ]
+  },
+  {
+    "name": "fused group-conv with LeakyRelu",
+    "operator": "FusedConv",
+    "attributes": [
+      { "name": "activation", "data": "LeakyRelu", "type": "string" },
+      { "name": "kernel_shape", "data": [2, 2], "type": "ints" },
+      { "name": "group", "data": 3, "type": "int" },
+      { "name": "activation_params", "data": [2.0], "type": "floats" }
+    ],
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "cases": [
+      {
+        "name": "T[0]",
+        "inputs": [
+          {
+            "data": [
+              0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0,
+              18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0
+            ],
+            "dims": [1, 3, 3, 3],
+            "type": "float32"
+          },
+          {
+            "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+            "dims": [3, 1, 2, 2],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [9, -6, 51, 47, -170, -10, 251, 229, 847, 889, 973, 1015],
+            "dims": [1, 3, 2, 2],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "NHWC group-conv with LeakyRelu",
+    "operator": "Conv",
+    "attributes": [
+      { "name": "activation", "data": "LeakyRelu", "type": "string" },
+      { "name": "kernel_shape", "data": [2, 2], "type": "ints" },
+      { "name": "group", "data": 3, "type": "int" },
+      { "name": "activation_params", "data": [2.0], "type": "floats" }
+    ],
+    "opset": { "domain": "com.ms.internal.nhwc", "version": 1 },
+    "cases": [
+      {
+        "name": "T[0]",
+        "inputs": [
+          {
+            "data": [
+              0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0,
+              18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0
+            ],
+            "dims": [1, 3, 3, 3],
+            "type": "float32"
+          },
+          {
+            "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+            "dims": [3, 1, 2, 2],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [-162, 63, -158, 33, 281, 85, 105, 337, 455, 177, 515, 609],
+            "dims": [1, 2, 2, 3],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "fused conv with LeakyRelu",
+    "operator": "FusedConv",
+    "attributes": [
+      { "name": "activation", "data": "LeakyRelu", "type": "string" },
+      { "name": "kernel_shape", "data": [2, 2], "type": "ints" },
+      { "name": "activation_params", "data": [2.0], "type": "floats" }
+    ],
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "cases": [
+      {
+        "name": "T[0]",
+        "inputs": [
+          {
+            "data": [10, 20, -30, -40, -50, -60, 70, 80, 90],
+            "dims": [1, 1, 3, 3],
+            "type": "float32"
+          },
+          {
+            "data": [1, 2, 3, 4],
+            "dims": [1, 1, 2, 2],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [-540, -860, 390, 430],
+            "dims": [1, 1, 2, 2],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "NHWC conv with LeakyRelu",
+    "operator": "Conv",
+    "attributes": [
+      { "name": "activation", "data": "LeakyRelu", "type": "string" },
+      { "name": "kernel_shape", "data": [2, 2], "type": "ints" },
+      { "name": "activation_params", "data": [2.0], "type": "floats" }
+    ],
+    "opset": { "domain": "com.ms.internal.nhwc", "version": 1 },
+    "cases": [
+      {
+        "name": "T[0]",
+        "inputs": [
+          {
+            "data": [10, 20, -30, -40, -50, -60, 70, 80, 90],
+            "dims": [1, 3, 3, 1],
+            "type": "float32"
+          },
+          {
+            "data": [1, 2, 3, 4],
+            "dims": [1, 1, 2, 2],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [-540, -860, 390, 430],
+            "dims": [1, 2, 2, 1],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
  }
]
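As a quick sanity check on the single-channel "fused conv with LeakyRelu" case above, a small sketch (assuming the standard valid-padding, stride-1 convolution) that recomputes the expected output, with alpha = 2.0 taken from `activation_params`:

```ts
// 3x3 input, 2x2 kernel, no padding, stride 1 -> 2x2 output, then LeakyRelu(alpha = 2.0).
const input = [[10, 20, -30], [-40, -50, -60], [70, 80, 90]];
const kernel = [[1, 2], [3, 4]];
const alpha = 2.0;

const out: number[] = [];
for (let r = 0; r < 2; r++) {
  for (let c = 0; c < 2; c++) {
    let acc = 0;
    for (let kr = 0; kr < 2; kr++) {
      for (let kc = 0; kc < 2; kc++) {
        acc += input[r + kr][c + kc] * kernel[kr][kc];
      }
    }
    out.push(acc >= 0 ? acc : alpha * acc);  // LeakyRelu
  }
}
console.log(out);  // [-540, -860, 390, 430], matching the expected "outputs" data
```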
