Commit 4d05b24

Author: mikeiovine (committed)
[SR] Make sigrid_transforms fusion work on graph outputs
Pull Request resolved: #73091

This is a re-work of D33669034; the change was backed out due to a data race causing crashes. The `output_types` vector was the culprit: it was previously lazily initialized on the first iteration, which was problematic because of static runtime's hidden assumption that ops are thread-safe.

The re-work now only does the list unpack fusion if the output dtypes can be statically determined, e.g. if the sigrid transforms instance and `use_offsets` are both constant. Note that this holds for all the models we care about, and we were already partially making this assumption by dereferencing the `std::optional` sigrid transforms instance in most of the ops.

Another advantage is that this makes the code simpler compared to D33669034: once the output types are determined, they can be moved into the op lambda and shared as read-only data.

ghstack-source-id: 149756644
Differential Revision: [D34290401](https://our.internmc.facebook.com/intern/diff/D34290401/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D34290401/)!
1 parent 987f146 commit 4d05b24

File tree

1 file changed: +36 −27 lines


torch/csrc/jit/runtime/static/passes.cpp

Lines changed: 36 additions & 27 deletions
@@ -718,8 +718,40 @@ void EliminateTrivialEquallySplit(std::shared_ptr<torch::jit::Graph>& graph) {
   }
 }
 
-// NB: The alias type of the fused op needs to be changed to
-// c10::AliasAnalysisKind::PURE_FUNCTION to make alias analysis work.
+namespace {
+
+bool shouldNotFuseListUnpackSpecialCase(Node* node) {
+  const static std::array<c10::Symbol, 3> sigrid_transforms_symbols{
+      c10::Symbol::fromQualString("fb::variadic_sigrid_transforms_torch_bind"),
+      c10::Symbol::fromQualString("fb::sigrid_transforms_torch_bind"),
+      c10::Symbol::fromQualString("fb::sigrid_transforms")};
+
+  if (std::find(
+          sigrid_transforms_symbols.begin(),
+          sigrid_transforms_symbols.end(),
+          node->kind()) == sigrid_transforms_symbols.end()) {
+    return false;
+  }
+
+  // To fuse with sigrid transforms, we must be able to statically determine
+  // `instance` and `use_offsets` - these two together let us statically
+  // determine the types of the outputs. Rationale: it is a huge pain to write
+  // fused sigrid transforms without static type information, and these two
+  // arguments are indeed statically known in every model we've seen.
+  // The reason why trying to fuse the outputs is annoying without static type
+  // information is that, if one of the outputs is not managed, you need to
+  // reset to an empty tensor of the correct type each iteration. So, if we
+  // can't collect types ahead of time, we would have to do it lazily on the
+  // first iteration, which could be wasteful in terms of time/memory
+  // - either each thread would have its own set of output types, or we would
+  // need a lock to prevent data races.
+  const auto num_inputs = node->inputs().size();
+  return !toIValue(node->input(0)).has_value() ||
+      !toIValue(node->input(num_inputs - 1)).has_value();
+}
+
+} // namespace
+
 void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
   const FastMap<c10::Symbol, c10::Symbol> unfused_to_fused = {
       OP_PAIR("fb::equally_split", "static_runtime::fused_equally_split"),
@@ -744,12 +776,7 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
           "fb::gather_ranges_to_dense_v2",
           "static_runtime::fused_gather_ranges_to_dense_v2")};
 
-  AliasDb alias_db(
-      graph,
-      /*isFrozen=*/false);
   // replacement contains (old_node, new_node, list_unpack_node)
-  const std::vector<Value*> graph_outputs(
-      graph->outputs().begin(), graph->outputs().end());
   std::vector<std::tuple<Node*, Node*, Node*>> replacement;
   DepthFirstGraphNodeIterator graph_it(graph);
   for (auto node = graph_it.next(); node != nullptr; node = graph_it.next()) {
@@ -773,20 +800,8 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
       continue;
     }
 
-    const bool checks_all_outputs =
-        node->kind() == fromQualString("fb::equally_split") ||
-        node->kind() == fromQualString("fb::gather_ranges_to_dense") ||
-        node->kind() == fromQualString("fb::gather_ranges_to_dense_v2");
-
-    if (!checks_all_outputs) {
-      // If any output of the ListUnpack node is unmanaged, disable fusion
-      // since the fused op assumes all outputs are either managed or not.
-      // Ops excluded here check all outputs.
-      const std::vector<Value*> list_unpack_outputs_vec(
-          list_unpack_outputs.begin(), list_unpack_outputs.end());
-      if (alias_db.mayContainAlias(list_unpack_outputs_vec, graph_outputs)) {
-        continue;
-      }
+    if (shouldNotFuseListUnpackSpecialCase(node)) {
+      continue;
     }
 
     const auto& new_sym = unfused_to_fused_it->second;
@@ -813,12 +828,6 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
     list_unpack_node->destroy();
     old_node->destroy();
   }
-
-#ifndef NDEBUG
-  graph->lint();
-  AliasDb db2(graph);
-  torch::jit::Lint(&db2);
-#endif
 } // namespace jit
 
 void EnableStaticRuntimeLayerNorm(std::shared_ptr<torch::jit::Graph>& graph) {
