@@ -6042,13 +6042,13 @@ at::Tensor grid_sampler_3d_generated_plumbing(const at::Tensor & input, const at
6042
6042
return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
6043
6043
}
6044
6044
template <typename batch_rule_t, batch_rule_t batch_rule>
6045
- ::std::tuple<at::Tensor,at::Tensor> grid_sampler_3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
6045
+ ::std::tuple<at::Tensor,at::Tensor> grid_sampler_3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, std::array<bool, 2> output_mask ) {
6046
6046
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
6047
6047
auto maybe_layer = maybeCurrentDynamicLayer();
6048
6048
TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
6049
6049
int64_t cur_level = maybe_layer->layerId();
6050
6050
if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) {
6051
- return ATEN_FN(grid_sampler_3d_backward)(grad_output, input, grid, interpolation_mode, padding_mode, align_corners);
6051
+ return ATEN_FN(grid_sampler_3d_backward)(grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask );
6052
6052
}
6053
6053
Tensor grad_output_value;
6054
6054
optional<int64_t> grad_output_bdim;
@@ -6059,7 +6059,7 @@ ::std::tuple<at::Tensor,at::Tensor> grid_sampler_3d_backward_generated_plumbing(
6059
6059
Tensor grid_value;
6060
6060
optional<int64_t> grid_bdim;
6061
6061
std::tie(grid_value, grid_bdim) = unwrapTensorAtLevel(grid, cur_level);
6062
- auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners);
6062
+ auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners, output_mask );
6063
6063
return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level));
6064
6064
}
6065
6065
template <typename batch_rule_t, batch_rule_t batch_rule>
0 commit comments