Skip to content

Commit 6a14fcb

Browse files
DenisVieriu97pytorchmergebot
authored andcommitted
[MPS] Add support for aten::masked_select on mps (#119) (pytorch#85818)
Reuse the `index.Tensor_out` implementation since it's already expanding the bool/byte indices to long tensors. Pull Request resolved: pytorch#85818 Approved by: https://github.com/kulinseth
1 parent 85258ec commit 6a14fcb

File tree

4 files changed

+52
-4
lines changed

4 files changed

+52
-4
lines changed

aten/src/ATen/mps/MPSFallback.mm

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ void mps_error_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack)
4848
m.impl("linalg_vector_norm", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
4949
m.impl("sgn.out", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
5050
m.impl("nonzero", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
51-
m.impl("masked_select", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
5251
}
5352

5453
} // namespace at

aten/src/ATen/native/mps/operations/Indexing.mm

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ bool dispatchIndexKernel(TensorIteratorBase& iter,
146146
return true;
147147
}
148148

149-
150149
static void validateInputData(const TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride, const std::string& op, bool accumulate) {
151150
using namespace mps;
152151

@@ -186,6 +185,43 @@ void index_put_kernel_mps(TensorIterator& iter, IntArrayRef index_size, IntArray
186185
}
187186
}
188187

188+
static Tensor & masked_select_out_mps_impl(Tensor & result, const Tensor & self, const Tensor & mask) {
189+
NoNamesGuard guard;
190+
191+
TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool,
192+
"masked_select: expected BoolTensor or ByteTensor for mask");
193+
TORCH_CHECK(self.scalar_type() == result.scalar_type(),
194+
"masked_select(): self and result must have the same scalar type");
195+
196+
auto mask_temp = (mask.dim() == 0)
197+
? c10::MaybeOwned<Tensor>::owned(mask.unsqueeze(0))
198+
: c10::MaybeOwned<Tensor>::borrowed(mask);
199+
auto self_temp = (self.dim() == 0)
200+
? c10::MaybeOwned<Tensor>::owned(self.unsqueeze(0))
201+
: c10::MaybeOwned<Tensor>::borrowed(self);
202+
203+
// Cannot reassign to mask_temp and self_temp here! if they are
204+
// owning and expand_outplace returns a borrow, the returned borrow
205+
// would dangle.
206+
auto mask_self_expanded = expand_outplace(*mask_temp, *self_temp);
207+
at::index_out(
208+
result, *std::get<1>(mask_self_expanded),
209+
c10::List<c10::optional<at::Tensor>>({*std::move(std::get<0>(mask_self_expanded))}));
210+
211+
return result;
212+
}
213+
214+
Tensor masked_select_mps(const Tensor & self, const Tensor & mask) {
215+
namedinference::compute_broadcast_outnames(self, mask);
216+
Tensor result = at::empty({0}, self.options());
217+
return masked_select_out_mps_impl(result, self, mask);
218+
}
219+
220+
Tensor & masked_select_out_mps(const Tensor & self, const Tensor & mask, Tensor & result) {
221+
namedinference::compute_broadcast_outnames(self, mask);
222+
return masked_select_out_mps_impl(result, self, mask);
223+
}
224+
189225
Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
190226
using namespace mps;
191227

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8026,13 +8026,15 @@
80268026
dispatch:
80278027
CPU: masked_select_out_cpu
80288028
CUDA: masked_select_out_cuda
8029+
MPS: masked_select_out_mps
80298030
tags: dynamic_output_shape
80308031

80318032
- func: masked_select(Tensor self, Tensor mask) -> Tensor
80328033
variants: method, function
80338034
dispatch:
80348035
CPU: masked_select_cpu
80358036
CUDA: masked_select_cuda
8037+
MPS: masked_select_mps
80368038
tags: dynamic_output_shape
80378039

80388040
- func: masked_select_backward(Tensor grad, Tensor input, Tensor mask) -> Tensor

test/test_mps.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6048,6 +6048,17 @@ class TestAdvancedIndexing(TestCase):
60486048
supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
60496049
supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]
60506050

6051+
def test_masked_select(self):
6052+
x = torch.randn(3, 4)
6053+
x_mps = x.to("mps")
6054+
mask = x.ge(0.5)
6055+
mask_mps = x_mps.ge(0.5)
6056+
6057+
res = torch.masked_select(x, mask)
6058+
res_mps = torch.masked_select(x_mps, mask_mps)
6059+
6060+
self.assertEqual(res, res_mps)
6061+
60516062
# examples from https://www.tutorialspoint.com/numpy/numpy_advanced_indexing.htm
60526063
def test_indexing_get(self):
60536064
def helper(dtype):
@@ -6390,10 +6401,10 @@ def test_index_put_accumulate_duplicate_indices(self, device="mps"):
63906401
delta = torch.empty(i, dtype=torch.float32, device=device).uniform_(-1, 1)
63916402

63926403
# cumsum not supported on 'mps', fallback on 'cpu'
6393-
indices = delta.to("cpu").cumsum(0).long().to("mps")
6404+
indices = delta.cpu().cumsum(0).long().to("mps")
63946405

63956406
# abs for int64 is not supported on mps, fallback on 'cpu' to calculate it
6396-
input = torch.randn(indices.to("cpu").abs().to("mps").max() + 1, device=device)
6407+
input = torch.randn(indices.cpu().abs().max().to("mps") + 1, device=device)
63976408
values = torch.randn(indices.size(0), device=device)
63986409
output = input.index_put((indices,), values, accumulate=True)
63996410

0 commit comments

Comments
 (0)