Skip to content

Commit b2c7e57

Browse files
committed
update syntax
1 parent 500996f commit b2c7e57

File tree

4 files changed

+80
-17
lines changed

4 files changed

+80
-17
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -804,8 +804,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
804804
Optional<Index>:$clusterSizeY,
805805
Optional<Index>:$clusterSizeZ,
806806
Optional<I32>:$dynamicSharedMemorySize,
807-
OptionalAttr<FlatSymbolRefAttr>:$function,
808-
OptionalAttr<FlatSymbolRefAttr>:$module)>,
807+
OptionalAttr<FlatSymbolRefAttr>:$module,
808+
OptionalAttr<FlatSymbolRefAttr>:$function)>,
809809
Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
810810
let summary = "GPU kernel launch operation";
811811

@@ -850,6 +850,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
850850
`blocks` `(` ssa-id-list `)` `in` ssa-reassignment
851851
`threads` `(` ssa-id-list `)` `in` ssa-reassignment
852852
(dynamic_shared_memory_size ssa-use)?
853+
(`module(` symbol-ref-id `)`)?
854+
(`function(` symbol-ref-id `)`)?
853855
memory-attribution
854856
region attr-dict?
855857
ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
@@ -907,6 +909,14 @@ def GPU_LaunchOp : GPU_Op<"launch", [
907909
// sizes are immediately usable inside body region.
908910
"some_op"(%cx, %bx, %tx) : (index, index, index) -> ()
909911
}
912+
913+
// Launch with module and function attributes.
914+
gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %0, %sz_by = %1, %sz_bz = %2)
915+
threads(%tx, %ty, %tz) in (%sz_tx = %3, %sz_ty = %4, %sz_tz = %5)
916+
module(@kernel_module) function(@kernel_func) {
917+
"some_op"(%bx, %tx) : (index, index) -> ()
918+
%42 = load %val1[%bx] : memref<?xf32, 1>
919+
}
910920
```
911921

912922
Rationale: using operation/block arguments gives analyses a clear way of
@@ -931,7 +941,9 @@ def GPU_LaunchOp : GPU_Op<"launch", [
931941
CArg<"TypeRange", "{}">:$privateAttributions,
932942
CArg<"Value", "nullptr">:$clusterSizeX,
933943
CArg<"Value", "nullptr">:$clusterSizeY,
934-
CArg<"Value", "nullptr">:$clusterSizeZ)>
944+
CArg<"Value", "nullptr">:$clusterSizeZ,
945+
CArg<"FlatSymbolRefAttr", "nullptr">:$module,
946+
CArg<"FlatSymbolRefAttr", "nullptr">:$function)>,
935947
];
936948

937949
let extraClassDeclaration = [{

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,8 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
756756
Type asyncTokenType, ValueRange asyncDependencies,
757757
TypeRange workgroupAttributions,
758758
TypeRange privateAttributions, Value clusterSizeX,
759-
Value clusterSizeY, Value clusterSizeZ) {
759+
Value clusterSizeY, Value clusterSizeZ,
760+
FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
760761
OpBuilder::InsertionGuard g(builder);
761762

762763
// Add a WorkGroup attribution attribute. This attribute is required to
@@ -781,6 +782,12 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
781782
if (dynamicSharedMemorySize)
782783
result.addOperands(dynamicSharedMemorySize);
783784

785+
// Add optional module and function attributes.
786+
if (module)
787+
result.addAttribute(getModuleAttrName(), module);
788+
if (function)
789+
result.addAttribute(getFunctionAttrName(), function);
790+
784791
// Create a kernel body region with kNumConfigRegionAttributes + N memory
785792
// attributions, where the first kNumConfigRegionAttributes arguments have
786793
// `index` type and the rest have the same types as the data operands.
@@ -944,6 +951,21 @@ void LaunchOp::print(OpAsmPrinter &p) {
944951
p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
945952
<< getDynamicSharedMemorySize();
946953

954+
// Print optional module attribute.
955+
StringRef moduleAttrName = getModuleAttrName();
956+
if (auto module = getModule()) {
957+
printer << ' ' << moduleAttrName << '(';
958+
printer.printSymbolName(*module);
959+
printer << ')';
960+
}
961+
// Print optional function attribute.
962+
StringRef functionAttrName = getFunctionAttrName();
963+
if (auto function = getFunction()) {
964+
printer << ' ' << functionAttrName << '(';
965+
printer.printSymbolName(*function);
966+
printer << ')';
967+
}
968+
947969
printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
948970
printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
949971

@@ -952,7 +974,8 @@ void LaunchOp::print(OpAsmPrinter &p) {
952974
p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
953975
p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
954976
LaunchOp::getOperandSegmentSizeAttr(),
955-
getNumWorkgroupAttributionsAttrName()});
977+
getNumWorkgroupAttributionsAttrName(),
978+
moduleAttrName, functionAttrName});
956979
}
957980

958981
// Parse the size assignment blocks for blocks and threads. These have the form
@@ -990,6 +1013,9 @@ parseSizeAssignment(OpAsmParser &parser,
9901013
/// `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
9911014
/// `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
9921015
/// `threads` `(` ssa-id-list `)` `in` ssa-reassignment
1016+
/// (`dynamic_shared_memory_size` ssa-id)?
1017+
/// (`module(` symbol-ref-id `)`)?
1018+
/// (`function(` symbol-ref-id `)`)?
9931019
/// memory-attribution
9941020
/// region attr-dict?
9951021
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
@@ -1060,6 +1086,27 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
10601086
return failure();
10611087
}
10621088

