@@ -443,9 +443,9 @@ export const createMatmulProgramInfo =
 
       const components = isVec4 ? 4 : 1;
       const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components];
-      const aShapeOrRank = aShapeTemp.length;
+      const aRank = aShapeTemp.length;
       const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components];
-      const bShapeOrRank = bShapeTemp.length;
+      const bRank = bShapeTemp.length;
       const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
       const programUniforms: ProgramUniform[] =
           [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
@@ -467,12 +467,12 @@ export const createMatmulProgramInfo =
       programUniforms.push(...createTensorShapeVariables(outputShapeTemp));
 
       const getShaderSource = (shaderHelper: ShaderHelper) => {
-        const batchShapeOrRank = outerDims.length;
-        const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1);
+        const batchRank = outerDims.length;
+        const batchDims = internalVariable('batchDims', inputs[0].dataType, batchRank, 1);
         const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
 
-        const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components);
-        const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components);
+        const A = inputVariable('a', inputs[0].dataType, aRank, components);
+        const B = inputVariable('b', inputs[1].dataType, bRank, components);
         const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components);
         const inputVariables = [A, B];
         if (hasBias) {
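Note on the renames above: the variable helpers (`inputVariable`, `outputVariable`, `internalVariable`) accept either a concrete shape, which gets baked into the generated WGSL, or just a rank, in which case the shape is read from uniforms at dispatch time. With shape uniforms now unconditionally enabled, the matmul call sites only ever pass the rank, hence `aShapeOrRank`/`bShapeOrRank` becoming `aRank`/`bRank`. A minimal sketch of that union convention, using an illustrative stand-in rather than the real helpers:

// Illustrative stand-in for the shape-or-rank convention; not the real
// inputVariable/outputVariable helpers from ops/common.ts.
type ShapeOrRank = number|readonly number[];

const describeVariable = (name: string, shapeOrRank: ShapeOrRank): string =>
    typeof shapeOrRank === 'number' ?
    `${name}: rank ${shapeOrRank}, shape read from uniforms at dispatch time` :
    `${name}: shape [${shapeOrRank.join(', ')}] hard-coded into the WGSL source`;

const dims = [2, 3, 4];
// Before: each call site chose a form, driven by enableShapesUniforms(dims.length).
const useShapesUniforms = true;
console.log(describeVariable('a', useShapesUniforms ? dims.length : dims));
// After: shapes always travel as uniforms, so only the rank is ever passed.
console.log(describeVariable('a', dims.length));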
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts
@@ -8,7 +8,7 @@ import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, ProgramInfo} from '../types';
 
-import {createTensorShapeVariables, enableShapesUniforms, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common';
+import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common';
 
 export interface BatchNormAttributes extends AttributeWithCacheKey {
   readonly epsilon: number;
@@ -61,7 +61,7 @@ const createBatchNormInferenceProgramInfo =
       const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1;
       const outputSize = ShapeUtil.size(yShape) / components;
       // Only support uniforms for opset version >= 9 (spatial = true).
-      const useShapesUniforms = enableShapesUniforms(yShape.length) && spatial;
+      const useShapesUniforms = spatial;
       const shapeOrRank = useShapesUniforms ? yShape.length : yShape;
       const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components);
       const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents);
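Batch-norm is the one call site that keeps a flag: since `enableShapesUniforms(...)` always returned true, `useShapesUniforms` reduces to `spatial`, but the older non-spatial form noted in the comment above still bakes concrete dims into the shader instead of reading them from uniforms. A small illustrative sketch of how `shapeOrRank` continues to switch on that flag:

// Illustrative only: spatial batch norm passes a rank (shape comes in as a uniform),
// while the non-spatial path still hands the concrete output shape to the shader.
const yShape = [1, 64, 28, 28];
const pickShapeOrRank = (spatial: boolean): number|readonly number[] =>
    spatial ? yShape.length : yShape;

console.log(pickShapeOrRank(true));   // 4 -> rank only, dims supplied via uniforms
console.log(pickShapeOrRank(false));  // [1, 64, 28, 28] -> dims baked into WGSL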
37 changes: 14 additions & 23 deletions js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
@@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
 import {BroadcastUtil, ShapeUtil} from '../../util';
 import {ComputeContext, ProgramInfo} from '../types';
 
-import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
+import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
 
 type BuiltinFunctionName = string;
 type BinaryCustomExpression = (expressionA: string, expressionB: string) => string;
@@ -18,8 +18,7 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{
 const createBinaryOpProgramShader =
     (shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[],
      vectorize: boolean, doBroadcast: boolean, sharedDimensionDivisibleBy4: boolean, funcCall: BinaryFunctionCall,
-     typeA: number, typeB: number, typeOutput: number, useShapesUniforms: boolean,
-     additionalImplementation?: string) => {
+     typeA: number, typeB: number, typeOutput: number, additionalImplementation?: string) => {
       let expressionScalar: BinaryCustomExpression;
       let expressionVector: BinaryCustomExpression;
       if (typeof funcCall === 'string') {
@@ -31,12 +30,9 @@ const createBinaryOpProgramShader =
         expressionVector = funcCall.vector;
       }
 
-      const inputAShapeOrRank = useShapesUniforms ? dimsA.length : dimsA;
-      const inputBShapeOrRank = useShapesUniforms ? dimsB.length : dimsB;
-      const outputShapeOrRank = useShapesUniforms ? dimsOutput.length : dimsOutput;
-      const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4);
-      const a = inputVariable('aData', typeA, inputAShapeOrRank, 4);
-      const b = inputVariable('bData', typeB, inputBShapeOrRank, 4);
+      const output = outputVariable('outputData', typeOutput, dimsOutput.length, 4);
+      const a = inputVariable('aData', typeA, dimsA.length, 4);
+      const b = inputVariable('bData', typeB, dimsB.length, 4);
 
       let assignment: string;
       if (vectorize) {
@@ -169,30 +165,25 @@ const createBinaryOpProgramInfo =
        vectorize = true;
      }
      cacheKeyAux.push(vectorize);
-     const useShapesUniforms = enableShapesUniforms(a.dims.length) && enableShapesUniforms(b.dims.length) &&
-         enableShapesUniforms(outputShape.length);
+
      return {
        name,
        shaderCache: {
          hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'),
-         inputDependencies: useShapesUniforms ? ['rank', 'rank'] : ['dims', 'dims'],
+         inputDependencies: ['rank', 'rank'],
        },
        getShaderSource: (shaderHelper) => createBinaryOpProgramShader(
            shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall,
-           a.dataType, b.dataType, outputDataType, useShapesUniforms, additionalImplementation),
+           a.dataType, b.dataType, outputDataType, additionalImplementation),
        getRunData: () => ({
          outputs: [{dims: outputShape, dataType: outputDataType}],
          dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)},
-         programUniforms: useShapesUniforms ?
-             [
-               {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
-               ...createTensorShapeVariables(a.dims),
-               ...createTensorShapeVariables(b.dims),
-               ...createTensorShapeVariables(outputShape),
-             ] :
-             [
-               {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
-             ],
+         programUniforms: [
+           {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
+           ...createTensorShapeVariables(a.dims),
+           ...createTensorShapeVariables(b.dims),
+           ...createTensorShapeVariables(outputShape),
+         ],
        }),
      };
    };
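With `useShapesUniforms` gone, the binary-op program always declares its tensors by rank ('rank', 'rank' in `inputDependencies`) and always appends shape data for `a`, `b`, and the output to `programUniforms`; the 'dims' fallback and its duplicated uniform list disappear. A rough, self-contained sketch of the resulting uniform layout, using hypothetical stand-ins for `ProgramUniform` (from '../types') and `createTensorShapeVariables` (from './common', assumed here to pack dims plus strides):

// Hypothetical stand-in; the real ProgramUniform is imported from '../types'.
interface ProgramUniform { type: 'uint32'|'int32'; data: number|number[]; }

// Rough equivalent of createTensorShapeVariables: the shape plus its strides.
const shapeAndStrides = (dims: readonly number[]): ProgramUniform[] => {
  const strides = dims.map((_, i) => dims.slice(i + 1).reduce((a, b) => a * b, 1));
  return [{type: 'uint32', data: [...dims]}, {type: 'uint32', data: strides}];
};

const aDims = [1, 3, 4], bDims = [2, 3, 4], outDims = [2, 3, 4];
const outputSize = outDims.reduce((a, b) => a * b, 1);

// After this change the uniform list is built unconditionally, mirroring the
// rank-only inputDependencies above.
const programUniforms: ProgramUniform[] = [
  {type: 'uint32', data: Math.ceil(outputSize / 4)},
  ...shapeAndStrides(aDims),
  ...shapeAndStrides(bDims),
  ...shapeAndStrides(outDims),
];
console.log(programUniforms.length);  // 7: one size entry plus (shape, strides) per tensor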
3 changes: 0 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -908,6 +908,3 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly
   }
   return dims;
 };
-
-// TODO: remove this when all related uses have been removed.
-export const enableShapesUniforms = (_rank: number): boolean => true;
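The deleted helper was a transitional shim that returned true for every rank, so each remaining `enableShapesUniforms(rank) ? rankForm : dimsForm` expression collapses to its first branch; that is the mechanical rewrite applied throughout this PR. A tiny sketch of the collapse, with a local copy of the shim:

// Local copy of the removed shim, kept here only to show the simplification.
const enableShapesUniforms = (_rank: number): boolean => true;

const dims = [8, 16];
// Before: nominally conditional on the rank...
const before = enableShapesUniforms(dims.length) ? dims.length : dims;
// ...after: the predicate is constant true, so only the rank branch survives.
const after = dims.length;
console.log(before === after);  // true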
26 changes: 8 additions & 18 deletions js/web/lib/wasm/jsep/webgpu/ops/concat.ts
@@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
 
-import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
+import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
 
 export interface ConcatAttributes extends AttributeWithCacheKey {
   readonly axis: number;
@@ -94,32 +94,22 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
 
   let previousSum = 0;
   const inputDependencies: ProgramInputTensorInfoDependency[] = [];
-  const inputShapeOrRanks = [];
-  const enableInputShapesUniforms = [];
+  const inputRanks = [];
   const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}];
   for (let i = 0; i < inputs.length; ++i) {
     previousSum += inputs[i].dims[adjustedAxis];
     sizeInConcatAxis[i] = previousSum;
-    enableInputShapesUniforms.push(enableShapesUniforms(inputs[i].dims.length));
-    inputShapeOrRanks.push(enableInputShapesUniforms[i] ? inputs[i].dims.length : inputs[i].dims);
-    inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]);
-    inputDependencies.push(enableInputShapesUniforms[i] ? 'rank' : 'dims');
+    inputRanks.push(inputs[i].dims.length);
+    inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
+    inputDependencies.push('rank');
     programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]});
   }
   for (let i = 0; i < inputs.length; ++i) {
-    if (enableInputShapesUniforms[i]) {
-      programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
-    }
-  }
-
-  const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
-  if (enableOutputShapesUniforms) {
-    programUniforms.push(...createTensorShapeVariables(outputShape));
+    programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
   }
+  programUniforms.push(...createTensorShapeVariables(outputShape));
 
-  const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;
-  const output = outputVariable('output', dataType, outputShapeOrRank);
-
+  const output = outputVariable('output', dataType, outputShape.length);
   const indicesAxis = output.indicesGet('indices', adjustedAxis);
   const sizeInConcatAxisStr =
       Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
31 changes: 10 additions & 21 deletions js/web/lib/wasm/jsep/webgpu/ops/einsum.ts
@@ -6,8 +6,7 @@ import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
 
-import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
-
+import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
 
 export interface EinsumAttributes extends AttributeWithCacheKey {
   readonly equation: string;
@@ -181,14 +180,12 @@ class EinsumEquation {
 const appendMax = (name: string): string => name + '_max';
 
 const createEinsumProgramInfo =
-    (enableInputShapesUniforms: readonly boolean[], inputShapes: Array<readonly number[]>, dataType: number,
-     einsumEquation: EinsumEquation, outputShape: readonly number[]): ProgramInfo => {
-      const shapeOrRanks = inputShapes.map((dims, index) => enableInputShapesUniforms[index] ? dims.length : dims);
-      const inputVars = shapeOrRanks.map((shapeOrRank, index) => inputVariable(`input${index}`, dataType, shapeOrRank));
+    (inputShapes: Array<readonly number[]>, dataType: number, einsumEquation: EinsumEquation,
+     outputShape: readonly number[]): ProgramInfo => {
+      const ranks = inputShapes.map((dims) => dims.length);
+      const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank));
       const outputSize = ShapeUtil.size(outputShape);
-      const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
-      const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;
-      const output = outputVariable('output', dataType, outputShapeOrRank);
+      const output = outputVariable('output', dataType, outputShape.length);
       const uniformsSymbols =
           [...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol));
       const getShaderSource = (shaderHelper: ShaderHelper) => {
@@ -269,10 +266,7 @@ const createEinsumProgramInfo =
       };
       return {
         name: 'Einsum',
-        shaderCache: {
-          hint: einsumEquation.equation,
-          inputDependencies: enableInputShapesUniforms.map((enableShapeUniform) => enableShapeUniform ? 'rank' : 'dims')
-        },
+        shaderCache: {hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank')},
         getRunData: () => {
           // The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The
           // filter is added to make sure that dimValue is never 0.
@@ -281,12 +275,9 @@ const createEinsumProgramInfo =
               .map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0}));
           programUniformsInit.push({type: 'uint32', data: outputSize});
           const programUniforms: ProgramUniform[] =
-              inputShapes.filter((_, index) => enableInputShapesUniforms[index])
-                  .map((dims, _) => [...createTensorShapeVariables(dims)])
+              inputShapes.map((dims, _) => [...createTensorShapeVariables(dims)])
                   .reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit);
-          if (enableOutputShapesUniforms) {
-            programUniforms.push(...createTensorShapeVariables(outputShape));
-          }
+          programUniforms.push(...createTensorShapeVariables(outputShape));
           return ({
             outputs: [{dims: outputShape, dataType}],
             dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
@@ -299,11 +290,9 @@ const createEinsumProgramInfo =
 
 export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => {
   const einsumEquation = new EinsumEquation(context.inputs, attributes.equation);
-  const enableInputShapesUniforms = context.inputs.map((input, _) => enableShapesUniforms(input.dims.length));
   const outputShape = einsumEquation.outputDims;
   const inputShapes = context.inputs.map((input, _) => input.dims);
-  context.compute(createEinsumProgramInfo(
-      enableInputShapesUniforms, inputShapes, context.inputs[0].dataType, einsumEquation, outputShape));
+  context.compute(createEinsumProgramInfo(inputShapes, context.inputs[0].dataType, einsumEquation, outputShape));
 };
 
 export const parseEinsumAttributes = (attributes: Record<string, unknown>): EinsumAttributes => {
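For einsum, dropping the per-input `enableInputShapesUniforms` array turns the filter-then-map over flagged inputs into a plain map over all of them, and `inputDependencies` becomes 'rank' across the board. A compact before/after sketch using plain arrays (illustrative, not the real ProgramUniform plumbing):

const inputShapes: Array<readonly number[]> = [[2, 3], [3, 4], [4, 5]];
// Previously computed per input; it was always true, which is why the filter could go.
const enableInputShapesUniforms = inputShapes.map(() => true);

// Before: filter + map over the flagged inputs.
const before = inputShapes.filter((_, index) => enableInputShapesUniforms[index])
                   .map((dims) => [...dims])
                   .reduce((acc, cur) => acc.concat(cur), [] as number[]);

// After: a plain map over every input produces the same flattened payload.
const after = inputShapes.map((dims) => [...dims])
                  .reduce((acc, cur) => acc.concat(cur), [] as number[]);

console.log(before.join() === after.join());  // true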
25 changes: 8 additions & 17 deletions js/web/lib/wasm/jsep/webgpu/ops/expand.ts
@@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
 
-import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
+import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
 
 const validateInputs = (inputs: readonly TensorView[]): void => {
   if (!inputs || inputs.length !== 2) {
@@ -49,15 +49,9 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
   const components = dataType === DataType.bool ? 4 : 1;
   const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components);
 
-  const enableInputShapeUniform = enableShapesUniforms(inputShape.length);
-  const enableOutputShapeUniform = enableShapesUniforms(outputShape.length);
-
-
   const getShaderSource = (shaderHelper: ShaderHelper) => {
-    const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape;
-    const outputShapeOrRank = enableOutputShapeUniform ? outputShape.length : outputShape;
-    const input = inputVariable('input', dataType, inputShapeOrRank, components);
-    const output = outputVariable('output', dataType, outputShapeOrRank, components);
+    const input = inputVariable('input', dataType, inputShape.length, components);
+    const output = outputVariable('output', dataType, outputShape.length, components);
     let assignment: string;
     if (dataType === DataType.bool) {
       const singleAssignment = (resStr: string, x: number, typeCast = '') => `
@@ -90,16 +84,13 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
           ${assignment}`;
   };
 
-  const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}];
-  if (enableInputShapeUniform) {
-    programUniforms.push(...createTensorShapeVariables(inputShape));
-  }
-  if (enableOutputShapeUniform) {
-    programUniforms.push(...createTensorShapeVariables(outputShape));
-  }
+  const programUniforms: ProgramUniform[] = [
+    {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape),
+    ...createTensorShapeVariables(outputShape)
+  ];
   return {
     name: 'Expand',
-    shaderCache: {hint: `${outputShape.length}`, inputDependencies: [enableInputShapeUniform ? 'rank' : 'dims']},
+    shaderCache: {hint: `${outputShape.length}`, inputDependencies: ['rank']},
     getShaderSource,
     getRunData: () => ({
       outputs: [{dims: outputShape, dataType: inputs[0].dataType}],