Skip to content

Commit 8acf2d1

Browse files
committed
Updated plumbing (manually)
1 parent 736cccd commit 8acf2d1

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

functorch/csrc/VmapGeneratedPlumbing.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6042,13 +6042,13 @@ at::Tensor grid_sampler_3d_generated_plumbing(const at::Tensor & input, const at
60426042
return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
60436043
}
60446044
template <typename batch_rule_t, batch_rule_t batch_rule>
6045-
::std::tuple<at::Tensor,at::Tensor> grid_sampler_3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
6045+
::std::tuple<at::Tensor,at::Tensor> grid_sampler_3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, std::array<bool, 2> output_mask) {
60466046
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
60476047
auto maybe_layer = maybeCurrentDynamicLayer();
60486048
TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
60496049
int64_t cur_level = maybe_layer->layerId();
60506050
if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) {
6051-
return ATEN_FN(grid_sampler_3d_backward)(grad_output, input, grid, interpolation_mode, padding_mode, align_corners);
6051+
return ATEN_FN(grid_sampler_3d_backward)(grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask);
60526052
}
60536053
Tensor grad_output_value;
60546054
optional<int64_t> grad_output_bdim;
@@ -6059,7 +6059,7 @@ ::std::tuple<at::Tensor,at::Tensor> grid_sampler_3d_backward_generated_plumbing(
60596059
Tensor grid_value;
60606060
optional<int64_t> grid_bdim;
60616061
std::tie(grid_value, grid_bdim) = unwrapTensorAtLevel(grid, cur_level);
6062-
auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners);
6062+
auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners, output_mask);
60636063
return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level));
60646064
}
60656065
template <typename batch_rule_t, batch_rule_t batch_rule>

0 commit comments

Comments
 (0)