1089+
// Parse optional module attribute.
1090+
StringRef moduleAttrName = getModuleAttrName();
1091+
if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
1092+
FlatSymbolRefAttr moduleSymbol;
1093+
if (parser.parseLParen() ||
1094+
parser.parseAttribute(moduleSymbol, Type(), moduleAttrName,
1095+
result.attributes) ||
1096+
parser.parseRParen())
1097+
return failure();
1098+
}
1099+
// Parse optional function attribute.
1100+
StringRef functionAttrName = getFunctionAttrName();
1101+
if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
1102+
FlatSymbolRefAttr funcSymbol;
1103+
if (parser.parseLParen() ||
1104+
parser.parseAttribute(funcSymbol, Type(), functionAttrName,
1105+
result.attributes) ||
1106+
parser.parseRParen())
1107+
return failure();
1108+
}
1109+
10631110
// Create the region arguments, it has kNumConfigRegionAttributes arguments
10641111
// that correspond to block/thread identifiers and grid/block sizes, all
10651112
// having `index` type, a variadic number of WorkGroup Attributions and

mlir/test/Dialect/GPU/ops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ module attributes {gpu.container_module} {
1919

2020
// CHECK-LABEL:func @launch_with_module_func_attr(%{{.*}}: index)
2121
func.func @launch_with_module_func_attr(%sz : index) {
22-
// CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}})
22+
// CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) module(@test_module) function(@test_kernel_func)
2323
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
24-
threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
24+
threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz)
25+
module(@test_module) function(@test_kernel_func) {
2526
// CHECK: gpu.terminator
2627
gpu.terminator
27-
// CHECK: {function = @test_kernel_func, module = @existing_module}
28-
} {function = @test_kernel_func, module = @existing_module}
28+
}
2929
return
3030
}
3131

mlir/test/Dialect/GPU/outlining.mlir

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -523,10 +523,11 @@ func.func @testKernelAttributes() {
523523
%bDimZ = arith.constant 8 : index
524524

525525
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
526-
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
526+
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ)
527+
module(@test_module) function(@test_kernel_func) {
527528
"some_op"(%bx, %tx) : (index, index) -> ()
528529
gpu.terminator
529-
} {module = @test_module, function = @test_kernel_func}
530+
}
530531
return
531532
}
532533

@@ -556,10 +557,11 @@ func.func @testExistingModule() {
556557
%bDimZ = arith.constant 8 : index
557558

558559
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
559-
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
560+
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ)
561+
module(@existing_module) function(@test_kernel_func) {
560562
"some_op"(%bx, %tx) : (index, index) -> ()
561563
gpu.terminator
562-
} {module = @existing_module, function = @test_kernel_func}
564+
}
563565
return
564566
}
565567

@@ -578,10 +580,11 @@ func.func @testKernelModuleOnly() {
578580
%bDimZ = arith.constant 8 : index
579581

580582
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
581-
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
583+
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ)
584+
module(@test_module) {
582585
"some_op"(%bx, %tx) : (index, index) -> ()
583586
gpu.terminator
584-
} {module = @test_module}
587+
}
585588
return
586589
}
587590

@@ -601,10 +604,11 @@ func.func @testKernelFuncOnly() {
601604
%bDimZ = arith.constant 8 : index
602605

603606
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
604-
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
607+
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ)
608+
function(@test_kernel_func) {
605609
"some_op"(%bx, %tx) : (index, index) -> ()
606610
gpu.terminator
607-
} {function = @test_kernel_func}
611+
}
608612
return
609613
}
610614

0 commit comments

Comments
 (0)