Skip to content

[mlir][polynomial] Move primitive root attr to ring attr #111931

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
Original file line number Diff line number Diff line change
Expand Up @@ -311,12 +311,12 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {

`f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`

The choice of primitive root may be optionally specified.
The choice of primitive root is specified in the primitiveRootAttr of RingAttr.
Its degree affects the behavior of ntt performed, with n-th primitive root
performing cyclic convolution and 2n-th primitive root performing negacyclic
convolution.
}];
let arguments = (ins
Polynomial_PolynomialType:$input,
OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
);
let arguments = (ins Polynomial_PolynomialType:$input);
let results = (outs RankedTensorOf<[AnyInteger]>:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
let hasCanonicalizer = 1;
Expand All @@ -335,12 +335,12 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
`polynomial.ntt`). The ring of the polynomial is taken from the required
encoding attribute of the tensor.

The choice of primitive root may be optionally specified.
The choice of primitive root is specified in the primitiveRootAttr of RingAttr.
Its degree affects the behavior of ntt performed, with n-th primitive root
performing cyclic convolution and 2n-th primitive root performing negacyclic
convolution.
}];
let arguments = (
ins RankedTensorOf<[AnyInteger]>:$input,
OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
);
let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
let results = (outs Polynomial_PolynomialType:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
let hasCanonicalizer = 1;
Expand Down
56 changes: 29 additions & 27 deletions mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,26 @@ def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr<
}];
}

def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
let summary = "an attribute containing an integer and its degree as a root of unity";
let description = [{
A primitive root attribute stores an integer root `value` and an integer
`degree`, corresponding to a primitive root of unity of the given degree in
an unspecified ring.

Example:

```mlir
#poly = #polynomial.primitive_root<value=123 : i32, degree : 7 index>
```
}];
let parameters = (ins
"::mlir::IntegerAttr":$value,
"::mlir::IntegerAttr":$degree
);
let assemblyFormat = "`<` struct(params) `>`";
}

def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
let summary = "an attribute specifying a polynomial ring";
let description = [{
Expand All @@ -142,6 +162,9 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
modulus. For single-variable polynomials, an "polynomialModulus" is always specificed
via a single polynomial, which we call `polynomialModulus`.

For ntt/intt and mul to ntt/intt optimization to work, an n-th or 2n-th
_primitiveRoot_ should be specified.

An expressive example is polynomials with i32 coefficients, whose
coefficients are taken modulo `2**32 - 5`, with a polynomial modulus of
`x**1024 - 1`.
Expand Down Expand Up @@ -177,46 +200,25 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
let parameters = (ins
"Type": $coefficientType,
OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus
OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
OptionalParameter<"::mlir::polynomial::PrimitiveRootAttr">: $primitiveRoot
);
let genVerifyDecl = 1;
let assemblyFormat = "`<` struct(params) `>`";
let builders = [
AttrBuilderWithInferredContext<
(ins "::mlir::Type":$coefficientTy,
CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr), [{
CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
CArg<"::mlir::polynomial::PrimitiveRootAttr", "nullptr"> :$primitiveRootAttr), [{
return $_get(
coefficientTy.getContext(),
coefficientTy,
coefficientModulusAttr,
polynomialModulusAttr);
polynomialModulusAttr,
primitiveRootAttr);
}]>,
];
}

def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
let summary = "an attribute containing an integer and its degree as a root of unity";
let description = [{
A primitive root attribute stores an integer root `value` and an integer
`degree`, corresponding to a primitive root of unity of the given degree in
an unspecified ring.

This is used as an attribute on `polynomial.ntt` and `polynomial.intt` ops
to specify the root of unity used in lowering the transform.

Example:

```mlir
#poly = #polynomial.primitive_root<value=123 : i32, degree : 7 index>
```
}];
let parameters = (ins
"::mlir::IntegerAttr":$value,
"::mlir::IntegerAttr":$degree
);
let assemblyFormat = "`<` struct(params) `>`";
}


#endif // POLYNOMIAL_ATTRIBUTES
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
LogicalResult
RingAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
Type coefficientType, IntegerAttr coefficientModulus,
IntPolynomialAttr polynomialModulus) {
IntPolynomialAttr polynomialModulus,
PrimitiveRootAttr primitiveRoot) {
if (coefficientModulus) {
auto coeffIntType = llvm::dyn_cast<IntegerType>(coefficientType);
if (!coeffIntType) {
Expand Down
12 changes: 4 additions & 8 deletions mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ include "mlir/Dialect/Polynomial/IR/Polynomial.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/PatternBase.td"

def Equal : Constraint<CPred<"$0 == $1">>;

// Get a -1 integer attribute of the same type as the polynomial SSA value's
// ring coefficient type.
def getMinusOne
Expand All @@ -30,15 +28,13 @@ def SubAsAdd : Pat<
(Arith_ConstantOp (getMinusOne $g))))>;

def INTTAfterNTT : Pat<
(Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2),
(replaceWithValue $poly),
[(Equal $r1, $r2)]
(Polynomial_INTTOp (Polynomial_NTTOp $poly)),
(replaceWithValue $poly)
>;

def NTTAfterINTT : Pat<
(Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2),
(replaceWithValue $tensor),
[(Equal $r1, $r2)]
(Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
(replaceWithValue $tensor)
>;

#endif // POLYNOMIAL_CANONICALIZATION
17 changes: 10 additions & 7 deletions mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
/// Verify that the types involved in an NTT or INTT operation are
/// compatible.
static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
RankedTensorType tensorType,
std::optional<PrimitiveRootAttr> root) {
RankedTensorType tensorType) {
Attribute encoding = tensorType.getEncoding();
if (!encoding) {
return op->emitOpError()
Expand Down Expand Up @@ -166,9 +165,10 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
return diag;
}

if (root.has_value()) {
APInt rootValue = root.value().getValue().getValue();
APInt rootDegree = root.value().getDegree().getValue();
auto root = ring.getPrimitiveRoot();
if (root) {
APInt rootValue = root.getValue().getValue();
APInt rootDegree = root.getDegree().getValue();
APInt cmod = ring.getCoefficientModulus().getValue();
if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
return op->emitOpError()
Expand All @@ -177,19 +177,22 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
<< "of unity mod " << cmod.getZExtValue()
<< ", with the specified degree " << rootDegree.getZExtValue();
}
} else {
return op->emitOpError()
<< "primitive root not provided but ntt/intt op called";
}

return success();
}

LogicalResult NTTOp::verify() {
return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
getOutput().getType(), getRoot());
getOutput().getType());
}

