Skip to content

Commit ab29203

Browse files
authored
[mlir][polynomial] use typed attributes for polynomial.constant op (#92818)
Co-authored-by: Jeremy Kun <[email protected]>
1 parent 66db7c6 commit ab29203

File tree

5 files changed

+180
-31
lines changed

5 files changed

+180
-31
lines changed

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td

+19-19
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
5252
// add two polynomials modulo x^1024 - 1
5353
#poly = #polynomial.int_polynomial<x**1024 - 1>
5454
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
55-
%0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
56-
%1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
55+
%0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring>
56+
%1 = polynomial.constant int<x**5 - x + 1> : !polynomial.polynomial<#ring>
5757
%2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring>
5858
```
5959
}];
@@ -76,8 +76,8 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
7676
// subtract two polynomials modulo x^1024 - 1
7777
#poly = #polynomial.int_polynomial<x**1024 - 1>
7878
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
79-
%0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
80-
%1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
79+
%0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring>
80+
%1 = polynomial.constant int<x**5 - x + 1> : !polynomial.polynomial<#ring>
8181
%2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring>
8282
```
8383
}];
@@ -101,8 +101,8 @@ def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
101101
// multiply two polynomials modulo x^1024 - 1
102102
#poly = #polynomial.int_polynomial<x**1024 - 1>
103103
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
104-
%0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
105-
%1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
104+
%0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring>
105+
%1 = polynomial.constant int<x**5 - x + 1> : !polynomial.polynomial<#ring>
106106
%2 = polynomial.mul %0, %1 : !polynomial.polynomial<#ring>
107107
```
108108
}];
@@ -126,7 +126,7 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
126126
// multiply two polynomials modulo x^1024 - 1
127127
#poly = #polynomial.int_polynomial<x**1024 - 1>
128128
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
129-
%0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
129+
%0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring>
130130
%1 = arith.constant 3 : i32
131131
%2 = polynomial.mul_scalar %0, %1 : !polynomial.polynomial<#ring>, i32
132132
```
@@ -157,7 +157,7 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
157157
```mlir
158158
#poly = #polynomial.int_polynomial<x**1024 - 1>
159159
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
160-
%0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
160+
%0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring>
161161
%1, %2 = polynomial.leading_term %0 : !polynomial.polynomial<#ring> -> (index, i32)
162162
```
163163
}];
@@ -272,29 +272,29 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
272272
let hasVerifier = 1;
273273
}
274274

275-
def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
276-
Polynomial_FloatPolynomialAttr,
277-
Polynomial_IntPolynomialAttr
275+
def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[
276+
Polynomial_TypedFloatPolynomialAttr,
277+
Polynomial_TypedIntPolynomialAttr
278278
]>;
279279

280280
// Not deriving from Polynomial_Op due to need for custom assembly format
281-
def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
281+
def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant",
282+
[Pure, InferTypeOpAdaptor]> {
282283
let summary = "Define a constant polynomial via an attribute.";
283284
let description = [{
284285
Example:
285286

286287
```mlir
287-
#poly = #polynomial.int_polynomial<x**1024 - 1>
288-
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
289-
%0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
288+
!int_poly_ty = !polynomial.polynomial<ring=<coefficientType=i32>>
289+
%0 = polynomial.constant int<1 + x**2> : !int_poly_ty
290290

291-
#float_ring = #polynomial.ring<coefficientType=f32>
292-
%0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring>
291+
!float_poly_ty = !polynomial.polynomial<ring=<coefficientType=f32>>
292+
%1 = polynomial.constant float<0.5 + 1.3e06 x**2> : !float_poly_ty
293293
```
294294
}];
295-
let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
295+
let arguments = (ins Polynomial_AnyTypedPolynomialAttr:$value);
296296
let results = (outs Polynomial_PolynomialType:$output);
297-
let assemblyFormat = "attr-dict `:` type($output)";
297+
let hasCustomAssemblyFormat = 1;
298298
}
299299

300300
def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {

mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td

+67-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
1818
}
1919

2020
def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynomial"> {
21-
let summary = "An attribute containing a single-variable polynomial with integer coefficients.";
21+
let summary = "an attribute containing a single-variable polynomial with integer coefficients";
2222
let description = [{
2323
A polynomial attribute represents a single-variable polynomial with integer
2424
coefficients, which is used to define the modulus of a `RingAttr`, as well
@@ -41,7 +41,7 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom
4141
}
4242

4343
def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
44-
let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients.";
44+
let summary = "an attribute containing a single-variable polynomial with double precision floating point coefficients";
4545
let description = [{
4646
A polynomial attribute represents a single-variable polynomial with double
4747
precision floating point coefficients.
@@ -62,8 +62,72 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p
6262
let hasCustomAssemblyFormat = 1;
6363
}
6464

65+
def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
66+
"TypedIntPolynomial", "typed_int_polynomial", [TypedAttrInterface]> {
67+
let summary = "a typed int_polynomial";
68+
let description = [{
69+
Example:
70+
71+
```mlir
72+
!poly_ty = !polynomial.polynomial<ring=<coefficientType=i32>>
73+
#poly = int<1 x**7 + 4> : !poly_ty
74+
#poly_verbose = #polynomial.typed_int_polynomial<1 x**7 + 4> : !poly_ty
75+
```
76+
}];
77+
let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::IntPolynomialAttr":$value);
78+
let assemblyFormat = "$value `:` $type";
79+
let builders = [
80+
AttrBuilderWithInferredContext<(ins "Type":$type,
81+
"const IntPolynomial &":$value), [{
82+
return $_get(
83+
type.getContext(),
84+
type,
85+
IntPolynomialAttr::get(type.getContext(), value));
86+
}]>,
87+
AttrBuilderWithInferredContext<(ins "Type":$type,
88+
"const Attribute &":$value), [{
89+
return $_get(type.getContext(), type, ::llvm::cast<IntPolynomialAttr>(value));
90+
}]>
91+
];
92+
let extraClassDeclaration = [{
93+
using ValueType = ::mlir::Attribute;
94+
}];
95+
}
96+
97+
def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr<
98+
"TypedFloatPolynomial", "typed_float_polynomial", [TypedAttrInterface]> {
99+
let summary = "a typed float_polynomial";
100+
let description = [{
101+
Example:
102+
103+
```mlir
104+
!poly_ty = !polynomial.polynomial<ring=<coefficientType=f32>>
105+
#poly = float<1.4 x**7 + 4.5> : !poly_ty
106+
#poly_verbose = #polynomial.typed_float_polynomial<1.4 x**7 + 4.5> : !poly_ty
107+
```
108+
}];
109+
let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::FloatPolynomialAttr":$value);
110+
let assemblyFormat = "$value `:` $type";
111+
let builders = [
112+
AttrBuilderWithInferredContext<(ins "Type":$type,
113+
"const FloatPolynomial &":$value), [{
114+
return $_get(
115+
type.getContext(),
116+
type,
117+
FloatPolynomialAttr::get(type.getContext(), value));
118+
}]>,
119+
AttrBuilderWithInferredContext<(ins "Type":$type,
120+
"const Attribute &":$value), [{
121+
return $_get(type.getContext(), type, ::llvm::cast<FloatPolynomialAttr>(value));
122+
}]>
123+
];
124+
let extraClassDeclaration = [{
125+
using ValueType = ::mlir::Attribute;
126+
}];
127+
}
128+
65129
def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
66-
let summary = "An attribute specifying a polynomial ring.";
130+
let summary = "an attribute specifying a polynomial ring";
67131
let description = [{
68132
A ring describes the domain in which polynomial arithmetic occurs. The ring
69133
attribute in `polynomial` represents the more specific case of polynomials

mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
101101
return success();
102102
}
103103

104-
template <typename PolynoimalAttrTy, typename Monomial>
104+
template <typename Monomial>
105105
LogicalResult
106106
parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
107107
llvm::StringSet<> &variables,
@@ -155,7 +155,7 @@ Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) {
155155
llvm::SmallVector<IntMonomial> monomials;
156156
llvm::StringSet<> variables;
157157

158-
if (failed(parsePolynomialAttr<IntPolynomialAttr, IntMonomial>(
158+
if (failed(parsePolynomialAttr<IntMonomial>(
159159
parser, monomials, variables,
160160
[&](IntMonomial &monomial) -> OptionalParseResult {
161161
APInt parsedCoeff(apintBitWidth, 1);
@@ -175,7 +175,6 @@ Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) {
175175
}
176176
return IntPolynomialAttr::get(parser.getContext(), result.value());
177177
}
178-
179178
Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
180179
if (failed(parser.parseLess()))
181180
return {};
@@ -191,8 +190,8 @@ Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
191190
return OptionalParseResult(result);
192191
};
193192

194-
if (failed(parsePolynomialAttr<FloatPolynomialAttr, FloatMonomial>(
195-
parser, monomials, variables, parseAndStoreCoefficient))) {
193+
if (failed(parsePolynomialAttr<FloatMonomial>(parser, monomials, variables,
194+
parseAndStoreCoefficient))) {
196195
return {};
197196
}
198197

mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp

+82
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,88 @@ LogicalResult INTTOp::verify() {
186186
return verifyNTTOp(this->getOperation(), ring, tensorType);
187187
}
188188

189+
ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
190+
// Using the built-in parser.parseAttribute requires the full
191+
// #polynomial.typed_int_polynomial syntax, which is excessive.
192+
// Instead we parse a keyword int to signal it's an integer polynomial
193+
Type type;
194+
if (succeeded(parser.parseOptionalKeyword("float"))) {
195+
Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr);
196+
if (floatPolyAttr) {
197+
if (parser.parseColon() || parser.parseType(type))
198+
return failure();
199+
result.addAttribute("value",
200+
TypedFloatPolynomialAttr::get(type, floatPolyAttr));
201+
result.addTypes(type);
202+
return success();
203+
}
204+
}
205+
206+
if (succeeded(parser.parseOptionalKeyword("int"))) {
207+
Attribute intPolyAttr = IntPolynomialAttr::parse(parser, nullptr);
208+
if (intPolyAttr) {
209+
if (parser.parseColon() || parser.parseType(type))
210+
return failure();
211+
212+
result.addAttribute("value",
213+
TypedIntPolynomialAttr::get(type, intPolyAttr));
214+
result.addTypes(type);
215+
return success();
216+
}
217+
}
218+
219+
// In the worst case, still accept the verbose versions.
220+
TypedIntPolynomialAttr typedIntPolyAttr;
221+
OptionalParseResult res =
222+
parser.parseOptionalAttribute<TypedIntPolynomialAttr>(
223+
typedIntPolyAttr, "value", result.attributes);
224+
if (res.has_value() && succeeded(res.value())) {
225+
result.addTypes(typedIntPolyAttr.getType());
226+
return success();
227+
}
228+
229+
TypedFloatPolynomialAttr typedFloatPolyAttr;
230+
res = parser.parseAttribute<TypedFloatPolynomialAttr>(
231+
typedFloatPolyAttr, "value", result.attributes);
232+
if (res.has_value() && succeeded(res.value())) {
233+
result.addTypes(typedFloatPolyAttr.getType());
234+
return success();
235+
}
236+
237+
return failure();
238+
}
239+
240+
void ConstantOp::print(OpAsmPrinter &p) {
241+
p << " ";
242+
if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(getValue())) {
243+
p << "int";
244+
intPoly.getValue().print(p);
245+
} else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(getValue())) {
246+
p << "float";
247+
floatPoly.getValue().print(p);
248+
} else {
249+
assert(false && "unexpected attribute type");
250+
}
251+
p << " : ";
252+
p.printType(getOutput().getType());
253+
}
254+
255+
LogicalResult ConstantOp::inferReturnTypes(
256+
MLIRContext *context, std::optional<mlir::Location> location,
257+
ConstantOp::Adaptor adaptor,
258+
llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
259+
Attribute operand = adaptor.getValue();
260+
if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(operand)) {
261+
inferredReturnTypes.push_back(intPoly.getType());
262+
} else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(operand)) {
263+
inferredReturnTypes.push_back(floatPoly.getType());
264+
} else {
265+
assert(false && "unexpected attribute type");
266+
return failure();
267+
}
268+
return success();
269+
}
270+
189271
//===----------------------------------------------------------------------===//
190272
// TableGen'd canonicalization patterns
191273
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Polynomial/ops.mlir

+8-4
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,19 @@ module {
7474

7575
func.func @test_monic_monomial_mul() {
7676
%five = arith.constant 5 : index
77-
%0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<ring=#ring1>
77+
%0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<ring=#ring1>
7878
%1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<ring=#ring1>, index) -> !polynomial.polynomial<ring=#ring1>
7979
return
8080
}
8181

8282
func.func @test_constant() {
83-
%0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<ring=#ring1>
84-
%1 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<ring=#ring1>
85-
%2 = polynomial.constant {value=#polynomial.float_polynomial<1.5 + 0.5 x**2>} : !polynomial.polynomial<ring=#ring2>
83+
%0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<ring=#ring1>
84+
%1 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<ring=#ring1>
85+
%2 = polynomial.constant float<1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
86+
87+
// Test verbose fallbacks
88+
%verb0 = polynomial.constant #polynomial.typed_int_polynomial<1 + x**2> : !polynomial.polynomial<ring=#ring1>
89+
%verb2 = polynomial.constant #polynomial.typed_float_polynomial<1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
8690
return
8791
}
8892

0 commit comments

Comments
 (0)