Skip to content

Commit c76c491

Browse files
Mike Iovinepytorchmergebot
Mike Iovine
authored andcommitted
[SR] Make fused_sigrid_transforms work on graph outputs (#71507)
Summary: Pull Request resolved: #71507 We previously disabled `FuseListUnpack` if the fused outputs of the op would alias the graph outputs. The concern was that some ops were assuming that `p_node->Output(0).isTensor()` implies `p_node->Output(i).isTensor()` for all `i > 0`. This condition can be violated if there exists both managed and unmanaged tensors in the output list. Instead of adding this special case and missing out on some fusions, we should implement fused ops correctly. Reviewed By: d1jang Differential Revision: D33669034 fbshipit-source-id: 8b291b5fe610ffbe47b88a5a018daa63cb5665b0 (cherry picked from commit c6cba23)
1 parent 2539b6a commit c76c491

File tree

1 file changed

+0
-18
lines changed

1 file changed

+0
-18
lines changed

torch/csrc/jit/runtime/static/passes.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -746,8 +746,6 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
746746
AliasDb alias_db(
747747
graph,
748748
/*isFrozen=*/false);
749-
const std::vector<Value*> graph_outputs(
750-
graph->outputs().begin(), graph->outputs().end());
751749
auto nodes = graph->nodes();
752750
std::vector<Node*> to_remove;
753751
for (auto* node : nodes) {
@@ -771,22 +769,6 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
771769
continue;
772770
}
773771

774-
const bool checks_all_outputs =
775-
node->kind() == fromQualString("fb::equally_split") ||
776-
node->kind() == fromQualString("fb::gather_ranges_to_dense") ||
777-
node->kind() == fromQualString("fb::gather_ranges_to_dense_v2");
778-
779-
if (!checks_all_outputs) {
780-
// If any output of the ListUnpack node is unmanaged, disable fusion
781-
// since the fused op assumes all outputs are either managed or not.
782-
// Ops excluded here check all outputs.
783-
const std::vector<Value*> list_unpack_outputs_vec(
784-
list_unpack_outputs.begin(), list_unpack_outputs.end());
785-
if (alias_db.mayContainAlias(list_unpack_outputs_vec, graph_outputs)) {
786-
continue;
787-
}
788-
}
789-
790772
const auto& new_sym = unfused_to_fused_it->second;
791773
auto* new_node = graph->create(new_sym, 0);
792774

0 commit comments

Comments
 (0)