You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: aten/src/ATen/native/SoftMax.cpp
+2-1Lines changed: 2 additions & 1 deletion
Original file line number
Diff line number
Diff line change
@@ -137,7 +137,8 @@ void host_softmax(
137
137
if (MaskedSoftMax) {
138
138
TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined");
139
139
int64_t mask_type = mask_type_.value();
140
-
TORCH_CHECK((mask_type == 0) || (mask_type == 1), "Mask Type should be 0 (src_mask) or 1 (src_key_padding_mask)");
140
+
// If mask_type == 2, then mask_.sizes() must equal input_.sizes()
141
+
TORCH_CHECK((mask_type == 0) || (mask_type == 1) || (mask_type == 2), "Mask Type should be 0 (src_mask) or 1 (src_key_padding_mask), or 2 (default_mask)");
141
142
142
143
// TODO: Add support for TxT src_mask
143
144
TORCH_CHECK(mask_type != 0, "src_mask not currently supported on CPU");
0 commit comments