LogicalResult INTTOp::verify() {
return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
getInput().getType(), getRoot());
getInput().getType());
}

ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Dialect/Polynomial/canonicalization.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: mlir-opt -canonicalize %s | FileCheck %s
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=#root>
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
!tensor_ty = tensor<8xi32, #ntt_ring>

Expand All @@ -11,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty
// CHECK-NOT: polynomial.ntt
// CHECK-NOT: polynomial.intt
// CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]] : [[T]]
%t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
%p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
%t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
%p1 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
%p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
// CHECK: return %[[RESULT]] : [[T]]
return %p2 : !ntt_poly_ty
Expand All @@ -24,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
// CHECK-NOT: polynomial.intt
// CHECK-NOT: polynomial.ntt
// CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
%p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
%t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
%p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
%t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
%t2 = arith.addi %t1, %t1 : !tensor_ty
// CHECK: return %[[RESULT]] : [[T]]
return %t2 : !tensor_ty
Expand Down
11 changes: 6 additions & 5 deletions mlir/test/Dialect/Polynomial/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
!poly_ty = !polynomial.polynomial<ring=#ring>

#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
#ntt_ring_root = #polynomial.primitive_root<value=31:i32, degree=8:index>
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=#ntt_ring_root>
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>

#ntt_poly_2 = #polynomial.int_polynomial<1 + x**65536>
#ntt_ring_2 = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#ntt_poly_2>
#ntt_ring_2_root = #polynomial.primitive_root<value=283965:i32, degree=131072:i32>
#ntt_ring_2 = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#ntt_poly_2, primitiveRoot=#ntt_ring_2_root>
!ntt_poly_ty_2 = !polynomial.polynomial<ring=#ntt_ring_2>

module {
Expand Down Expand Up @@ -96,17 +97,17 @@ module {
}

func.func @test_ntt(%0 : !ntt_poly_ty) {
%1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
%1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
return
}

func.func @test_ntt_with_overflowing_root(%0 : !ntt_poly_ty_2) {
%1 = polynomial.ntt %0 {root=#ntt_ring_2_root} : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2>
%1 = polynomial.ntt %0 : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2>
return
}

func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
%1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
%1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
return
}
}
37 changes: 28 additions & 9 deletions mlir/test/Dialect/Polynomial/ops_errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -55,36 +55,39 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
// -----

#my_poly = #polynomial.int_polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
!poly_ty = !polynomial.polynomial<ring=#ring>

// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
func.func @test_invalid_ntt(%0 : !poly_ty) {
// expected-error@below {{expects a ring encoding to be provided to the tensor}}
%1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32>
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
return
}

// -----

#my_poly = #polynomial.int_polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
!poly_ty = !polynomial.polynomial<ring=#ring>

// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
func.func @test_invalid_ntt(%0 : !poly_ty) {
// expected-error@below {{tensor encoding is not a ring attribute}}
%1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32, #my_poly>
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
return
}

// -----

#my_poly = #polynomial.int_polynomial<-1 + x**1024>
#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly>
#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
!poly_ty = !polynomial.polynomial<ring=#ring>

// CHECK-NOT: @test_invalid_intt
Expand All @@ -98,29 +101,45 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
// -----

#my_poly = #polynomial.int_polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
!poly_ty = !polynomial.polynomial<ring=#ring>

// CHECK-NOT: @test_invalid_intt
// CHECK-NOT: polynomial.intt
func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
// expected-error@below {{does not match output type}}
// expected-note@below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
%1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<1025xi32, #ring> -> !poly_ty
%1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
return
}

// -----

#my_poly = #polynomial.int_polynomial<-1 + x**8>
// A valid root is 31
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
#root = #polynomial.primitive_root<value=32:i32, degree=8:index>
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
!poly_ty = !polynomial.polynomial<ring=#ring>

// CHECK-NOT: @test_invalid_intt
// CHECK-NOT: polynomial.intt
func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
// expected-error@below {{provided root 32 is not a primitive root of unity mod 256, with the specified degree 8}}
%1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=32:i16, degree=8:index>} : tensor<8xi32, #ring> -> !poly_ty
%1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
return
}

// -----

#my_poly = #polynomial.int_polynomial<-1 + x**8>
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
!poly_ty = !polynomial.polynomial<ring=#ring>

// CHECK-NOT: @test_invalid_intt
// CHECK-NOT: polynomial.intt
func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
// expected-error@below {{primitive root not provided but ntt/intt op called}}
%1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
return
}
Loading