Commit 59f3c32

Output allocate patch (#1790)
Cache output strides along with sizes. This is needed to support the current expand implementation, which produces non-contiguous output tensors.
1 parent fe93bf5 commit 59f3c32
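
For context on the description above, here is a minimal standalone ATen sketch (not part of this commit; the program is illustrative only) of why an expanded output cannot be reproduced from cached sizes alone: expand gives broadcast dimensions stride 0, so the allocation has to use both sizes and strides.

// Illustrative only: shows the stride pattern expand produces.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor x = at::randn({4, 1, 1});   // contiguous, strides {1, 1, 1}
  at::Tensor y = x.expand({4, 3, 2});    // same storage, strides {1, 0, 0}
  std::cout << y.sizes() << " " << y.strides() << std::endl;
  // Reproducing this layout requires both sizes and strides:
  at::Tensor out = at::empty_strided(y.sizes(), y.strides(), x.options());
  return 0;
}

This is the same reason the executor change below switches from empty_cuda to empty_strided_cuda and records output strides in the executor entry.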

3 files changed: +38 additions, 1 deletion

torch/csrc/jit/codegen/cuda/executor.cpp

Lines changed: 3 additions & 1 deletion
@@ -762,8 +762,9 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
   if (outputs.empty()) {
     FUSER_PERF_SCOPE("ExecutorRunFusion::OutputAlloc");
     for (const auto i : c10::irange(executor_entry->output_sizes.size())) {
-      allocated_outputs.push_back(at::native::empty_cuda(
+      allocated_outputs.push_back(at::native::empty_strided_cuda(
           executor_entry->output_sizes[i],
+          executor_entry->output_strides[i],
           executor_entry->output_types[i],
           c10::nullopt,
           options_.device,
@@ -934,6 +935,7 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
     executor_entry->io_alias_indices = alias_indices;
     for (const auto& output : allocated_outputs) {
       executor_entry->output_sizes.push_back(output.sizes().vec());
+      executor_entry->output_strides.push_back(output.strides().vec());
       executor_entry->output_types.push_back(output.scalar_type());
     }

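For readers skimming the diff, a simplified, self-contained sketch of the caching pattern this hunk extends. It uses public ATen calls (at::empty_strided) and an illustrative OutputCache struct in place of the internal ExecutorEntry and at::native::empty_strided_cuda entry points: the first run records each output's sizes and strides, and cached runs allocate with both.

// Simplified sketch (assumption: plain ATen stands in for the internal
// ExecutorEntry / at::native::empty_strided_cuda entry points).
#include <ATen/ATen.h>
#include <vector>

struct OutputCache {
  std::vector<std::vector<int64_t>> sizes;
  std::vector<std::vector<int64_t>> strides;
  std::vector<at::ScalarType> types;
};

// First run: record the full layout of each produced output.
void recordOutputs(OutputCache& cache, const std::vector<at::Tensor>& outputs) {
  for (const auto& out : outputs) {
    cache.sizes.push_back(out.sizes().vec());
    cache.strides.push_back(out.strides().vec());
    cache.types.push_back(out.scalar_type());
  }
}

// Cached run: allocate outputs with the recorded sizes *and* strides, so
// non-contiguous (e.g. expanded) layouts are reproduced exactly.
std::vector<at::Tensor> allocateOutputs(const OutputCache& cache, at::Device device) {
  std::vector<at::Tensor> allocated;
  for (size_t i = 0; i < cache.sizes.size(); ++i) {
    allocated.push_back(at::empty_strided(
        cache.sizes[i],
        cache.strides[i],
        at::TensorOptions().dtype(cache.types[i]).device(device)));
  }
  return allocated;
}
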
torch/csrc/jit/codegen/cuda/executor.h

Lines changed: 1 addition & 0 deletions
@@ -74,6 +74,7 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
   LaunchParams launch_params;
   std::vector<std::pair<int, int>> io_alias_indices;
   std::vector<std::vector<int64_t>> output_sizes;
+  std::vector<std::vector<int64_t>> output_strides;
   std::vector<at::ScalarType> output_types;
   std::vector<std::vector<int64_t>> buffer_sizes;
   std::vector<at::ScalarType> buffer_types;

torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp

Lines changed: 34 additions & 0 deletions
@@ -972,6 +972,40 @@ TEST_F(NVFuserTest, FusionComputeAtRootDomainMapWithView_CUDA) {
       tv1->axis(1)->toString());
 }
 
+TEST_F(NVFuserTest, FusionExpandRepro_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  const std::vector<int64_t> input_shape1{4, 1, 1};
+  const std::vector<int64_t> input_shape2{4, 3, 2};
+
+  auto tv0 = makeConcreteTensor({-1, 1, 1});
+  fusion.addInput(tv0);
+  auto tv1 = makeSymbolicTensor(3);
+  fusion.addInput(tv1);
+
+  auto tv2 = expand_as(tv0, tv1);
+  fusion.addOutput(tv2);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor at_x = at::randn(input_shape1, options);
+  at::Tensor at_y = at::randn(input_shape2, options);
+  std::vector<IValue> aten_inputs = {at_x, at_y};
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  LaunchParams l_params;
+  auto outputs = fe.runFusion(aten_inputs, {}, l_params, 0);
+
+  auto out = at_x.expand_as(at_y);
+
+  testValidate(&fusion, outputs, aten_inputs, {out}, __LINE__, __FILE__);
+
+  // second run to verify cached output allocation
+  outputs = fe.runFusion(aten_inputs, {}, l_params, 0);
+  testValidate(&fusion, outputs, aten_inputs, {out}, __LINE__, __FILE__);
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)
