From 7e6e545d3b9c69d9c3bbd2eeb0d107b087a85c4b Mon Sep 17 00:00:00 2001 From: Zenithal Date: Fri, 11 Oct 2024 02:01:45 +0000 Subject: [PATCH] [mlir][polynomial] Move primitive root attr to ring attr Related to https://github.com/llvm/llvm-project/pull/93227 and https://github.com/google/heir/issues/993 When ntt/intt ops are emitted as a result of pattern rewrite, the primitive root attr must be provided in some way, and it is convenient for it to be provided in ring attr. As for using different primitive root for the same polynomial, to_tensor/tensor.cast/from_tensor should be enough for changing primitiveRoot attribute in RingAttr. --- .../mlir/Dialect/Polynomial/IR/Polynomial.td | 20 +++---- .../Polynomial/IR/PolynomialAttributes.td | 56 ++++++++++--------- .../Polynomial/IR/PolynomialAttributes.cpp | 3 +- .../IR/PolynomialCanonicalization.td | 12 ++-- .../Dialect/Polynomial/IR/PolynomialOps.cpp | 17 +++--- .../Dialect/Polynomial/canonicalization.mlir | 10 ++-- mlir/test/Dialect/Polynomial/ops.mlir | 11 ++-- mlir/test/Dialect/Polynomial/ops_errors.mlir | 37 +++++++++--- 8 files changed, 94 insertions(+), 72 deletions(-) diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td index 755396c8b9023..63f9ff1def4e1 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td @@ -311,12 +311,12 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> { `f[k] = F(omega[n]^k) ; k = {0, ..., n-1}` - The choice of primitive root may be optionally specified. + The choice of primitive root is specified in the primitiveRootAttr of RingAttr. + Its degree affects the behavior of ntt performed, with n-th primitive root + performing cyclic convolution and 2n-th primitive root performing negacyclic + convolution. }]; - let arguments = (ins - Polynomial_PolynomialType:$input, - OptionalAttr:$root - ); + let arguments = (ins Polynomial_PolynomialType:$input); let results = (outs RankedTensorOf<[AnyInteger]>:$output); let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)"; let hasCanonicalizer = 1; @@ -335,12 +335,12 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> { `polynomial.ntt`). The ring of the polynomial is taken from the required encoding attribute of the tensor. - The choice of primitive root may be optionally specified. + The choice of primitive root is specified in the primitiveRootAttr of RingAttr. + Its degree affects the behavior of ntt performed, with n-th primitive root + performing cyclic convolution and 2n-th primitive root performing negacyclic + convolution. }]; - let arguments = ( - ins RankedTensorOf<[AnyInteger]>:$input, - OptionalAttr:$root - ); + let arguments = (ins RankedTensorOf<[AnyInteger]>:$input); let results = (outs Polynomial_PolynomialType:$output); let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)"; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td index 7d59add3d37c2..00c9239fc6369 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td @@ -126,6 +126,26 @@ def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr< }]; } +def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> { + let summary = "an attribute containing an integer and its degree as a root of unity"; + let description = [{ + A primitive root attribute stores an integer root `value` and an integer + `degree`, corresponding to a primitive root of unity of the given degree in + an unspecified ring. + + Example: + + ```mlir + #poly = #polynomial.primitive_root + ``` + }]; + let parameters = (ins + "::mlir::IntegerAttr":$value, + "::mlir::IntegerAttr":$degree + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> { let summary = "an attribute specifying a polynomial ring"; let description = [{ @@ -142,6 +162,9 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> { modulus. For single-variable polynomials, an "polynomialModulus" is always specificed via a single polynomial, which we call `polynomialModulus`. + For ntt/intt and mul to ntt/intt optimization to work, an n-th or 2n-th + _primitiveRoot_ should be specified. + An expressive example is polynomials with i32 coefficients, whose coefficients are taken modulo `2**32 - 5`, with a polynomial modulus of `x**1024 - 1`. @@ -177,7 +200,8 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> { let parameters = (ins "Type": $coefficientType, OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus, - OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus + OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus, + OptionalParameter<"::mlir::polynomial::PrimitiveRootAttr">: $primitiveRoot ); let genVerifyDecl = 1; let assemblyFormat = "`<` struct(params) `>`"; @@ -185,38 +209,16 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> { AttrBuilderWithInferredContext< (ins "::mlir::Type":$coefficientTy, CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr, - CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr), [{ + CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr, + CArg<"::mlir::polynomial::PrimitiveRootAttr", "nullptr"> :$primitiveRootAttr), [{ return $_get( coefficientTy.getContext(), coefficientTy, coefficientModulusAttr, - polynomialModulusAttr); + polynomialModulusAttr, + primitiveRootAttr); }]>, ]; } -def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> { - let summary = "an attribute containing an integer and its degree as a root of unity"; - let description = [{ - A primitive root attribute stores an integer root `value` and an integer - `degree`, corresponding to a primitive root of unity of the given degree in - an unspecified ring. - - This is used as an attribute on `polynomial.ntt` and `polynomial.intt` ops - to specify the root of unity used in lowering the transform. - - Example: - - ```mlir - #poly = #polynomial.primitive_root - ``` - }]; - let parameters = (ins - "::mlir::IntegerAttr":$value, - "::mlir::IntegerAttr":$degree - ); - let assemblyFormat = "`<` struct(params) `>`"; -} - - #endif // POLYNOMIAL_ATTRIBUTES diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp index cd7789a2e9531..f3f6afdee9950 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp @@ -206,7 +206,8 @@ Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) { LogicalResult RingAttr::verify(function_ref emitError, Type coefficientType, IntegerAttr coefficientModulus, - IntPolynomialAttr polynomialModulus) { + IntPolynomialAttr polynomialModulus, + PrimitiveRootAttr primitiveRoot) { if (coefficientModulus) { auto coeffIntType = llvm::dyn_cast(coefficientType); if (!coeffIntType) { diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td index 28c45e6846380..a26b34e29d561 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td @@ -14,8 +14,6 @@ include "mlir/Dialect/Polynomial/IR/Polynomial.td" include "mlir/IR/OpBase.td" include "mlir/IR/PatternBase.td" -def Equal : Constraint>; - // Get a -1 integer attribute of the same type as the polynomial SSA value's // ring coefficient type. def getMinusOne @@ -30,15 +28,13 @@ def SubAsAdd : Pat< (Arith_ConstantOp (getMinusOne $g))))>; def INTTAfterNTT : Pat< - (Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2), - (replaceWithValue $poly), - [(Equal $r1, $r2)] + (Polynomial_INTTOp (Polynomial_NTTOp $poly)), + (replaceWithValue $poly) >; def NTTAfterINTT : Pat< - (Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2), - (replaceWithValue $tensor), - [(Equal $r1, $r2)] + (Polynomial_NTTOp (Polynomial_INTTOp $tensor)), + (replaceWithValue $tensor) >; #endif // POLYNOMIAL_CANONICALIZATION diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp index 460ef17167e80..30a6a004c50af 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp @@ -134,8 +134,7 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n, /// Verify that the types involved in an NTT or INTT operation are /// compatible. static LogicalResult verifyNTTOp(Operation *op, RingAttr ring, - RankedTensorType tensorType, - std::optional root) { + RankedTensorType tensorType) { Attribute encoding = tensorType.getEncoding(); if (!encoding) { return op->emitOpError() @@ -166,9 +165,10 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring, return diag; } - if (root.has_value()) { - APInt rootValue = root.value().getValue().getValue(); - APInt rootDegree = root.value().getDegree().getValue(); + auto root = ring.getPrimitiveRoot(); + if (root) { + APInt rootValue = root.getValue().getValue(); + APInt rootDegree = root.getDegree().getValue(); APInt cmod = ring.getCoefficientModulus().getValue(); if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) { return op->emitOpError() @@ -177,6 +177,9 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring, << "of unity mod " << cmod.getZExtValue() << ", with the specified degree " << rootDegree.getZExtValue(); } + } else { + return op->emitOpError() + << "primitive root not provided but ntt/intt op called"; } return success(); @@ -184,12 +187,12 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring, LogicalResult NTTOp::verify() { return verifyNTTOp(this->getOperation(), getInput().getType().getRing(), - getOutput().getType(), getRoot()); + getOutput().getType()); } LogicalResult INTTOp::verify() { return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(), - getInput().getType(), getRoot()); + getInput().getType()); } ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) { diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir index c0ee514daab64..5a517a5e1ed9b 100644 --- a/mlir/test/Dialect/Polynomial/canonicalization.mlir +++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt -canonicalize %s | FileCheck %s #ntt_poly = #polynomial.int_polynomial<-1 + x**8> -#ntt_ring = #polynomial.ring #root = #polynomial.primitive_root +#ntt_ring = #polynomial.ring !ntt_poly_ty = !polynomial.polynomial !tensor_ty = tensor<8xi32, #ntt_ring> @@ -11,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty // CHECK-NOT: polynomial.ntt // CHECK-NOT: polynomial.intt // CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]] : [[T]] - %t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty - %p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty + %t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty + %p1 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty %p2 = polynomial.add %p1, %p1 : !ntt_poly_ty // CHECK: return %[[RESULT]] : [[T]] return %p2 : !ntt_poly_ty @@ -24,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty { // CHECK-NOT: polynomial.intt // CHECK-NOT: polynomial.ntt // CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]] - %p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty - %t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty + %p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty + %t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty %t2 = arith.addi %t1, %t1 : !tensor_ty // CHECK: return %[[RESULT]] : [[T]] return %t2 : !tensor_ty diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir index faeb68a8b2c09..4998730c80c7e 100644 --- a/mlir/test/Dialect/Polynomial/ops.mlir +++ b/mlir/test/Dialect/Polynomial/ops.mlir @@ -15,12 +15,13 @@ !poly_ty = !polynomial.polynomial #ntt_poly = #polynomial.int_polynomial<-1 + x**8> -#ntt_ring = #polynomial.ring +#ntt_ring_root = #polynomial.primitive_root +#ntt_ring = #polynomial.ring !ntt_poly_ty = !polynomial.polynomial #ntt_poly_2 = #polynomial.int_polynomial<1 + x**65536> -#ntt_ring_2 = #polynomial.ring #ntt_ring_2_root = #polynomial.primitive_root +#ntt_ring_2 = #polynomial.ring !ntt_poly_ty_2 = !polynomial.polynomial module { @@ -96,17 +97,17 @@ module { } func.func @test_ntt(%0 : !ntt_poly_ty) { - %1 = polynomial.ntt %0 {root=#polynomial.primitive_root} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring> + %1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring> return } func.func @test_ntt_with_overflowing_root(%0 : !ntt_poly_ty_2) { - %1 = polynomial.ntt %0 {root=#ntt_ring_2_root} : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2> + %1 = polynomial.ntt %0 : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2> return } func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) { - %1 = polynomial.intt %0 {root=#polynomial.primitive_root} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty + %1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty return } } diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir index 4937e17027afa..003967e3f4228 100644 --- a/mlir/test/Dialect/Polynomial/ops_errors.mlir +++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir @@ -55,36 +55,39 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty { // ----- #my_poly = #polynomial.int_polynomial<-1 + x**1024> -#ring = #polynomial.ring +#root = #polynomial.primitive_root +#ring = #polynomial.ring !poly_ty = !polynomial.polynomial // CHECK-NOT: @test_invalid_ntt // CHECK-NOT: polynomial.ntt func.func @test_invalid_ntt(%0 : !poly_ty) { // expected-error@below {{expects a ring encoding to be provided to the tensor}} - %1 = polynomial.ntt %0 {root=#polynomial.primitive_root} : !poly_ty -> tensor<1024xi32> + %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32> return } // ----- #my_poly = #polynomial.int_polynomial<-1 + x**1024> -#ring = #polynomial.ring +#root = #polynomial.primitive_root +#ring = #polynomial.ring !poly_ty = !polynomial.polynomial // CHECK-NOT: @test_invalid_ntt // CHECK-NOT: polynomial.ntt func.func @test_invalid_ntt(%0 : !poly_ty) { // expected-error@below {{tensor encoding is not a ring attribute}} - %1 = polynomial.ntt %0 {root=#polynomial.primitive_root} : !poly_ty -> tensor<1024xi32, #my_poly> + %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly> return } // ----- #my_poly = #polynomial.int_polynomial<-1 + x**1024> +#root = #polynomial.primitive_root #ring = #polynomial.ring -#ring1 = #polynomial.ring +#ring1 = #polynomial.ring !poly_ty = !polynomial.polynomial // CHECK-NOT: @test_invalid_intt @@ -98,7 +101,8 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) { // ----- #my_poly = #polynomial.int_polynomial<-1 + x**1024> -#ring = #polynomial.ring +#root = #polynomial.primitive_root +#ring = #polynomial.ring !poly_ty = !polynomial.polynomial // CHECK-NOT: @test_invalid_intt @@ -106,7 +110,7 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) { func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) { // expected-error@below {{does not match output type}} // expected-note@below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}} - %1 = polynomial.intt %0 {root=#polynomial.primitive_root} : tensor<1025xi32, #ring> -> !poly_ty + %1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty return } @@ -114,13 +118,28 @@ func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) { #my_poly = #polynomial.int_polynomial<-1 + x**8> // A valid root is 31 -#ring = #polynomial.ring +#root = #polynomial.primitive_root +#ring = #polynomial.ring !poly_ty = !polynomial.polynomial // CHECK-NOT: @test_invalid_intt // CHECK-NOT: polynomial.intt func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) { // expected-error@below {{provided root 32 is not a primitive root of unity mod 256, with the specified degree 8}} - %1 = polynomial.intt %0 {root=#polynomial.primitive_root} : tensor<8xi32, #ring> -> !poly_ty + %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty + return +} + +// ----- + +#my_poly = #polynomial.int_polynomial<-1 + x**8> +#ring = #polynomial.ring +!poly_ty = !polynomial.polynomial + +// CHECK-NOT: @test_invalid_intt +// CHECK-NOT: polynomial.intt +func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) { + // expected-error@below {{primitive root not provided but ntt/intt op called}} + %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty return }