Skip to content

Commit 39cdefb

Browse files
authored
[mlir][nvvm] Add prefetch.tensormap (llvm#67564)
This PR adds `prefetch.tensormap` Op. It brings the cache line containing the given tma descriptor for subsequent use by the cp.async.bulk.tensor instruction. https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prefetch-prefetchu
1 parent f2898de commit 39cdefb

File tree

5 files changed

+58
-0
lines changed

5 files changed

+58
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,6 +1438,17 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
14381438
let hasVerifier = 1;
14391439
}
14401440

1441+
def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
1442+
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
1443+
Arguments<(ins LLVM_i64ptr_any:$tmaDescriptor, PtxPredicate:$predicate)> {
1444+
let assemblyFormat = "$tmaDescriptor (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
1445+
let extraClassDefinition = [{
1446+
std::string $cppClass::getPtx() {
1447+
return std::string("prefetch.tensormap [%0];");
1448+
}
1449+
}];
1450+
}
1451+
14411452
//===----------------------------------------------------------------------===//
14421453
// NVVM Wgmma Ops
14431454
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,18 @@ def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {
619619
let assemblyFormat = "$barriers `[` $mbarId `]` `,` $phase `,` $ticks attr-dict `:` type($barriers)";
620620
}
621621

622+
def NVGPU_TmaPrefetchOp : NVGPU_Op<"tma.prefetch.descriptor", []> {
623+
let summary = "Prefetch given `nvgpu.tensormap.descriptor` ";
624+
let description = [{
625+
The Op brings the cache line containing the given `$tmaDescriptor` for
626+
subsequent use by the `tma.async.load` instruction.
627+
}];
628+
let arguments = (ins NVGPU_TensorMapDescriptor:$tensorMapDescriptor, Optional<I1>:$predicate);
629+
let assemblyFormat = [{
630+
$tensorMapDescriptor (`,` $predicate^)? attr-dict `:` type($tensorMapDescriptor)
631+
}];
632+
}
633+
622634
def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", []> {
623635
let summary = "TMA asynchronous load";
624636
let description = [{

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1610,6 +1610,18 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
16101610
}
16111611
};
16121612

1613+
struct NVGPUTmaPrefetchOpLowering
1614+
: public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
1615+
using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
1616+
LogicalResult
1617+
matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1618+
ConversionPatternRewriter &rewriter) const override {
1619+
rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
1620+
op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
1621+
return success();
1622+
}
1623+
};
1624+
16131625
} // namespace
16141626

16151627
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1623,6 +1635,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
16231635
NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
16241636
NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
16251637
NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
1638+
NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
16261639
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
16271640
NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
16281641
NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,17 @@ func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : m
643643
func.return
644644
}
645645

646+
// CHECK-LABEL: @tma_prefetch(
647+
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.tensormap.descriptor<tensor = memref<128xf32, 3>, swizzle = none, l2promo = none, oob = nan, interleave = none>, %[[arg1:[a-zA-Z0-9_]+]]: i1
648+
func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {
649+
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.tensormap.descriptor<tensor = memref<128xf32, 3>, swizzle = none, l2promo = none, oob = nan, interleave = none> to !llvm.ptr
650+
// CHECK: nvvm.prefetch.tensormap %[[S0]] : !llvm.ptr
651+
nvgpu.tma.prefetch.descriptor %tensorMap1d: !tensorMap1d
652+
// CHECK: nvvm.prefetch.tensormap %[[S0]], predicate = %[[arg1]] : !llvm.ptr, i1
653+
nvgpu.tma.prefetch.descriptor %tensorMap1d, %p: !tensorMap1d
654+
func.return
655+
}
656+
646657
!lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
647658
!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
648659

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,3 +504,14 @@ func.func @elect_one_leader_sync() {
504504
%cnd = nvvm.elect.sync -> i1
505505
return
506506
}
507+
508+
// -----
509+
510+
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
511+
llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
512+
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
513+
nvvm.prefetch.tensormap %desc : !llvm.ptr
514+
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$1 prefetch.tensormap [$0];", "l,b"
515+
nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
516+
llvm.return
517+
}

0 commit comments

Comments
 (0)