Skip to content

Commit 3d14cda

Browse files
committed
feat(//core/lowering): Adding a new pass to handle new dim checks for
batchnorm Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 6eeba1c commit 3d14cda

File tree

4 files changed

+91
-0
lines changed

4 files changed

+91
-0
lines changed

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
4040
passes::Conv2DToConvolution(g);
4141
passes::Conv3DToConvolution(g);
4242
passes::FuseAddMMBranches(g);
43+
passes::RemoveBNDimCheck(g);
4344
torch::jit::EliminateCommonSubexpression(g);
4445
// torch::jit::UnrollLoops(g);
4546
torch::jit::EliminateCommonSubexpression(g);

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ cc_library(
1818
"exception_elimination.cpp",
1919
"fuse_addmm_branches.cpp",
2020
"fuse_flatten_linear.cpp",
21+
"remove_bn_dim_check.cpp",
2122
"remove_contiguous.cpp",
2223
"remove_dropout.cpp",
2324
"remove_to.cpp",

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1212
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
1313
void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
1414
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
15+
void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
1516
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
1617
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
1718
void RemoveTo(std::shared_ptr<torch::jit::Graph> graph);
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#include "torch/csrc/jit/ir/alias_analysis.h"
2+
#include "torch/csrc/jit/jit_log.h"
3+
#include "torch/csrc/jit/passes/constant_propagation.h"
4+
#include "torch/csrc/jit/passes/dead_code_elimination.h"
5+
#include "torch/csrc/jit/passes/guard_elimination.h"
6+
#include "torch/csrc/jit/passes/peephole.h"
7+
#include "torch/csrc/jit/runtime/graph_executor.h"
8+
9+
#include "core/util/prelude.h"
10+
11+
#include <vector>
12+
13+
namespace trtorch {
14+
namespace core {
15+
namespace lowering {
16+
namespace passes {
17+
namespace {
18+
using namespace torch::jit;
19+
struct BNDimCheckRemoval {
20+
BNDimCheckRemoval(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {}
21+
22+
void run() {
23+
findBNDimCheckNodes(graph_->block());
24+
torch::jit::EliminateDeadCode(graph_);
25+
LOG_GRAPH("Post aten::addmm branch fusion: " << *graph_);
26+
}
27+
28+
private:
29+
bool isBNDimCheckNodes(Node* n) {
30+
/// Check if this Node hosts a pattern like so:
31+
/// %290 : bool = aten::ne(%289, %9)
32+
/// = prim::If(%290)
33+
/// block0():
34+
/// %291 : str = aten::format(%10, %289)
35+
/// = prim::RaiseException(%291)
36+
/// -> ()
37+
/// block1():
38+
/// -> ()
39+
40+
if (n->blocks().size() != 2) {
41+
return false;
42+
}
43+
auto arm1 = n->blocks()[0];
44+
auto arm2 = n->blocks()[1];
45+
if (arm1->outputs().size() != 0 || arm2->outputs().size() != 0) {
46+
// Make sure that the node doesn't actually produce any Value that are
47+
// used by other nodes
48+
return false;
49+
}
50+
51+
auto arm1_start = arm1->nodes().begin();
52+
53+
if ((*arm1_start)->kind() != c10::Symbol::fromQualString("aten::format") && (*(++arm1_start))->kind() != prim::RaiseException && (*(++arm1_start))->kind() != prim::Return) {
54+
// Make sure that block0 is solely just the exception and the return
55+
return false;
56+
}
57+
58+
if ((*(arm2->nodes().begin()))->kind() != prim::Return) {
59+
// Make sure that block1 is solely the return
60+
return false;
61+
}
62+
63+
return true;
64+
}
65+
66+
void findBNDimCheckNodes(Block* b) {
67+
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
68+
auto n = *it;
69+
if (n->kind() == prim::If && isBNDimCheckNodes(n)) {
70+
LOG_GRAPH("Found that node " << *n << " is an batch norm dim check node (EliminateChecks)" << std::endl);
71+
it.destroyCurrent();
72+
}
73+
}
74+
}
75+
76+
std::shared_ptr<Graph> graph_;
77+
};
78+
} // namespace
79+
80+
void RemoveBNDimCheck(std::shared_ptr<Graph> graph) {
81+
BNDimCheckRemoval bndcr(std::move(graph));
82+
bndcr.run();
83+
}
84+
85+
} // namespace passes
86+
} // namespace lowering
87+
} // namespace core
88+
} // namespace trtorch

0 commit comments

Comments
 (0)