diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md index efc1f044e1f1f..b9c0a45bf4677 100644 --- a/mlir/docs/DeclarativeRewrites.md +++ b/mlir/docs/DeclarativeRewrites.md @@ -380,6 +380,11 @@ template. The string can be an arbitrary C++ expression that evaluates into some C++ object expected at the `NativeCodeCall` site (here it would be expecting an array attribute). Typically the string should be a function call. +In the case of properties, the return value of the `NativeCodeCall` should +be in terms of the _interface_ type of a property. For example, the `NativeCodeCall` +for a `StringProp` should return a `StringRef`, which will copied into the underlying +`std::string`, just as if it were an argument to the operation's builder. + ##### `NativeCodeCall` placeholders In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N` and `$N...`. @@ -416,14 +421,20 @@ must be either passed by reference or pointer to the variable used as argument so that the matched value can be returned. In the same example, `$val` will be bound to a variable with `Attribute` type (as `I32Attr`) and the type of the second argument in `Foo()` could be `Attribute&` or `Attribute*`. Names with -attribute constraints will be captured as `Attribute`s while everything else -will be treated as `Value`s. +attribute constraints will be captured as `Attribute`s, names with +property constraints (which must have a concrete interface type) will be treated +as that type, and everything else will be treated as `Value`s. Positional placeholders will be substituted by the `dag` object parameters at the `NativeCodeCall` use site. For example, if we define `SomeCall : NativeCodeCall<"someFn($1, $2, $0)">` and use it like `(SomeCall $in0, $in1, $in2)`, then this will be translated into C++ call `someFn($in1, $in2, $in0)`. +In the case of properties, the placeholder will be bound to a value of the _interface_ +type of the property. For example, passing in a `StringProp` as an argument to a `NativeCodeCall` will pass a `StringRef` (as if the getter of the matched +operation were called) and not a `std::string`. See +`mlir/include/mlir/IR/Properties.td` for details on interface vs. storage type. + Positional range placeholders will be substituted by multiple `dag` object parameters at the `NativeCodeCall` use site. For example, if we define `SomeCall : NativeCodeCall<"someFn($1...)">` and use it like `(SomeCall $in0, diff --git a/mlir/include/mlir/IR/Properties.td b/mlir/include/mlir/IR/Properties.td index 25a45489c7b53..cae9cce2531ca 100644 --- a/mlir/include/mlir/IR/Properties.td +++ b/mlir/include/mlir/IR/Properties.td @@ -400,6 +400,21 @@ class ConfinedProperty : ConfinedProp, Deprecated<"moved to shorter name ConfinedProp">; +/// Defines a constant value of type `prop` to be used in pattern matching. +/// When used as a constraint, forms a matcher that tests that the property is +/// equal to the given value (and matches any other constraints on the property). +/// The constant value is given as a string and should be of the _interface_ type +/// of the attribute. +/// +/// This requires that the given property's inference type be comparable to the +/// given value with `==`, and does require specify a concrete property type. +class ConstantProp + : ConfinedProp, + "constant '" # prop.summary # "': " # val> { + string value = val; +} + //===----------------------------------------------------------------------===// // Primitive property combinators //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index 1c9e128f0a0fb..49b2dae62dc22 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -73,12 +73,23 @@ class DagLeaf { // specifies an attribute constraint. bool isAttrMatcher() const; + // Returns true if this DAG leaf is matching a property. That is, it + // specifies a property constraint. + bool isPropMatcher() const; + + // Returns true if this DAG leaf is describing a property. That is, it + // is a subclass of `Property` in tablegen. + bool isPropDefinition() const; + // Returns true if this DAG leaf is wrapping native code call. bool isNativeCodeCall() const; // Returns true if this DAG leaf is specifying a constant attribute. bool isConstantAttr() const; + // Returns true if this DAG leaf is specifying a constant property. + bool isConstantProp() const; + // Returns true if this DAG leaf is specifying an enum case. bool isEnumCase() const; @@ -88,9 +99,19 @@ class DagLeaf { // Returns this DAG leaf as a constraint. Asserts if fails. Constraint getAsConstraint() const; + // Returns this DAG leaf as a property constraint. Asserts if fails. This + // allows access to the interface type. + PropConstraint getAsPropConstraint() const; + + // Returns this DAG leaf as a property definition. Asserts if fails. + Property getAsProperty() const; + // Returns this DAG leaf as an constant attribute. Asserts if fails. ConstantAttr getAsConstantAttr() const; + // Returns this DAG leaf as an constant property. Asserts if fails. + ConstantProp getAsConstantProp() const; + // Returns this DAG leaf as an enum case. // Precondition: isEnumCase() EnumCase getAsEnumCase() const; @@ -279,6 +300,10 @@ class SymbolInfoMap { // the DAG of the operation, `operandIndexOrNumValues` specifies the // operand index, and `variadicSubIndex` must be set to `std::nullopt`. // + // * Properties not associated with an operation (e.g. as arguments to + // native code) have their corresponding PropConstraint stored in the + // `dag` field. This constraint is only used when + // // * If a symbol is defined in a `variadic` DAG, `dag` specifies the DAG // of the parent operation, `operandIndexOrNumValues` specifies the // declared operand index of the variadic operand in the parent @@ -364,12 +389,20 @@ class SymbolInfoMap { // What kind of entity this symbol represents: // * Attr: op attribute + // * Prop: op property // * Operand: op operand // * Result: op result // * Value: a value not attached to an op (e.g., from NativeCodeCall) // * MultipleValues: a pack of values not attached to an op (e.g., from // NativeCodeCall). This kind supports indexing. - enum class Kind : uint8_t { Attr, Operand, Result, Value, MultipleValues }; + enum class Kind : uint8_t { + Attr, + Prop, + Operand, + Result, + Value, + MultipleValues + }; // Creates a SymbolInfo instance. `dagAndConstant` is only used for `Attr` // and `Operand` so should be std::nullopt for `Result` and `Value` kind. @@ -384,6 +417,15 @@ class SymbolInfoMap { static SymbolInfo getAttr() { return SymbolInfo(nullptr, Kind::Attr, std::nullopt); } + static SymbolInfo getProp(const Operator *op, int index) { + return SymbolInfo(op, Kind::Prop, + DagAndConstant(nullptr, index, std::nullopt)); + } + static SymbolInfo getProp(const PropConstraint *constraint) { + // -1 for anthe `operandIndexOrNumValues` is a sentinel value. + return SymbolInfo(nullptr, Kind::Prop, + DagAndConstant(constraint, -1, std::nullopt)); + } static SymbolInfo getOperand(DagNode node, const Operator *op, int operandIndex, std::optional variadicSubIndex = std::nullopt) { @@ -488,6 +530,10 @@ class SymbolInfoMap { // is already bound. bool bindAttr(StringRef symbol); + // Registers the given `symbol` as bound to a property that satisfies the + // given `constraint`. `constraint` must name a concrete interface type. + bool bindProp(StringRef symbol, const PropConstraint &constraint); + // Returns true if the given `symbol` is bound. bool contains(StringRef symbol) const; diff --git a/mlir/include/mlir/TableGen/Property.h b/mlir/include/mlir/TableGen/Property.h index 6af96f077efe5..81e6d85720829 100644 --- a/mlir/include/mlir/TableGen/Property.h +++ b/mlir/include/mlir/TableGen/Property.h @@ -32,9 +32,9 @@ class Pred; // Wrapper class providing helper methods for accesing property constraint // values. class PropConstraint : public Constraint { +public: using Constraint::Constraint; -public: static bool classof(const Constraint *c) { return c->getKind() == CK_Prop; } StringRef getInterfaceType() const; @@ -143,6 +143,10 @@ class Property : public PropConstraint { // property constraints, this function is added for future-proofing) Property getBaseProperty() const; + // Returns true if this property is backed by a TableGen definition and that + // definition is a subclass of `className`. + bool isSubClassOf(StringRef className) const; + private: // Elements describing a Property, in general fetched from the record. StringRef summary; @@ -169,6 +173,21 @@ struct NamedProperty { Property prop; }; +// Wrapper class providing helper methods for processing constant property +// values defined using the `ConstantProp` subclass of `Property` +// in TableGen. +class ConstantProp : public Property { +public: + explicit ConstantProp(const llvm::DefInit *def) : Property(def) { + assert(isSubClassOf("ConstantProp")); + } + + static bool classof(Property *p) { return p->isSubClassOf("ConstantProp"); } + + // Return the constant value of the property as an expression + // that produces an interface-type constant. + StringRef getValue() const; +}; } // namespace tblgen } // namespace mlir diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp index 4ce6ab1dbfce5..2c119fd680b69 100644 --- a/mlir/lib/TableGen/CodeGenHelpers.cpp +++ b/mlir/lib/TableGen/CodeGenHelpers.cpp @@ -205,10 +205,14 @@ static ::llvm::LogicalResult {0}( /// Code for a pattern type or attribute constraint. /// -/// {3}: "Type type" or "Attribute attr". -static const char *const patternAttrOrTypeConstraintCode = R"( +/// {0}: name of function +/// {1}: Condition template +/// {2}: Constraint summary +/// {3}: "::mlir::Type type" or "::mlirAttribute attr" or "propType prop". +/// Can be "T prop" for generic property constraints. +static const char *const patternConstraintCode = R"( static ::llvm::LogicalResult {0}( - ::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, ::mlir::{3}, + ::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, {3}, ::llvm::StringRef failureStr) { if (!({1})) { return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) { @@ -265,15 +269,31 @@ void StaticVerifierFunctionEmitter::emitPatternConstraints() { FmtContext ctx; ctx.addSubst("_op", "*op").withBuilder("rewriter").withSelf("type"); for (auto &it : typeConstraints) { - os << formatv(patternAttrOrTypeConstraintCode, it.second, + os << formatv(patternConstraintCode, it.second, tgfmt(it.first.getConditionTemplate(), &ctx), - escapeString(it.first.getSummary()), "Type type"); + escapeString(it.first.getSummary()), "::mlir::Type type"); } ctx.withSelf("attr"); for (auto &it : attrConstraints) { - os << formatv(patternAttrOrTypeConstraintCode, it.second, + os << formatv(patternConstraintCode, it.second, tgfmt(it.first.getConditionTemplate(), &ctx), - escapeString(it.first.getSummary()), "Attribute attr"); + escapeString(it.first.getSummary()), + "::mlir::Attribute attr"); + } + ctx.withSelf("prop"); + for (auto &it : propConstraints) { + PropConstraint propConstraint = cast(it.first); + StringRef interfaceType = propConstraint.getInterfaceType(); + // Constraints that are generic over multiple interface types are + // templatized under the assumption that they'll be used correctly. + if (interfaceType.empty()) { + interfaceType = "T"; + os << "template "; + } + os << formatv(patternConstraintCode, it.second, + tgfmt(propConstraint.getConditionTemplate(), &ctx), + escapeString(propConstraint.getSummary()), + Twine(interfaceType) + " prop"); } } @@ -367,10 +387,15 @@ void StaticVerifierFunctionEmitter::collectOpConstraints( void StaticVerifierFunctionEmitter::collectPatternConstraints( const ArrayRef constraints) { for (auto &leaf : constraints) { - assert(leaf.isOperandMatcher() || leaf.isAttrMatcher()); - collectConstraint( - leaf.isOperandMatcher() ? typeConstraints : attrConstraints, - leaf.isOperandMatcher() ? "type" : "attr", leaf.getAsConstraint()); + assert(leaf.isOperandMatcher() || leaf.isAttrMatcher() || + leaf.isPropMatcher()); + Constraint constraint = leaf.getAsConstraint(); + if (leaf.isOperandMatcher()) + collectConstraint(typeConstraints, "type", constraint); + else if (leaf.isAttrMatcher()) + collectConstraint(attrConstraints, "attr", constraint); + else if (leaf.isPropMatcher()) + collectConstraint(propConstraints, "prop", constraint); } } diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 13541de66578d..1a1a58ad271bb 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -51,6 +51,16 @@ bool DagLeaf::isAttrMatcher() const { return isSubClassOf("AttrConstraint"); } +bool DagLeaf::isPropMatcher() const { + // Property matchers specify a property constraint. + return isSubClassOf("PropConstraint"); +} + +bool DagLeaf::isPropDefinition() const { + // Property matchers specify a property definition. + return isSubClassOf("Property"); +} + bool DagLeaf::isNativeCodeCall() const { return isSubClassOf("NativeCodeCall"); } @@ -59,14 +69,26 @@ bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); } bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumCase"); } +bool DagLeaf::isConstantProp() const { return isSubClassOf("ConstantProp"); } + bool DagLeaf::isStringAttr() const { return isa(def); } Constraint DagLeaf::getAsConstraint() const { - assert((isOperandMatcher() || isAttrMatcher()) && - "the DAG leaf must be operand or attribute"); + assert((isOperandMatcher() || isAttrMatcher() || isPropMatcher()) && + "the DAG leaf must be operand, attribute, or property"); return Constraint(cast(def)->getDef()); } +PropConstraint DagLeaf::getAsPropConstraint() const { + assert(isPropMatcher() && "the DAG leaf must be a property matcher"); + return PropConstraint(cast(def)->getDef()); +} + +Property DagLeaf::getAsProperty() const { + assert(isPropDefinition() && "the DAG leaf must be a property definition"); + return Property(cast(def)->getDef()); +} + ConstantAttr DagLeaf::getAsConstantAttr() const { assert(isConstantAttr() && "the DAG leaf must be constant attribute"); return ConstantAttr(cast(def)); @@ -77,6 +99,11 @@ EnumCase DagLeaf::getAsEnumCase() const { return EnumCase(cast(def)); } +ConstantProp DagLeaf::getAsConstantProp() const { + assert(isConstantProp() && "the DAG leaf must be a constant property value"); + return ConstantProp(cast(def)); +} + std::string DagLeaf::getConditionTemplate() const { return getAsConstraint().getConditionTemplate(); } @@ -232,6 +259,7 @@ SymbolInfoMap::SymbolInfo::SymbolInfo( int SymbolInfoMap::SymbolInfo::getStaticValueCount() const { switch (kind) { case Kind::Attr: + case Kind::Prop: case Kind::Operand: case Kind::Value: return 1; @@ -258,6 +286,18 @@ std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const { // TODO(suderman): Use a more exact type when available. return "::mlir::Attribute"; } + case Kind::Prop: { + if (op) + return cast(op->getArg(getArgIndex())) + ->prop.getInterfaceType() + .str(); + assert(dagAndConstant && dagAndConstant->dag && + "generic properties must carry their constraint"); + return reinterpret_cast(dagAndConstant->dag) + ->getAsPropConstraint() + .getInterfaceType() + .str(); + } case Kind::Operand: { // Use operand range for captured operands (to support potential variadic // operands). @@ -300,6 +340,12 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( LLVM_DEBUG(dbgs() << repl << " (Attr)\n"); return std::string(repl); } + case Kind::Prop: { + assert(index < 0); + auto repl = formatv(fmt, name); + LLVM_DEBUG(dbgs() << repl << " (Prop)\n"); + return std::string(repl); + } case Kind::Operand: { assert(index < 0); auto *operand = cast(op->getArg(getArgIndex())); @@ -388,10 +434,11 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( LLVM_DEBUG(dbgs() << "getAllRangeUse for '" << name << "': "); switch (kind) { case Kind::Attr: + case Kind::Prop: case Kind::Operand: { assert(index < 0 && "only allowed for symbol bound to result"); auto repl = formatv(fmt, name); - LLVM_DEBUG(dbgs() << repl << " (Operand/Attr)\n"); + LLVM_DEBUG(dbgs() << repl << " (Operand/Attr/Prop)\n"); return std::string(repl); } case Kind::Result: { @@ -449,9 +496,11 @@ bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol, PrintFatalError(loc, error); } - auto symInfo = - isa(op.getArg(argIndex)) - ? SymbolInfo::getAttr(&op, argIndex) + Argument arg = op.getArg(argIndex); + SymbolInfo symInfo = + isa(arg) ? SymbolInfo::getAttr(&op, argIndex) + : isa(arg) + ? SymbolInfo::getProp(&op, argIndex) : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex); std::string key = symbol.str(); @@ -503,6 +552,13 @@ bool SymbolInfoMap::bindAttr(StringRef symbol) { return symbolInfoMap.count(inserted->first) == 1; } +bool SymbolInfoMap::bindProp(StringRef symbol, + const PropConstraint &constraint) { + auto inserted = + symbolInfoMap.emplace(symbol.str(), SymbolInfo::getProp(&constraint)); + return symbolInfoMap.count(inserted->first) == 1; +} + bool SymbolInfoMap::contains(StringRef symbol) const { return find(symbol) != symbolInfoMap.end(); } @@ -774,10 +830,23 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, if (!treeArgName.empty() && treeArgName != "_") { DagLeaf leaf = tree.getArgAsLeaf(i); - // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c), + // In (NativeCodeCall<"Foo($_self, $0, $1, $2, $3)"> I8Attr:$a, I8:$b, + // $c, I8Prop:$d), if (leaf.isUnspecified()) { // This is case of $c, a Value without any constraints. verifyBind(infoMap.bindValue(treeArgName), treeArgName); + } else if (leaf.isPropMatcher()) { + // This is case of $d, a binding to a certain property. + auto propConstraint = leaf.getAsPropConstraint(); + if (propConstraint.getInterfaceType().empty()) { + PrintFatalError(&def, + formatv("binding symbol '{0}' in NativeCodeCall to " + "a property constraint without specifying " + "that constraint's type is unsupported", + treeArgName)); + } + verifyBind(infoMap.bindProp(treeArgName, propConstraint), + treeArgName); } else { auto constraint = leaf.getAsConstraint(); bool isAttr = leaf.isAttrMatcher() || leaf.isEnumCase() || diff --git a/mlir/lib/TableGen/Property.cpp b/mlir/lib/TableGen/Property.cpp index 9a70c1b6e8d62..47f43267cd197 100644 --- a/mlir/lib/TableGen/Property.cpp +++ b/mlir/lib/TableGen/Property.cpp @@ -112,3 +112,11 @@ Property Property::getBaseProperty() const { } return *this; } + +bool Property::isSubClassOf(StringRef className) const { + return def && def->isSubClassOf(className); +} + +StringRef ConstantProp::getValue() const { + return def->getValueAsString("value"); +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 79bcd9c2e0a9a..dbedcfb6079d7 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -3455,6 +3455,38 @@ def OpWithPropertyPredicates : TEST_Op<"op_with_property_predicates"> { let assemblyFormat = "attr-dict prop-dict"; } +def TestPropPatternOp1 : TEST_Op<"prop_pattern_op_1"> { + let arguments = (ins + StringProp:$tag, + I64Prop:$val, + BoolProp:$cond + ); + let results = (outs I32:$results); + let assemblyFormat = "$tag $val $cond attr-dict"; +} + +def TestPropPatternOp2 : TEST_Op<"prop_pattern_op_2"> { + let arguments = (ins + I32:$input, + StringProp:$tag + ); + let assemblyFormat = "$input $tag attr-dict"; +} + +def : Pat< + (TestPropPatternOp1 $tag, NonNegativeI64Prop:$val, ConstantProp), + (TestPropPatternOp1 $tag, (NativeCodeCall<"$0 + 1"> $val), ConstantProp)>; + +def : Pat< + (TestPropPatternOp2 (TestPropPatternOp1 $tag1, $val, ConstantProp), + PropConstraint, "non-empty string">:$tag2), + (TestPropPatternOp2 + (TestPropPatternOp1 $tag1, + (NativeCodeCall<"-($0)"> $val), + ConstantProp), + (NativeCodeCall<"$0.str() + \".\" + $1.str()"> $tag1, $tag2)) +>; + //===----------------------------------------------------------------------===// // Test Dataflow //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir index 27598fb63a6c8..bd55338618eec 100644 --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -586,7 +586,7 @@ func.func @testMatchMultiVariadicSubSymbol(%arg0: i32, %arg1: i32, %arg2: i32, % // CHECK-LABEL: @testMatchMixedVaradicOptional func.func @testMatchMixedVaradicOptional(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () { - // CHECK: "test.mixed_variadic_in6"(%arg0, %arg1, %arg2) <{attr1 = 2 : i32}> : (i32, i32, i32) -> () + // CHECK: "test.mixed_variadic_in6"(%arg0, %arg1, %arg2) <{attr1 = 2 : i32}> : (i32, i32, i32) -> () "test.mixed_variadic_optional_in7"(%arg0, %arg1, %arg2) {attr1 = 2 : i32, operandSegmentSizes = array} : (i32, i32, i32) -> () // CHECK: test.mixed_variadic_optional_in7 "test.mixed_variadic_optional_in7"(%arg0, %arg1) {attr1 = 2 : i32, operandSegmentSizes = array} : (i32, i32) -> () @@ -594,6 +594,32 @@ func.func @testMatchMixedVaradicOptional(%arg0: i32, %arg1: i32, %arg2: i32, %ar return } +//===----------------------------------------------------------------------===// +// Test patterns that operate on properties +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @testSimplePropertyRewrite +func.func @testSimplePropertyRewrite() { + // CHECK-NEXT: test.prop_pattern_op_1 "o1" 2 true + test.prop_pattern_op_1 "o1" 1 false + // Pattern not applied when predicate not met + // CHECK-NEXT: test.prop_pattern_op_1 "o2" -1 false + test.prop_pattern_op_1 "o2" -1 false + // Pattern not applied when constant doesn't match + // CHCEK-NEXT: test.prop_pattern_op_1 "o3" 1 true + test.prop_pattern_op_1 "o3" 1 true + return +} + +// CHECK-LABEL: @testNestedPropertyRewrite +func.func @testNestedPropertyRewrite() { + // CHECK: %[[v:.*]] = test.prop_pattern_op_1 "s1" -2 false + // CHECK: test.prop_pattern_op_2 %[[v]] "s1.t1" + %v = test.prop_pattern_op_1 "s1" 1 false + test.prop_pattern_op_2 %v "t1" + return +} + //===----------------------------------------------------------------------===// // Test that natives calls are only called once during rewrites. //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td index fc36a51789ec2..40af548b140ff 100644 --- a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td +++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td @@ -45,3 +45,35 @@ def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>; // CHECK: tblgen_values.push_back((*x.getODSResults(0).begin())); // CHECK: tblgen_props.y = ::llvm::dyn_cast_if_present(y); // CHECK: tblgen_AOp_0 = rewriter.create(odsLoc, tblgen_types, tblgen_values, tblgen_props); + +// Note: These use strings to pick up a non-trivial storage/interface type +// difference. +def COp : NS_Op<"c_op", []> { + let arguments = (ins + I32:$x, + StringProp:$y + ); + + let results = (outs I32:$z); +} + +def DOp : NS_Op<"d_op", []> { + let arguments = (ins + StringProp:$y + ); + + let results = (outs I32:$z); +} +def test2 : Pat<(COp (DOp:$x $y), $_), (COp $x, $y)>; +// CHECK-LABEL: struct test2 +// CHECK: ::llvm::LogicalResult matchAndRewrite +// CHECK-DAG: ::llvm::StringRef y; +// CHECK-DAG: test::DOp x; +// CHECK-DAG: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; +// CHECK: tblgen_ops.push_back(op0); +// CHECK: x = castedOp1; +// CHECK: tblgen_prop = castedOp1.getProperties().getY(); +// CHECK: y = tblgen_prop; +// CHECK: tblgen_ops.push_back(op1); +// CHECK: test::COp::Properties tblgen_props; +// CHECK: tblgen_props.setY(y); diff --git a/mlir/test/mlir-tblgen/rewriter-static-matcher.td b/mlir/test/mlir-tblgen/rewriter-static-matcher.td index c5debf5104bbb..99a15921dab2e 100644 --- a/mlir/test/mlir-tblgen/rewriter-static-matcher.td +++ b/mlir/test/mlir-tblgen/rewriter-static-matcher.td @@ -45,6 +45,24 @@ def DOp : NS_Op<"d_op", []> { def Foo : NativeCodeCall<"foo($_builder, $0)">; +def NonNegProp : PropConstraint= 0">, "non-negative integer">; + +def EOp : NS_Op<"e_op", []> { + let arguments = (ins + I32Prop:$x, + I64Prop:$y, + AnyInteger:$z + ); + let results = (outs I32:$res); +} + +def FOp: NS_Op<"f_op", []> { + let arguments = (ins + I32Prop:$a, + AnyInteger:$b + ); +} + // Test static matcher for duplicate DagNode // --- @@ -52,9 +70,16 @@ def Foo : NativeCodeCall<"foo($_builder, $0)">; // CHECK-NEXT: {{.*::mlir::Type type}} // CHECK: static ::llvm::LogicalResult [[$ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]]( // CHECK-NEXT: {{.*::mlir::Attribute attr}} +// CHECK: template +// CHECK-NEXT: static ::llvm::LogicalResult [[$PROP_CONSTRAINT:__mlir_ods_local_prop_constraint.*]]( +// CHECK-NEXT: {{.*T prop}} // CHECK: static ::llvm::LogicalResult [[$DAG_MATCHER:static_dag_matcher.*]]( // CHECK: if(::mlir::failed([[$ATTR_CONSTRAINT]] // CHECK: if(::mlir::failed([[$TYPE_CONSTRAINT]] +// CHECK: static ::llvm::LogicalResult [[$DAG_MATCHER2:static_dag_matcher.*]]( +// CHECK-SAME: int32_t &x +// CHECK: if(::mlir::failed([[$PROP_CONSTRAINT]] +// CHECK: if(::mlir::failed([[$TYPE_CONSTRAINT]] // CHECK: if(::mlir::failed([[$DAG_MATCHER]](rewriter, op1, tblgen_ops def : Pat<(AOp (BOp I32Attr:$attr, I32:$int)), @@ -68,3 +93,11 @@ def : Pat<(COp $_, (BOp I32Attr:$attr, I32:$int)), // CHECK: ::llvm::SmallVector<::mlir::Value, 4> [[$ARR:tblgen_variadic_values_.*]]; // CHECK: [[$ARR]].push_back([[$VAR]]); def : Pat<(AOp $x), (DOp (variadic (Foo $x)))>; + +// CHECK: if(::mlir::failed([[$DAG_MATCHER2]]({{.*}} x{{[,)]}} +def : Pat<(AOp (EOp NonNegProp:$x, NonNegProp:$_, I32:$z)), + (AOp $z)>; + +// CHECK: if(::mlir::failed([[$DAG_MATCHER2]]({{.*}} x{{[,)]}} +def : Pat<(FOp $_, (EOp NonNegProp:$x, NonNegProp:$_, I32:$z)), + (COp $x, $z)>; diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 75721c89793b5..975a524a53285 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -125,6 +125,11 @@ class PatternEmitter { void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex, int depth); + // Emits C++ statements for matching the `argIndex`-th argument of the given + // DAG `tree` as a property. + void emitPropertyMatch(DagNode tree, StringRef castedName, int argIndex, + int depth); + // Emits C++ for checking a match with a corresponding match failure // diagnostic. void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt, @@ -338,7 +343,7 @@ class StaticMatcherHelper { // for each DagNode. int staticMatcherCounter = 0; - // The DagLeaf which contains type or attr constraint. + // The DagLeaf which contains type, attr, or prop constraint. SetVector constraints; // Static type/attribute verification function emitter. @@ -487,6 +492,19 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, auto leaf = tree.getArgAsLeaf(i); if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { os << "::mlir::Attribute " << argName << ";\n"; + } else if (leaf.isPropMatcher()) { + StringRef interfaceType = leaf.getAsPropConstraint().getInterfaceType(); + if (interfaceType.empty()) + PrintFatalError(loc, "NativeCodeCall cannot have a property operand " + "with unspecified interface type"); + os << interfaceType << " " << argName; + if (leaf.isPropDefinition()) { + Property propDef = leaf.getAsProperty(); + // Ensure properties that aren't zero-arg-constructable still work. + if (propDef.hasDefaultValue()) + os << " = " << propDef.getDefaultValue(); + } + os << ";\n"; } else { os << "::mlir::Value " << argName << ";\n"; } @@ -539,7 +557,7 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, auto constraint = leaf.getAsConstraint(); std::string self; - if (leaf.isAttrMatcher() || leaf.isConstantAttr()) + if (leaf.isAttrMatcher() || leaf.isConstantAttr() || leaf.isPropMatcher()) self = argName; else self = formatv("{0}.getType()", argName); @@ -665,6 +683,8 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { ++nextOperand; } else if (isa(opArg)) { emitAttributeMatch(tree, castedName, opArgIdx, depth); + } else if (isa(opArg)) { + emitPropertyMatch(tree, castedName, opArgIdx, depth); } else { PrintFatalError(loc, "unhandled case when matching op"); } @@ -942,6 +962,46 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName, os.unindent() << "}\n"; } +void PatternEmitter::emitPropertyMatch(DagNode tree, StringRef castedName, + int argIndex, int depth) { + Operator &op = tree.getDialectOp(opMap); + auto *namedProp = cast(op.getArg(argIndex)); + + os << "{\n"; + os.indent() << formatv( + "[[maybe_unused]] auto tblgen_prop = {0}.getProperties().{1}();\n", + castedName, op.getGetterName(namedProp->name)); + + auto matcher = tree.getArgAsLeaf(argIndex); + if (!matcher.isUnspecified()) { + if (!matcher.isPropMatcher()) { + PrintFatalError( + loc, formatv("the {1}-th argument of op '{0}' should be a property", + op.getOperationName(), argIndex + 1)); + } + + // If a constraint is specified, we need to generate function call to its + // static verifier. + StringRef verifier = staticMatcherHelper.getVerifierName(matcher); + emitStaticVerifierCall( + verifier, castedName, "tblgen_prop", + formatv("\"op '{0}' property '{1}' failed to satisfy constraint: " + "'{2}'\"", + op.getOperationName(), namedProp->name, + escapeString(matcher.getAsConstraint().getSummary())) + .str()); + } + + // Capture the value + auto name = tree.getArgName(argIndex); + // `$_` is a special symbol to ignore op argument matching. + if (!name.empty() && name != "_") { + os << formatv("{0} = tblgen_prop;\n", name); + } + + os.unindent() << "}\n"; +} + void PatternEmitter::emitMatchCheck( StringRef opName, const FmtObjectBase &matchFmt, const llvm::formatv_object_base &failureFmt) { @@ -1384,6 +1444,10 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf, std::string val = std::to_string(enumCase.getValue()); return handleConstantAttr(Attribute(&enumCase.getDef()), val); } + if (leaf.isConstantProp()) { + auto constantProp = leaf.getAsConstantProp(); + return constantProp.getValue().str(); + } LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); @@ -1710,7 +1774,7 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs( for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { const auto *operand = llvm::dyn_cast_if_present(resultOp.getArg(argIndex)); - // We do not need special handling for attributes. + // We do not need special handling for attributes or properties. if (!operand) continue; @@ -1776,7 +1840,7 @@ void PatternEmitter::supplyValuesForOpArgs( if (auto subTree = node.getArgAsNestedDag(argIndex)) { if (!subTree.isNativeCodeCall()) PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " - "for creating attribute"); + "for creating attributes and properties"); os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex)); } else { auto leaf = node.getArgAsLeaf(argIndex); @@ -1788,6 +1852,11 @@ void PatternEmitter::supplyValuesForOpArgs( PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); if (!patArgName.empty()) os << "/*" << patArgName << "=*/"; + } else if (leaf.isConstantProp()) { + if (!isa(opArg)) + PrintFatalError(loc, Twine("expected property ") + Twine(argIndex)); + if (!patArgName.empty()) + os << "/*" << patArgName << "=*/"; } else { os << "/*" << opArgName << "=*/"; } @@ -1820,6 +1889,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs( " tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), " "tmpAttr);\n}\n"; const char *setterCmd = (useProperties) ? setPropCmd : addAttrCmd; + const char *propSetterCmd = "tblgen_props.{0}({1});\n"; int numVariadic = 0; bool hasOperandSegmentSizes = false; @@ -1845,6 +1915,28 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs( continue; } + if (isa(resultOp.getArg(argIndex))) { + // The argument in the op definition. + auto opArgName = resultOp.getArgName(argIndex); + auto setterName = resultOp.getSetterName(opArgName); + if (auto subTree = node.getArgAsNestedDag(argIndex)) { + if (!subTree.isNativeCodeCall()) + PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " + "for creating property"); + + os << formatv(propSetterCmd, setterName, + childNodeNames.lookup(argIndex)); + } else { + auto leaf = node.getArgAsLeaf(argIndex); + // The argument in the result DAG pattern. + auto patArgName = node.getArgName(argIndex); + // The argument in the result DAG pattern. + os << formatv(propSetterCmd, setterName, + handleOpArgument(leaf, patArgName)); + } + continue; + } + const auto *operand = cast(resultOp.getArg(argIndex)); if (operand->isVariadic()) { @@ -1973,6 +2065,12 @@ StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) { assert(constraint && "attribute constraint was not uniqued"); return *constraint; } + if (leaf.isPropMatcher()) { + std::optional constraint = + staticVerifierEmitter.getPropConstraintFn(leaf.getAsConstraint()); + assert(constraint && "prop constraint was not uniqued"); + return *constraint; + } assert(leaf.isOperandMatcher()); return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint()); }