Skip to content

Conversation

rolfmorel
Copy link
Contributor

@rolfmorel rolfmorel commented Sep 17, 2025

Introduces a Transform-dialect SMT-extension so that we can have an op to express constrains on Transform-dialect params, in particular when these params are knobs -- see transform.tune.knob -- and can hence be seen as symbolic variables. This op allows expressing joint constraints over multiple params/knobs together.

While the op's semantics are clearly defined, the interpreted semantics -- i.e. the apply() method -- for now just defaults to failure. In the future we should support attaching an implementation so that users can Bring Your Own Solver and thereby control performance of interpreting the op. For now the main usage is to walk schedule IR and collect these constraints so that knobs can be rewritten to constants that satisfy the constraints.

@llvmbot
Copy link
Member

llvmbot commented Sep 17, 2025

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

Changes

Introduces a Transform-dialect SMT-extension so that we can have an op to express constrains on Transform-dialect params, in particular when these params are knobs -- see transform.tune.knob -- and can hence be seen as symbolic variables. This op allows expressing joint constraints over multiple params/knobs together.

While the op's semantics are clearly defined, the operational semantics - i.e. the apply() method - for now just defaults to failure. In the future we should support attaching an implementation so that users can Bring Your Own Solver and thereby control performance of interpreting the op. For now the main usage is to walk schedule IR and collect these constraints so that knobs can be rewritten to constants that satisfy the constraints.


Patch is 24.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159450.diff

18 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/CMakeLists.txt (+1)
  • (added) mlir/include/mlir/Dialect/Transform/SMTExtension/CMakeLists.txt (+6)
  • (added) mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtension.h (+27)
  • (added) mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h (+22)
  • (added) mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td (+52)
  • (modified) mlir/lib/Bindings/Python/DialectSMT.cpp (+7)
  • (modified) mlir/lib/Dialect/Transform/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt (+12)
  • (added) mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp (+35)
  • (added) mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp (+55)
  • (modified) mlir/lib/RegisterAllExtensions.cpp (+2)
  • (modified) mlir/python/CMakeLists.txt (+9)
  • (added) mlir/python/mlir/dialects/TransformSMTExtensionOps.td (+19)
  • (modified) mlir/python/mlir/dialects/smt.py (+1)
  • (added) mlir/python/mlir/dialects/transform/smt.py (+36)
  • (added) mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir (+32)
  • (added) mlir/test/Dialect/Transform/test-smt-extension.mlir (+87)
  • (added) mlir/test/python/dialects/transform_smt_ext.py (+50)
diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
index e70479b2a39f2..eb91ceccd4ef2 100644
--- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
@@ -4,5 +4,6 @@ add_subdirectory(IR)
 add_subdirectory(IRDLExtension)
 add_subdirectory(LoopExtension)
 add_subdirectory(PDLExtension)
+add_subdirectory(SMTExtension)
 add_subdirectory(Transforms)
 add_subdirectory(TuneExtension)
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/SMTExtension/CMakeLists.txt
new file mode 100644
index 0000000000000..da037c1e809de
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS SMTExtensionOps.td)
+mlir_tablegen(SMTExtensionOps.h.inc -gen-op-decls)
+mlir_tablegen(SMTExtensionOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTransformDialectSMTExtensionOpsIncGen)
+
+add_mlir_doc(SMTExtensionOps SMTExtensionOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtension.h b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtension.h
new file mode 100644
index 0000000000000..7079873cec048
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtension.h
@@ -0,0 +1,27 @@
+//===- SMTExtension.h - SMT extension for Transform dialect -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
+#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace transform {
+/// Registers the SMT extension of the Transform dialect in the given registry.
+void registerSMTExtension(DialectRegistry &dialectRegistry);
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
new file mode 100644
index 0000000000000..dfea2039a16c3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
@@ -0,0 +1,22 @@
+//===- SMTExtensionOps.h - SMT extension for Transform dialect --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
+#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h.inc"
+
+
+#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
new file mode 100644
index 0000000000000..b987cb31e54bb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
@@ -0,0 +1,52 @@
+//===- SMTExtensionOps.td - Transform dialect operations ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS
+#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
+  DeclareOpInterfaceMethods<TransformOpInterface>,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  NoTerminator
+]> {
+  let cppNamespace = [{ mlir::transform::smt }];
+
+  let summary = "Express contraints on params interpreted as symbolic values";
+  let description = [{
+    Allows expressing constraints on params using the SMT dialect.
+
+    Each Transform dialect param provided as an operand has a corresponding
+    argument of SMT-type in the region. The SMT-Dialect ops in the region use
+    these arguments as operands.
+
+    The semantics of this op is that all the ops in the region together express
+    a constraint on the params-interpreted-as-smt-vars. The op fails in case the
+    expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the
+    op succeeds.
+
+    ---
+
+    TODO: currently the operational semantics per the Transform interpreter is
+    to always fail. The intention is build out support for hooking in your own
+    operational semantics so you can invoke your favourite solver to determine
+    satisfiability of the corresponding constraint problem.
+  }];
+
+  let arguments = (ins Variadic<TransformParamTypeInterface>:$params);
+  let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat =
+      "`(` $params `)` attr-dict `:` type(operands) $body";
+
+  let hasVerifier = 1;
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index 3123e3bdda496..6e28d96ca58a7 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -41,6 +41,13 @@ static void populateDialectSMTSubmodule(nanobind::module_ &m) {
                 return mlirSMTTypeGetBitVector(context, width);
               },
               "cls"_a, "width"_a, "context"_a = nb::none());
