@@ -756,7 +756,8 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
756
756
Type asyncTokenType, ValueRange asyncDependencies,
757
757
TypeRange workgroupAttributions,
758
758
TypeRange privateAttributions, Value clusterSizeX,
759
- Value clusterSizeY, Value clusterSizeZ) {
759
+ Value clusterSizeY, Value clusterSizeZ,
760
+ FlatSymbolRefAttr module , FlatSymbolRefAttr function) {
760
761
OpBuilder::InsertionGuard g (builder);
761
762
762
763
// Add a WorkGroup attribution attribute. This attribute is required to
@@ -781,6 +782,12 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
781
782
if (dynamicSharedMemorySize)
782
783
result.addOperands (dynamicSharedMemorySize);
783
784
785
+ // Add optional module and function attributes.
786
+ if (module )
787
+ result.addAttribute (getModuleAttrName (), module );
788
+ if (function)
789
+ result.addAttribute (getFunctionAttrName (), function);
790
+
784
791
// Create a kernel body region with kNumConfigRegionAttributes + N memory
785
792
// attributions, where the first kNumConfigRegionAttributes arguments have
786
793
// `index` type and the rest have the same types as the data operands.
@@ -944,6 +951,21 @@ void LaunchOp::print(OpAsmPrinter &p) {
944
951
p << ' ' << getDynamicSharedMemorySizeKeyword () << ' '
945
952
<< getDynamicSharedMemorySize ();
946
953
954
+ // Print optional module attribute.
955
+ StringRef moduleAttrName = getModuleAttrName ();
956
+ if (auto module = getModule ()) {
957
+ printer << ' ' << moduleAttrName << ' (' ;
958
+ printer.printSymbolName (*module );
959
+ printer << ' )' ;
960
+ }
961
+ // Print optional function attribute.
962
+ StringRef functionAttrName = getFunctionAttrName ();
963
+ if (auto function = getFunction ()) {
964
+ printer << ' ' << functionAttrName << ' (' ;
965
+ printer.printSymbolName (*function);
966
+ printer << ' )' ;
967
+ }
968
+
947
969
printAttributions (p, getWorkgroupKeyword (), getWorkgroupAttributions ());
948
970
printAttributions (p, getPrivateKeyword (), getPrivateAttributions ());
949
971
@@ -952,7 +974,8 @@ void LaunchOp::print(OpAsmPrinter &p) {
952
974
p.printRegion (getBody (), /* printEntryBlockArgs=*/ false );
953
975
p.printOptionalAttrDict ((*this )->getAttrs (), /* elidedAttrs=*/ {
954
976
LaunchOp::getOperandSegmentSizeAttr (),
955
- getNumWorkgroupAttributionsAttrName ()});
977
+ getNumWorkgroupAttributionsAttrName (),
978
+ moduleAttrName, functionAttrName});
956
979
}
957
980
958
981
// Parse the size assignment blocks for blocks and threads. These have the form
@@ -990,6 +1013,9 @@ parseSizeAssignment(OpAsmParser &parser,
990
1013
// / `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
991
1014
// / `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
992
1015
// / `threads` `(` ssa-id-list `)` `in` ssa-reassignment
1016
+ // / (`dynamic_shared_memory_size` ssa-id)?
1017
+ // / (`module(` symbol-ref-id `)`)?
1018
+ // / (`function(` symbol-ref-id `)`)?
993
1019
// / memory-attribution
994
1020
// / region attr-dict?
995
1021
// / ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
@@ -1060,6 +1086,27 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
1060
1086
return failure ();
1061
1087
}
1062
1088
1089
+ // Parse optional module attribute.
1090
+ StringRef moduleAttrName = getModuleAttrName ();
1091
+ if (succeeded (parser.parseOptionalKeyword (moduleAttrName))) {
1092
+ FlatSymbolRefAttr moduleSymbol;
1093
+ if (parser.parseLParen () ||
1094
+ parser.parseAttribute (moduleSymbol, Type (), moduleAttrName,
1095
+ result.attributes ) ||
1096
+ parser.parseRParen ())
1097
+ return failure ();
1098
+ }
1099
+ // Parse optional function attribute.
1100
+ StringRef functionAttrName = getFunctionAttrName ();
1101
+ if (succeeded (parser.parseOptionalKeyword (functionAttrName))) {
1102
+ FlatSymbolRefAttr funcSymbol;
1103
+ if (parser.parseLParen () ||
1104
+ parser.parseAttribute (funcSymbol, Type (), functionAttrName,
1105
+ result.attributes ) ||
1106
+ parser.parseRParen ())
1107
+ return failure ();
1108
+ }
1109
+
1063
1110
// Create the region arguments, it has kNumConfigRegionAttributes arguments
1064
1111
// that correspond to block/thread identifiers and grid/block sizes, all
1065
1112
// having `index` type, a variadic number of WorkGroup Attributions and
0 commit comments