diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index ba146920fae2e..0db097d14cd3c 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -25,22 +25,37 @@ using namespace mlir; -/// Remap locations from the inlined blocks with CallSiteLoc locations with the -/// provided caller location. +/// Remap all locations reachable from the inlined blocks with CallSiteLoc +/// locations with the provided caller location. static void remapInlinedLocations(iterator_range inlinedBlocks, Location callerLoc) { - DenseMap mappedLocations; - auto remapOpLoc = [&](Operation *op) { - auto it = mappedLocations.find(op->getLoc()); - if (it == mappedLocations.end()) { - auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc); - it = mappedLocations.try_emplace(op->getLoc(), newLoc).first; + DenseMap mappedLocations; + auto remapLoc = [&](Location loc) { + auto [it, inserted] = mappedLocations.try_emplace(loc); + // Only query the attribute uniquer once per callsite attribute. + if (inserted) { + auto newLoc = CallSiteLoc::get(loc, callerLoc); + it->getSecond() = newLoc; } - op->setLoc(it->second); + return it->second; }; - for (auto &block : inlinedBlocks) - block.walk(remapOpLoc); + + AttrTypeReplacer attrReplacer; + attrReplacer.addReplacement( + [&](LocationAttr loc) -> std::pair { + return {remapLoc(loc), WalkResult::skip()}; + }); + + for (Block &block : inlinedBlocks) { + for (BlockArgument &arg : block.getArguments()) + if (LocationAttr newLoc = remapLoc(arg.getLoc())) + arg.setLoc(newLoc); + + for (Operation &op : block) + attrReplacer.recursivelyReplaceElementsIn(&op, /*replaceAttrs=*/false, + /*replaceLocs=*/true); + } } static void remapInlinedOperands(iterator_range inlinedBlocks, diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir index 2a08e625ba79e..79a2936b104fa 100644 --- a/mlir/test/Transforms/inlining.mlir +++ b/mlir/test/Transforms/inlining.mlir @@ -215,9 +215,9 @@ func.func @func_with_block_args_location(%arg0 : i32) { // INLINE-LOC-LABEL: func @func_with_block_args_location_callee1 // INLINE-LOC: cf.br -// INLINE-LOC: ^bb{{[0-9]+}}(%{{.*}}: i32 loc("foo") +// INLINE-LOC: ^bb{{[0-9]+}}(%{{.*}}: i32 loc(callsite("foo" at "bar")) func.func @func_with_block_args_location_callee1(%arg0 : i32) { - call @func_with_block_args_location(%arg0) : (i32) -> () + call @func_with_block_args_location(%arg0) : (i32) -> () loc("bar") return }