+  auto smtIntType = mlir_type_subclass(m, "IntType", mlirSMTTypeIsAInt)
+                         .def_classmethod(
+                             "get",
+                             [](const nb::object &, MlirContext context) {
+                               return mlirSMTTypeGetInt(context);
+                             },
+                             "cls"_a, "context"_a.none() = nb::none());
 
   auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
                          bool indentLetBody) {
diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt
index 6e628353258d6..123c4b92271fe 100644
--- a/mlir/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/CMakeLists.txt
@@ -4,6 +4,7 @@ add_subdirectory(IR)
 add_subdirectory(IRDLExtension)
 add_subdirectory(LoopExtension)
 add_subdirectory(PDLExtension)
+add_subdirectory(SMTExtension)
 add_subdirectory(Transforms)
 add_subdirectory(TuneExtension)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt
new file mode 100644
index 0000000000000..ba1cc464e506d
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_dialect_library(MLIRTransformSMTExtension
+  SMTExtension.cpp
+  SMTExtensionOps.cpp
+
+  DEPENDS
+  MLIRTransformDialectSMTExtensionOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRTransformDialect
+  MLIRSMT
+)
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp
new file mode 100644
index 0000000000000..228e8d342a1f6
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp
@@ -0,0 +1,35 @@
+//===- SMTExtension.cpp - SMT extension for the Transform dialect ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
+#include "mlir/IR/DialectRegistry.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class SMTExtension : public transform::TransformDialectExtension<SMTExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SMTExtension)
+
+  SMTExtension() {
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+void mlir::transform::registerSMTExtension(DialectRegistry &dialectRegistry) {
+  dialectRegistry.addExtensions<SMTExtension>();
+}
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
new file mode 100644
index 0000000000000..8e7d0b18b7311
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
@@ -0,0 +1,55 @@
+//===- SMTExtensionOps.cpp - SMT extension for the Transform dialect ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
+#include "mlir/Dialect/SMT/IR/SMTDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
+
+using namespace mlir;
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// ConstrainParamsOp
+//===----------------------------------------------------------------------===//
+
+void transform::smt::ConstrainParamsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getParamsMutable(), effects);
+}
+
+DiagnosedSilenceableFailure
+transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
+                                         transform::TransformResults &results,
+                                         transform::TransformState &state) {
+  // TODO: Proper operational semantics are to chuck the SMT problem in the body
+  //       to a SMT solver with the arguments of the body constrained to the
+  //       values passed into the op. Success or failure is then determined by
+  //       the solver's result.
+  //       One way to support this is to just promise the TransformOpInterface
+  //       and allow for users to attach their own implementation, which would,
+  //       e.g., translate the ops to SMTLIB and hand that over to the user's
+  //       favourite solver. This requires changes to the dialect's verifier.
+  return emitDefiniteFailure() << "op does not have interpreted semantics yet";
+}
+
+LogicalResult transform::smt::ConstrainParamsOp::verify() {
+  if (getOperands().size() != getBody().getNumArguments())
+    return emitOpError(
+        "must have the same number of block arguments as operands");
+
+  for (auto &op : getBody().getOps()) {
+    if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
+      return emitOpError(
+          "ops contained in region should belong to SMT-dialect");
+  }
+
+  return success();
+}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 69a85dbe141ce..3839172fd0b42 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -53,6 +53,7 @@
 #include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
 #include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
 #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
 #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
@@ -108,6 +109,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
   transform::registerIRDLExtension(registry);
   transform::registerLoopExtension(registry);
   transform::registerPDLExtension(registry);
