
Commit 723ac1d

abhi-iyer authored and narendasan committed
fix(): added some fixes, trt/jit output still mismatches
Signed-off-by: Abhiram Iyer <[email protected]>
Signed-off-by: Abhiram Iyer <[email protected]>
1 parent d7c3164 commit 723ac1d

File tree (5 files changed: +43 −4 lines)

core/conversion/converters/impl/lstm_cell.cpp
core/lowering/lowering.cpp
core/lowering/passes/BUILD
core/lowering/passes/conv3d_to_convolution.cpp
core/lowering/passes/passes.h

core/conversion/converters/impl/lstm_cell.cpp

Lines changed: 6 additions & 3 deletions
@@ -84,7 +84,7 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
 
     auto out2 = (args[5].isIValue() && args[5].IValue()->isNone()) ? mm2_out : add_bias(mm2_out, args[5].ITensorOrFreeze(ctx), "b_hh", ctx, n);
 
-    // gates
+    // get all 4 gates
     auto add = ctx->net->addElementWise(*out1, *out2, nvinfer1::ElementWiseOperation::kSUM);
     TRTORCH_CHECK(add, "Unable to create ElementWise layer from node: " << *n);
     auto add_out = add->getOutput(0);
@@ -135,14 +135,17 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     TRTORCH_CHECK(in_cell, "Unable to create ElementWise layer from node: " << *n);
     auto cy = ctx->net->addElementWise(*forget_cx->getOutput(0), *in_cell->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
     TRTORCH_CHECK(cy, "Unable to create ElementWise layer from node: " << *n);
-    auto cy_out = ctx->AssociateValueAndTensor(n->outputs()[1], cy->getOutput(0));
+    auto cy_out = cy->getOutput(0);
 
     // compute hy
     auto cy_tanh = ctx->net->addActivation(*cy_out, nvinfer1::ActivationType::kTANH);
     TRTORCH_CHECK(cy_tanh, "Unable to create tanh activation layer from node: " << *n);
     auto hy = ctx->net->addElementWise(*outgate, *cy_tanh->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
     TRTORCH_CHECK(hy, "Unable to create ElementWise layer from node: " << *n);
-    auto hy_out = ctx->AssociateValueAndTensor(n->outputs()[0], hy->getOutput(0));
+    auto hy_out = hy->getOutput(0);
+
+    ctx->AssociateValueAndTensor(n->outputs()[0], hy_out);
+    ctx->AssociateValueAndTensor(n->outputs()[1], cy_out);
 
     LOG_DEBUG("Output tensor [hy] shape: " << hy_out->getDimensions());
     LOG_DEBUG("Output tensor [cy] shape: " << cy_out->getDimensions());

core/lowering/lowering.cpp

Lines changed: 2 additions & 1 deletion
@@ -32,9 +32,10 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
     passes::RemoveDropout(g);
     passes::FuseFlattenLinear(g);
     passes::Conv2DToConvolution(g);
+    passes::Conv3DToConvolution(g);
     passes::FuseAddMMBranches(g);
     torch::jit::EliminateCommonSubexpression(g);
-    torch::jit::UnrollLoops(g);
+    //torch::jit::UnrollLoops(g);
     torch::jit::EliminateCommonSubexpression(g);
     passes::UnpackAddMM(g);
     //passes::UnpackBatchNorm(g);

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ cc_library(
     ],
     srcs = [
         "conv2d_to_convolution.cpp",
+        "conv3d_to_convolution.cpp",
         "exception_elimination.cpp",
         "fuse_addmm_branches.cpp",
         "fuse_flatten_linear.cpp",

core/lowering/passes/conv3d_to_convolution.cpp

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+#include <torch/csrc/jit/passes/subgraph_rewrite.h>
+
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string conv3d_pattern = R"IR(
+    graph(%x, %w, %b, %s, %p, %d, %g):
+      %4 : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g)
+      return (%4))IR";
+  std::string convolution_pattern = R"IR(
+    graph(%x, %w, %b, %s, %p, %d, %g):
+      %1 : bool = prim::Constant[value=0]()
+      %2 : int[] = prim::Constant[value=[0, 0]]()
+      %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1)
+      return (%4))IR";
+
+  // map the aten::conv3d pattern to the equivalent aten::_convolution call
+  torch::jit::SubgraphRewriter map_conv3d_to_convolution;
+  map_conv3d_to_convolution.RegisterRewritePattern(
+    conv3d_pattern, convolution_pattern);
+  map_conv3d_to_convolution.runOnGraph(graph);
+  LOG_GRAPH("Post map conv3d -> _convolution: " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace trtorch
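
For reference, here is a rough sketch of how this new pass could be exercised on its own; it is not part of the commit, and the module path "conv3d_module.ts" plus the standalone build setup are assumptions:

#include <iostream>
#include <memory>

#include <torch/script.h>

#include "core/lowering/passes/passes.h"

int main() {
  // Hypothetical scripted module; any model whose forward() calls torch.nn.Conv3d
  // carries an aten::conv3d node in its TorchScript graph.
  auto mod = torch::jit::load("conv3d_module.ts");
  auto g = mod.get_method("forward").graph();

  // Rewrite aten::conv3d nodes into the aten::_convolution form shown above.
  trtorch::core::lowering::passes::Conv3DToConvolution(g);

  std::cout << *g << std::endl;  // the graph should now contain aten::_convolution
  return 0;
}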

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ namespace lowering {
 namespace passes {
 
 void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
+void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
 void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
 void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
