
Commit d18dcc0

naromero77amd authored and jithunnair-amd committed
[UPCP][release/2.4] TunableOp fix for batched MM with views. (#1723)
Fixes pytorch#140278. Based on PR pytorch#140673. Note that in test/test_linalg.py I had to include Case #4 from upstream in addition to Case #5 to resolve the merge conflict from the cherry-pick of the upstream commit. Verified manually that the test_bmm_tunableop_rocm UT passes. cc: @jeffdaily
1 parent def7d70 commit d18dcc0

2 files changed: +23, -3 lines


aten/src/ATen/cuda/tunable/GemmCommon.h

Lines changed: 3 additions & 3 deletions
@@ -177,19 +177,19 @@ struct GemmStridedBatchedParams : OpParams {
   }
 
   size_t GetSizeA() const {
-    size_t size_stride = std::min(lda, stride_a) * ((transa == 'n' || transa == 'N') ? k : m) * batch;
+    size_t size_stride = stride_a * batch;
     size_t size_dense = m * k * batch;
     return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
   }
 
   size_t GetSizeB() const {
-    size_t size_stride = std::min(ldb, stride_b) * ((transb == 'n' || transb == 'N') ? n : k) * batch;
+    size_t size_stride = stride_b * batch;
     size_t size_dense = k * n * batch;
     return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
   }
 
   size_t GetSizeC() const {
-    size_t size_stride = std::min(ldc, stride_c) * n * batch;
+    size_t size_stride = stride_c * batch;
     size_t size_dense = m * n * batch;
     return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
   }
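
To make the effect of the GemmCommon.h change concrete, the following is a small Python sketch (not part of the commit; the helper names and numbers are illustrative, loosely modeled on the chunked views in case 5 of the test below). It mirrors the old and new size_stride formulas from GetSizeA and shows that, for a view whose batch stride exceeds the dense per-matrix size, the new formula yields a larger buffer bound that follows the batch stride of the view.

# Illustrative sketch only -- compares the pre-fix and post-fix
# size_stride formulas from GemmStridedBatchedParams::GetSizeA.

def size_a_old(lda, stride_a, transa, m, k, batch):
    # pre-fix: min(lda, stride_a) scaled by one matrix dimension
    dim = k if transa in ("n", "N") else m
    return min(lda, stride_a) * dim * batch

def size_a_new(stride_a, batch):
    # post-fix: the span implied by the batch stride alone
    return stride_a * batch

# Hypothetical strided view, loosely modeled on case 5 of the test:
# 512x64 chunks taken out of a 1024x64 parent, so each batch element
# is 512x64 dense but consecutive batches are 1024*64 elements apart.
m, k, batch = 512, 64, 16 * 16
lda, stride_a = 64, 1024 * 64

size_dense = m * k * batch                 # dense fallback used by both versions
old_bound = max(size_a_old(lda, stride_a, "n", m, k, batch), size_dense)
new_bound = max(size_a_new(stride_a, batch), size_dense)
print(old_bound, new_bound)                # 8388608 16777216 for these numbers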

test/test_linalg.py

Lines changed: 20 additions & 0 deletions
@@ -4577,6 +4577,26 @@ def test_bmm_tunableop_rocm(self, device, dtype):
         i2 = torch.randn((M, B, K), device=device, dtype=dtype)
         i2 = torch.permute(i2, (1, 2, 0))
         out = torch.bmm(i1, i2)
+        # case 4
+        input_tensor = torch.rand((1920, 1, 100), device=device, dtype=dtype)
+        input_tensor = torch.as_strided(
+            input_tensor, size=(1920, 1, 100), stride=(100, 100, 1)
+        )
+        batch1_tensor = torch.rand((1920, 256, 512), device=device, dtype=dtype)
+        batch1_tensor = torch.as_strided(
+            batch1_tensor, size=(1920, 256, 512), stride=(512, 983040, 1)
+        )
+        batch2_tensor = torch.rand((1920, 512, 100), device=device, dtype=dtype)
+        batch2_tensor = torch.as_strided(
+            batch2_tensor, size=(1920, 512, 100), stride=(51200, 100, 1)
+        )
+        out = torch.baddbmm(input_tensor, batch1_tensor, batch2_tensor)
+        # case 5
+        q = torch.randn([16, 16, 1024, 64], device=device, dtype=dtype)
+        k = torch.randn([16, 16, 1024, 64], device=device, dtype=dtype)
+        q_chunks = q.split(512, dim=-2)
+        k_chunks = k.split(64, dim=-2)
+        C = torch.matmul(q_chunks[0], k_chunks[0])
         # clean up, remove any file that was generated
         try:
             import os
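
For reference, a standalone way to exercise the new case 4 scenario outside the test suite might look like the sketch below. This is not part of the commit: it assumes a ROCm (or CUDA) device and the torch.cuda.tunable controls available in the 2.4 release, and the results filename is an arbitrary placeholder.

import torch

# Assumed PyTorch 2.4 TunableOp controls; the equivalent environment variables
# (PYTORCH_TUNABLEOP_ENABLED=1, PYTORCH_TUNABLEOP_TUNING=1) also exist.
torch.cuda.tunable.enable(True)
torch.cuda.tunable.tuning_enable(True)
torch.cuda.tunable.set_filename("tunableop_results.csv")  # placeholder filename

device, dtype = "cuda", torch.float32

# Case 4 from the test: baddbmm over as_strided views.
input_tensor = torch.rand((1920, 1, 100), device=device, dtype=dtype)
input_tensor = torch.as_strided(input_tensor, size=(1920, 1, 100), stride=(100, 100, 1))
batch1 = torch.rand((1920, 256, 512), device=device, dtype=dtype)
batch1 = torch.as_strided(batch1, size=(1920, 256, 512), stride=(512, 983040, 1))
batch2 = torch.rand((1920, 512, 100), device=device, dtype=dtype)
batch2 = torch.as_strided(batch2, size=(1920, 512, 100), stride=(51200, 100, 1))

out = torch.baddbmm(input_tensor, batch1, batch2)
print(out.shape)  # torch.Size([1920, 256, 100])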

0 commit comments
