
Commit 7cf63d8

Mike Iovine authored and cyyever committed
[SR] Make sigrid_transforms fusion work on graph outputs (#73091)
Summary:

Pull Request resolved: pytorch/pytorch#73091

This is a re-work of D33669034 (pytorch/pytorch@c76c491); the change was backed out due to a data race causing crashes. The `output_types` vector was the culprit. It was previously lazily initialized on the first iteration. This was problematic because of static runtime's hidden assumption that ops are thread-safe.

The re-work now only does the list unpack fusion if the output dtypes can be statically determined, e.g. if the sigrid transforms instance and `use_offsets` are both constant. Note that this is true for all the models we care about. Also, we were already partially making this assumption by dereferencing the `std::optional` sigrid transforms instance in most of the ops.

Another advantage of this is that it makes the code simpler compared to D33669034 (pytorch/pytorch@c76c491). Once the output types are determined, they can be moved into the op lambda and shared as read-only data.

ghstack-source-id: 150704445
Reviewed By: d1jang
Differential Revision: D34290401
fbshipit-source-id: 9573e6f08ee9e8282de961bf5f5cc8d32b81e601
(cherry picked from commit 715b0077bd18cb144b9653f5f51057b9440252ad)
1 parent d1c0fbb commit 7cf63d8
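For illustration, here is a minimal standalone C++ sketch, not Static Runtime code and with invented names (`OutputTypes`, `makeRacyOp`, `makeSafeOp`), of the two patterns the summary contrasts: lazy first-iteration initialization of shared state, which races when two threads execute the op concurrently, versus computing the data up front and moving it into the op lambda as read-only state:

```cpp
#include <functional>
#include <memory>
#include <string>
#include <vector>

// Stand-in for the per-op output type list; in the real pass these are dtypes.
using OutputTypes = std::vector<std::string>;

// Racy pattern (roughly the backed-out D33669034 approach): shared state is
// filled in lazily on the first run. Two threads executing the op at the
// same time can both observe the empty vector and both write to it.
std::function<void()> makeRacyOp() {
  auto types = std::make_shared<OutputTypes>();
  return [types] {
    if (types->empty()) {            // unsynchronized read...
      types->emplace_back("Tensor"); // ...followed by an unsynchronized write
    }
    // ... run the op using *types ...
  };
}

// Fixed pattern (roughly this commit): determine the types once, before any
// execution, and move them into the lambda. All threads then share immutable
// captured data, which is safe without locks.
std::function<void()> makeSafeOp(OutputTypes precomputed) {
  return [types = std::move(precomputed)] {
    // ... run the op using `types` (read-only) ...
  };
}
```

The commit corresponds to the second pattern: because the sigrid transforms instance and `use_offsets` are constants, the output types can be computed once at graph-preparation time and never mutated afterwards.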

File tree

1 file changed (+36, -27 lines)


torch/csrc/jit/runtime/static/passes.cpp

Lines changed: 36 additions & 27 deletions
```diff
@@ -718,8 +718,40 @@ void EliminateTrivialEquallySplit(std::shared_ptr<torch::jit::Graph>& graph) {
   }
 }
 
-// NB: The alias type of the fused op needs to be changed to
-// c10::AliasAnalysisKind::PURE_FUNCTION to make alias analysis work.
+namespace {
+
+bool shouldNotFuseListUnpackSpecialCase(const Node* node) {
+  const static std::array<c10::Symbol, 3> sigrid_transforms_symbols{
+      c10::Symbol::fromQualString("fb::variadic_sigrid_transforms_torch_bind"),
+      c10::Symbol::fromQualString("fb::sigrid_transforms_torch_bind"),
+      c10::Symbol::fromQualString("fb::sigrid_transforms")};
+
+  if (std::find(
+          sigrid_transforms_symbols.begin(),
+          sigrid_transforms_symbols.end(),
+          node->kind()) == sigrid_transforms_symbols.end()) {
+    return false;
+  }
+
+  // To fuse with sigrid transforms, we must be able to statically determine
+  // `instance` and `use_offsets` - these two together let us statically
+  // determine the types of the outputs. Rationale: it is a huge pain to write
+  // fused sigrid transforms without static type information, and these two
+  // arguments are indeed statically known in every model we've seen.
+  // The reason why trying to fuse the outputs is annoying without static type
+  // information is that, if one of the outputs is not managed, you need to
+  // reset to an empty tensor of the correct type each iteration. So, if we
+  // can't collect types ahead of time, we would have to do it lazily on the
+  // first iteration, which could be wasteful in terms of time/memory
+  // - either each thread would have its own set of output types, or we would
+  // need a lock to prevent data races.
+  const auto num_inputs = node->inputs().size();
+  return !toIValue(node->input(0)).has_value() ||
+      !toIValue(node->input(num_inputs - 1)).has_value();
+}
+
+} // namespace
+
 void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
   const FastMap<c10::Symbol, c10::Symbol> unfused_to_fused = {
       OP_PAIR("fb::equally_split", "static_runtime::fused_equally_split"),
```
```diff
@@ -746,12 +778,7 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
       OP_PAIR(
           "fb::split_and_squeeze", "static_runtime::fused_split_and_squeeze")};
 
-  AliasDb alias_db(
-      graph,
-      /*isFrozen=*/false);
   // replacement contains (old_node, new_node, list_unpack_node)
-  const std::vector<Value*> graph_outputs(
-      graph->outputs().begin(), graph->outputs().end());
   std::vector<std::tuple<Node*, Node*, Node*>> replacement;
   DepthFirstGraphNodeIterator graph_it(graph);
   for (auto node = graph_it.next(); node != nullptr; node = graph_it.next()) {
```
```diff
@@ -775,20 +802,8 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
       continue;
     }
 
-    const bool checks_all_outputs =
-        node->kind() == fromQualString("fb::equally_split") ||
-        node->kind() == fromQualString("fb::gather_ranges_to_dense") ||
-        node->kind() == fromQualString("fb::gather_ranges_to_dense_v2");
-
-    if (!checks_all_outputs) {
-      // If any output of the ListUnpack node is unmanaged, disable fusion
-      // since the fused op assumes all outputs are either managed or not.
-      // Ops excluded here check all outputs.
-      const std::vector<Value*> list_unpack_outputs_vec(
-          list_unpack_outputs.begin(), list_unpack_outputs.end());
-      if (alias_db.mayContainAlias(list_unpack_outputs_vec, graph_outputs)) {
-        continue;
-      }
+    if (shouldNotFuseListUnpackSpecialCase(node)) {
+      continue;
     }
 
     const auto& new_sym = unfused_to_fused_it->second;
```
```diff
@@ -815,12 +830,6 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
     list_unpack_node->destroy();
     old_node->destroy();
   }
-
-#ifndef NDEBUG
-  graph->lint();
-  AliasDb db2(graph);
-  torch::jit::Lint(&db2);
-#endif
 } // namespace jit
 
 void EnableStaticRuntimeLayerNorm(std::shared_ptr<torch::jit::Graph>& graph) {
```
