From 3c2859cd78f7ddb26635d8bd5140c7bf5fb44b38 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 26 Nov 2021 05:51:25 +0000 Subject: [PATCH] [batching] slice_scatter --- functorch/csrc/BatchRulesScatterOps.cpp | 35 +++++++++++++++++++++++++ test/test_ops.py | 1 - test/test_vmap.py | 1 - 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/functorch/csrc/BatchRulesScatterOps.cpp b/functorch/csrc/BatchRulesScatterOps.cpp index 75a730554..22afd1507 100644 --- a/functorch/csrc/BatchRulesScatterOps.cpp +++ b/functorch/csrc/BatchRulesScatterOps.cpp @@ -423,9 +423,44 @@ std::tuple> index_select_batch_rule( return std::make_tuple(result, 0); } +namespace { +Tensor get_expanded_index(const Tensor& index, IntArrayRef self_size, int64_t dim) { + if (index.dim() == 0) { + return index.expand(self_size); + } + + // setup new_index_shape as [BS, 1, ..., idx_size, ..., 1] + // to reshape index_ + auto idx_size = index.size(0); // get non-batch size of index tensor + Tensor index_; + { + VmapDimVector new_index_shape(self_size.size(), 1); + new_index_shape[dim] = idx_size; + index_ = index.reshape(new_index_shape); + } + // Now apply expand to index_ + { + VmapDimVector new_index_shape = {self_size.begin(), self_size.end()}; + new_index_shape[dim] = idx_size; + index_ = index_.expand(new_index_shape); + } + return index_; +} +} + +Tensor slice_scatter_decomp(const Tensor &self, const Tensor &src, + int64_t dim, c10::optional start, + c10::optional end, int64_t step) { + auto idx = at::arange(start.value_or(0), end.value_or(self.size(dim)), step, self.options().dtype(kLong)); + idx = get_expanded_index(idx, self.sizes(), dim); + return at::scatter(self, dim, idx, src); + +} + TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { m.impl("index.Tensor", index_plumbing); m.impl("index_put_", index_put__plumbing); + m.impl("slice_scatter", slice_scatter_decomp); VMAP_SUPPORT("gather", gather_batch_rule); VMAP_SUPPORT("gather_backward", gather_backward_batch_rule); VMAP_SUPPORT("scatter.value", scatter_value_batch_rule); diff --git a/test/test_ops.py b/test/test_ops.py index 37938b7db..8053bce3f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -495,7 +495,6 @@ def test_vmapvjp(self, device, dtype, op): xfail('nn.functional.instance_norm'), xfail('nn.functional.poisson_nll_loss'), xfail('select_scatter'), - xfail('slice_scatter'), })) def test_vmapvjp_has_batch_rule(self, device, dtype, op): # These are too annoying to put into the list above diff --git a/test/test_vmap.py b/test/test_vmap.py index f56bd74f1..f6960b176 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -3195,7 +3195,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('searchsorted'), xfail('select_scatter'), xfail('short', 'channels_last'), - xfail('slice_scatter'), xfail('unique_consecutive'), xfail('unique'), xfail('nn.functional.conv1d'),