Skip to content

Commit 15e92d0

Browse files
committed
rebase
1 parent 969b91a commit 15e92d0

File tree

4 files changed

+67
-76
lines changed

4 files changed

+67
-76
lines changed

mlir/python/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ declare_mlir_dialect_python_bindings(
5252
TD_FILE dialects/AffineOps.td
5353
SOURCES
5454
dialects/affine.py
55-
dialects/_affine_ops_ext.py
5655
DIALECT_NAME affine
5756
GEN_ENUM_BINDINGS)
5857

mlir/python/mlir/dialects/_affine_ops_ext.py

Lines changed: 0 additions & 56 deletions
This file was deleted.
Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,50 @@
1-
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2-
# See https://llvm.org/LICENSE.txt for license information.
3-
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55
from ._affine_ops_gen import *
6+
from ._affine_ops_gen import _Dialect
7+
8+
try:
9+
from ..ir import *
10+
from ._ods_common import (
11+
get_op_result_or_value as _get_op_result_or_value,
12+
get_op_results_or_values as _get_op_results_or_values,
13+
_cext as _ods_cext,
14+
)
15+
except ImportError as e:
16+
raise RuntimeError("Error loading imports from extension module") from e
17+
18+
from typing import Optional, Sequence, Union
19+
20+
21+
@_ods_cext.register_operation(_Dialect, replace=True)
22+
class AffineStoreOp(AffineStoreOp):
23+
"""Specialization for the Affine store operation."""
24+
25+
def __init__(
26+
self,
27+
value: Union[Operation, OpView, Value],
28+
memref: Union[Operation, OpView, Value],
29+
map: AffineMap = None,
30+
*,
31+
map_operands=None,
32+
loc=None,
33+
ip=None,
34+
):
35+
"""Creates an affine store operation.
36+
37+
- `value`: the value to store into the memref.
38+
- `memref`: the buffer to store into.
39+
- `map`: the affine map that maps the map_operands to the index of the
40+
memref.
41+
- `map_operands`: the list of arguments to substitute the dimensions,
42+
then symbols in the affine map, in increasing order.
43+
"""
44+
map = map if map is not None else []
45+
map_operands = map_operands if map_operands is not None else []
46+
indicies = [_get_op_result_or_value(op) for op in map_operands]
47+
_ods_successors = None
48+
super().__init__(
49+
value, memref, indicies, AffineMapAttr.get(map), loc=loc, ip=ip
50+
)

mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -295,13 +295,17 @@ static bool isODSReserved(StringRef str) {
295295
/// (does not change the `name` if it already is suitable) and returns the
296296
/// modified version.
297297
static std::string sanitizeName(StringRef name) {
298-
std::string processed_str = name.str();
298+
std::string processedStr = name.str();
299+
std::replace_if(
300+
processedStr.begin(), processedStr.end(),
301+
[](char c) { return !llvm::isAlnum(c); }, '_');
299302

300-
std::replace(processed_str.begin(), processed_str.end(), '-', '_');
303+
if (llvm::isDigit(*processedStr.begin()))
304+
return "_" + processedStr;
301305

302-
if (isPythonReserved(processed_str) || isODSReserved(processed_str))
303-
return processed_str + "_";
304-
return processed_str;
306+
if (isPythonReserved(processedStr) || isODSReserved(processedStr))
307+
return processedStr + "_";
308+
return processedStr;
305309
}
306310

307311
static std::string attrSizedTraitForKind(const char *kind) {
@@ -977,7 +981,6 @@ static void emitValueBuilder(const Operator &op,
977981
llvm::SmallVector<std::string> functionArgs,
978982
raw_ostream &os) {
979983
auto name = sanitizeName(op.getOperationName());
980-
iterator_range<llvm::SplittingIterator> splitName = llvm::split(name, ".");
981984
// Params with (possibly) default args.
982985
auto valueBuilderParams =
983986
llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
@@ -996,16 +999,16 @@ static void emitValueBuilder(const Operator &op,
996999
auto lhs = *llvm::split(arg, "=").begin();
9971000
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
9981001
});
999-
os << llvm::formatv(
1000-
valueBuilderTemplate,
1001-
// Drop dialect name and then sanitize again (to catch e.g. func.return).
1002-
sanitizeName(llvm::join(++splitName.begin(), splitName.end(), "_")),
1003-
op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
1004-
llvm::join(opBuilderArgs, ", "),
1005-
(op.getNumResults() > 1
1006-
? "_Sequence[_ods_ir.OpResult]"
1007-
: (op.getNumResults() > 0 ? "_ods_ir.OpResult"
1008-
: "_ods_ir.Operation")));
1002+
std::string nameWithoutDialect =
1003+
op.getOperationName().substr(op.getOperationName().find('.') + 1);
1004+
os << llvm::formatv(valueBuilderTemplate, sanitizeName(nameWithoutDialect),
1005+
op.getCppClassName(),
1006+
llvm::join(valueBuilderParams, ", "),
1007+
llvm::join(opBuilderArgs, ", "),
1008+
(op.getNumResults() > 1
1009+
? "_Sequence[_ods_ir.OpResult]"
1010+
: (op.getNumResults() > 0 ? "_ods_ir.OpResult"
1011+
: "_ods_ir.Operation")));
10091012
}
10101013

10111014
/// Emits bindings for a specific Op to the given output stream.

0 commit comments

Comments
 (0)