Skip to content

[mlir][polynomial] Move primitive root attr to ring attr #111931

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

ZenithalHourlyRate
Copy link
Member

Related to #93227
and google/heir#993

When ntt/intt ops are emitted as a result of pattern rewrite,
the primitive root attr must be provided in some way, and it
is convenient for it to be provided in ring attr.

As for using different primitive root for the same polynomial,
to_tensor/tensor.cast/from_tensor should be enough for changing
primitiveRoot attribute in RingAttr.

Cc @j2kun

Related to llvm#93227
and google/heir#993

When ntt/intt ops are emitted as a result of pattern rewrite,
the primitive root attr must be provided in some way, and it
is convenient for it to be provided in ring attr.

As for using different primitive root for the same polynomial,
to_tensor/tensor.cast/from_tensor should be enough for changing
primitiveRoot attribute in RingAttr.
@llvmbot llvmbot added the mlir label Oct 11, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 11, 2024

@llvm/pr-subscribers-mlir

Author: Hongren Zheng (ZenithalHourlyRate)

Changes

Related to #93227
and google/heir#993

When ntt/intt ops are emitted as a result of pattern rewrite,
the primitive root attr must be provided in some way, and it
is convenient for it to be provided in ring attr.

As for using different primitive root for the same polynomial,
to_tensor/tensor.cast/from_tensor should be enough for changing
primitiveRoot attribute in RingAttr.

Cc @j2kun


Full diff: https://github.com/llvm/llvm-project/pull/111931.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td (+10-10)
  • (modified) mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td (+29-27)
  • (modified) mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td (+4-8)
  • (modified) mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp (+10-7)
  • (modified) mlir/test/Dialect/Polynomial/canonicalization.mlir (+5-5)
  • (modified) mlir/test/Dialect/Polynomial/ops.mlir (+6-5)
  • (modified) mlir/test/Dialect/Polynomial/ops_errors.mlir (+28-9)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 755396c8b90235..63f9ff1def4e19 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -311,12 +311,12 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
 
       `f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`
 
-    The choice of primitive root may be optionally specified.
+    The choice of primitive root is specified in the primitiveRootAttr of RingAttr.
+    Its degree affects the behavior of ntt performed, with n-th primitive root
+    performing cyclic convolution and 2n-th primitive root performing negacyclic
+    convolution.
   }];
-  let arguments = (ins
-    Polynomial_PolynomialType:$input,
-    OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
-  );
+  let arguments = (ins Polynomial_PolynomialType:$input);
   let results = (outs RankedTensorOf<[AnyInteger]>:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
   let hasCanonicalizer = 1;
@@ -335,12 +335,12 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
     `polynomial.ntt`). The ring of the polynomial is taken from the required
     encoding attribute of the tensor.
 
-    The choice of primitive root may be optionally specified.
+    The choice of primitive root is specified in the primitiveRootAttr of RingAttr.
+    Its degree affects the behavior of ntt performed, with n-th primitive root
+    performing cyclic convolution and 2n-th primitive root performing negacyclic
+    convolution.
   }];
-  let arguments = (
-    ins RankedTensorOf<[AnyInteger]>:$input,
-    OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
-  );
+  let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
   let results = (outs Polynomial_PolynomialType:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
index 7d59add3d37c2b..00c9239fc6369d 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
@@ -126,6 +126,26 @@ def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr<
   }];
 }
 
+def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
+  let summary = "an attribute containing an integer and its degree as a root of unity";
+  let description = [{
+    A primitive root attribute stores an integer root `value` and an integer
+    `degree`, corresponding to a primitive root of unity of the given degree in
+    an unspecified ring.
+
+    Example:
+
+    ```mlir
+    #poly = #polynomial.primitive_root<value=123 : i32, degree : 7 index>
+    ```
+  }];
+  let parameters = (ins
+    "::mlir::IntegerAttr":$value,
+    "::mlir::IntegerAttr":$degree
+  );
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
 def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
   let summary = "an attribute specifying a polynomial ring";
   let description = [{
@@ -142,6 +162,9 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
     modulus. For single-variable polynomials, an "polynomialModulus" is always specificed
     via a single polynomial, which we call `polynomialModulus`.
 
+    For ntt/intt and mul to ntt/intt optimization to work, an n-th or 2n-th
+    _primitiveRoot_ should be specified.
+
     An expressive example is polynomials with i32 coefficients, whose
     coefficients are taken modulo `2**32 - 5`, with a polynomial modulus of
     `x**1024 - 1`.
@@ -177,7 +200,8 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
   let parameters = (ins
     "Type": $coefficientType,
     OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
-    OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus
+    OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
+    OptionalParameter<"::mlir::polynomial::PrimitiveRootAttr">: $primitiveRoot
   );
   let genVerifyDecl = 1;
   let assemblyFormat = "`<` struct(params) `>`";
@@ -185,38 +209,16 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
     AttrBuilderWithInferredContext<
         (ins "::mlir::Type":$coefficientTy,
               CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
-              CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr), [{
+              CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
+              CArg<"::mlir::polynomial::PrimitiveRootAttr", "nullptr"> :$primitiveRootAttr), [{
       return $_get(
         coefficientTy.getContext(),
         coefficientTy,
         coefficientModulusAttr,
-        polynomialModulusAttr);
+        polynomialModulusAttr,
+        primitiveRootAttr);
     }]>,
   ];
 }
 
