#include <torch/csrc/autograd/record_function.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/operator_options.h>
#include <torch/csrc/jit/pass_manager.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

#include <algorithm>

using namespace torch::jit;

namespace {

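// Symbol used to tag the fusion-group nodes created by this pass.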
const Symbol& getTensorExprSymbol() {
  static Symbol s = Symbol::fromQualString("tensorexpr::Group");
  return s;
}

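// Collect the inputs that are produced inside `block` and order them so that
// later-defined producers come first (reverse topological order).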
value_list sortReverseTopological(ArrayRef<Value*> inputs, Block* block) {
  value_list result;
  for (auto i : inputs) {
    if (i->node()->owningBlock() == block) {
      result.push_back(i);
    }
  }
  // Sort in reverse topological order
  std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
    return a->node()->isAfter(b->node());
  });
  return result;
}

bool isSupported(Node* node) {
  // TODO: extend this beyond aten::add as more ops gain tensor expression
  // support.
  return node->kind() == Symbol::fromQualString("aten::add");
}

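// Returns true if `node` is eligible to be placed inside a fusion group.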
bool canHandle(Node* node, AliasDb& aliasDb) {
  if (node->kind() == prim::Constant) {
    return true;
  }
  if (node->kind() == prim::Loop) {
    return false; // TODO: handle control flow
  }
  return isSupported(node);
}

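// Bail out of tryMerge (returning nullopt) when a fusion precondition fails,
// logging the failed condition for debugging.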
#define REQ(cond)                           \
  if (!(cond)) {                            \
    GRAPH_DEBUG("Failed cond " #cond "\n"); \
    return c10::nullopt;                    \
  }

c10::optional<Node*> tryMerge(
    Node* consumer,
    Node* producer,
    AliasDb& aliasDb) {
  GRAPH_DEBUG(
      "Trying producer ",
      producer->kind().toQualString(),
      " and consumer ",
      consumer->kind().toQualString(),
      ":\n");

  // Symbolic checks
  REQ(canHandle(producer, aliasDb));
  REQ(canHandle(consumer, aliasDb) ||
      consumer->kind() == getTensorExprSymbol());

  // Alias checks
  // Requirement:
  // - moveAfterTopologicallyValid(consumer, producer)
  // - One of:
  //   1) Both are in-place ops
  //   2) Consumer is in-place, producer !hasInputWriters
  //   3) Producer is in-place, consumer !hasOutputWriters
  REQ(aliasDb.moveAfterTopologicallyValid(consumer, producer));

  // 1)
  if (!(aliasDb.isMutable(consumer) && aliasDb.isMutable(producer))) {
    // 2)
    if (aliasDb.isMutable(consumer)) {
      REQ(!aliasDb.hasInputWriters(producer));
      // 3)
    } else if (aliasDb.isMutable(producer)) {
      REQ(!aliasDb.hasOutputWriters(consumer));
    }
  }

  // Wrap the consumer in a new fusion-group node if it isn't one already.
  if (!consumer->hasAttribute(attr::Subgraph) &&
      consumer->kind() != getTensorExprSymbol()) {
    consumer = SubgraphUtils::createSingletonSubgraph(
        consumer, getTensorExprSymbol());
  }
  if (producer->kind() == prim::Constant) {
    // Constants have no inputs, so clone them directly into the subgraph.
    auto& subgraph = consumer->g(attr::Subgraph);
    Node* in_const = subgraph->createClone(producer, [](Value*) -> Value* {
      throw std::runtime_error("unexpected input");
    });
    subgraph->insertNode(in_const);
  } else {
    SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);
  }
  return consumer;
}
#undef REQ

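// Try to merge each in-block producer of `consumer` into it. Returns the
// iterator to resume scanning from and whether any merge happened.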
std::pair<graph_node_list::iterator, bool> scanNode(
    Node* consumer,
    AliasDb& aliasDb,
    Block* block) {
  auto inputs = sortReverseTopological(consumer->inputs(), block);
  for (auto input : inputs) {
    if (auto group = tryMerge(consumer, input->node(), aliasDb)) {
      // We successfully merged, so the new group's `inputs` may have changed.
      // Rescan the new group for more merging opportunities.
      return {group.value()->reverseIterator(), true};
    }
  }
  return {++consumer->reverseIterator(), false};
}

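// Main pass: repeatedly scan the block in reverse, merging supported nodes
// into tensorexpr::Group subgraphs until no more merges are possible.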
void fuseTensorExprs(std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("Before TExprFuser: ", graph);

  AliasDb aliasDb(graph);
  auto block = graph->block();

  bool any_changed = true;
  while (any_changed) {
    any_changed = false;
    for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) {
      bool changed;
      std::tie(it, changed) = scanNode(*it, aliasDb, block);
      any_changed |= changed;
    }
  }

  EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);

  GRAPH_DUMP("After TExprFuser: ", graph);
}

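// Placeholder executor for a fused group: it only records a profiling event
// and performs no computation yet.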
Operation createTensorExprOp(const Node* node) {
  return [](Stack& stack) {
    RECORD_FUNCTION("TensorExprGroup", std::vector<c10::IValue>());
    // TODO: compile and run the fused subgraph; for now this is a no-op.
    return 0;
  };
}

c10::OperatorOptions getAliasAnalysisOption(AliasAnalysisKind k) {
  auto options = c10::OperatorOptions();
  options.setAliasAnalysis(k);
  return options;
}

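// Register the custom operator that executes tensorexpr::Group nodes, along
// with the graph pass that creates them.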
RegisterOperators TensorExprOps({
    torch::jit::Operator(
        getTensorExprSymbol(),
        createTensorExprOp,
        getAliasAnalysisOption(AliasAnalysisKind::PURE_FUNCTION)),
});

RegisterPass pass(fuseTensorExprs);

} // namespace