/**
 * Copyright (c) 2017-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "GlowFuser.h"

#include <llvm/Support/raw_ostream.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h> // torch::jit::SubgraphUtils

namespace glow {

// This is mainly copied from pytorch/tvm.
// This pass fuses the addmm or matmul + add patterns that the JIT emits for
// linear layers back into aten::linear, so they can be mapped directly to
// Glow IR by the Glow integration.
// This pass can be deleted once the JIT emits aten::linear itself.
void FuseLinear(std::shared_ptr<torch::jit::Graph> &graph) {
  std::string addmm_pattern = R"IR(
graph(%input, %weight, %bias, %4):
  %weight_t = aten::t(%weight)
  %res = aten::addmm(%bias, %input, %weight_t, %4, %4)
  return (%res))IR";
  std::string matmul_add_pattern = R"IR(
graph(%input, %weight, %bias, %4):
  %weight_t = aten::t(%weight)
  %output = aten::matmul(%input, %weight_t)
  %res = aten::add_(%output, %bias, %4)
  return (%res))IR";
  std::string mm_add_pattern = R"IR(
graph(%input, %weight, %bias, %4):
  %weight_t = aten::t(%weight)
  %output = aten::mm(%input, %weight_t)
  %res = aten::add_(%output, %bias, %4)
  return (%res))IR";
  std::string fused_linear = R"IR(
graph(%input, %weight, %bias, %4):
  %res = aten::linear(%input, %weight, %bias)
  return (%res))IR";

  std::string matmul_pattern = R"IR(
graph(%input, %weight):
  %weight_t = aten::t(%weight)
  %output = aten::matmul(%input, %weight_t)
  return (%output))IR";
  std::string mm_pattern = R"IR(
graph(%input, %weight):
  %weight_t = aten::t(%weight)
  %output = aten::mm(%input, %weight_t)
  return (%output))IR";
  std::string fused_linear_bias_none = R"IR(
graph(%input, %weight):
  %bias: Tensor? = prim::Constant()
  %res = aten::linear(%input, %weight, %bias)
  return (%res))IR";

  // Replace the addmm pattern with aten::linear.
  torch::jit::SubgraphRewriter addmm_to_linear;
  addmm_to_linear.RegisterRewritePattern(addmm_pattern, fused_linear);
  addmm_to_linear.runOnGraph(graph);

  // Replace the matmul + add pattern with aten::linear.
  torch::jit::SubgraphRewriter matmuladd_to_linear;
  matmuladd_to_linear.RegisterRewritePattern(matmul_add_pattern, fused_linear);
  matmuladd_to_linear.runOnGraph(graph);

  // Replace the mm + add pattern with aten::linear.
  torch::jit::SubgraphRewriter mmadd_to_linear;
  mmadd_to_linear.RegisterRewritePattern(mm_add_pattern, fused_linear);
  mmadd_to_linear.runOnGraph(graph);

  // Replace the bare matmul pattern with aten::linear (bias=None).
  torch::jit::SubgraphRewriter matmul_to_linear;
  matmul_to_linear.RegisterRewritePattern(matmul_pattern,
                                          fused_linear_bias_none);
  matmul_to_linear.runOnGraph(graph);

  // Replace the bare mm pattern with aten::linear (bias=None).
  torch::jit::SubgraphRewriter mm_to_linear;
  mm_to_linear.RegisterRewritePattern(mm_pattern, fused_linear_bias_none);
  mm_to_linear.runOnGraph(graph);
}

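// Returns the subset of `inputs` whose producing nodes live in `block`,
// sorted so that values produced later in the block come first (reverse
// topological order).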
torch::jit::value_list
sortReverseTopological(at::ArrayRef<torch::jit::Value *> inputs,
                       torch::jit::Block *block) {
  torch::jit::value_list result;
  for (auto i : inputs) {
    if (i->node()->owningBlock() == block) {
      result.push_back(i);
    }
  }

  std::sort(result.begin(), result.end(),
            [&](torch::jit::Value *a, torch::jit::Value *b) {
              return a->node()->isAfter(b->node());
            });
  return result;
}

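// A node can be merged into a fusion group if it is a constant or if the
// caller-provided support predicate accepts it; a block can be merged only
// if every node in it can.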
bool canMerge(torch::jit::Node *node, isSupportFunc fn) {
  return node->kind() == torch::jit::prim::Constant || fn(node);
}

bool canMerge(torch::jit::Block *block, isSupportFunc fn) {
  for (torch::jit::Node *node : block->nodes()) {
    if (!canMerge(node, fn)) {
      return false;
    }
  }
  return true;
}

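// Bail out of tryMerge (returning c10::nullopt) and log a message whenever a
// fusion requirement does not hold.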
#define REQ(cond, log_info)                                                    \
  if (!(cond)) {                                                               \
    llvm::errs() << log_info;                                                  \
    return c10::nullopt;                                                       \
  }

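// Attempts to pull `producer` into the fusion group rooted at `consumer`.
// Returns the (possibly newly created) fusion group node on success, or
// c10::nullopt if the nodes cannot be merged safely.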
c10::optional<torch::jit::Node *> tryMerge(torch::jit::Node *consumer,
                                           torch::jit::Node *producer,
                                           torch::jit::AliasDb &aliasDb,
                                           isSupportFunc fn, at::Symbol kind) {

  std::string symbol_name_producer = producer->kind().toQualString();
  std::string symbol_name_consumer = consumer->kind().toQualString();
  REQ(canMerge(producer, fn),
      "Detected unknown node: " + symbol_name_producer + ".\n")
  REQ(consumer->kind() == kind || canMerge(consumer, fn),
      "Detected unknown node: " + symbol_name_consumer + ".\n")

  // Alias checks
  // Requirement:
  // - moveAfterTopologicallyValid(consumer, producer)
  // - One of:
  //   1) Both are in-place ops
  //   2) Consumer is in-place, producer !hasInputWriters
  //   3) Producer is in-place, consumer !hasOutputWriters
  REQ(aliasDb.moveAfterTopologicallyValid(consumer, producer),
      "Unable to move consumer after producer in a topologically valid "
      "way.\n");

  // 1)
  if (!(aliasDb.isMutable(consumer) && aliasDb.isMutable(producer))) {
    // 2)
    if (aliasDb.isMutable(consumer)) {
      REQ(!aliasDb.hasInputWriters(producer),
          "Producer has input writers; cannot merge.\n");
      // 3)
    } else if (aliasDb.isMutable(producer)) {
      REQ(!aliasDb.hasOutputWriters(consumer),
          "Consumer has output writers; cannot merge.\n");
    }
  }

  if (!consumer->hasAttribute(torch::jit::attr::Subgraph) &&
      consumer->kind() != kind) {
    consumer =
        torch::jit::SubgraphUtils::createSingletonSubgraph(consumer, kind);
  }
  if (producer->kind() == torch::jit::prim::Constant) {
    auto &subgraph = consumer->g(torch::jit::attr::Subgraph);
    torch::jit::Node *in_const = subgraph->createClone(
        producer, [](torch::jit::Value *) -> torch::jit::Value * {
          throw std::runtime_error("unexpected input");
        });
    subgraph->insertNode(in_const);
  } else {
    torch::jit::SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);
  }
  return consumer;
}
#undef REQ

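// Tries to merge `node` with each of its in-block inputs, closest producer
// first. Returns the reverse iterator to resume the fusion scan from: the
// fusion group if a merge happened, otherwise the next node in the reverse
// scan.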
torch::jit::graph_node_list::iterator
getNewNode(torch::jit::Node *node, torch::jit::AliasDb &aliasDb,
           torch::jit::Block *block, isSupportFunc fn, at::Symbol kind) {
  auto node_inputs = sortReverseTopological(node->inputs(), block);
  for (auto input : node_inputs) {
    if (auto group = tryMerge(node, input->node(), aliasDb, fn, kind)) {
      return group.value()->reverseIterator();
    }
  }
  return ++node->reverseIterator();
}

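// Walks the top-level block in reverse, greedily merging supported nodes into
// fusion groups of the given `kind`, then cleans up with CSE and DCE.
//
// Example usage (a sketch; the support predicate and fusion symbol below are
// illustrative placeholders, not defined in this file):
//   auto isSupported = [](torch::jit::Node *node) { /* ... */ return true; };
//   GlowCustomFuse(graph, isSupported,
//                  at::Symbol::fromQualString("glow::FusionGroup"));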
void GlowCustomFuse(std::shared_ptr<torch::jit::Graph> graph, isSupportFunc fn,
                    at::Symbol kind) {
  torch::jit::AliasDb aliasDb(graph);
  auto block = graph->block();

  for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) {
    it = getNewNode(*it, aliasDb, block, fn, kind);
  }
  EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);
}

} // namespace glow