-def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
-  let summary = "an attribute containing an integer and its degree as a root of unity";
-  let description = [{
-    A primitive root attribute stores an integer root `value` and an integer
-    `degree`, corresponding to a primitive root of unity of the given degree in
-    an unspecified ring.
-
-    This is used as an attribute on `polynomial.ntt` and `polynomial.intt` ops
-    to specify the root of unity used in lowering the transform.
-
-    Example:
-
-    ```mlir
-    #poly = #polynomial.primitive_root<value=123 : i32, degree : 7 index>
-    ```
-  }];
-  let parameters = (ins
-    "::mlir::IntegerAttr":$value,
-    "::mlir::IntegerAttr":$degree
-  );
-  let assemblyFormat = "`<` struct(params) `>`";
-}
-
-
 #endif // POLYNOMIAL_ATTRIBUTES
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index cd7789a2e9531c..f3f6afdee9950c 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -206,7 +206,8 @@ Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
 LogicalResult
 RingAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
                  Type coefficientType, IntegerAttr coefficientModulus,
-                 IntPolynomialAttr polynomialModulus) {
+                 IntPolynomialAttr polynomialModulus,
+                 PrimitiveRootAttr primitiveRoot) {
   if (coefficientModulus) {
     auto coeffIntType = llvm::dyn_cast<IntegerType>(coefficientType);
     if (!coeffIntType) {
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index 28c45e6846380c..a26b34e29d561f 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -14,8 +14,6 @@ include "mlir/Dialect/Polynomial/IR/Polynomial.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/PatternBase.td"
 
-def Equal : Constraint<CPred<"$0 == $1">>;
-
 // Get a -1 integer attribute of the same type as the polynomial SSA value's
 // ring coefficient type.
 def getMinusOne
@@ -30,15 +28,13 @@ def SubAsAdd : Pat<
       (Arith_ConstantOp (getMinusOne $g))))>;
 
 def INTTAfterNTT : Pat<
-  (Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2),
-  (replaceWithValue $poly),
-  [(Equal $r1, $r2)]
+  (Polynomial_INTTOp (Polynomial_NTTOp $poly)),
+  (replaceWithValue $poly)
 >;
 
 def NTTAfterINTT : Pat<
-  (Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2),
-  (replaceWithValue $tensor),
-  [(Equal $r1, $r2)]
+  (Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
+  (replaceWithValue $tensor)
 >;
 
 #endif  // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 460ef17167e801..30a6a004c50aff 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -134,8 +134,7 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
 /// Verify that the types involved in an NTT or INTT operation are
 /// compatible.
 static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
-                                 RankedTensorType tensorType,
-                                 std::optional<PrimitiveRootAttr> root) {
+                                 RankedTensorType tensorType) {
   Attribute encoding = tensorType.getEncoding();
   if (!encoding) {
     return op->emitOpError()
@@ -166,9 +165,10 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
     return diag;
   }
 
-  if (root.has_value()) {
-    APInt rootValue = root.value().getValue().getValue();
-    APInt rootDegree = root.value().getDegree().getValue();
+  auto root = ring.getPrimitiveRoot();
+  if (root) {
+    APInt rootValue = root.getValue().getValue();
+    APInt rootDegree = root.getDegree().getValue();
     APInt cmod = ring.getCoefficientModulus().getValue();
     if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
       return op->emitOpError()
@@ -177,6 +177,9 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
              << "of unity mod " << cmod.getZExtValue()
              << ", with the specified degree " << rootDegree.getZExtValue();
     }
+  } else {
+    return op->emitOpError()
+           << "primitive root not provided but ntt/intt op called";
   }
 
   return success();
@@ -184,12 +187,12 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
 
 LogicalResult NTTOp::verify() {
   return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
-                     getOutput().getType(), getRoot());
+                     getOutput().getType());
 }
 
 LogicalResult INTTOp::verify() {
   return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
-                     getInput().getType(), getRoot());
+                     getInput().getType());
 }
 
 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index c0ee514daab645..5a517a5e1ed9b4 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -canonicalize %s | FileCheck %s
 #ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
 #root = #polynomial.primitive_root<value=31:i32, degree=8:index>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=#root>
 !ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 !tensor_ty = tensor<8xi32, #ntt_ring>
 
