Commit 8f4f163

swolchok authored and pytorchmergebot committed
[PyTorch] Flip polarity of masked_softmax mask (#78)
Summary:
X-link: pytorch/pytorch-canary#78
Pull Request resolved: pytorch#75039

The mask polarity didn't match torch.nn.MultiheadAttention. Now it does.

ghstack-source-id: 152815449

Test Plan: updated tests

Reviewed By: zrphercule

Differential Revision: D34929186

fbshipit-source-id: 1eaee615bafd5a6f058f1faefa54f8f4aa01c92e
(cherry picked from commit 00eea72)
1 parent 87ab665 commit 8f4f163
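
The gist of the change: torch._masked_softmax previously treated a True mask entry as "keep this position"; after this commit a True entry means "mask this position out", which is the convention torch.nn.MultiheadAttention uses for its boolean masks. A minimal sketch of the new convention using only public ops (the tensor values are illustrative, not taken from the tests):

import torch

scores = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[False, False, True]])  # True = masked out (new convention)

# What the CUDA fallback in this commit now computes internally:
out = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
print(out)  # the last entry is 0; the first two renormalize among themselves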

File tree

4 files changed: +8 -7 lines changed


aten/src/ATen/native/SoftMax.cpp

Lines changed: 2 additions & 2 deletions
@@ -170,7 +170,7 @@ void host_softmax(
           }
         } else {
           for (const auto d : c10::irange(0, dim_size)) {
-            if (mask_data[d * dim_stride]) {
+            if (!mask_data[d * dim_stride]) {
               max_input = is_meaningful_max
                   ? std::max(max_input, input_data[d * dim_stride])
                   : input_data[d * dim_stride];
@@ -183,7 +183,7 @@ void host_softmax(
         acc_type<scalar_t, false> tmpsum = 0;
         for (const auto d : c10::irange(dim_size)) {
           scalar_t z{};
-          if (!MaskedSoftMax || mask_data[d * dim_stride]) {
+          if (!MaskedSoftMax || !mask_data[d * dim_stride]) {
             z = std::exp(input_data[d * dim_stride] - max_input);
           } else {
             z = 0;
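
For reference, here is a rough Python transcription of the flipped CPU loop above for a single row (names are illustrative; the log-softmax path and fully-masked-row handling of the real kernel are omitted):

import math

def masked_softmax_row(row, mask):
    # Only entries whose mask is False (not masked out) feed the running max and the sum.
    kept = [x for x, m in zip(row, mask) if not m]
    max_input = max(kept)  # a fully masked row is not handled in this sketch
    exps = [0.0 if m else math.exp(x - max_input) for x, m in zip(row, mask)]
    total = sum(exps)
    return [e / total for e in exps]

print(masked_softmax_row([1.0, 2.0, 3.0], [False, False, True]))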

aten/src/ATen/native/cuda/PersistentSoftmax.cuh

Lines changed: 3 additions & 3 deletions
@@ -126,7 +126,7 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc
         if (!is_transformer_mask) {
           idx += i*element_count;
         }
-        if (mask[idx]) {
+        if (!mask[idx]) {
           max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
           is_meaningful_max = true;
         }
@@ -160,7 +160,7 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc
           idx += i*element_count;
         }

-        if (mask[idx]) {
+        if (!mask[idx]) {
           if (is_log_softmax) {
             sum[i] += std::exp(elements[i][it] - max_value[i]);
           } else {
@@ -188,7 +188,7 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc
         if (!is_transformer_mask) {
           idx += i*element_count;
         }
-        if (!mask[idx]) {
+        if (mask[idx]) {
           dst[i*element_count+it*WARP_SIZE] = 0;
           continue;
         }
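
The last hunk above writes 0 directly into dst for positions whose mask is True. A quick illustration of what that means for the output, using public ops rather than the warp kernel itself (assuming at least one unmasked element per row):

import torch

x = torch.randn(2, 5)
mask = torch.tensor([[False, True, False, True, False],
                     [True, False, False, False, True]])  # True = masked out

out = torch.softmax(x.masked_fill(mask, float("-inf")), dim=-1)
assert bool((out[mask] == 0).all())  # masked positions come out as exact zeros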

aten/src/ATen/native/cuda/SoftMax.cu

Lines changed: 1 addition & 2 deletions
@@ -959,8 +959,7 @@ Tensor masked_softmax_cuda(const Tensor& input, const Tensor& mask) {
       input.scalar_type(),
       "masked_softmax",
       [&] {
-        Tensor mask_not = mask.logical_not();
-        output = at::softmax(input.masked_fill(mask_not, -std::numeric_limits<scalar_t>::infinity()), -1);
+        output = at::softmax(input.masked_fill(mask, -std::numeric_limits<scalar_t>::infinity()), -1);
       });
   return output;
 }
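
Dropping the logical_not here is equivalent to the old fallback provided callers also flip the mask they pass in, which is exactly what the rest of this commit does. A small sketch of that equivalence with public ops (shapes and masking probability are arbitrary):

import torch

x = torch.randn(2, 3, 8, 8)
new_mask = torch.rand(2, 3, 8, 8) < 0.3   # new convention: True = masked out
new_mask[..., 0] = False                  # keep at least one position per row
old_mask = ~new_mask                      # old convention: True = keep

# Old fallback: invert the "keep" mask, then fill with -inf.
old_way = torch.softmax(x.masked_fill(old_mask.logical_not(), float("-inf")), dim=-1)
# New fallback: fill with -inf wherever the mask is True.
new_way = torch.softmax(x.masked_fill(new_mask, float("-inf")), dim=-1)
torch.testing.assert_close(old_way, new_way)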

test/test_nn.py

Lines changed: 2 additions & 0 deletions
@@ -16155,6 +16155,7 @@ def test_masked_softmax(self, device):
         mask = mask.cuda()
         mask = mask.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool()
         native_res = torch._masked_softmax(input, mask)
+        mask = ~mask
         mask = mask.float()

         def slow_masked_softmax(input, mask):
@@ -16178,6 +16179,7 @@ def test_masked_softmax_transformer_layout(self, device):
         mask = mask.bool()
         native_res = torch._masked_softmax(input, mask)
         mask = mask.reshape(B, 1, 1, L).expand(B, num_heads, L, L)
+        mask = ~mask
         mask = mask.float()

         def slow_masked_softmax(input, mask):
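
The added mask = ~mask keeps the test's Python reference in sync: the native op now expects True to mean "masked out", while slow_masked_softmax (whose body is not shown in this hunk) presumably multiplies by a float mask where 1.0 means "keep". A hedged sketch of that relationship, with an assumed reference implementation:

import torch

def reference_softmax(scores, keep):
    # Assumed shape of the test's slow reference: a float mask, 1.0 = keep.
    exp = torch.exp(scores) * keep
    return exp / exp.sum(dim=-1, keepdim=True)

scores = torch.randn(1, 1, 2, 4)
drop = torch.tensor([[[[False, True, False, False],
                       [True, False, False, True]]]])  # native convention: True = masked out
keep = (~drop).float()                                 # reference convention: 1.0 = keep

torch.testing.assert_close(
    reference_softmax(scores, keep),
    torch.softmax(scores.masked_fill(drop, float("-inf")), dim=-1),
)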

0 commit comments
