From 1203b90c4ba7bfa79ab2fefe81ae7f05e96bde1c Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Mon, 20 May 2024 13:48:08 -0700 Subject: [PATCH 1/8] [mlir][polynomial] fix polynomial.constant syntax in docstrings --- .../mlir/Dialect/Polynomial/IR/Polynomial.td | 20 +++++++++---------- .../Polynomial/IR/PolynomialDialect.td | 6 +++--- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td index 3ef899d3376b1..e03d2ec81e9c8 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td @@ -52,8 +52,8 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> { // add two polynomials modulo x^1024 - 1 #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> - %1 = polynomial.constant #polynomial.int_polynomial : !polynomial.polynomial<#ring> + %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring> + %1 = polynomial.constant {value=#polynomial.int_polynomial} : !polynomial.polynomial<#ring> %2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring> ``` }]; @@ -76,8 +76,8 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> { // subtract two polynomials modulo x^1024 - 1 #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> - %1 = polynomial.constant #polynomial.int_polynomial : !polynomial.polynomial<#ring> + %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring> + %1 = polynomial.constant {value=#polynomial.int_polynomial} : !polynomial.polynomial<#ring> %2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring> ``` }]; @@ -101,8 +101,8 @@ def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> { // multiply two polynomials modulo x^1024 - 1 #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> - %1 = polynomial.constant #polynomial.int_polynomial : !polynomial.polynomial<#ring> + %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring> + %1 = polynomial.constant {value=#polynomial.int_polynomial} : !polynomial.polynomial<#ring> %2 = polynomial.mul %0, %1 : !polynomial.polynomial<#ring> ``` }]; @@ -126,7 +126,7 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [ // multiply two polynomials modulo x^1024 - 1 #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> + %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring> %1 = arith.constant 3 : i32 %2 = polynomial.mul_scalar %0, %1 : !polynomial.polynomial<#ring>, i32 ``` @@ -157,7 +157,7 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> { ```mlir #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> + %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring> %1, %2 = polynomial.leading_term %0 : !polynomial.polynomial<#ring> -> (index, i32) ``` }]; @@ -286,10 +286,10 @@ def Polynomial_ConstantOp : Op { ```mlir #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> + %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring> #float_ring = #polynomial.ring - %0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring> + %0 = polynomial.constant {value=#polynomial.float_polynomial<0.5 + 1.3e06 x**2>} : !polynomial.polynomial<#float_ring> ``` }]; let arguments = (ins Polynomial_AnyPolynomialAttr:$value); diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td index b0573b3715f78..73783815781cf 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td @@ -33,18 +33,18 @@ def Polynomial_Dialect : Dialect { ```mlir // A constant polynomial in a ring with i32 coefficients and no polynomial modulus #ring = #polynomial.ring - %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring> + %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring> // A constant polynomial in a ring with i32 coefficients, modulo (x^1024 + 1) #modulus = #polynomial.int_polynomial<1 + x**1024> #ring = #polynomial.ring - %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring> + %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring> // A constant polynomial in a ring with i32 coefficients, with a polynomial // modulus of (x^1024 + 1) and a coefficient modulus of 17. #modulus = #polynomial.int_polynomial<1 + x**1024> #ring = #polynomial.ring - %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring> + %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring> ``` }]; From 29b42317e7b3800ffe4a3b4b0f89f499ce4e51e7 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Mon, 20 May 2024 15:53:26 -0700 Subject: [PATCH 2/8] Revert "[mlir][polynomial] fix polynomial.constant syntax in docstrings" This reverts commit 1203b90c4ba7bfa79ab2fefe81ae7f05e96bde1c. --- .../mlir/Dialect/Polynomial/IR/Polynomial.td | 20 +++++++++---------- .../Polynomial/IR/PolynomialDialect.td | 6 +++--- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td index e03d2ec81e9c8..3ef899d3376b1 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td @@ -52,8 +52,8 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> { // add two polynomials modulo x^1024 - 1 #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring> - %1 = polynomial.constant {value=#polynomial.int_polynomial} : !polynomial.polynomial<#ring> + %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> + %1 = polynomial.constant #polynomial.int_polynomial : !polynomial.polynomial<#ring> %2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring> ``` }]; @@ -76,8 +76,8 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> { // subtract two polynomials modulo x^1024 - 1 #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring> - %1 = polynomial.constant {value=#polynomial.int_polynomial} : !polynomial.polynomial<#ring> + %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> + %1 = polynomial.constant #polynomial.int_polynomial : !polynomial.polynomial<#ring> %2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring> ``` }]; @@ -101,8 +101,8 @@ def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> { // multiply two polynomials modulo x^1024 - 1 #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring> - %1 = polynomial.constant {value=#polynomial.int_polynomial} : !polynomial.polynomial<#ring> + %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> + %1 = polynomial.constant #polynomial.int_polynomial : !polynomial.polynomial<#ring> %2 = polynomial.mul %0, %1 : !polynomial.polynomial<#ring> ``` }]; @@ -126,7 +126,7 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [ // multiply two polynomials modulo x^1024 - 1 #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring> + %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> %1 = arith.constant 3 : i32 %2 = polynomial.mul_scalar %0, %1 : !polynomial.polynomial<#ring>, i32 ``` @@ -157,7 +157,7 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> { ```mlir #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring> + %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> %1, %2 = polynomial.leading_term %0 : !polynomial.polynomial<#ring> -> (index, i32) ``` }]; @@ -286,10 +286,10 @@ def Polynomial_ConstantOp : Op { ```mlir #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring> + %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> #float_ring = #polynomial.ring - %0 = polynomial.constant {value=#polynomial.float_polynomial<0.5 + 1.3e06 x**2>} : !polynomial.polynomial<#float_ring> + %0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring> ``` }]; let arguments = (ins Polynomial_AnyPolynomialAttr:$value); diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td index 73783815781cf..b0573b3715f78 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td @@ -33,18 +33,18 @@ def Polynomial_Dialect : Dialect { ```mlir // A constant polynomial in a ring with i32 coefficients and no polynomial modulus #ring = #polynomial.ring - %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring> + %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring> // A constant polynomial in a ring with i32 coefficients, modulo (x^1024 + 1) #modulus = #polynomial.int_polynomial<1 + x**1024> #ring = #polynomial.ring - %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring> + %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring> // A constant polynomial in a ring with i32 coefficients, with a polynomial // modulus of (x^1024 + 1) and a coefficient modulus of 17. #modulus = #polynomial.int_polynomial<1 + x**1024> #ring = #polynomial.ring - %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring> + %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring> ``` }]; From f6276fe2d81a883676cfc24de152c0f16afdc16d Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Mon, 20 May 2024 17:01:42 -0700 Subject: [PATCH 3/8] add typed variants for polynomial.constant op --- .../mlir/Dialect/Polynomial/IR/Polynomial.td | 13 ++--- .../Polynomial/IR/PolynomialAttributes.td | 54 +++++++++++++++++-- .../Dialect/Polynomial/IR/PolynomialOps.cpp | 15 ++++++ mlir/test/Dialect/Polynomial/ops.mlir | 8 +-- 4 files changed, 77 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td index 3ef899d3376b1..85a9dd6b935d2 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td @@ -272,13 +272,14 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> { let hasVerifier = 1; } -def Polynomial_AnyPolynomialAttr : AnyAttrOf<[ - Polynomial_FloatPolynomialAttr, - Polynomial_IntPolynomialAttr +def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[ + Polynomial_TypedFloatPolynomialAttr, + Polynomial_TypedIntPolynomialAttr ]>; // Not deriving from Polynomial_Op due to need for custom assembly format -def Polynomial_ConstantOp : Op { +def Polynomial_ConstantOp : Op { let summary = "Define a constant polynomial via an attribute."; let description = [{ Example: @@ -292,9 +293,9 @@ def Polynomial_ConstantOp : Op { %0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring> ``` }]; - let arguments = (ins Polynomial_AnyPolynomialAttr:$value); + let arguments = (ins Polynomial_AnyTypedPolynomialAttr:$value); let results = (outs Polynomial_PolynomialType:$output); - let assemblyFormat = "attr-dict `:` type($output)"; + let assemblyFormat = "attr-dict $value"; } def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> { diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td index e5dbfa7fa21ee..1ea07e21e0076 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td @@ -18,7 +18,7 @@ class Polynomial_Attr traits = []> } def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynomial"> { - let summary = "An attribute containing a single-variable polynomial with integer coefficients."; + let summary = "an attribute containing a single-variable polynomial with integer coefficients"; let description = [{ A polynomial attribute represents a single-variable polynomial with integer coefficients, which is used to define the modulus of a `RingAttr`, as well @@ -41,7 +41,7 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom } def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> { - let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients."; + let summary = "an attribute containing a single-variable polynomial with double precision floating point coefficients"; let description = [{ A polynomial attribute represents a single-variable polynomial with double precision floating point coefficients. @@ -62,8 +62,56 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p let hasCustomAssemblyFormat = 1; } +def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr< + "TypedIntPolynomial", "typed_int_polynomial", [TypedAttrInterface]> { + let summary = "a typed int_polynomial"; + let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::IntPolynomialAttr":$value); + let assemblyFormat = "$value `:` $type"; + let builders = [ + AttrBuilderWithInferredContext<(ins "Type":$type, + "const IntPolynomial &":$value), [{ + return $_get( + type.getContext(), + type, + IntPolynomialAttr::get(type.getContext(), value)); + }]>, + AttrBuilderWithInferredContext<(ins "Type":$type, + "const Attribute &":$value), [{ + return $_get(type.getContext(), type, ::llvm::cast(value)); + }]> + ]; + let extraClassDeclaration = [{ + // used for constFoldBinaryOp + using ValueType = ::mlir::Attribute; + }]; +} + +def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr< + "TypedFloatPolynomial", "typed_float_polynomial", [TypedAttrInterface]> { + let summary = "a typed float_polynomial"; + let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::FloatPolynomialAttr":$value); + let assemblyFormat = "$value `:` $type"; + let builders = [ + AttrBuilderWithInferredContext<(ins "Type":$type, + "const FloatPolynomial &":$value), [{ + return $_get( + type.getContext(), + type, + FloatPolynomialAttr::get(type.getContext(), value)); + }]>, + AttrBuilderWithInferredContext<(ins "Type":$type, + "const Attribute &":$value), [{ + return $_get(type.getContext(), type, ::llvm::cast(value)); + }]> + ]; + let extraClassDeclaration = [{ + // used for constFoldBinaryOp + using ValueType = ::mlir::Attribute; + }]; +} + def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> { - let summary = "An attribute specifying a polynomial ring."; + let summary = "an attribute specifying a polynomial ring"; let description = [{ A ring describes the domain in which polynomial arithmetic occurs. The ring attribute in `polynomial` represents the more specific case of polynomials diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp index 1a2439fe810b5..4c2fed6bab312 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp @@ -186,6 +186,21 @@ LogicalResult INTTOp::verify() { return verifyNTTOp(this->getOperation(), ring, tensorType); } +LogicalResult ConstantOp::inferReturnTypes( + MLIRContext *context, std::optional location, + ConstantOp::Adaptor adaptor, + llvm::SmallVectorImpl &inferredReturnTypes) { + Attribute operand = adaptor.getValue(); + if (auto intPoly = dyn_cast(operand)) { + inferredReturnTypes.push_back(intPoly.getType()); + } else if (auto floatPoly = dyn_cast(operand)) { + inferredReturnTypes.push_back(floatPoly.getType()); + } else { + return failure(); + } + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd canonicalization patterns //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir index ff709960c50e9..695b1acf18bd7 100644 --- a/mlir/test/Dialect/Polynomial/ops.mlir +++ b/mlir/test/Dialect/Polynomial/ops.mlir @@ -74,15 +74,15 @@ module { func.func @test_monic_monomial_mul() { %five = arith.constant 5 : index - %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial + %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial %1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial, index) -> !polynomial.polynomial return } func.func @test_constant() { - %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial - %1 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial - %2 = polynomial.constant {value=#polynomial.float_polynomial<1.5 + 0.5 x**2>} : !polynomial.polynomial + %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial + %1 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial + %2 = polynomial.constant #polynomial.float_polynomial<1.5 + 0.5 x**2> : !polynomial.polynomial return } From ef17f2af7f75bd2c98c05720eb8d3f77d652ef43 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Mon, 20 May 2024 22:25:14 -0700 Subject: [PATCH 4/8] show broken attempt --- .../mlir/Dialect/Polynomial/IR/Polynomial.td | 2 +- .../Polynomial/IR/PolynomialAttributes.td | 12 +++- .../Polynomial/IR/PolynomialAttributes.cpp | 64 ++++++++++++------- .../Dialect/Polynomial/IR/PolynomialOps.cpp | 55 ++++++++++++++++ mlir/test/Dialect/Polynomial/ops.mlir | 8 +-- 5 files changed, 111 insertions(+), 30 deletions(-) diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td index 85a9dd6b935d2..a0bd0bb0861bd 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td @@ -295,7 +295,7 @@ def Polynomial_ConstantOp : Op { diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td index 1ea07e21e0076..3bae6204299d1 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td @@ -38,6 +38,11 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom }]; let parameters = (ins "::mlir::polynomial::IntPolynomial":$polynomial); let hasCustomAssemblyFormat = 1; + let extraClassDeclaration = [{ + /// A parser which, upon failure to parse, does not emit errors and just returns + /// a null attribute. + static Attribute parse(AsmParser &parser, Type type, bool optional); + }]; } def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> { @@ -60,6 +65,11 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p }]; let parameters = (ins "FloatPolynomial":$polynomial); let hasCustomAssemblyFormat = 1; + let extraClassDeclaration = [{ + /// A parser which, upon failure to parse, does not emit errors and just returns + /// a null attribute. + static Attribute parse(AsmParser &parser, Type type, bool optional); + }]; } def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr< @@ -81,7 +91,6 @@ def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr< }]> ]; let extraClassDeclaration = [{ - // used for constFoldBinaryOp using ValueType = ::mlir::Attribute; }]; } @@ -105,7 +114,6 @@ def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr< }]> ]; let extraClassDeclaration = [{ - // used for constFoldBinaryOp using ValueType = ::mlir::Attribute; }]; } diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp index 890ce5226c30f..94169b5e93cf8 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp @@ -38,10 +38,11 @@ using ParseCoefficientFn = std::function; /// a '+'. /// template -ParseResult -parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable, - bool &isConstantTerm, bool &shouldParseMore, - ParseCoefficientFn parseAndStoreCoefficient) { +ParseResult parseMonomial(AsmParser &parser, Monomial &monomial, + llvm::StringRef &variable, bool &isConstantTerm, + bool &shouldParseMore, + ParseCoefficientFn parseAndStoreCoefficient, + bool optional) { OptionalParseResult parsedCoeffResult = parseAndStoreCoefficient(monomial); isConstantTerm = false; @@ -85,8 +86,9 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable, // If there's a **, then the integer exponent is required. APInt parsedExponent(apintBitWidth, 0); if (failed(parser.parseInteger(parsedExponent))) { - parser.emitError(parser.getCurrentLocation(), - "found invalid integer exponent"); + if (!optional) + parser.emitError(parser.getCurrentLocation(), + "found invalid integer exponent"); return failure(); } @@ -101,11 +103,12 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable, return success(); } -template +template LogicalResult parsePolynomialAttr(AsmParser &parser, llvm::SmallVector &monomials, llvm::StringSet<> &variables, - ParseCoefficientFn parseAndStoreCoefficient) { + ParseCoefficientFn parseAndStoreCoefficient, + bool optional) { while (true) { Monomial parsedMonomial; llvm::StringRef parsedVariableRef; @@ -113,8 +116,9 @@ parsePolynomialAttr(AsmParser &parser, llvm::SmallVector &monomials, bool shouldParseMore; if (failed(parseMonomial( parser, parsedMonomial, parsedVariableRef, isConstantTerm, - shouldParseMore, parseAndStoreCoefficient))) { - parser.emitError(parser.getCurrentLocation(), "expected a monomial"); + shouldParseMore, parseAndStoreCoefficient, optional))) { + if (!optional) + parser.emitError(parser.getCurrentLocation(), "expected a monomial"); return failure(); } @@ -130,18 +134,20 @@ parsePolynomialAttr(AsmParser &parser, llvm::SmallVector &monomials, if (succeeded(parser.parseOptionalGreater())) { break; } - parser.emitError( - parser.getCurrentLocation(), - "expected + and more monomials, or > to end polynomial attribute"); + if (!optional) + parser.emitError( + parser.getCurrentLocation(), + "expected + and more monomials, or > to end polynomial attribute"); return failure(); } if (variables.size() > 1) { std::string vars = llvm::join(variables.keys(), ", "); - parser.emitError( - parser.getCurrentLocation(), - "polynomials must have one indeterminate, but there were multiple: " + - vars); + if (!optional) + parser.emitError( + parser.getCurrentLocation(), + "polynomials must have one indeterminate, but there were multiple: " + + vars); return failure(); } @@ -149,13 +155,18 @@ parsePolynomialAttr(AsmParser &parser, llvm::SmallVector &monomials, } Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) { + return IntPolynomialAttr::parse(parser, type, /*optional=*/false); +} + +Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type, + bool optional) { if (failed(parser.parseLess())) return {}; llvm::SmallVector monomials; llvm::StringSet<> variables; - if (failed(parsePolynomialAttr( + if (failed(parsePolynomialAttr( parser, monomials, variables, [&](IntMonomial &monomial) -> OptionalParseResult { APInt parsedCoeff(apintBitWidth, 1); @@ -163,20 +174,27 @@ Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) { parser.parseOptionalInteger(parsedCoeff); monomial.setCoefficient(parsedCoeff); return result; - }))) { + }, + optional))) { return {}; } auto result = IntPolynomial::fromMonomials(monomials); if (failed(result)) { - parser.emitError(parser.getCurrentLocation()) - << "parsed polynomial must have unique exponents among monomials"; + if (!optional) + parser.emitError(parser.getCurrentLocation()) + << "parsed polynomial must have unique exponents among monomials"; return {}; } return IntPolynomialAttr::get(parser.getContext(), result.value()); } Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) { + return FloatPolynomialAttr::parse(parser, type, /*optional=*/false); +} + +Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type, + bool optional) { if (failed(parser.parseLess())) return {}; @@ -191,8 +209,8 @@ Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) { return OptionalParseResult(result); }; - if (failed(parsePolynomialAttr( - parser, monomials, variables, parseAndStoreCoefficient))) { + if (failed(parsePolynomialAttr( + parser, monomials, variables, parseAndStoreCoefficient, optional))) { return {}; } diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp index 4c2fed6bab312..c7c61d2ad8190 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp @@ -186,6 +186,60 @@ LogicalResult INTTOp::verify() { return verifyNTTOp(this->getOperation(), ring, tensorType); } +ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) { + // Using the built-in parser.parseAttribute requires the full + // #polynomial.typed_int_polynomial syntax, which is excessive. + // Instead we manually parse the components. + Type type; + parser.parseOptionalAttribute(); + + IntPolynomialAttr intPolyAttr; + parser.parseOptionalAttribute(intPolyAttr); + if (intPolyAttr) { + if (parser.parseColon() || parser.parseType(type)) + return failure(); + + result.addAttribute("value", + TypedIntPolynomialAttr::get(type, intPolyAttr)); + result.addTypes(type); + return success(); + } + + Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr, /*optional=*/true); + if (floatPolyAttr) { + if (parser.parseColon() || parser.parseType(type)) + return failure(); + result.addAttribute("value", + TypedFloatPolynomialAttr::get(type, intPolyAttr)); + result.addTypes(type); + return success(); + } + + // In the worst case, still accept the verbose versions. + TypedIntPolynomialAttr typedIntPolyAttr; + ParseResult res = parser.parseAttribute( + typedIntPolyAttr, "value", result.attributes); + if (succeeded(res)) { + result.addTypes(typedIntPolyAttr.getType()); + return success(); + } + + TypedFloatPolynomialAttr typedFloatPolyAttr; + res = parser.parseAttribute( + typedFloatPolyAttr, "value", result.attributes); + if (succeeded(res)) { + result.addTypes(typedFloatPolyAttr.getType()); + return success(); + } + + return failure(); +} + +void ConstantOp::print(OpAsmPrinter &p) { + p << " "; + p.printAttribute(getValue()); +} + LogicalResult ConstantOp::inferReturnTypes( MLIRContext *context, std::optional location, ConstantOp::Adaptor adaptor, @@ -196,6 +250,7 @@ LogicalResult ConstantOp::inferReturnTypes( } else if (auto floatPoly = dyn_cast(operand)) { inferredReturnTypes.push_back(floatPoly.getType()); } else { + assert(false && "unexpected attribute type"); return failure(); } return success(); diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir index 695b1acf18bd7..cfe3446a1dccf 100644 --- a/mlir/test/Dialect/Polynomial/ops.mlir +++ b/mlir/test/Dialect/Polynomial/ops.mlir @@ -74,15 +74,15 @@ module { func.func @test_monic_monomial_mul() { %five = arith.constant 5 : index - %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial + %0 = polynomial.constant <1 + x**2> : !polynomial.polynomial %1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial, index) -> !polynomial.polynomial return } func.func @test_constant() { - %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial - %1 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial - %2 = polynomial.constant #polynomial.float_polynomial<1.5 + 0.5 x**2> : !polynomial.polynomial + %0 = polynomial.constant <1 + x**2> : !polynomial.polynomial + %1 = polynomial.constant <1 + x**2> : !polynomial.polynomial + %2 = polynomial.constant <1.5 + 0.5 x**2> : !polynomial.polynomial return } From 74470e38e3b380c0ae1fe6951ada2fe8a189d532 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Tue, 21 May 2024 15:42:17 -0700 Subject: [PATCH 5/8] use int/float keywords --- .../Dialect/Polynomial/IR/PolynomialOps.cpp | 61 +++++++++++-------- mlir/test/Dialect/Polynomial/ops.mlir | 12 ++-- 2 files changed, 44 insertions(+), 29 deletions(-) diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp index c7c61d2ad8190..38e7db85a1e97 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp @@ -189,37 +189,38 @@ LogicalResult INTTOp::verify() { ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) { // Using the built-in parser.parseAttribute requires the full // #polynomial.typed_int_polynomial syntax, which is excessive. - // Instead we manually parse the components. + // Instead we parse a keyword int to signal it's an integer polynomial Type type; - parser.parseOptionalAttribute(); - - IntPolynomialAttr intPolyAttr; - parser.parseOptionalAttribute(intPolyAttr); - if (intPolyAttr) { - if (parser.parseColon() || parser.parseType(type)) - return failure(); - - result.addAttribute("value", - TypedIntPolynomialAttr::get(type, intPolyAttr)); - result.addTypes(type); - return success(); + if (succeeded(parser.parseOptionalKeyword("float"))) { + Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr); + if (floatPolyAttr) { + if (parser.parseColon() || parser.parseType(type)) + return failure(); + result.addAttribute("value", + TypedFloatPolynomialAttr::get(type, floatPolyAttr)); + result.addTypes(type); + return success(); + } } - Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr, /*optional=*/true); - if (floatPolyAttr) { - if (parser.parseColon() || parser.parseType(type)) - return failure(); - result.addAttribute("value", - TypedFloatPolynomialAttr::get(type, intPolyAttr)); - result.addTypes(type); - return success(); + if (succeeded(parser.parseOptionalKeyword("int"))) { + Attribute intPolyAttr = IntPolynomialAttr::parse(parser, nullptr); + if (intPolyAttr) { + if (parser.parseColon() || parser.parseType(type)) + return failure(); + + result.addAttribute("value", + TypedIntPolynomialAttr::get(type, intPolyAttr)); + result.addTypes(type); + return success(); + } } // In the worst case, still accept the verbose versions. TypedIntPolynomialAttr typedIntPolyAttr; - ParseResult res = parser.parseAttribute( + OptionalParseResult res = parser.parseOptionalAttribute( typedIntPolyAttr, "value", result.attributes); - if (succeeded(res)) { + if (res.has_value() && succeeded(res.value())) { result.addTypes(typedIntPolyAttr.getType()); return success(); } @@ -227,7 +228,7 @@ ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) { TypedFloatPolynomialAttr typedFloatPolyAttr; res = parser.parseAttribute( typedFloatPolyAttr, "value", result.attributes); - if (succeeded(res)) { + if (res.has_value() && succeeded(res.value())) { result.addTypes(typedFloatPolyAttr.getType()); return success(); } @@ -237,7 +238,17 @@ ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) { void ConstantOp::print(OpAsmPrinter &p) { p << " "; - p.printAttribute(getValue()); + if (auto intPoly = dyn_cast(getValue())) { + p << "int"; + intPoly.getValue().print(p); + } else if (auto floatPoly = dyn_cast(getValue())) { + p << "float"; + floatPoly.getValue().print(p); + } else { + assert(false && "unexpected attribute type"); + } + p << " : "; + p.printType(getOutput().getType()); } LogicalResult ConstantOp::inferReturnTypes( diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir index cfe3446a1dccf..4716e37ff8852 100644 --- a/mlir/test/Dialect/Polynomial/ops.mlir +++ b/mlir/test/Dialect/Polynomial/ops.mlir @@ -74,15 +74,19 @@ module { func.func @test_monic_monomial_mul() { %five = arith.constant 5 : index - %0 = polynomial.constant <1 + x**2> : !polynomial.polynomial + %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial %1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial, index) -> !polynomial.polynomial return } func.func @test_constant() { - %0 = polynomial.constant <1 + x**2> : !polynomial.polynomial - %1 = polynomial.constant <1 + x**2> : !polynomial.polynomial - %2 = polynomial.constant <1.5 + 0.5 x**2> : !polynomial.polynomial + %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial + %1 = polynomial.constant int<1 + x**2> : !polynomial.polynomial + %2 = polynomial.constant float<1.5 + 0.5 x**2> : !polynomial.polynomial + + // Test verbose fallbacks + %verb0 = polynomial.constant #polynomial.typed_int_polynomial<1 + x**2> : !polynomial.polynomial + %verb2 = polynomial.constant #polynomial.typed_float_polynomial<1.5 + 0.5 x**2> : !polynomial.polynomial return } From 431bf8af1e851d8261e50a06413bf9955dafa97d Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Tue, 21 May 2024 15:50:32 -0700 Subject: [PATCH 6/8] update docs one last time --- .../mlir/Dialect/Polynomial/IR/Polynomial.td | 25 +++++++++---------- .../Polynomial/IR/PolynomialAttributes.td | 18 +++++++++++++ 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td index a0bd0bb0861bd..f99cbccd243ec 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td @@ -52,8 +52,8 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> { // add two polynomials modulo x^1024 - 1 #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> - %1 = polynomial.constant #polynomial.int_polynomial : !polynomial.polynomial<#ring> + %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring> + %1 = polynomial.constant int : !polynomial.polynomial<#ring> %2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring> ``` }]; @@ -76,8 +76,8 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> { // subtract two polynomials modulo x^1024 - 1 #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> - %1 = polynomial.constant #polynomial.int_polynomial : !polynomial.polynomial<#ring> + %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring> + %1 = polynomial.constant int : !polynomial.polynomial<#ring> %2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring> ``` }]; @@ -101,8 +101,8 @@ def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> { // multiply two polynomials modulo x^1024 - 1 #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> - %1 = polynomial.constant #polynomial.int_polynomial : !polynomial.polynomial<#ring> + %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring> + %1 = polynomial.constant int : !polynomial.polynomial<#ring> %2 = polynomial.mul %0, %1 : !polynomial.polynomial<#ring> ``` }]; @@ -126,7 +126,7 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [ // multiply two polynomials modulo x^1024 - 1 #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> + %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring> %1 = arith.constant 3 : i32 %2 = polynomial.mul_scalar %0, %1 : !polynomial.polynomial<#ring>, i32 ``` @@ -157,7 +157,7 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> { ```mlir #poly = #polynomial.int_polynomial #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> + %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring> %1, %2 = polynomial.leading_term %0 : !polynomial.polynomial<#ring> -> (index, i32) ``` }]; @@ -285,12 +285,11 @@ def Polynomial_ConstantOp : Op - #ring = #polynomial.ring - %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring> + !int_poly_ty = !polynomial.polynomial> + %0 = polynomial.constant int<1 + x**2> : !int_poly_ty - #float_ring = #polynomial.ring - %0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring> + !float_poly_ty = !polynomial.polynomial> + %1 = polynomial.constant float<0.5 + 1.3e06 x**2> : !float_poly_ty ``` }]; let arguments = (ins Polynomial_AnyTypedPolynomialAttr:$value); diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td index 3bae6204299d1..5298542faac9a 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td @@ -75,6 +75,15 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr< "TypedIntPolynomial", "typed_int_polynomial", [TypedAttrInterface]> { let summary = "a typed int_polynomial"; + let description = [{ + Example: + + ```mlir + !poly_ty = !polynomial.polynomial> + #poly = int<1 x**7 + 4> : !poly_ty + #poly_verbose = #polynomial.typed_int_polynomial<1 x**7 + 4> : !poly_ty + ``` + }]; let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::IntPolynomialAttr":$value); let assemblyFormat = "$value `:` $type"; let builders = [ @@ -98,6 +107,15 @@ def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr< def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr< "TypedFloatPolynomial", "typed_float_polynomial", [TypedAttrInterface]> { let summary = "a typed float_polynomial"; + let description = [{ + Example: + + ```mlir + !poly_ty = !polynomial.polynomial> + #poly = float<1.4 x**7 + 4.5> : !poly_ty + #poly_verbose = #polynomial.typed_float_polynomial<1.4 x**7 + 4.5> : !poly_ty + ``` + }]; let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::FloatPolynomialAttr":$value); let assemblyFormat = "$value `:` $type"; let builders = [ From f9019bcddeeae49c77f2d82250c2e765a32ae716 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Tue, 21 May 2024 15:53:47 -0700 Subject: [PATCH 7/8] remove optional parse option --- .../Polynomial/IR/PolynomialAttributes.td | 10 --- .../Polynomial/IR/PolynomialAttributes.cpp | 61 +++++++------------ 2 files changed, 21 insertions(+), 50 deletions(-) diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td index 5298542faac9a..655020adf808b 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td @@ -38,11 +38,6 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom }]; let parameters = (ins "::mlir::polynomial::IntPolynomial":$polynomial); let hasCustomAssemblyFormat = 1; - let extraClassDeclaration = [{ - /// A parser which, upon failure to parse, does not emit errors and just returns - /// a null attribute. - static Attribute parse(AsmParser &parser, Type type, bool optional); - }]; } def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> { @@ -65,11 +60,6 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p }]; let parameters = (ins "FloatPolynomial":$polynomial); let hasCustomAssemblyFormat = 1; - let extraClassDeclaration = [{ - /// A parser which, upon failure to parse, does not emit errors and just returns - /// a null attribute. - static Attribute parse(AsmParser &parser, Type type, bool optional); - }]; } def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr< diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp index 94169b5e93cf8..cc7d3172b1a1d 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp @@ -38,11 +38,10 @@ using ParseCoefficientFn = std::function; /// a '+'. /// template -ParseResult parseMonomial(AsmParser &parser, Monomial &monomial, - llvm::StringRef &variable, bool &isConstantTerm, - bool &shouldParseMore, - ParseCoefficientFn parseAndStoreCoefficient, - bool optional) { +ParseResult +parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable, + bool &isConstantTerm, bool &shouldParseMore, + ParseCoefficientFn parseAndStoreCoefficient) { OptionalParseResult parsedCoeffResult = parseAndStoreCoefficient(monomial); isConstantTerm = false; @@ -86,9 +85,8 @@ ParseResult parseMonomial(AsmParser &parser, Monomial &monomial, // If there's a **, then the integer exponent is required. APInt parsedExponent(apintBitWidth, 0); if (failed(parser.parseInteger(parsedExponent))) { - if (!optional) - parser.emitError(parser.getCurrentLocation(), - "found invalid integer exponent"); + parser.emitError(parser.getCurrentLocation(), + "found invalid integer exponent"); return failure(); } @@ -107,8 +105,7 @@ template LogicalResult parsePolynomialAttr(AsmParser &parser, llvm::SmallVector &monomials, llvm::StringSet<> &variables, - ParseCoefficientFn parseAndStoreCoefficient, - bool optional) { + ParseCoefficientFn parseAndStoreCoefficient) { while (true) { Monomial parsedMonomial; llvm::StringRef parsedVariableRef; @@ -116,9 +113,8 @@ parsePolynomialAttr(AsmParser &parser, llvm::SmallVector &monomials, bool shouldParseMore; if (failed(parseMonomial( parser, parsedMonomial, parsedVariableRef, isConstantTerm, - shouldParseMore, parseAndStoreCoefficient, optional))) { - if (!optional) - parser.emitError(parser.getCurrentLocation(), "expected a monomial"); + shouldParseMore, parseAndStoreCoefficient))) { + parser.emitError(parser.getCurrentLocation(), "expected a monomial"); return failure(); } @@ -134,20 +130,18 @@ parsePolynomialAttr(AsmParser &parser, llvm::SmallVector &monomials, if (succeeded(parser.parseOptionalGreater())) { break; } - if (!optional) - parser.emitError( - parser.getCurrentLocation(), - "expected + and more monomials, or > to end polynomial attribute"); + parser.emitError( + parser.getCurrentLocation(), + "expected + and more monomials, or > to end polynomial attribute"); return failure(); } if (variables.size() > 1) { std::string vars = llvm::join(variables.keys(), ", "); - if (!optional) - parser.emitError( - parser.getCurrentLocation(), - "polynomials must have one indeterminate, but there were multiple: " + - vars); + parser.emitError( + parser.getCurrentLocation(), + "polynomials must have one indeterminate, but there were multiple: " + + vars); return failure(); } @@ -155,11 +149,6 @@ parsePolynomialAttr(AsmParser &parser, llvm::SmallVector &monomials, } Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) { - return IntPolynomialAttr::parse(parser, type, /*optional=*/false); -} - -Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type, - bool optional) { if (failed(parser.parseLess())) return {}; @@ -174,27 +163,19 @@ Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type, parser.parseOptionalInteger(parsedCoeff); monomial.setCoefficient(parsedCoeff); return result; - }, - optional))) { + }))) { return {}; } auto result = IntPolynomial::fromMonomials(monomials); if (failed(result)) { - if (!optional) - parser.emitError(parser.getCurrentLocation()) - << "parsed polynomial must have unique exponents among monomials"; + parser.emitError(parser.getCurrentLocation()) + << "parsed polynomial must have unique exponents among monomials"; return {}; } return IntPolynomialAttr::get(parser.getContext(), result.value()); } - Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) { - return FloatPolynomialAttr::parse(parser, type, /*optional=*/false); -} - -Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type, - bool optional) { if (failed(parser.parseLess())) return {}; @@ -209,8 +190,8 @@ Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type, return OptionalParseResult(result); }; - if (failed(parsePolynomialAttr( - parser, monomials, variables, parseAndStoreCoefficient, optional))) { + if (failed(parsePolynomialAttr(parser, monomials, variables, + parseAndStoreCoefficient))) { return {}; } From 0da56f3f6270069c589cadd6e920a927de536f35 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Tue, 21 May 2024 15:59:41 -0700 Subject: [PATCH 8/8] clang-format --- mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp index 38e7db85a1e97..d0a25fd9288b9 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp @@ -218,8 +218,9 @@ ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) { // In the worst case, still accept the verbose versions. TypedIntPolynomialAttr typedIntPolyAttr; - OptionalParseResult res = parser.parseOptionalAttribute( - typedIntPolyAttr, "value", result.attributes); + OptionalParseResult res = + parser.parseOptionalAttribute( + typedIntPolyAttr, "value", result.attributes); if (res.has_value() && succeeded(res.value())) { result.addTypes(typedIntPolyAttr.getType()); return success();