+  transform::registerSMTExtension(registry);
   transform::registerTuneExtension(registry);
   vector::registerTransformDialectExtension(registry);
   arm_neon::registerTransformDialectExtension(registry);
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index c983914722ce1..cc7676ada7583 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   DIALECT_NAME transform
   EXTENSION_NAME transform_pdl_extension)
 
+declare_mlir_dialect_extension_python_bindings(
+ADD_TO_PARENT MLIRPythonSources.Dialects
+ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/TransformSMTExtensionOps.td
+  SOURCES
+    dialects/transform/smt.py
+  DIALECT_NAME transform
+  EXTENSION_NAME transform_smt_extension)
+
 declare_mlir_dialect_extension_python_bindings(
 ADD_TO_PARENT MLIRPythonSources.Dialects
 ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/TransformSMTExtensionOps.td b/mlir/python/mlir/dialects/TransformSMTExtensionOps.td
new file mode 100644
index 0000000000000..3e92417a35d13
--- /dev/null
+++ b/mlir/python/mlir/dialects/TransformSMTExtensionOps.td
@@ -0,0 +1,19 @@
+//===-- TransformSMTExtensionOps.td - Binding entry point --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the generated Python bindings for the SMT extension of the
+// Transform dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS
+#define PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS
+
+include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td"
+
+#endif // PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS
diff --git a/mlir/python/mlir/dialects/smt.py b/mlir/python/mlir/dialects/smt.py
index ae7a4c41cbc3a..38970d17abd47 100644
--- a/mlir/python/mlir/dialects/smt.py
+++ b/mlir/python/mlir/dialects/smt.py
@@ -3,6 +3,7 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._smt_ops_gen import *
+from ._smt_enum_gen import *
 
 from .._mlir_libs._mlirDialectsSMT import *
 from ..extras.meta import region_op
diff --git a/mlir/python/mlir/dialects/transform/smt.py b/mlir/python/mlir/dialects/transform/smt.py
new file mode 100644
index 0000000000000..7cb06e8bfed54
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/smt.py
@@ -0,0 +1,36 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Sequence
+
+from ...ir import Type, Block
+from .._transform_smt_extension_ops_gen import *
+from .._transform_smt_extension_ops_gen import _Dialect
+from ...dialects import transform
+
+try:
+    from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ConstrainParamsOp(ConstrainParamsOp):
+    def __init__(
+        self,
+        params: Sequence[transform.AnyParamType],
+        arg_types: Sequence[Type],
+        loc=None,
+        ip=None,
+    ):
+        assert len(params) == len(arg_types)
+        super().__init__(
+            params,
+            loc=loc,
+            ip=ip,
+        )
+        self.regions[0].blocks.append(*arg_types)
+
+    @property
+    def body(self) -> Block:
+        return self.regions[0].blocks[0]
diff --git a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
new file mode 100644
index 0000000000000..3961d7c5ba72b
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
+
+// CHECK-LABEL: @constraint_not_using_smt_ops
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
+    %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+    // expected-error@below {{ops contained in region should belong to SMT-dialect}}
+    transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+      ^bb0(%param_as_smt_var: !smt.int):
+      %c4 = arith.constant 4 : i32
+      // This is the kind of thing one might think works:
+      //arith.remsi %param_as_smt_var, %c4 : i32
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @operands_not_one_to_one_with_vars
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
+    %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+    // expected-error@below {{must have the same number of block arguments as operands}}
+    transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+      ^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
+      // This is the kind of thing one might think works:
+      //arith.remsi %param_as_smt_var, %c4 : i32
+    }
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-smt-extension.mlir b/mlir/test/Dialect/Transform/test-smt-extension.mlir
new file mode 100644
index 0000000000000..29d15175ae4ec
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-smt-extension.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-opt %s --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @schedule_with_constrained_param
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @schedule_with_constrained_param(%arg0: !transform.any_op {transform.readonly}) {
+    // CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
+    %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+
+    // CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
+    transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+      // CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
+      ^bb0(%param_as_smt_var: !smt.int):
+      // CHECK: %[[C0:.*]] = smt.int.constant 0
+      %c0 = smt.int.constant 0
+      // CHECK: %[[C43:.*]] = smt.int.constant 43
+      %c43 = smt.int.constant 43
+      // CHECK: %[[LOWER_BOUND:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
+      %lower_bound = smt.int.cmp le %c0, %param_as_smt_var
+      // CHECK: smt.assert %[[LOWER_BOUND]]
+      smt.assert %lower_bound
+      // CHECK: %[[UPPER_BOUND:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]]
+      %upper_bound = smt.int.cmp le %param_as_smt_var, %c43
+      // CHECK: smt.assert %[[UPPER_BOUND]]
+      smt.assert %upper_bound
+    }
+    ...
[truncated]

Copy link

github-actions bot commented Sep 17, 2025

✅ With the latest revision this PR passed the Python code formatter.

Copy link

github-actions bot commented Sep 17, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@fschlimb fschlimb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! LGTM.

Introduces a SMT Transform-dialect extension so that we can have an op
to express constrains on Transform-dialect params, in particular when
these params are knobs -- see transform.tune.knob -- and can hence can
be seen as symbolic variables. This op allows expressing joint
constraints over multiple params/knobs together.

While the op's semantics are clearly defined, the operational semantics
- i.e. the `apply()` method - for now just defaults to failure. In the
future we should support attaching an implementation so that users can
Bring Your Own Solver and thereby control performance of interpreting
the op. For now the main usage is to walk schedule IR and collect these
constraints so that knobs can be rewritten to constants that satify the
constraints.
@rolfmorel rolfmorel enabled auto-merge (squash) September 21, 2025 20:24
@rolfmorel rolfmorel merged commit d8b84be into llvm:main Sep 21, 2025
9 checks passed
ckoparkar added a commit to ckoparkar/llvm-project that referenced this pull request Sep 23, 2025
* main: (1562 commits)
  Document Policy on supporting newer C++ standard in LLVM codebase (llvm#156823)
  [MLIR][Transform][SMT] Introduce transform.smt.constrain_params (llvm#159450)
  Reapply "[compiler-rt] Remove %T from shared object substitutions (llvm#155302)"
  [NFC] [IndVarSimplify] Add non-overflowing usub test (llvm#159683)
  [Github] Remove separate tools checkout from pr-code workflows (llvm#159967)
  [clang] fix using enum redecl in template regression (llvm#159996)
  [DAG] Skip `mstore` combine for `<1 x ty>` vectors (llvm#159915)
  [mlir] Expose optional `PatternBenefit` to `func` populate functions (NFC) (llvm#159986)
  [LV] Set correct costs for interleave group members.
  [clang] ast-dump: use template pattern for `instantiated_from` (llvm#159952)
  [ARM] ha-alignstack-call.ll - regenerate test checks (llvm#159988)
  [LLD][MachO] Silence warning when building with MSVC
  [llvm][Analysis] Silence warning when building with MSVC
  [LV] Skip select cost for invariant divisors in legacy cost model.
  [Clang] Fix an error-recovery crash after d1a80de (llvm#159976)
  [VPlanPatternMatch] Introduce m_ConstantInt (llvm#159558)
  [GlobalISel] Add G_ABS computeKnownBits (llvm#154413)
  [gn build] Port 4cabd1e
  Reland "[clangd] Add feature modules registry" (llvm#154836)
  [LV] Also handle non-uniform scalarized loads when processing AddrDefs.
  ...
SeongjaeP pushed a commit to SeongjaeP/llvm-project that referenced this pull request Sep 23, 2025
…#159450)

Introduces a Transform-dialect SMT-extension so that we can have an op
to express constrains on Transform-dialect params, in particular when
these params are knobs -- see transform.tune.knob -- and can hence be
seen as symbolic variables. This op allows expressing joint constraints
over multiple params/knobs together.

While the op's semantics are clearly defined, per SMTLIB, the interpreted
semantics -- i.e. the `apply()` method -- for now just defaults to failure. In
the future we should support attaching an implementation so that users
can Bring Your Own Solver and thereby control performance of 
interpreting the op. For now the main usage is to walk schedule IR and 
collect these constraints so that knobs can be rewritten to constants that
satisfy the constraints.
YixingZhang007 pushed a commit to YixingZhang007/llvm-project that referenced this pull request Sep 27, 2025
…#159450)

Introduces a Transform-dialect SMT-extension so that we can have an op
to express constrains on Transform-dialect params, in particular when
these params are knobs -- see transform.tune.knob -- and can hence be
seen as symbolic variables. This op allows expressing joint constraints
over multiple params/knobs together.

While the op's semantics are clearly defined, per SMTLIB, the interpreted
semantics -- i.e. the `apply()` method -- for now just defaults to failure. In
the future we should support attaching an implementation so that users
can Bring Your Own Solver and thereby control performance of 
interpreting the op. For now the main usage is to walk schedule IR and 
collect these constraints so that knobs can be rewritten to constants that
satisfy the constraints.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants