Skip to content

[mlir][Transforms] Dialect conversion: Simplify handling of dropped arguments #97213

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 20, 2024

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Jun 30, 2024

This commit simplifies the handling of dropped arguments and updates some dialect conversion documentation that is outdated.

When converting a block signature, a BlockTypeConversionRewrite object and potentially multiple ReplaceBlockArgRewrite are created. During the "commit" phase, uses of the old block arguments are replaced with the new block arguments, but the old implementation was written in an inconsistent way: some block arguments were replaced in BlockTypeConversionRewrite::commit and some were replaced in ReplaceBlockArgRewrite::commit. The new
BlockTypeConversionRewrite::commit implementation is much simpler and no longer modifies any IR; that is done only in ReplaceBlockArgRewrite now. The ConvertedArgInfo data structure is no longer needed.

To that end, materializations of dropped arguments are now built in applySignatureConversion instead of materializeLiveConversions; the latter function no longer has to deal with dropped arguments.

Other minor improvements:

  • Add more comments to applySignatureConversion.

Note: Error messages around failed materializations for dropped basic block arguments changed slightly. That is because those materializations are now built in legalizeUnresolvedMaterialization instead of legalizeConvertedArgumentTypes.

This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion.

This is a re-upload of #96207.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/block_arg_rewrite_2 branch from 552f1a8 to c254f78 Compare July 13, 2024 15:37
@matthias-springer matthias-springer changed the base branch from main to users/matthias-springer/arg_mat_experiment July 13, 2024 15:37
@matthias-springer matthias-springer force-pushed the users/matthias-springer/arg_mat_experiment branch from d8a0ebe to 5773176 Compare July 13, 2024 19:42
@matthias-springer matthias-springer force-pushed the users/matthias-springer/block_arg_rewrite_2 branch from c254f78 to b0b7813 Compare July 13, 2024 19:52
@matthias-springer matthias-springer marked this pull request as ready for review July 14, 2024 09:06
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jul 14, 2024
@llvmbot
Copy link
Member

llvmbot commented Jul 14, 2024

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit simplifies the handling of dropped arguments and updates some dialect conversion documentation that is outdated.

When converting a block signature, a BlockTypeConversionRewrite object and potentially multiple ReplaceBlockArgRewrite are created. During the "commit" phase, uses of the old block arguments are replaced with the new block arguments, but the old implementation was written in an inconsistent way: some block arguments were replaced in BlockTypeConversionRewrite::commit and some were replaced in ReplaceBlockArgRewrite::commit. The new
BlockTypeConversionRewrite::commit implementation is much simpler and no longer modifies any IR; that is done only in ReplaceBlockArgRewrite now. The ConvertedArgInfo data structure is no longer needed.

To that end, materializations of dropped arguments are now built in applySignatureConversion instead of materializeLiveConversions; the latter function no longer has to deal with dropped arguments.

Other minor improvements:

  • Add more comments to applySignatureConversion.

Note: Error messages around failed materializations for dropped basic block arguments changed slightly. That is because those materializations are now built in legalizeUnresolvedMaterialization instead of legalizeConvertedArgumentTypes.

This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion.

This is a re-upload of #96207.


Full diff: https://github.com/llvm/llvm-project/pull/97213.diff

2 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+55-118)
  • (modified) mlir/test/Transforms/test-legalize-type-conversion.mlir (+2-4)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1e0afee2373a9..0b552a7e1ca3b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -432,34 +432,14 @@ class MoveBlockRewrite : public BlockRewrite {
   Block *insertBeforeBlock;
 };
 
-/// This structure contains the information pertaining to an argument that has
-/// been converted.
-struct ConvertedArgInfo {
-  ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
-                   Value castValue = nullptr)
-      : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
-
-  /// The start index of in the new argument list that contains arguments that
-  /// replace the original.
-  unsigned newArgIdx;
-
-  /// The number of arguments that replaced the original argument.
-  unsigned newArgSize;
-
-  /// The cast value that was created to cast from the new arguments to the
-  /// old. This only used if 'newArgSize' > 1.
-  Value castValue;
-};
-
 /// Block type conversion. This rewrite is partially reflected in the IR.
 class BlockTypeConversionRewrite : public BlockRewrite {
 public:
-  BlockTypeConversionRewrite(
-      ConversionPatternRewriterImpl &rewriterImpl, Block *block,
-      Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
-      const TypeConverter *converter)
+  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                             Block *block, Block *origBlock,
+                             const TypeConverter *converter)
       : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
