
Commit a9d9919

[wip] Basic fuser pass to select texpr subgraphs
1 parent d1c7556 commit a9d9919

2 files changed: +171 −0

caffe2/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -417,6 +417,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/passes/requires_grad_analysis.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/specialize_autogradzero.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/subgraph_rewrite.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/passes/tensorexpr_fuser.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/python_print.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
torch/csrc/jit/passes/tensorexpr_fuser.cpp (new file)

Lines changed: 170 additions & 0 deletions

@@ -0,0 +1,170 @@
#include <torch/csrc/autograd/record_function.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/operator_options.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/pass_manager.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

using namespace torch::jit;

namespace {

const Symbol& getTensorExprSymbol() {
  static Symbol s = Symbol::fromQualString("tensorexpr::Group");
  return s;
}

value_list sortReverseTopological(ArrayRef<Value*> inputs, Block* block) {
  value_list result;
  for (auto i : inputs) {
    if (i->node()->owningBlock() == block) {
      result.push_back(i);
    }
  }
  // Sort in reverse topological order
  std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
    return a->node()->isAfter(b->node());
  });
  return result;
}

bool isSupported(Node* node) {
  // TODO:
  return node->kind() == Symbol::fromQualString("aten::add");
}

bool canHandle(Node* node, AliasDb& aliasDb) {
  if (node->kind() == prim::Constant) {
    return true;
  }
  if (node->kind() == prim::Loop) {
    return false; // TODO
  }
  return isSupported(node);
}

#define REQ(cond)                           \
  if (!(cond)) {                            \
    GRAPH_DEBUG("Failed cond " #cond "\n"); \
    return c10::nullopt;                    \
  }

c10::optional<Node*> tryMerge(
    Node* consumer,
    Node* producer,
    AliasDb& aliasDb) {
  GRAPH_DEBUG(
      "Trying producer ",
      producer->kind().toQualString(),
      " and consumer ",
      consumer->kind().toQualString(),
      ":\n");

  // Symbolic checks
  REQ(canHandle(producer, aliasDb));
  REQ((canHandle(consumer, aliasDb) ||
       consumer->kind() == getTensorExprSymbol()));

  // Alias checks
  // Requirement:
  // - moveAfterTopologicallyValid(consumer, producer)
  // - One of:
  //   1) Both are in-place ops
  //   2) Consumer is in-place, producer !hasInputWriters
  //   3) Producer is in-place, consumer !hasOutputWriters
  REQ(aliasDb.moveAfterTopologicallyValid(consumer, producer));

  // 1)
  if (!(aliasDb.isMutable(consumer) && aliasDb.isMutable(producer))) {
    // 2)
    if (aliasDb.isMutable(consumer)) {
      REQ(!aliasDb.hasInputWriters(producer));
      // 3)
    } else if (aliasDb.isMutable(producer)) {
      REQ(!aliasDb.hasOutputWriters(consumer));
    }
  }

  if (!consumer->hasAttribute(attr::Subgraph) &&
      consumer->kind() != getTensorExprSymbol()) {
    consumer = SubgraphUtils::createSingletonSubgraph(
        consumer, getTensorExprSymbol());
  }
  if (producer->kind() == prim::Constant) {
    auto& subgraph = consumer->g(attr::Subgraph);
    Node* in_const = subgraph->createClone(producer, [](Value*) -> Value* {
      throw std::runtime_error("unexpected input");
    });
    subgraph->insertNode(in_const);
  } else {
    SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);
  }
  return consumer;
}
#undef REQ

std::pair<graph_node_list::iterator, bool> scanNode(
    Node* consumer,
    AliasDb& aliasDb,
    Block* block) {
  auto inputs = sortReverseTopological(consumer->inputs(), block);
  for (auto input : inputs) {
    if (auto group = tryMerge(consumer, input->node(), aliasDb)) {
      // we successfully merged, so the new group's `inputs` may have
      // changed. So rescan the new group for more merging opportunities.
      return {group.value()->reverseIterator(), true};
    }
  }
  return {++consumer->reverseIterator(), false};
}

void fuseTensorExprs(std::shared_ptr<Graph>& graph) {
  std::cout << "Entering TExprFuser\n";
  std::cout << *graph;

  AliasDb aliasDb(graph);
  auto block = graph->block();

  bool any_changed = true;
  while (any_changed) {
    any_changed = false;
    for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) {
      bool changed;
      std::tie(it, changed) = scanNode(*it, aliasDb, block);
      any_changed |= changed;
    }
  }

  EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);

  std::cout << "Finishing TExprFuser\n";
  std::cout << *graph;
}

Operation createTensorExprOp(const Node* node) {
  return [](Stack& stack) {
    RECORD_FUNCTION("TensorExprGroup", std::vector<c10::IValue>());
    // Do something?
    return 0;
  };
}

c10::OperatorOptions getAliasAnalysisOption(AliasAnalysisKind k) {
  auto options = c10::OperatorOptions();
  options.setAliasAnalysis(k);
  return options;
}

RegisterOperators TensorExprOps({
    torch::jit::Operator(
        getTensorExprSymbol(),
        createTensorExprOp,
        getAliasAnalysisOption(AliasAnalysisKind::PURE_FUNCTION)),
});

RegisterPass pass(fuseTensorExprs);

} // namespace
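
Note that at this WIP stage isSupported() whitelists only aten::add, and the operator registered for tensorexpr::Group is still a stub that does nothing. Below is a minimal sketch, not part of the commit, of how one might try to observe the pass from Python, assuming the pass is compiled in and the RegisterPass hook runs during graph-executor optimization; the function f and the tensor shapes are arbitrary illustrations.

# Minimal sketch (not part of this commit): exercising the pass from Python,
# assuming the registered pass runs during graph-executor optimization.
import torch

@torch.jit.script
def f(a, b, c):
    # isSupported() currently whitelists only aten::add, so the two adds here
    # are candidates for merging into a single tensorexpr::Group subgraph.
    return a + b + c

x = torch.randn(4)
f(x, x, x)  # run once so the executor optimizes the graph
# The optimized graph may then contain a tensorexpr::Group node; the fuser's
# GRAPH_DEBUG output can be surfaced with PYTORCH_JIT_LOG_LEVEL=tensorexpr_fuser.
print(f.graph_for(x, x, x))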
