
Commit 85258ec

george-qi authored and pytorchmergebot committed
Add mask_type=2 to masked_softmax for when mask.size() == input.size() (pytorch#85915)
Pull Request resolved: pytorch#85915
Approved by: https://github.com/cpuhrsch
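For context, a minimal sketch of the call this change enables, using the private torch._masked_softmax op. This is an internal API; the (input, mask, dim, mask_type) argument order and the convention that True mask entries are dropped from the softmax are assumed from the test helper further down, so treat the snippet as illustrative:

import torch

B, L = 4, 8
x = torch.randn(B, L)
# Per the TORCH_CHECK messages in this commit:
#   0 = src_mask (TxT), 1 = src_key_padding_mask (BxL), 2 = default_mask,
# where mask_type == 2 requires mask.size() == input.size().
mask = torch.randint(0, 2, (B, L)).bool()  # same shape as x; True = masked out
out = torch._masked_softmax(x, mask, x.dim() - 1, 2)  # mask_type=2 was rejected before this commit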
1 parent: 6004c65 · commit: 85258ec

File tree (3 files changed: 18 additions, 16 deletions):

aten/src/ATen/native/SoftMax.cpp
aten/src/ATen/native/cuda/SoftMax.cu
test/test_nn.py


aten/src/ATen/native/SoftMax.cpp (2 additions, 1 deletion)

@@ -137,7 +137,8 @@ void host_softmax(
   if (MaskedSoftMax) {
     TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined");
     int64_t mask_type = mask_type_.value();
-    TORCH_CHECK((mask_type == 0) || (mask_type == 1), "Mask Type should be 0 (src_mask) or 1 (src_key_padding_mask)");
+    // If mask_type == 2, then mask_.sizes() must equal input_.sizes()
+    TORCH_CHECK((mask_type == 0) || (mask_type == 1) || (mask_type == 2), "Mask Type should be 0 (src_mask), 1 (src_key_padding_mask), or 2 (default_mask)");

     // TODO: Add support for TxT src_mask
     TORCH_CHECK(mask_type != 0, "src_mask not currently supported on CPU");
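On CPU the only change is the widened check; mask_type == 0 is still rejected. For reference, a full-size (mask_type == 2) mask selects the same computation as this hedged Python sketch, assuming the True-means-masked convention used in test_nn.py:

import torch

def slow_masked_softmax(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
    # Reference semantics for mask_type == 2: mask.shape == x.shape, and
    # positions where mask is True are excluded by pre-filling with -inf.
    return torch.softmax(x.masked_fill(mask, float("-inf")), dim=dim)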

aten/src/ATen/native/cuda/SoftMax.cu (2 additions, 1 deletion)

@@ -963,7 +963,7 @@ Tensor masked_softmax_cuda(const Tensor& input_, const Tensor& mask_, const c10:

   TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined");
   int64_t mask_type = mask_type_.value();
-  TORCH_CHECK((mask_type == 0) || (mask_type == 1), "Mask Type should be 0 (src_mask) or 1 (src_key_padding_mask)");
+  TORCH_CHECK((mask_type == 0) || (mask_type == 1) || (mask_type == 2), "Mask Type should be 0 (src_mask), 1 (src_key_padding_mask), or 2 (default_mask)");

   // If input is [B, H, T, T] and mask is [B, T]
   // we have special fast kernel
@@ -975,6 +975,7 @@ Tensor masked_softmax_cuda(const Tensor& input_, const Tensor& mask_, const c10:
   // TODO We should have special fast kernel for TxT mask as well
   // mask_type == 0 => mask_ is a src_mask
   bool is_TxT_mask = (mask_type == 0) && input_.dim() == 4 && mask_.dim() == 2 && input_.size(3) == mask_.size(1) && input_.size(2) == mask_.size(0) && mask_.size(0) == mask_.size(1);
+  // If mask_type == 2, then mask_.sizes() must equal input_.sizes()
   TORCH_CHECK(mask_.sizes() == input_.sizes() || is_BxT_mask || is_TxT_mask, "Mask shape should match input. mask: ", mask_.sizes(), " input: ", input_.sizes());

   auto input = input_.dim() == 0 ? input_.view(1) : input_;
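The widened TORCH_CHECK accepts three mask layouts. A hedged Python mirror of that shape logic follows; the is_BxT_mask condition sits outside this diff and is reconstructed from the [B, H, T, T] / [B, T] comment above, so treat it as an assumption:

def mask_shape_ok(input, mask, mask_type):
    # [B, T] key-padding mask against [B, H, T, T] input (fast-kernel case; assumed).
    is_BxT_mask = (mask_type == 1 and input.dim() == 4 and mask.dim() == 2
                   and input.size(0) == mask.size(0) and input.size(3) == mask.size(1))
    # [T, T] src mask against [B, H, T, T] input, as in the diff above.
    is_TxT_mask = (mask_type == 0 and input.dim() == 4 and mask.dim() == 2
                   and input.size(3) == mask.size(1) and input.size(2) == mask.size(0)
                   and mask.size(0) == mask.size(1))
    # Everything else must match the input shape exactly; mask_type == 2
    # guarantees this by contract.
    return mask.shape == input.shape or is_BxT_mask or is_TxT_mask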

test/test_nn.py (14 additions, 14 deletions)

@@ -15449,26 +15449,26 @@ def test_masked_softmax_grad(self, device):
         for shape in shapes:
             dims = [0, len(shape) - 1] if len(shape) > 0 else [0]
             for dim in dims:
-                input = torch.randn(shape, requires_grad=True)
-                mask = torch.randint(0, 2, shape).bool()
-                mask_type = 1  # BxL => src_key_padding_mask
-                if (self.device_type == "cuda"):
-                    input = input.cuda().detach().requires_grad_()
-                    mask = mask.cuda()
-                self._test_masked_softmax_helper(input, dim, mask, mask_type)
+                for mask_type in [1, 2]:  # 1 = BxL => src_key_padding_mask
+                    input = torch.randn(shape, requires_grad=True)
+                    mask = torch.randint(0, 2, shape).bool()
+                    if (self.device_type == "cuda"):
+                        input = input.cuda().detach().requires_grad_()
+                        mask = mask.cuda()
+                    self._test_masked_softmax_helper(input, dim, mask, mask_type)

     # In this test, the forward pass is expected to produce nan's because when dim=0, we only have unspecified values
     def test_masked_softmax_forward_with_nans(self, device):
         dim = 0
         shapes = [(4, 5), (50, 100), (1500, 1200)]
         for (x, y) in shapes:
-            input = torch.randn((x, y), requires_grad=True)
-            mask = torch.tensor([i % 2 for i in range(y)]).expand((x, y)).bool()
-            mask_type = 1  # BxL => src_key_padding_mask
-            if (self.device_type == "cuda"):
-                input = input.cuda().detach().requires_grad_()
-                mask = mask.cuda()
-            self._test_masked_softmax_helper(input, dim, mask, mask_type)
+            for mask_type in [1, 2]:  # 1 = BxL => src_key_padding_mask
+                input = torch.randn((x, y), requires_grad=True)
+                mask = torch.tensor([i % 2 for i in range(y)]).expand((x, y)).bool()
+                if (self.device_type == "cuda"):
+                    input = input.cuda().detach().requires_grad_()
+                    mask = mask.cuda()
+                self._test_masked_softmax_helper(input, dim, mask, mask_type)

     @onlyCUDA
     def test_masked_softmax_transformer_layout(self, device):
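_test_masked_softmax_helper itself is not part of this diff. A hedged sketch of the comparison it presumably performs, assuming the fused op is checked against softmax over a -inf-filled input (NaN-tolerant, for the dim=0 test above):

import torch

def check_masked_softmax(input, dim, mask, mask_type):
    fused = torch._masked_softmax(input, mask, dim, mask_type)
    # Reference: fill masked-out (True) positions with -inf, then softmax.
    expected = torch.softmax(input.masked_fill(mask, float("-inf")), dim=dim)
    torch.testing.assert_close(fused, expected, equal_nan=True)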
