Skip to content

Commit 82002db

Browse files
authored
Lowering aten::pad to aten::constant_pad_nd/aten::reflection_padXd/aten::replication_padXd (#1588)
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent dd34ec1 commit 82002db

File tree

7 files changed

+358
-0
lines changed

7 files changed

+358
-0
lines changed

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
144144
passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());
145145
passes::ReplaceScalarImplicit(g);
146146
passes::RewriteInputsWithParams(g, params);
147+
passes::ReplaceAtenPad(g);
147148
LOG_GRAPH(*g);
148149
}
149150

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ cc_library(
2828
"remove_dropout.cpp",
2929
"remove_nops.cpp",
3030
"remove_unnecessary_casts.cpp",
31+
"replace_aten_pad.cpp",
3132
"rewrite_inputs_with_params.cpp",
3233
"silu_to_sigmoid_multiplication.cpp",
3334
"unpack_addmm.cpp",

core/lowering/passes/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ target_sources(${lib_name}
1515
"${CMAKE_CURRENT_SOURCE_DIR}/remove_nops.cpp"
1616
"${CMAKE_CURRENT_SOURCE_DIR}/remove_set_attrs.cpp"
1717
"${CMAKE_CURRENT_SOURCE_DIR}/remove_unnecessary_casts.cpp"
18+
"${CMAKE_CURRENT_SOURCE_DIR}/replace_aten_pad.cpp"
1819
"${CMAKE_CURRENT_SOURCE_DIR}/silu_to_sigmoid_multiplication.cpp"
1920
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_addmm.cpp"
2021
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_batch_norm.cpp"

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::str
4545
void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
4646
void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
4747
void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);
48+
void ReplaceAtenPad(std::shared_ptr<torch::jit::Graph>& graph);
4849

4950
// utility functions exposed for testing
5051
std::string unmangle_cls_name(const std::string& name);
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

namespace {

// Replaces the aten::pad node currently pointed to by `it` with a newly
// created node of kind `kind` taking `inputs`, rewires all uses of the old
// output to the new one, destroys the old node, and rewinds `it` to the node
// preceding the replacement so the caller's `++it` resumes cleanly.
void ReplacePadNode(
    std::shared_ptr<torch::jit::Graph>& graph,
    torch::jit::graph_node_list_iterator& it,
    const char* kind,
    const std::vector<torch::jit::Value*>& inputs) {
  torch::jit::Node* new_node = graph->create(
      c10::Symbol::fromQualString(kind),
      torch::jit::ArrayRef<torch::jit::Value*>(inputs),
      /*num_outputs=*/1);
  // Insert after the old node so the new node's inputs are already defined.
  new_node->insertAfter(*it);
  new_node->outputs()[0]->setType(c10::TensorType::get());
  it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
  // Step back before destroying the current node so the iterator stays valid:
  // `pre` is the node before the old aten::pad; after destruction the outer
  // loop's ++it lands on the freshly inserted replacement node.
  auto pre = --it;
  ++it;
  it->destroy();
  it = pre;
}

} // namespace

// Lowers aten::pad(Tensor self, int[] pad, str mode='constant', float? value=None)
// into the concrete padding ops TensorRT converters understand:
//   mode == "constant"  -> aten::constant_pad_nd
//   mode == "reflect"   -> aten::reflection_pad1d/2d (3d unsupported, logged)
//   mode == "replicate" -> aten::replication_pad1d/2d/3d
//   mode == "circular"  -> unsupported, logged
// Nodes whose `mode` or `pad` inputs are not compile-time constants are left
// untouched (the original unconditionally dereferenced toIValue's optional,
// which is UB for non-constant inputs).
void ReplaceAtenPad(std::shared_ptr<torch::jit::Graph>& graph) {
  for (auto it = graph->block()->nodes().begin(), end = graph->block()->nodes().end(); it != end; ++it) {
    if (it->kind() != c10::Symbol::fromQualString("aten::pad")) {
      continue;
    }
    // aten::pad(Tensor self, int[] pad, str mode='constant', float? value=None) -> (Tensor)
    auto mode = it->inputs()[2];
    if (!mode->type()->isSubtypeOf(c10::StringType::get())) {
      continue;
    }
    auto mode_ivalue = torch::jit::toIValue(mode);
    if (!mode_ivalue.has_value()) {
      continue; // mode is not a constant; nothing we can statically rewrite
    }
    std::string mode_str = mode_ivalue->to<std::string>();

    if (mode_str == "constant") {
      // aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)
      ReplacePadNode(
          graph, it, "aten::constant_pad_nd", {it->inputs()[0], it->inputs()[1], it->inputs()[3]});
      continue;
    }
    if (mode_str == "circular") {
      LOG_ERROR("Torch-TRT doesn't support circular padding currently.");
      continue;
    }
    if (mode_str != "reflect" && mode_str != "replicate") {
      continue;
    }

    // Both reflect and replicate dispatch on the pad list length (2/4/6).
    auto pad = it->inputs()[1];
    auto pad_ivalue = torch::jit::toIValue(pad);
    if (!pad_ivalue.has_value()) {
      continue; // pad list is not a constant; cannot pick a 1d/2d/3d variant
    }
    c10::List<int64_t> pad_list = pad_ivalue->to<c10::List<int64_t>>();

    if (mode_str == "reflect") {
      if (pad_list.size() == 2) {
        // aten::reflection_pad1d(Tensor self, int[2] padding) -> (Tensor)
        ReplacePadNode(graph, it, "aten::reflection_pad1d", {it->inputs()[0], it->inputs()[1]});
      } else if (pad_list.size() == 4) {
        // aten::reflection_pad2d(Tensor self, int[4] padding) -> (Tensor)
        ReplacePadNode(graph, it, "aten::reflection_pad2d", {it->inputs()[0], it->inputs()[1]});
      } else if (pad_list.size() == 6) {
        LOG_ERROR("Torch-TRT doesn't support aten::reflection_pad3d currently.");
      }
    } else { // "replicate"
      if (pad_list.size() == 2) {
        // aten::replication_pad1d(Tensor self, int[2] padding) -> (Tensor)
        ReplacePadNode(graph, it, "aten::replication_pad1d", {it->inputs()[0], it->inputs()[1]});
      } else if (pad_list.size() == 4) {
        // aten::replication_pad2d(Tensor self, int[4] padding) -> (Tensor)
        ReplacePadNode(graph, it, "aten::replication_pad2d", {it->inputs()[0], it->inputs()[1]});
      } else if (pad_list.size() == 6) {
        // aten::replication_pad3d(Tensor self, int[6] padding) -> (Tensor)
        ReplacePadNode(graph, it, "aten::replication_pad3d", {it->inputs()[0], it->inputs()[1]});
      }
    }
  }
  LOG_GRAPH("Post map aten::pad -> aten::constant_pad_nd/aten::reflection_padXd/aten::replication_padXd: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt

tests/core/lowering/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ lowering_test(
9595
name = "test_rewrite_inputs_with_params",
9696
)
9797

98+
lowering_test(
99+
name = "test_replace_aten_pad_pass",
100+
)
101+
98102
test_suite(
99103
name = "lowering_tests",
100104
tests = [
@@ -111,6 +115,7 @@ test_suite(
111115
":test_remove_detach_pass",
112116
":test_remove_dropout_pass",
113117
":test_remove_unnecessary_casts",
118+
":test_replace_aten_pad_pass",
114119
":test_rewrite_inputs_with_params",
115120
":test_unpack_hardsigmoid",
116121
":test_unpack_hardswish",

0 commit comments

Comments
 (0)