@@ -283,8 +283,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
283
283
284
284
auto srcMemrefType = cast<MemRefType>(op.getSrcMemref ().getType ());
285
285
Value srcPtr =
286
- getStridedElementPtr (b.getLoc (), srcMemrefType, adaptor. getSrcMemref () ,
287
- adaptor.getIndices (), rewriter );
286
+ getStridedElementPtr (rewriter, b.getLoc (), srcMemrefType,
287
+ adaptor.getSrcMemref (), adaptor. getIndices () );
288
288
Value ldMatrixResult = b.create <NVVM::LdMatrixOp>(
289
289
ldMatrixResultType, srcPtr,
290
290
/* num=*/ op.getNumTiles (),
@@ -661,8 +661,8 @@ struct NVGPUAsyncCopyLowering
661
661
Location loc = op.getLoc ();
662
662
auto dstMemrefType = cast<MemRefType>(op.getDst ().getType ());
663
663
Value dstPtr =
664
- getStridedElementPtr (b.getLoc (), dstMemrefType, adaptor. getDst () ,
665
- adaptor.getDstIndices (), rewriter );
664
+ getStridedElementPtr (rewriter, b.getLoc (), dstMemrefType,
665
+ adaptor.getDst (), adaptor. getDstIndices () );
666
666
FailureOr<unsigned > dstAddressSpace =
667
667
getTypeConverter ()->getMemRefAddressSpace (dstMemrefType);
668
668
if (failed (dstAddressSpace))
@@ -676,8 +676,9 @@ struct NVGPUAsyncCopyLowering
676
676
return rewriter.notifyMatchFailure (
677
677
loc, " source memref address space not convertible to integer" );
678
678
679
- Value scrPtr = getStridedElementPtr (loc, srcMemrefType, adaptor.getSrc (),
680
- adaptor.getSrcIndices (), rewriter);
679
+ Value scrPtr =
680
+ getStridedElementPtr (rewriter, loc, srcMemrefType, adaptor.getSrc (),
681
+ adaptor.getSrcIndices ());
681
682
// Intrinsics takes a global pointer so we need an address space cast.
682
683
auto srcPointerGlobalType = LLVM::LLVMPointerType::get (
683
684
op->getContext (), NVVM::NVVMMemorySpace::kGlobalMemorySpace );
@@ -814,7 +815,7 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
814
815
MemRefType mbarrierMemrefType =
815
816
nvgpu::getMBarrierMemrefType (rewriter.getContext (), mbarType);
816
817
return ConvertToLLVMPattern::getStridedElementPtr (
817
- b.getLoc (), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter );
818
+ rewriter, b.getLoc (), mbarrierMemrefType, memrefDesc, {mbarId});
818
819
}
819
820
};
820
821
@@ -995,8 +996,8 @@ struct NVGPUTmaAsyncLoadOpLowering
995
996
ConversionPatternRewriter &rewriter) const override {
996
997
ImplicitLocOpBuilder b (op->getLoc (), rewriter);
997
998
auto srcMemrefType = cast<MemRefType>(op.getDst ().getType ());
998
- Value dest = getStridedElementPtr (op->getLoc (), srcMemrefType,
999
- adaptor.getDst (), {}, rewriter );
999
+ Value dest = getStridedElementPtr (rewriter, op->getLoc (), srcMemrefType,
1000
+ adaptor.getDst (), {});
1000
1001
Value barrier =
1001
1002
getMbarrierPtr (b, op.getBarriers ().getType (), adaptor.getBarriers (),
1002
1003
adaptor.getMbarId (), rewriter);
@@ -1021,8 +1022,8 @@ struct NVGPUTmaAsyncStoreOpLowering
1021
1022
ConversionPatternRewriter &rewriter) const override {
1022
1023
ImplicitLocOpBuilder b (op->getLoc (), rewriter);
1023
1024
auto srcMemrefType = cast<MemRefType>(op.getSrc ().getType ());
1024
- Value dest = getStridedElementPtr (op->getLoc (), srcMemrefType,
1025
- adaptor.getSrc (), {}, rewriter );
1025
+ Value dest = getStridedElementPtr (rewriter, op->getLoc (), srcMemrefType,
1026
+ adaptor.getSrc (), {});
1026
1027
SmallVector<Value> coords = adaptor.getCoordinates ();
1027
1028
for (auto [index, value] : llvm::enumerate (coords)) {
1028
1029
coords[index] = truncToI32 (b, value);
@@ -1083,8 +1084,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
1083
1084
Value leadDim = makeConst (leadDimVal);
1084
1085
1085
1086
Value baseAddr = getStridedElementPtr (
1086
- op->getLoc (), cast<MemRefType>(op.getTensor ().getType ()),
1087
- adaptor.getTensor (), {}, rewriter );
1087
+ rewriter, op->getLoc (), cast<MemRefType>(op.getTensor ().getType ()),
1088
+ adaptor.getTensor (), {});
1088
1089
Value basePtr = b.create <LLVM::PtrToIntOp>(ti64, baseAddr);
1089
1090
// Just use 14 bits for base address
1090
1091
Value basePtr14bit = shiftRight (shiftLeft (basePtr, 46 ), 50 );
0 commit comments