@@ -295,13 +295,17 @@ static bool isODSReserved(StringRef str) {
295
295
// / (does not change the `name` if it already is suitable) and returns the
296
296
// / modified version.
297
297
static std::string sanitizeName (StringRef name) {
298
- std::string processed_str = name.str ();
298
+ std::string processedStr = name.str ();
299
+ std::replace_if (
300
+ processedStr.begin (), processedStr.end (),
301
+ [](char c) { return !llvm::isAlnum (c); }, ' _' );
299
302
300
- std::replace (processed_str.begin (), processed_str.end (), ' -' , ' _' );
303
+ if (llvm::isDigit (*processedStr.begin ()))
304
+ return " _" + processedStr;
301
305
302
- if (isPythonReserved (processed_str ) || isODSReserved (processed_str ))
303
- return processed_str + " _" ;
304
- return processed_str ;
306
+ if (isPythonReserved (processedStr ) || isODSReserved (processedStr ))
307
+ return processedStr + " _" ;
308
+ return processedStr ;
305
309
}
306
310
307
311
static std::string attrSizedTraitForKind (const char *kind) {
@@ -977,7 +981,6 @@ static void emitValueBuilder(const Operator &op,
977
981
llvm::SmallVector<std::string> functionArgs,
978
982
raw_ostream &os) {
979
983
auto name = sanitizeName (op.getOperationName ());
980
- iterator_range<llvm::SplittingIterator> splitName = llvm::split (name, " ." );
981
984
// Params with (possibly) default args.
982
985
auto valueBuilderParams =
983
986
llvm::map_range (functionArgs, [](const std::string &argAndMaybeDefault) {
@@ -996,16 +999,16 @@ static void emitValueBuilder(const Operator &op,
996
999
auto lhs = *llvm::split (arg, " =" ).begin ();
997
1000
return (lhs + " =" + llvm::convertToSnakeFromCamelCase (lhs)).str ();
998
1001
});
999
- os << llvm::formatv (
1000
- valueBuilderTemplate,
1001
- // Drop dialect name and then sanitize again (to catch e.g. func.return).
1002
- sanitizeName ( llvm::join (++splitName. begin (), splitName. end (), " _ " ) ),
1003
- op. getCppClassName (), llvm::join (valueBuilderParams, " , " ),
1004
- llvm::join (opBuilderArgs, " , " ),
1005
- (op.getNumResults () > 1
1006
- ? " _Sequence[_ods_ir.OpResult]"
1007
- : (op.getNumResults () > 0 ? " _ods_ir.OpResult"
1008
- : " _ods_ir.Operation" )));
1002
+ std::string nameWithoutDialect =
1003
+ op. getOperationName (). substr (op. getOperationName (). find ( ' . ' ) + 1 );
1004
+ os << llvm::formatv (valueBuilderTemplate, sanitizeName (nameWithoutDialect),
1005
+ op. getCppClassName ( ),
1006
+ llvm::join (valueBuilderParams, " , " ),
1007
+ llvm::join (opBuilderArgs, " , " ),
1008
+ (op.getNumResults () > 1
1009
+ ? " _Sequence[_ods_ir.OpResult]"
1010
+ : (op.getNumResults () > 0 ? " _ods_ir.OpResult"
1011
+ : " _ods_ir.Operation" )));
1009
1012
}
1010
1013
1011
1014
// / Emits bindings for a specific Op to the given output stream.
0 commit comments