Skip to content

[mlir][polynomial] use typed attributes for polynomial.constant op #92818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
// add two polynomials modulo x^1024 - 1
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
%0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
%1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
%0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring>
%1 = polynomial.constant int<x**5 - x + 1> : !polynomial.polynomial<#ring>
%2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring>
```
}];
Expand All @@ -76,8 +76,8 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
// subtract two polynomials modulo x^1024 - 1
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
%0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
%1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
%0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring>
%1 = polynomial.constant int<x**5 - x + 1> : !polynomial.polynomial<#ring>
%2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring>
```
}];
Expand All @@ -101,8 +101,8 @@ def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
// multiply two polynomials modulo x^1024 - 1
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
%0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
%1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
%0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring>
%1 = polynomial.constant int<x**5 - x + 1> : !polynomial.polynomial<#ring>
%2 = polynomial.mul %0, %1 : !polynomial.polynomial<#ring>
```
}];
Expand All @@ -126,7 +126,7 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
// multiply two polynomials modulo x^1024 - 1
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
%0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
%0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring>
%1 = arith.constant 3 : i32
%2 = polynomial.mul_scalar %0, %1 : !polynomial.polynomial<#ring>, i32
```
Expand Down Expand Up @@ -157,7 +157,7 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
```mlir
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
%0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
%0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring>
%1, %2 = polynomial.leading_term %0 : !polynomial.polynomial<#ring> -> (index, i32)
```
}];
Expand Down Expand Up @@ -272,29 +272,29 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
let hasVerifier = 1;
}

def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
Polynomial_FloatPolynomialAttr,
Polynomial_IntPolynomialAttr
def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[
Polynomial_TypedFloatPolynomialAttr,
Polynomial_TypedIntPolynomialAttr
]>;

// Not deriving from Polynomial_Op due to need for custom assembly format
def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant",
[Pure, InferTypeOpAdaptor]> {
let summary = "Define a constant polynomial via an attribute.";
let description = [{
Example:

```mlir
#poly = #polynomial.int_polynomial<x**1024 - 1>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
%0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
!int_poly_ty = !polynomial.polynomial<ring=<coefficientType=i32>>
%0 = polynomial.constant int<1 + x**2> : !int_poly_ty

#float_ring = #polynomial.ring<coefficientType=f32>
%0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring>
!float_poly_ty = !polynomial.polynomial<ring=<coefficientType=f32>>
%1 = polynomial.constant float<0.5 + 1.3e06 x**2> : !float_poly_ty
```
}];
let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
let arguments = (ins Polynomial_AnyTypedPolynomialAttr:$value);
let results = (outs Polynomial_PolynomialType:$output);
let assemblyFormat = "attr-dict `:` type($output)";
let hasCustomAssemblyFormat = 1;
}

def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
Expand Down
70 changes: 67 additions & 3 deletions mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
}

