
Commit 3787ca3

zrphercule authored and facebook-github-bot committed
Migrate away from CustomFuseGraph (#3403)
Summary: This is basically the Glow version of pytorch/tvm#72. It no longer uses PyTorch's CustomFuseGraph. Comments marking the copied code will be added and lint fixed once finished. Please don't give a detailed review until WIP is removed, but feel free to leave any big-scope opinions.

Pull Request resolved: #3403

Reviewed By: jackm321

Differential Revision: D16775646

Pulled By: zrphercule

fbshipit-source-id: a6d4dd757bf0db2ec0f4092330962b7e7fdf241d
1 parent 3e24abe commit 3787ca3
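
As an orientation aid (not part of the diff itself), here is a minimal sketch of how a caller could drive the new fuser in place of torch::jit::CustomFuseGraph. Only GlowCustomFuse and isSupportFunc come from this commit; the predicate lambda and the "glow::FusionGroup" symbol below are illustrative stand-ins for PyTorchModelLoader::isNodeSupported and the fuseSymbol defined in PyTorchCommon.cpp.

#include "GlowFuser.h"

// Sketch only: collapse every run of supported nodes into fusion-group nodes
// that a Glow loader can consume later.
void fuseForGlow(std::shared_ptr<torch::jit::Graph> graph) {
  // Stand-in support predicate; the real one is
  // PyTorchModelLoader::isNodeSupported.
  glow::isSupportFunc isSupported = [](torch::jit::Node *node) {
    return node->kind() == at::Symbol::fromQualString("aten::linear") ||
           node->kind() == at::Symbol::fromQualString("aten::relu");
  };
  // "glow::FusionGroup" is an assumed name; the actual fuseSymbol lives in
  // PyTorchCommon.cpp.
  glow::GlowCustomFuse(graph, isSupported,
                       at::Symbol::fromQualString("glow::FusionGroup"));
}

The predicate is the only backend-specific piece: any node for which it returns true (plus prim::Constant nodes) becomes eligible for merging into a fusion group.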

6 files changed: +230, -110 lines changed

torch_glow/src/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ link_directories(${PYTORCH_DIR}/lib)
 
 add_library(PyTorchModelLoader
   PyTorchCommon.cpp
-  FusingOptimizer.cpp
+  GlowFuser.cpp
   PyTorchModelLoader.cpp
   CachingGraphRunner.cpp)
 target_compile_options(PyTorchModelLoader

torch_glow/src/FusingOptimizer.cpp

Lines changed: 0 additions & 92 deletions
This file was deleted.

torch_glow/src/GlowFuser.cpp

Lines changed: 212 additions & 0 deletions (new file)

/**
 * Copyright (c) 2017-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "GlowFuser.h"

#include <llvm/Support/raw_ostream.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

namespace glow {

// This is mainly copied from pytorch/tvm.
// This pass fuses the addmm or matmul + add patterns generated by the JIT back
// into aten::linear, so the Glow integration can map them directly to Glow IR.
// It can be deleted once the JIT emits aten::linear itself.
void FuseLinear(std::shared_ptr<torch::jit::Graph> &graph) {
  std::string addmm_pattern = R"IR(
  graph(%input, %weight, %bias, %4):
    %weight_t = aten::t(%weight)
    %res = aten::addmm(%bias, %input, %weight_t, %4, %4)
    return (%res))IR";
  std::string matmul_add_pattern = R"IR(
  graph(%input, %weight, %bias, %4):
    %weight_t = aten::t(%weight)
    %output = aten::matmul(%input, %weight_t)
    %res = aten::add_(%output, %bias, %4)
    return (%res))IR";
  std::string mm_add_pattern = R"IR(
  graph(%input, %weight, %bias, %4):
    %weight_t = aten::t(%weight)
    %output = aten::mm(%input, %weight_t)
    %res = aten::add_(%output, %bias, %4)
    return (%res))IR";
  std::string fused_linear = R"IR(
  graph(%input, %weight, %bias, %4):
    %res = aten::linear(%input, %weight, %bias)
    return (%res))IR";

  std::string matmul_pattern = R"IR(
  graph(%input, %weight):
    %weight_t = aten::t(%weight)
    %output = aten::matmul(%input, %weight_t)
    return (%output))IR";
  std::string mm_pattern = R"IR(
  graph(%input, %weight):
    %weight_t = aten::t(%weight)
    %output = aten::mm(%input, %weight_t)
    return (%output))IR";
  std::string fused_linear_bias_none = R"IR(
  graph(%input, %weight):
    %bias: Tensor? = prim::Constant()
    %res = aten::linear(%input, %weight, %bias)
    return (%res))IR";

  // replace addmm pattern with linear
  torch::jit::SubgraphRewriter addmm_to_linear;
  addmm_to_linear.RegisterRewritePattern(addmm_pattern, fused_linear);
  addmm_to_linear.runOnGraph(graph);

  // replace matmul + add pattern with linear
  torch::jit::SubgraphRewriter matmuladd_to_linear;
  matmuladd_to_linear.RegisterRewritePattern(matmul_add_pattern, fused_linear);
  matmuladd_to_linear.runOnGraph(graph);

  // replace mm + add pattern with linear
  torch::jit::SubgraphRewriter mmadd_to_linear;
  mmadd_to_linear.RegisterRewritePattern(mm_add_pattern, fused_linear);
  mmadd_to_linear.runOnGraph(graph);

  // replace matmul with bias=None pattern with linear
  torch::jit::SubgraphRewriter matmul_to_linear;
  matmul_to_linear.RegisterRewritePattern(matmul_pattern,
                                          fused_linear_bias_none);
  matmul_to_linear.runOnGraph(graph);

  // replace mm with bias=None pattern with linear
  torch::jit::SubgraphRewriter mm_to_linear;
  mm_to_linear.RegisterRewritePattern(mm_pattern, fused_linear_bias_none);
  mm_to_linear.runOnGraph(graph);
}

torch::jit::value_list
sortReverseTopological(at::ArrayRef<torch::jit::Value *> inputs,
                       torch::jit::Block *block) {
  torch::jit::value_list result;
  for (auto i : inputs) {
    if (i->node()->owningBlock() == block) {
      result.push_back(i);
    }
  }

  std::sort(result.begin(), result.end(),
            [&](torch::jit::Value *a, torch::jit::Value *b) {
              return a->node()->isAfter(b->node());
            });
  return result;
}

bool canMerge(torch::jit::Node *node, isSupportFunc fn) {
  return node->kind() == torch::jit::prim::Constant || fn(node);
}

bool canMerge(torch::jit::Block *block, isSupportFunc fn) {
  for (torch::jit::Node *node : block->nodes()) {
    if (!canMerge(node, fn)) {
      return false;
    }
  }
  return true;
}

#define REQ(cond, log_info)                                                    \
  if (!(cond)) {                                                               \
    llvm::errs() << log_info;                                                  \
    return c10::nullopt;                                                       \
  }

c10::optional<torch::jit::Node *> tryMerge(torch::jit::Node *consumer,
                                           torch::jit::Node *producer,
                                           torch::jit::AliasDb &aliasDb,
                                           isSupportFunc fn, at::Symbol kind) {

  std::string symbol_name_producer = producer->kind().toQualString();
  std::string symbol_name_consumer = consumer->kind().toQualString();
  REQ(canMerge(producer, fn),
      "Detected unknown node: " + symbol_name_producer + ".\n")
  REQ(consumer->kind() == kind || canMerge(consumer, fn),
      "Detected unknown node: " + symbol_name_consumer + ".\n")

  // Alias checks
  // Requirement:
  // - moveAfterTopologicallyValid(consumer, producer)
  // - One of:
  //   1) Both are in-place ops
  //   2) Consumer is in-place, producer !hasInputWriters
  //   3) Producer is in-place, consumer !hasOutputWriters
  REQ(aliasDb.moveAfterTopologicallyValid(consumer, producer),
      "Unable to move after topologically valid.");

  // 1)
  if (!(aliasDb.isMutable(consumer) && aliasDb.isMutable(producer))) {
    // 2)
    if (aliasDb.isMutable(consumer)) {
      REQ(!aliasDb.hasInputWriters(producer),
          "Producer does not have input writer when merging.");
      // 3)
    } else if (aliasDb.isMutable(producer)) {
      REQ(!aliasDb.hasOutputWriters(consumer),
          "Consumer does not have output writer when merging.");
    }
  }

  if (!consumer->hasAttribute(torch::jit::attr::Subgraph) &&
      consumer->kind() != kind) {
    consumer =
        torch::jit::SubgraphUtils::createSingletonSubgraph(consumer, kind);
  }
  if (producer->kind() == torch::jit::prim::Constant) {
    auto &subgraph = consumer->g(torch::jit::attr::Subgraph);
    torch::jit::Node *in_const = subgraph->createClone(
        producer, [](torch::jit::Value *) -> torch::jit::Value * {
          throw std::runtime_error("unexpected input");
        });
    subgraph->insertNode(in_const);
  } else {
    torch::jit::SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);
  }
  return consumer;
}
#undef REQ

torch::jit::graph_node_list::iterator
getNewNode(torch::jit::Node *node, torch::jit::AliasDb &aliasDb,
           torch::jit::Block *block, isSupportFunc fn, at::Symbol kind) {
  auto node_inputs = sortReverseTopological(node->inputs(), block);
  for (auto input : node_inputs) {
    if (auto group = tryMerge(node, input->node(), aliasDb, fn, kind)) {
      return group.value()->reverseIterator();
    }
  }
  return ++node->reverseIterator();
}

void GlowCustomFuse(std::shared_ptr<torch::jit::Graph> graph, isSupportFunc fn,
                    at::Symbol kind) {
  torch::jit::AliasDb aliasDb(graph);
  auto block = graph->block();

  for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) {
    it = getNewNode(*it, aliasDb, block, fn, kind);
  }
  EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);
}

} // namespace glow
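
To see FuseLinear in action, here is a hedged sketch (illustrative, not part of the commit): it parses a tiny graph containing the addmm pattern and runs the pass over it. It assumes torch::jit::script::parseIR from <torch/csrc/jit/irparser.h>, as shipped in PyTorch of this era.

#include <torch/csrc/jit/irparser.h>

#include "GlowFuser.h"

void demoFuseLinear() {
  auto graph = std::make_shared<torch::jit::Graph>();
  // A traced nn.Linear typically lowers to aten::t + aten::addmm; %alpha plays
  // the role of both scalar arguments, matching the addmm pattern above.
  torch::jit::script::parseIR(R"IR(
graph(%input, %weight, %bias, %alpha):
  %weight_t = aten::t(%weight)
  %res = aten::addmm(%bias, %input, %weight_t, %alpha, %alpha)
  return (%res))IR",
                              graph.get());
  glow::FuseLinear(graph);
  // After the rewrite the body should collapse to a single call:
  //   %res = aten::linear(%input, %weight, %bias)
  graph->dump();
}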

torch_glow/src/FusingOptimizer.h renamed to torch_glow/src/GlowFuser.h

Lines changed: 7 additions & 5 deletions

@@ -14,14 +14,16 @@
  * limitations under the License.
  */
 
-#ifndef GLOW_TORCH_GLOW_SRC_FUSINGOPTIMIZER_H
-#define GLOW_TORCH_GLOW_SRC_FUSINGOPTIMIZER_H
-
 #include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/passes/graph_fuser.h>
+#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
 
 namespace glow {
+
+typedef std::function<bool(torch::jit::Node *)> isSupportFunc;
+
 /// Performs specific fusion for Linear operator.
 void FuseLinear(std::shared_ptr<torch::jit::Graph> &graph);
+void GlowCustomFuse(std::shared_ptr<torch::jit::Graph> graph, isSupportFunc fn,
+                    at::Symbol kind);
 } // namespace glow
-
-#endif // GLOW_TORCH_GLOW_SRC_FUSINGOPTIMIZER_H

torch_glow/src/PyTorchCommon.cpp

Lines changed: 2 additions & 3 deletions

@@ -15,7 +15,7 @@
  */
 
 #include "PyTorchCommon.h"
-#include "FusingOptimizer.h"
+#include "GlowFuser.h"
 #include "PyTorchModelLoader.h"
 #include <torch/csrc/jit/passes/graph_fuser.h>
 
@@ -40,8 +40,7 @@ void glowCustomFuse(std::shared_ptr<torch::jit::Graph> &g,
   // aten::linear before we fuse the whole graph.
   FuseLinear(g);
 
-  torch::jit::CustomFuseGraph(g, PyTorchModelLoader::isNodeSupported,
-                              fuseSymbol);
+  GlowCustomFuse(g, PyTorchModelLoader::isNodeSupported, fuseSymbol);
 }
 
 } // namespace glow
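
For completeness, a small illustrative sketch (not from the commit) of what a downstream consumer sees after glowCustomFuse runs: each fused region is a single node of the fuse symbol's kind carrying a Subgraph attribute, which the Glow side can then translate. The walk below is hypothetical; the real loading path lives in PyTorchModelLoader and CachingGraphRunner.

#include <torch/csrc/jit/ir.h>

void walkFusionGroups(const std::shared_ptr<torch::jit::Graph> &g,
                      at::Symbol fuseSymbol) {
  for (torch::jit::Node *node : g->nodes()) {
    if (node->kind() != fuseSymbol) {
      continue;
    }
    // The fused subgraph is attached to the node, exactly as tryMerge builds
    // it via SubgraphUtils in GlowFuser.cpp.
    std::shared_ptr<torch::jit::Graph> subgraph =
        node->g(torch::jit::attr::Subgraph);
    // A loader would translate `subgraph` into Glow IR here (hypothetical).
    (void)subgraph;
  }
}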