-        origBlock(origBlock), argInfo(argInfo), converter(converter) {}
+        origBlock(origBlock), converter(converter) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::BlockTypeConversion;
@@ -479,10 +459,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   /// The original block that was requested to have its signature converted.
   Block *origBlock;
 
-  /// The conversion information for each of the arguments. The information is
-  /// std::nullopt if the argument was dropped during conversion.
-  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-
   /// The type converter used to convert the arguments.
   const TypeConverter *converter;
 };
@@ -691,12 +667,16 @@ class CreateOperationRewrite : public OperationRewrite {
 /// The type of materialization.
 enum MaterializationKind {
   /// This materialization materializes a conversion for an illegal block
-  /// argument type, to a legal one.
+  /// argument type, to the original one.
   Argument,
 
   /// This materialization materializes a conversion from an illegal type to a
   /// legal one.
-  Target
+  Target,
+
+  /// This materialization materializes a conversion from a legal type back to
+  /// an illegal one.
+  Source
 };
 
 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
@@ -736,7 +716,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
 private:
   /// The corresponding type converter to use when resolving this
   /// materialization, and the kind of this materialization.
-  llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
+  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
       converterAndKind;
 };
 } // namespace
@@ -855,11 +835,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                        ValueRange inputs, Type outputType,
                                        const TypeConverter *converter);
 
-  Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
-                                               ValueRange inputs,
-                                               Type outputType,
-                                               const TypeConverter *converter);
-
   Value buildUnresolvedTargetMaterialization(Location loc, Value input,
                                              Type outputType,
                                              const TypeConverter *converter);
@@ -989,28 +964,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
           dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
     for (Operation *op : block->getUsers())
       listener->notifyOperationModified(op);