def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynomial"> {
let summary = "An attribute containing a single-variable polynomial with integer coefficients.";
let summary = "an attribute containing a single-variable polynomial with integer coefficients";
let description = [{
A polynomial attribute represents a single-variable polynomial with integer
coefficients, which is used to define the modulus of a `RingAttr`, as well
Expand All @@ -41,7 +41,7 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom
}

def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients.";
let summary = "an attribute containing a single-variable polynomial with double precision floating point coefficients";
let description = [{
A polynomial attribute represents a single-variable polynomial with double
precision floating point coefficients.
Expand All @@ -62,8 +62,72 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p
let hasCustomAssemblyFormat = 1;
}

def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
"TypedIntPolynomial", "typed_int_polynomial", [TypedAttrInterface]> {
let summary = "a typed int_polynomial";
let description = [{
Example:

```mlir
!poly_ty = !polynomial.polynomial<ring=<coefficientType=i32>>
#poly = int<1 x**7 + 4> : !poly_ty
#poly_verbose = #polynomial.typed_int_polynomial<1 x**7 + 4> : !poly_ty
```
}];
let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::IntPolynomialAttr":$value);
let assemblyFormat = "$value `:` $type";
let builders = [
AttrBuilderWithInferredContext<(ins "Type":$type,
"const IntPolynomial &":$value), [{
return $_get(
type.getContext(),
type,
IntPolynomialAttr::get(type.getContext(), value));
}]>,
AttrBuilderWithInferredContext<(ins "Type":$type,
"const Attribute &":$value), [{
return $_get(type.getContext(), type, ::llvm::cast<IntPolynomialAttr>(value));
}]>
];
let extraClassDeclaration = [{
using ValueType = ::mlir::Attribute;
}];
}

def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr<
"TypedFloatPolynomial", "typed_float_polynomial", [TypedAttrInterface]> {
let summary = "a typed float_polynomial";
let description = [{
Example:

```mlir
!poly_ty = !polynomial.polynomial<ring=<coefficientType=f32>>
#poly = float<1.4 x**7 + 4.5> : !poly_ty
#poly_verbose = #polynomial.typed_float_polynomial<1.4 x**7 + 4.5> : !poly_ty
```
}];
let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::FloatPolynomialAttr":$value);
let assemblyFormat = "$value `:` $type";
let builders = [
AttrBuilderWithInferredContext<(ins "Type":$type,
"const FloatPolynomial &":$value), [{
return $_get(
type.getContext(),
type,
FloatPolynomialAttr::get(type.getContext(), value));
}]>,
AttrBuilderWithInferredContext<(ins "Type":$type,
"const Attribute &":$value), [{
return $_get(type.getContext(), type, ::llvm::cast<FloatPolynomialAttr>(value));
}]>
];
let extraClassDeclaration = [{
using ValueType = ::mlir::Attribute;
}];
}

def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
let summary = "An attribute specifying a polynomial ring.";
let summary = "an attribute specifying a polynomial ring";
let description = [{
A ring describes the domain in which polynomial arithmetic occurs. The ring
attribute in `polynomial` represents the more specific case of polynomials
Expand Down
9 changes: 4 additions & 5 deletions mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
return success();
}

template <typename PolynoimalAttrTy, typename Monomial>
template <typename Monomial>
LogicalResult
parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
llvm::StringSet<> &variables,
Expand Down Expand Up @@ -155,7 +155,7 @@ Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) {
llvm::SmallVector<IntMonomial> monomials;
llvm::StringSet<> variables;

if (failed(parsePolynomialAttr<IntPolynomialAttr, IntMonomial>(
if (failed(parsePolynomialAttr<IntMonomial>(
parser, monomials, variables,
[&](IntMonomial &monomial) -> OptionalParseResult {
APInt parsedCoeff(apintBitWidth, 1);
Expand All @@ -175,7 +175,6 @@ Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) {
}
return IntPolynomialAttr::get(parser.getContext(), result.value());
}

Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseLess()))
return {};
Expand All @@ -191,8 +190,8 @@ Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
return OptionalParseResult(result);
};

if (failed(parsePolynomialAttr<FloatPolynomialAttr, FloatMonomial>(
parser, monomials, variables, parseAndStoreCoefficient))) {
if (failed(parsePolynomialAttr<FloatMonomial>(parser, monomials, variables,
parseAndStoreCoefficient))) {
return {};
}

Expand Down
82 changes: 82 additions & 0 deletions mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,88 @@ LogicalResult INTTOp::verify() {
return verifyNTTOp(this->getOperation(), ring, tensorType);
}

ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
// Using the built-in parser.parseAttribute requires the full
// #polynomial.typed_int_polynomial syntax, which is excessive.
// Instead we parse a keyword int to signal it's an integer polynomial
Type type;
if (succeeded(parser.parseOptionalKeyword("float"))) {
Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr);
if (floatPolyAttr) {
if (parser.parseColon() || parser.parseType(type))
return failure();
result.addAttribute("value",
TypedFloatPolynomialAttr::get(type, floatPolyAttr));
result.addTypes(type);
return success();
}
}

if (succeeded(parser.parseOptionalKeyword("int"))) {
Attribute intPolyAttr = IntPolynomialAttr::parse(parser, nullptr);
if (intPolyAttr) {
if (parser.parseColon() || parser.parseType(type))
return failure();

result.addAttribute("value",
TypedIntPolynomialAttr::get(type, intPolyAttr));
result.addTypes(type);
return success();
}
}

// In the worst case, still accept the verbose versions.
TypedIntPolynomialAttr typedIntPolyAttr;
OptionalParseResult res =
parser.parseOptionalAttribute<TypedIntPolynomialAttr>(
typedIntPolyAttr, "value", result.attributes);
if (res.has_value() && succeeded(res.value())) {
result.addTypes(typedIntPolyAttr.getType());
return success();
}

TypedFloatPolynomialAttr typedFloatPolyAttr;
res = parser.parseAttribute<TypedFloatPolynomialAttr>(
typedFloatPolyAttr, "value", result.attributes);
if (res.has_value() && succeeded(res.value())) {
result.addTypes(typedFloatPolyAttr.getType());
return success();
}

return failure();
}

void ConstantOp::print(OpAsmPrinter &p) {
p << " ";
if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(getValue())) {
p << "int";
intPoly.getValue().print(p);
} else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(getValue())) {
p << "float";
floatPoly.getValue().print(p);
} else {
assert(false && "unexpected attribute type");
}
p << " : ";
p.printType(getOutput().getType());
}

LogicalResult ConstantOp::inferReturnTypes(
MLIRContext *context, std::optional<mlir::Location> location,
ConstantOp::Adaptor adaptor,
llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
Attribute operand = adaptor.getValue();
if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(operand)) {
inferredReturnTypes.push_back(intPoly.getType());
} else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(operand)) {
inferredReturnTypes.push_back(floatPoly.getType());
} else {
assert(false && "unexpected attribute type");
return failure();
}
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd canonicalization patterns
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 8 additions & 4 deletions mlir/test/Dialect/Polynomial/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,19 @@ module {

func.func @test_monic_monomial_mul() {
%five = arith.constant 5 : index
%0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<ring=#ring1>
%0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<ring=#ring1>
%1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<ring=#ring1>, index) -> !polynomial.polynomial<ring=#ring1>
return
}

func.func @test_constant() {
%0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<ring=#ring1>
%1 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<ring=#ring1>
%2 = polynomial.constant {value=#polynomial.float_polynomial<1.5 + 0.5 x**2>} : !polynomial.polynomial<ring=#ring2>
%0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<ring=#ring1>
%1 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<ring=#ring1>
%2 = polynomial.constant float<1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>

// Test verbose fallbacks
%verb0 = polynomial.constant #polynomial.typed_int_polynomial<1 + x**2> : !polynomial.polynomial<ring=#ring1>
%verb2 = polynomial.constant #polynomial.typed_float_polynomial<1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
return
}

Expand Down
Loading