@@ -11,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty
   // CHECK-NOT: polynomial.ntt
   // CHECK-NOT: polynomial.intt
   // CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]]  : [[T]]
-  %t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
-  %p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+  %t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
+  %p1 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
   %p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
   // CHECK: return %[[RESULT]] : [[T]]
   return %p2 : !ntt_poly_ty
@@ -24,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
   // CHECK-NOT: polynomial.intt
   // CHECK-NOT: polynomial.ntt
   // CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
-  %p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
-  %t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+  %p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
+  %t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
   %t2 = arith.addi %t1, %t1 : !tensor_ty
   // CHECK: return %[[RESULT]] : [[T]]
   return %t2 : !tensor_ty
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index faeb68a8b2c093..4998730c80c7ea 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -15,12 +15,13 @@
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 #ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
+#ntt_ring_root = #polynomial.primitive_root<value=31:i32, degree=8:index>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=#ntt_ring_root>
 !ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 
 #ntt_poly_2 = #polynomial.int_polynomial<1 + x**65536>
-#ntt_ring_2 = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#ntt_poly_2>
 #ntt_ring_2_root = #polynomial.primitive_root<value=283965:i32, degree=131072:i32>
+#ntt_ring_2 = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#ntt_poly_2, primitiveRoot=#ntt_ring_2_root>
 !ntt_poly_ty_2 = !polynomial.polynomial<ring=#ntt_ring_2>
 
 module {
@@ -96,17 +97,17 @@ module {
   }
 
   func.func @test_ntt(%0 : !ntt_poly_ty) {
-    %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
+    %1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
     return
   }
 
   func.func @test_ntt_with_overflowing_root(%0 : !ntt_poly_ty_2) {
-    %1 = polynomial.ntt %0 {root=#ntt_ring_2_root} : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2>
+    %1 = polynomial.ntt %0 : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2>
     return
   }
 
   func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
-    %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
+    %1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
     return
   }
 }
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index 4937e17027afaa..003967e3f4228c 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -55,36 +55,39 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
+#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_ntt
 // CHECK-NOT: polynomial.ntt
 func.func @test_invalid_ntt(%0 : !poly_ty) {
   // expected-error@below {{expects a ring encoding to be provided to the tensor}}
-  %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32>
+  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
   return
 }
 
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
+#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_ntt
 // CHECK-NOT: polynomial.ntt
 func.func @test_invalid_ntt(%0 : !poly_ty) {
   // expected-error@below {{tensor encoding is not a ring attribute}}
-  %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32, #my_poly>
+  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
   return
 }
 
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
+#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
 #ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
-#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly>
+#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
@@ -98,7 +101,8 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
+#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
@@ -106,7 +110,7 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
 func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
   // expected-error@below {{does not match output type}}
   // expected-note@below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
-  %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<1025xi32, #ring> -> !poly_ty
+  %1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
   return
 }
 
@@ -114,13 +118,28 @@ func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
 
 #my_poly = #polynomial.int_polynomial<-1 + x**8>
 // A valid root is 31
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
+#root = #polynomial.primitive_root<value=32:i32, degree=8:index>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
 // CHECK-NOT: polynomial.intt
 func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
   // expected-error@below {{provided root 32 is not a primitive root of unity mod 256, with the specified degree 8}}
-  %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=32:i16, degree=8:index>} : tensor<8xi32, #ring> -> !poly_ty
+  %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.int_polynomial<-1 + x**8>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
+!poly_ty = !polynomial.polynomial<ring=#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
+  // expected-error@below {{primitive root not provided but ntt/intt op called}}
+  %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
   return
 }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants