Skip to content

Commit 2b7d6b4

Browse files
[mlir][transform] Fix handling of transitive include in interpreter.
Until now, the interpreter would only load those symbols from the provided library files that were declared in the main transform module. However, sequences in the library may include other sequences on their own. Until now, if such sequences were not *also* declared in the main transform module, the interpreter would fail to resolve them. Forward declaring all of them is undesirable as it defeats the purpose of encapsulation into library modules. This PR extends the loading missing as follows: in `defineDeclaredSymbols`, not only are the definitions inserted that are forward-declared in the main module, but any such inserted definition is scanned for further dependencies, and those are processed in the same way as the forward-declarations from the main module.
1 parent 214ce4d commit 2b7d6b4

File tree

3 files changed

+95
-9
lines changed

3 files changed

+95
-9
lines changed

mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -311,32 +311,43 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
311311
auto readOnlyName =
312312
StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
313313

314+
// Collect symbols missing in the block.
315+
SmallVector<SymbolOpInterface> missingSymbols;
316+
LLVM_DEBUG(DBGS() << "searching block for missing symbols:\n");
314317
for (Operation &op : llvm::make_early_inc_range(block)) {
315318
LLVM_DEBUG(DBGS() << op << "\n");
316319
auto symbol = dyn_cast<SymbolOpInterface>(op);
317320
if (!symbol)
318321
continue;
319322
if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
320323
continue;
324+
LLVM_DEBUG(DBGS() << " -> symbol missing\n");
325+
missingSymbols.push_back(symbol);
326+
}
321327

322-
LLVM_DEBUG(DBGS() << "looking for definition of symbol "
323-
<< symbol.getNameAttr() << ":");
324-
SymbolTable symbolTable(definitions);
325-
Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
328+
// Resolve missing symbols until they are all resolved.
329+
while (!missingSymbols.empty()) {
330+
SymbolOpInterface symbol = missingSymbols.pop_back_val();
331+
LLVM_DEBUG(DBGS() << "looking for definition of symbol @"
332+
<< symbol.getNameAttr().getValue() << ": ");
333+
SymbolTable definitionsSymbolTable(definitions);
334+
Operation *externalSymbol =
335+
definitionsSymbolTable.lookup(symbol.getNameAttr());
326336
if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
327337
externalSymbol->getRegion(0).empty()) {
328338
LLVM_DEBUG(llvm::dbgs() << "not found\n");
329339
continue;
330340
}
331341

332-
auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
342+
auto symbolFunc = dyn_cast<FunctionOpInterface>(symbol.getOperation());
333343
auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
334344
if (!symbolFunc || !externalSymbolFunc) {
335345
LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
336346
continue;
337347
}
338348

339-
LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
349+
LLVM_DEBUG(llvm::dbgs() << "found " << externalSymbol << " from "
350+
<< externalSymbol->getLoc() << "\n");
340351
if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
341352
return symbolFunc.emitError()
342353
<< "external definition has a mismatching signature ("
@@ -367,10 +378,53 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
367378
}
368379
}
369380

370-
OpBuilder builder(&op);
371-
builder.setInsertionPoint(&op);
372-
builder.clone(*externalSymbol);
381+
OpBuilder builder(symbol);
382+
builder.setInsertionPoint(symbol);
383+
Operation *newSymbol = builder.clone(*externalSymbol);
384+
builder.setInsertionPoint(newSymbol);
373385
symbol->erase();
386+
387+
LLVM_DEBUG(DBGS() << "scanning definition of @"
388+
<< externalSymbolFunc.getNameAttr().getValue()
389+
<< " for symbol usages\n");
390+
externalSymbolFunc.walk([&](CallOpInterface callOp) {
391+
LLVM_DEBUG(DBGS() << " found symbol usage in:\n" << callOp << "\n");
392+
CallInterfaceCallable callable = callOp.getCallableForCallee();
393+
if (!isa<SymbolRefAttr>(callable)) {
394+
LLVM_DEBUG(DBGS() << " not a 'SymbolRefAttr'\n");
395+
return WalkResult::advance();
396+
}
397+
398+
StringRef callableSymbol =
399+
cast<SymbolRefAttr>(callable).getLeafReference();
400+
LLVM_DEBUG(DBGS() << " looking for @" << callableSymbol
401+
<< " in definitions: ");
402+
403+
Operation *callableOp = definitionsSymbolTable.lookup(callableSymbol);
404+
if (!isa<SymbolRefAttr>(callable)) {
405+
LLVM_DEBUG(llvm::dbgs() << "not found\n");
406+
return WalkResult::advance();
407+
}
408+
LLVM_DEBUG(llvm::dbgs() << "found " << callableOp << " from "
409+
<< callableOp->getLoc() << "\n");
410+
411+
if (!block.getParent() || !block.getParent()->getParentOp()) {
412+
LLVM_DEBUG(DBGS() << "could not get parent of provided block");
413+
return WalkResult::advance();
414+
}
415+
416+
SymbolTable targetSymbolTable(block.getParent()->getParentOp());
417+
if (targetSymbolTable.lookup(callableSymbol)) {
418+
LLVM_DEBUG(DBGS() << " symbol @" << callableSymbol
419+
<< " already present in target\n");
420+
return WalkResult::advance();
421+
}
422+
423+
LLVM_DEBUG(DBGS() << " cloning op into target\n");
424+
builder.clone(*callableOp);
425+
426+
return WalkResult::advance();
427+
});
374428
}
375429

376430
return success();
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
2+
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
3+
4+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
5+
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
6+
7+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
8+
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
9+
10+
// The definition of the @bar named sequence is provided in another file. It
11+
// will be included because of the pass option. That sequence uses another named
12+
// sequence @foo, which should be made available here. Repeated application of
13+
// the same pass, with or without the library option, should not be a problem.
14+
// Note that the same diagnostic produced twice at the same location only
15+
// needs to be matched once.
16+
17+
// expected-remark @below {{message}}
18+
module attributes {transform.with_named_sequence} {
19+
// CHECK-DAG: transform.named_sequence @foo
20+
// CHECK-DAG: transform.named_sequence @bar
21+
transform.named_sequence private @bar(!transform.any_op {transform.readonly})
22+
23+
transform.sequence failures(propagate) {
24+
^bb0(%arg0: !transform.any_op):
25+
include @bar failures(propagate) (%arg0) : (!transform.any_op) -> ()
26+
}
27+
}

mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
// RUN: mlir-opt %s
22

33
module attributes {transform.with_named_sequence} {
4+
transform.named_sequence @bar(%arg0: !transform.any_op) {
5+
transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
6+
transform.yield
7+
}
8+
49
transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) {
510
transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
611
transform.yield

0 commit comments

Comments
 (0)