-
-  // Process the remapping for each of the original arguments.
-  for (auto [origArg, info] :
-       llvm::zip_equal(origBlock->getArguments(), argInfo)) {
-    // Handle the case of a 1->0 value mapping.
-    if (!info) {
-      if (Value newArg =
-              rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
-        rewriter.replaceAllUsesWith(origArg, newArg);
-      continue;
-    }
-
-    // Otherwise this is a 1->1+ value mapping.
-    Value castValue = info->castValue;
-    assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
-
-    // If the argument is still used, replace it with the generated cast.
-    if (!origArg.use_empty()) {
-      rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
-                                               castValue, origArg.getType()));
-    }
-  }
 }
 
 void BlockTypeConversionRewrite::rollback() {
@@ -1035,14 +988,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
       continue;
 
     Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
-    bool isDroppedArg = replacementValue == origArg;
-    if (!isDroppedArg)
-      builder.setInsertionPointAfterValue(replacementValue);
+    assert(replacementValue && "replacement value not found");
     Value newArg;
     if (converter) {
+      builder.setInsertionPointAfterValue(replacementValue);
       newArg = converter->materializeSourceConversion(
-          builder, origArg.getLoc(), origArg.getType(),
-          isDroppedArg ? ValueRange() : ValueRange(replacementValue));
+          builder, origArg.getLoc(), origArg.getType(), replacementValue);
       assert((!newArg || newArg.getType() == origArg.getType()) &&
              "materialization hook did not provide a value of the expected "
              "type");
@@ -1053,8 +1004,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
           << "failed to materialize conversion for block argument #"
           << it.index() << " that remained live after conversion, type was "
           << origArg.getType();
-      if (!isDroppedArg)
-        diag << ", with target type " << replacementValue.getType();
       diag.attachNote(liveUser->getLoc())
           << "see existing live user here: " << *liveUser;
       return failure();
@@ -1340,73 +1289,64 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
   // Replace all uses of the old block with the new block.
   block->replaceAllUsesWith(newBlock);
 
-  // Remap each of the original arguments as determined by the signature
-  // conversion.
-  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-  argInfo.resize(origArgCount);
-
   for (unsigned i = 0; i != origArgCount; ++i) {
-    auto inputMap = signatureConversion.getInputMapping(i);
-    if (!inputMap)
-      continue;
     BlockArgument origArg = block->getArgument(i);
+    Type origArgType = origArg.getType();
+
+    std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
+        signatureConversion.getInputMapping(i);
+    if (!inputMap) {
+      // This block argument was dropped and no replacement value was provided.
+      // Materialize a replacement value "out of thin air".
+      Value repl = buildUnresolvedMaterialization(
+          MaterializationKind::Source, newBlock, newBlock->begin(),
+          origArg.getLoc(), /*inputs=*/ValueRange(),
+          /*outputType=*/origArgType, converter);
+      mapping.map(origArg, repl);
+      appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+      continue;
+    }
 
-    // If inputMap->replacementValue is not nullptr, then the argument is
-    // dropped and a replacement value is provided to be the remappedValue.
-    if (inputMap->replacementValue) {
+    if (Value repl = inputMap->replacementValue) {
+      // This block argument was dropped and a replacement value was provided.
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      mapping.map(origArg, inputMap->replacementValue);
+      mapping.map(origArg, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
       continue;
     }
 
-    // Otherwise, this is a 1->1+ mapping.
+    // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
+    // dialect conversion. Therefore, we need an argument materialization to
+    // turn the replacement block arguments into a single SSA value that can be
+    // used as a replacement.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    Value newArg;
+    Value argMat = buildUnresolvedMaterialization(
+        MaterializationKind::Argument, newBlock, newBlock->begin(),
+        origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
+    mapping.map(origArg, argMat);
+    appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
 
-    // If this is a 1->1 mapping and the types of new and replacement arguments
-    // match (i.e. it's an identity map), then the argument is mapped to its
-    // original type.
     // FIXME: We simply pass through the replacement argument if there wasn't a
     // converter, which isn't great as it allows implicit type conversions to
     // appear. We should properly restructure this code to handle cases where a
     // converter isn't provided and also to properly handle the case where an
     // argument materialization is actually a temporary source materialization
     // (e.g. in the case of 1->N).
-    if (replArgs.size() == 1 &&
-        (!converter || replArgs[0].getType() == origArg.getType())) {
-      newArg = replArgs.front();
-      mapping.map(origArg, newArg);
-    } else {
-      // Build argument materialization: new block arguments -> old block
-      // argument type.
-      Value argMat = buildUnresolvedArgumentMaterialization(
-          newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
-      mapping.map(origArg, argMat);
-
-      // Build target materialization: old block argument type -> legal type.
-      // Note: This function returns an "empty" type if no valid conversion to
-      // a legal type exists. In that case, we continue the conversion with the
-      // original block argument type.
-      Type legalOutputType = converter->convertType(origArg.getType());
-      if (legalOutputType && legalOutputType != origArg.getType()) {
-        newArg = buildUnresolvedTargetMaterialization(
-            origArg.getLoc(), argMat, legalOutputType, converter);
-        mapping.map(argMat, newArg);
-      } else {
-        newArg = argMat;
-      }
+    Type legalOutputType;
+    if (converter)
+      legalOutputType = converter->convertType(origArgType);
+    if (legalOutputType && legalOutputType != origArgType) {
+      Value targetMat = buildUnresolvedTargetMaterialization(
+          origArg.getLoc(), argMat, legalOutputType, converter);
+      mapping.map(argMat, targetMat);
     }
-
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
-    argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
 
-  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
-                                            converter);
+  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
 
   // Erase the old block. (It is just unlinked for now and will be erased during
   // cleanup.)
@@ -1437,13 +1377,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
   return convertOp.getResult(0);
 }
-Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
-    Block *block, Location loc, ValueRange inputs, Type outputType,
-    const TypeConverter *converter) {
-  return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
-                                        block->begin(), loc, inputs, outputType,
-                                        converter);
-}
 Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
     Location loc, Value input, Type outputType,
     const TypeConverter *converter) {
@@ -2862,6 +2795,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
       newMaterialization = converter->materializeTargetConversion(
           rewriter, op->getLoc(), outputType, inputOperands);
       break;
+    case MaterializationKind::Source:
+      newMaterialization = converter->materializeSourceConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
     }
     if (newMaterialization) {
       assert(newMaterialization.getType() == outputType &&
@@ -2874,8 +2811,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
 
   InFlightDiagnostic diag = op->emitError()
                             << "failed to legalize unresolved materialization "
-                               "from "
-                            << inputOperands.getTypes() << " to " << outputType
+                               "from ("
+                            << inputOperands.getTypes() << ") to " << outputType
                             << " that remained live after conversion";
   if (Operation *liveUser = findLiveUser(op->getUsers())) {
     diag.attachNote(liveUser->getLoc())
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index b35cda8e724f6..8254be68912c8 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -2,9 +2,8 @@
 
 
 func.func @test_invalid_arg_materialization(
-  // expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
+  // expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
   %arg0: i16) {
-  // expected-note@below {{see existing live user here}}
   "foo.return"(%arg0) : (i16) -> ()
 }
 
@@ -104,9 +103,8 @@ func.func @test_block_argument_not_converted() {
 // Make sure argument type changes aren't implicitly forwarded.
 func.func @test_signature_conversion_no_converter() {
   "test.signature_conversion_no_converter"() ({
-  // expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
+  // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
   ^bb0(%arg0: f32):
-    // expected-note@below {{see existing live user here}}
     "test.type_consumer"(%arg0) : (f32) -> ()
     "test.return"(%arg0) : (f32) -> ()
   }) : () -> ()

@llvmbot
Copy link
Member

llvmbot commented Jul 14, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit simplifies the handling of dropped arguments and updates some dialect conversion documentation that is outdated.

When converting a block signature, a BlockTypeConversionRewrite object and potentially multiple ReplaceBlockArgRewrite are created. During the "commit" phase, uses of the old block arguments are replaced with the new block arguments, but the old implementation was written in an inconsistent way: some block arguments were replaced in BlockTypeConversionRewrite::commit and some were replaced in ReplaceBlockArgRewrite::commit. The new
BlockTypeConversionRewrite::commit implementation is much simpler and no longer modifies any IR; that is done only in ReplaceBlockArgRewrite now. The ConvertedArgInfo data structure is no longer needed.

To that end, materializations of dropped arguments are now built in applySignatureConversion instead of materializeLiveConversions; the latter function no longer has to deal with dropped arguments.

Other minor improvements:

  • Add more comments to applySignatureConversion.

Note: Error messages around failed materializations for dropped basic block arguments changed slightly. That is because those materializations are now built in legalizeUnresolvedMaterialization instead of legalizeConvertedArgumentTypes.

This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion.

This is a re-upload of #96207.


Full diff: https://github.com/llvm/llvm-project/pull/97213.diff

2 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+55-118)
  • (modified) mlir/test/Transforms/test-legalize-type-conversion.mlir (+2-4)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1e0afee2373a9..0b552a7e1ca3b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -432,34 +432,14 @@ class MoveBlockRewrite : public BlockRewrite {
   Block *insertBeforeBlock;
 };
 
-/// This structure contains the information pertaining to an argument that has
-/// been converted.
-struct ConvertedArgInfo {
-  ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
-                   Value castValue = nullptr)
-      : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
-
-  /// The start index of in the new argument list that contains arguments that
-  /// replace the original.
-  unsigned newArgIdx;
-
-  /// The number of arguments that replaced the original argument.
-  unsigned newArgSize;
-
-  /// The cast value that was created to cast from the new arguments to the
-  /// old. This only used if 'newArgSize' > 1.
-  Value castValue;
-};
-
 /// Block type conversion. This rewrite is partially reflected in the IR.
 class BlockTypeConversionRewrite : public BlockRewrite {
 public:
-  BlockTypeConversionRewrite(
-      ConversionPatternRewriterImpl &rewriterImpl, Block *block,
-      Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
-      const TypeConverter *converter)
+  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                             Block *block, Block *origBlock,
+                             const TypeConverter *converter)
       : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
-        origBlock(origBlock), argInfo(argInfo), converter(converter) {}
+        origBlock(origBlock), converter(converter) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::BlockTypeConversion;
@@ -479,10 +459,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   /// The original block that was requested to have its signature converted.
   Block *origBlock;
 
-  /// The conversion information for each of the arguments. The information is
-  /// std::nullopt if the argument was dropped during conversion.
-  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-
   /// The type converter used to convert the arguments.
   const TypeConverter *converter;
 };
@@ -691,12 +667,16 @@ class CreateOperationRewrite : public OperationRewrite {
 /// The type of materialization.
 enum MaterializationKind {
   /// This materialization materializes a conversion for an illegal block
-  /// argument type, to a legal one.
+  /// argument type, to the original one.
   Argument,
 
   /// This materialization materializes a conversion from an illegal type to a
   /// legal one.
-  Target
+  Target,
+
+  /// This materialization materializes a conversion from a legal type back to
+  /// an illegal one.
+  Source
 };
 
 /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
@@ -736,7 +716,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
 private:
   /// The corresponding type converter to use when resolving this
   /// materialization, and the kind of this materialization.
-  llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
+  llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
       converterAndKind;
 };
 } // namespace
@@ -855,11 +835,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                        ValueRange inputs, Type outputType,
                                        const TypeConverter *converter);
 
-  Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
-                                               ValueRange inputs,
-                                               Type outputType,
-                                               const TypeConverter *converter);
-
   Value buildUnresolvedTargetMaterialization(Location loc, Value input,
                                              Type outputType,
                                              const TypeConverter *converter);
@@ -989,28 +964,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
           dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
     for (Operation *op : block->getUsers())
       listener->notifyOperationModified(op);
-
-  // Process the remapping for each of the original arguments.
-  for (auto [origArg, info] :
-       llvm::zip_equal(origBlock->getArguments(), argInfo)) {
-    // Handle the case of a 1->0 value mapping.
-    if (!info) {
-      if (Value newArg =
-              rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
-        rewriter.replaceAllUsesWith(origArg, newArg);
-      continue;
-    }
-
-    // Otherwise this is a 1->1+ value mapping.
-    Value castValue = info->castValue;
-    assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
-
-    // If the argument is still used, replace it with the generated cast.
-    if (!origArg.use_empty()) {
-      rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
-                                               castValue, origArg.getType()));
-    }
-  }
 }
 
 void BlockTypeConversionRewrite::rollback() {
@@ -1035,14 +988,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
       continue;
 
     Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
-    bool isDroppedArg = replacementValue == origArg;
-    if (!isDroppedArg)
-      builder.setInsertionPointAfterValue(replacementValue);
+    assert(replacementValue && "replacement value not found");
     Value newArg;
     if (converter) {
+      builder.setInsertionPointAfterValue(replacementValue);
       newArg = converter->materializeSourceConversion(
-          builder, origArg.getLoc(), origArg.getType(),
-          isDroppedArg ? ValueRange() : ValueRange(replacementValue));
+          builder, origArg.getLoc(), origArg.getType(), replacementValue);
       assert((!newArg || newArg.getType() == origArg.getType()) &&
              "materialization hook did not provide a value of the expected "
              "type");
@@ -1053,8 +1004,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
           << "failed to materialize conversion for block argument #"
           << it.index() << " that remained live after conversion, type was "
           << origArg.getType();
-      if (!isDroppedArg)
-        diag << ", with target type " << replacementValue.getType();
       diag.attachNote(liveUser->getLoc())
           << "see existing live user here: " << *liveUser;
       return failure();
@@ -1340,73 +1289,64 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
   // Replace all uses of the old block with the new block.
   block->replaceAllUsesWith(newBlock);
 
-  // Remap each of the original arguments as determined by the signature
-  // conversion.
-  SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-  argInfo.resize(origArgCount);
-
   for (unsigned i = 0; i != origArgCount; ++i) {
-    auto inputMap = signatureConversion.getInputMapping(i);
-    if (!inputMap)
-      continue;
     BlockArgument origArg = block->getArgument(i);
+    Type origArgType = origArg.getType();
+
+    std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
+        signatureConversion.getInputMapping(i);
+    if (!inputMap) {
+      // This block argument was dropped and no replacement value was provided.
+      // Materialize a replacement value "out of thin air".
+      Value repl = buildUnresolvedMaterialization(
+          MaterializationKind::Source, newBlock, newBlock->begin(),
+          origArg.getLoc(), /*inputs=*/ValueRange(),
+          /*outputType=*/origArgType, converter);
+      mapping.map(origArg, repl);
+      appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+      continue;
+    }
 
-    // If inputMap->replacementValue is not nullptr, then the argument is
-    // dropped and a replacement value is provided to be the remappedValue.
-    if (inputMap->replacementValue) {
+    if (Value repl = inputMap->replacementValue) {
+      // This block argument was dropped and a replacement value was provided.
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      mapping.map(origArg, inputMap->replacementValue);
+      mapping.map(origArg, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
       continue;
     }
 
-    // Otherwise, this is a 1->1+ mapping.
+    // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
+    // dialect conversion. Therefore, we need an argument materialization to
+    // turn the replacement block arguments into a single SSA value that can be
+    // used as a replacement.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    Value newArg;
+    Value argMat = buildUnresolvedMaterialization(
+        MaterializationKind::Argument, newBlock, newBlock->begin(),
+        origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
+    mapping.map(origArg, argMat);
+    appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
 
-    // If this is a 1->1 mapping and the types of new and replacement arguments
-    // match (i.e. it's an identity map), then the argument is mapped to its
-    // original type.
     // FIXME: We simply pass through the replacement argument if there wasn't a
     // converter, which isn't great as it allows implicit type conversions to
     // appear. We should properly restructure this code to handle cases where a
     // converter isn't provided and also to properly handle the case where an
     // argument materialization is actually a temporary source materialization
     // (e.g. in the case of 1->N).
-    if (replArgs.size() == 1 &&
-        (!converter || replArgs[0].getType() == origArg.getType())) {
-      newArg = replArgs.front();
-      mapping.map(origArg, newArg);
-    } else {
-      // Build argument materialization: new block arguments -> old block
-      // argument type.
-      Value argMat = buildUnresolvedArgumentMaterialization(
-          newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
-      mapping.map(origArg, argMat);
-
-      // Build target materialization: old block argument type -> legal type.
-      // Note: This function returns an "empty" type if no valid conversion to
-      // a legal type exists. In that case, we continue the conversion with the
-      // original block argument type.
-      Type legalOutputType = converter->convertType(origArg.getType());
-      if (legalOutputType && legalOutputType != origArg.getType()) {
-        newArg = buildUnresolvedTargetMaterialization(
-            origArg.getLoc(), argMat, legalOutputType, converter);
-        mapping.map(argMat, newArg);
-      } else {
-        newArg = argMat;
-      }
+    Type legalOutputType;
+    if (converter)
+      legalOutputType = converter->convertType(origArgType);
+    if (legalOutputType && legalOutputType != origArgType) {
+      Value targetMat = buildUnresolvedTargetMaterialization(
+          origArg.getLoc(), argMat, legalOutputType, converter);
+      mapping.map(argMat, targetMat);
     }
-
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
-    argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
 
-  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
-                                            converter);
+  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
 
   // Erase the old block. (It is just unlinked for now and will be erased during
   // cleanup.)
@@ -1437,13 +1377,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
   return convertOp.getResult(0);
 }
-Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
-    Block *block, Location loc, ValueRange inputs, Type outputType,
-    const TypeConverter *converter) {
-  return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
-                                        block->begin(), loc, inputs, outputType,
-                                        converter);
-}
 Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
     Location loc, Value input, Type outputType,
     const TypeConverter *converter) {
@@ -2862,6 +2795,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
       newMaterialization = converter->materializeTargetConversion(
           rewriter, op->getLoc(), outputType, inputOperands);
       break;
+    case MaterializationKind::Source:
+      newMaterialization = converter->materializeSourceConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
     }
     if (newMaterialization) {
       assert(newMaterialization.getType() == outputType &&
@@ -2874,8 +2811,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
 
   InFlightDiagnostic diag = op->emitError()
                             << "failed to legalize unresolved materialization "
-                               "from "
-                            << inputOperands.getTypes() << " to " << outputType
+                               "from ("
+                            << inputOperands.getTypes() << ") to " << outputType
                             << " that remained live after conversion";
   if (Operation *liveUser = findLiveUser(op->getUsers())) {
     diag.attachNote(liveUser->getLoc())
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index b35cda8e724f6..8254be68912c8 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -2,9 +2,8 @@
 
 
 func.func @test_invalid_arg_materialization(
-  // expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
+  // expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
   %arg0: i16) {
-  // expected-note@below {{see existing live user here}}
   "foo.return"(%arg0) : (i16) -> ()
 }
 
@@ -104,9 +103,8 @@ func.func @test_block_argument_not_converted() {
 // Make sure argument type changes aren't implicitly forwarded.
 func.func @test_signature_conversion_no_converter() {
   "test.signature_conversion_no_converter"() ({
-  // expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
+  // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
   ^bb0(%arg0: f32):
-    // expected-note@below {{see existing live user here}}
     "test.type_consumer"(%arg0) : (f32) -> ()
     "test.return"(%arg0) : (f32) -> ()
   }) : () -> ()

@matthias-springer matthias-springer force-pushed the users/matthias-springer/arg_mat_experiment branch from 5773176 to cbbf741 Compare July 15, 2024 14:09
Base automatically changed from users/matthias-springer/arg_mat_experiment to main July 15, 2024 15:04
@ftynse
Copy link
Member

ftynse commented Jul 19, 2024

Has the issue that caused the revert of this been resolved?

@matthias-springer
Copy link
Member Author

Yes, #97903. I’m going to re-land the reverted PRs, but with a week pause in between in case something else is breaking for downstream users. (Our dialect conversion test suite in MLIR is not good.)

…rguments

This commit simplifies the handling of dropped arguments and updates some dialect conversion documentation that is outdated.

When converting a block signature, a BlockTypeConversionRewrite object and potentially multiple ReplaceBlockArgRewrite are created. During the "commit" phase, uses of the old block arguments are replaced with the new block arguments, but the old implementation was written in an inconsistent way: some block arguments were replaced in BlockTypeConversionRewrite::commit and some were replaced in ReplaceBlockArgRewrite::commit. The new
BlockTypeConversionRewrite::commit implementation is much simpler and no longer modifies any IR; that is done only in ReplaceBlockArgRewrite now. The ConvertedArgInfo data structure is no longer needed.

To that end, materializations of dropped arguments are now built in applySignatureConversion instead of materializeLiveConversions; the latter function no longer has to deal with dropped arguments.

Other minor improvements:

Improve variable name: origOutputType -> origArgType. Add an assertion to check that this field is only used for argument materializations.
Add more comments to applySignatureConversion.
Note: Error messages around failed materializations for dropped basic block arguments changed slightly. That is because those materializations are now built in legalizeUnresolvedMaterialization instead of legalizeConvertedArgumentTypes.

This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion.

This is a re-upload of #96207.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/block_arg_rewrite_2 branch from b0b7813 to 4114d5b Compare July 20, 2024 07:24
@matthias-springer matthias-springer merged commit bbd4af5 into main Jul 20, 2024
7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/block_arg_rewrite_2 branch July 20, 2024 08:12
ScottTodd added a commit to iree-org/llvm-project that referenced this pull request Jul 25, 2024
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're seeing asserts downstream after this PR that go away if the commit is reverted.

mlir/lib/IR/Region.cpp:25: MLIRContext *mlir::Region::getContext(): Assertion `container && "region is not attached to a container"' failed.

Logs:

IR before the pass that crashes for one test case: https://gist.github.com/ScottTodd/6a9fdc0976d7336291d61ccf24bcb22b

Our downstream code involved in the callstack is here: https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/InputConversion/Common/ConvertPrimitiveType.cpp . I'm not sure yet if our usage downstream needs to change or if there is a bug in this upstream code. Do you have any suggestions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like an edge case with detached IR. can you try passing the context from the rewriter to “ buildUnresolvedMaterialization” as an extra argument (in the dialect conversion)? if that doesn’t fix it, please revert and i will get back to it when i have time.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and create the builder with the context:

// Create an unresolved materialization. We use a new OpBuilder to avoid
// tracking the materialization like we do for other operations.
OpBuilder builder(insertBlock, insertPt);

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! This patch appears to fix our crashes: https://gist.github.com/ScottTodd/7d05663c3180f5ae5711e278479f0146

I have us downstream set to carry a local revert as a bandaid fix. Some options:

  1. Fix-forward: I send this patch as a PR (not sure what tests to add, I don't understand this code well enough)
  2. Fix-forward: you take over the patch
  3. We revert this PR and then proceed with a rollforward including the fix patch

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If that's all that's necessary I can send a fix later today.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Here's my patch as a commit, if it helps: ScottTodd@f766cd2

yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
…rguments (#97213)

Summary:
This commit simplifies the handling of dropped arguments and updates
some dialect conversion documentation that is outdated.

When converting a block signature, a `BlockTypeConversionRewrite` object
and potentially multiple `ReplaceBlockArgRewrite` are created. During
the "commit" phase, uses of the old block arguments are replaced with
the new block arguments, but the old implementation was written in an
inconsistent way: some block arguments were replaced in
`BlockTypeConversionRewrite::commit` and some were replaced in
`ReplaceBlockArgRewrite::commit`. The new
`BlockTypeConversionRewrite::commit` implementation is much simpler and
no longer modifies any IR; that is done only in `ReplaceBlockArgRewrite`
now. The `ConvertedArgInfo` data structure is no longer needed.

To that end, materializations of dropped arguments are now built in
`applySignatureConversion` instead of `materializeLiveConversions`; the
latter function no longer has to deal with dropped arguments.

Other minor improvements:
- Add more comments to `applySignatureConversion`.

Note: Error messages around failed materializations for dropped basic
block arguments changed slightly. That is because those materializations
are now built in `legalizeUnresolvedMaterialization` instead of
`legalizeConvertedArgumentTypes`.

This commit is in preparation of decoupling argument/source/target
materializations from the dialect conversion.

This is a re-upload of #96207.

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60251281
ScottTodd added a commit to iree-org/llvm-project that referenced this pull request Jul 26, 2024
bjacob pushed a commit to iree-org/llvm-project that referenced this pull request Jul 27, 2024
bjacob pushed a commit to iree-org/llvm-project that referenced this pull request Jul 27, 2024
bjacob pushed a commit to iree-org/llvm-project that referenced this pull request Jul 29, 2024
bjacob pushed a commit to iree-org/llvm-project that referenced this pull request Jul 29, 2024
hanhanW pushed a commit to iree-org/llvm-project that referenced this pull request Jul 29, 2024
@MaheshRavishankar
Copy link
Contributor

I am posting here (but I can make it an issue as well), cause I dont fully follow if this is a downstream error or an error in this PR. To repro you can use this example

func.func @alloc_transfer_read_write_vector4_vector8(%arg0: memref<4096x4096xf32>, %x: index, %y: index) {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = memref.alloc() : memref<128x32xf32, 3>
  %v = vector.transfer_read %arg0[%x, %y], %cst : memref<4096x4096xf32>, vector<4xf32>
  vector.transfer_write %v, %0[%x, %y] : vector<4xf32>, memref<128x32xf32, 3>
  %mat = vector.transfer_read %arg0[%x, %y], %cst : memref<4096x4096xf32>, vector<8xf32>
  vector.transfer_write %mat, %0[%x, %y] : vector<8xf32>, memref<128x32xf32, 3>
  memref.dealloc %0 : memref<128x32xf32, 3>
  return
}

and the following command from IREE

iree-opt --pass-pipeline="builtin.module(func.func(iree-spirv-vectorize-load-store))" repro.mlir

With this change, I am hitting an assertion here cause the pass expects the element type to be vector, but it isnt. Debugging this and adding a break point + --debug flag I see this being printed before the assertion

mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() <{function_type = (memref<4096x1024xvector<4xf32>>, index, index) -> (), sym_name = "alloc_transfer_read_write_vector4_vector8"}> ({
^bb0(%arg0: memref<4096x1024xvector<4xf32>>, %arg1: index, %arg2: index):
  %0 = "builtin.unrealized_conversion_cast"(%arg0) : (memref<4096x1024xvector<4xf32>>) -> memref<4096x4096xf32>
  %1 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
  %2 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<128x8xvector<4xf32>, 3>
  %3 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<128x32xf32, 3>
  %4 = "vector.transfer_read"(<<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, %1) <{in_bounds = [false], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d1)>}> : (memref<4096x4096xf32>, index, index, f32) -> vector<4xf32>
  "vector.transfer_write"(%4, %3, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>) <{in_bounds = [false], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d1)>}> : (vector<4xf32>, memref<128x32xf32, 3>, index, index) -> ()
  %5 = "vector.transfer_read"(<<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, %1) <{in_bounds = [false], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d1)>}> : (memref<4096x4096xf32>, index, index, f32) -> vector<8xf32>
  "vector.transfer_write"(%5, %3, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>) <{in_bounds = [false], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d1)>}> : (vector<8xf32>, memref<128x32xf32, 3>, index, index) -> ()
  "memref.dealloc"(%3) : (memref<128x32xf32, 3>) -> ()
  "func.return"() : () -> ()
}) : () -> ()


} -> SUCCESS
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'vector.transfer_read'(0x555555665f40) {
  %4 = "vector.transfer_read"(<<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, %1) <{in_bounds = [false], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d1)>}> : (memref<4096x4096xf32>, index, index, f32) -> vector<4xf32>

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'vector.transfer_read -> ()' {
Trying to match "mlir::iree_compiler::(anonymous namespace)::ProcessTransferRead"

Looking at the IR (and stepping through dialect conversion as much as I understood it, there might be couple of things happening

  1. At this point, the adaptor.geSource() is supposed to be of type memref<...<vector<4xf32>> but it is instead the original type. So somewhere the value tracking went off the tracks.
  2. The vector.transfer_read"(<<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, %1) is also suspect. I think the signatureConversion that happens when this pattern runs on the func.func was supposed to instead have this
    vector.transfer_read"(%0, %arg1, %arg2, %1) which means something is probably suspect in the signature conversion here (https://github.com/iree-org/iree/blob/488cfc9688032d61cd47c63ff795d7f3e4642a16/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp#L312) . I am debugging this, but posting here for some input.

cc @matthias-springer @antiagainst and @bjacob

matthias-springer added a commit that referenced this pull request Jul 30, 2024
This code got lost in #97213 and there was no test for it. Add it back with an MLIR test.

When a pattern is run without a type converter, we can assume that the new block arugments of a signature conversion are legal.
@matthias-springer
Copy link
Member Author

@MaheshRavishankar There's an untested piece of code in the dialect conversion that got lost. I think this should fix it, can you give it a try? #101148

bjacob pushed a commit to iree-org/llvm-project that referenced this pull request Jul 30, 2024
matthias-springer added a commit that referenced this pull request Jul 30, 2024
…101148)

This code got lost in #97213 and there was no test for it. Add it back
with an MLIR test.

When a pattern is run without a type converter, we can assume that the
new block argument types of a signature conversion are legal. That's
because they were specified by the user. This won't work for 1->N
conversions due to limitations in the dialect conversion infrastructure,
so the original `FIXME` has to stay in place.
hanhanW pushed a commit to iree-org/llvm-project that referenced this pull request Jul 30, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants