Skip to content

Commit 2aabbd0

Browse files
author
mikeiovine
committed
[SR] Make sigrid_transforms fusion work on graph outputs
Pull Request resolved: #73091 This is a re-work of D33669034; the change was backed out due to a data race causing crashes. The `output_types` vector was the culprit. It was previously lazily initialized on the first iteration. This was problematic because of static runtime's hidden assumption that ops are thread-safe. The re-work now only does the list unpack fusion if the output dtypes can be statically determined, e.g. if the sigrid transforms instance and `use_offsets` are both constant. Note that this is true for all the models we care about. Also, we were already partially making this assumption by dereferencing the `std::optional` sigrid transforms instance in most of the ops. Another advantage of this is that it makes the code simpler compared to D33669034. Once the output types are determined, they can be moved into the op lambda and shared as read-only data. ghstack-source-id: 150704445 Differential Revision: [D34290401](https://our.internmc.facebook.com/intern/diff/D34290401/) **NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D34290401/)!
1 parent 1d49711 commit 2aabbd0

File tree

1 file changed

+36
-27
lines changed

1 file changed

+36
-27
lines changed

torch/csrc/jit/runtime/static/passes.cpp

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -718,8 +718,40 @@ void EliminateTrivialEquallySplit(std::shared_ptr<torch::jit::Graph>& graph) {
718718
}
719719
}
720720

721-
// NB: The alias type of the fused op needs to be changed to
722-
// c10::AliasAnalysisKind::PURE_FUNCTION to make alias analysis work.
721+
namespace {
722+
723+
bool shouldNotFuseListUnpackSpecialCase(const Node* node) {
724+
const static std::array<c10::Symbol, 3> sigrid_transforms_symbols{
725+
c10::Symbol::fromQualString("fb::variadic_sigrid_transforms_torch_bind"),
726+
c10::Symbol::fromQualString("fb::sigrid_transforms_torch_bind"),
727+
c10::Symbol::fromQualString("fb::sigrid_transforms")};
728+
729+
if (std::find(
730+
sigrid_transforms_symbols.begin(),
731+
sigrid_transforms_symbols.end(),
732+
node->kind()) == sigrid_transforms_symbols.end()) {
733+
return false;
734+
}
735+
736+
// To fuse with sigrid transforms, we must be able to statically determine
737+
// `instance` and `use_offsets` - these two together let us statically
738+
// determine the types of the outputs. Rationale: it is a huge pain to write
739+
// fused sigrid transforms without static type information, and these two
740+
// arguments are indeed statically known in every model we've seen.
741+
// The reason why trying to fuse the outputs is annoying without static type
742+
// information is that, if one of the outputs is not managed, you need to
743+
// reset to an empty tensor of the correct type each iteration. So, if we
744+
// can't collect types ahead of time, we would have to do it lazily on the
745+
// first iteration, which would could be wasteful in terms of time/memory
746+
// - either each thread would have its own set of output types, or we would
747+
// need a lock to prevent data races.
748+
const auto num_inputs = node->inputs().size();
749+
return !toIValue(node->input(0)).has_value() ||
750+
!toIValue(node->input(num_inputs - 1)).has_value();
751+
}
752+
753+
} // namespace
754+
723755
void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
724756
const FastMap<c10::Symbol, c10::Symbol> unfused_to_fused = {
725757
OP_PAIR("fb::equally_split", "static_runtime::fused_equally_split"),
@@ -746,12 +778,7 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
746778
OP_PAIR(
747779
"fb::split_and_squeeze", "static_runtime::fused_split_and_squeeze")};
748780

749-
AliasDb alias_db(
750-
graph,
751-
/*isFrozen=*/false);
752781
// replacement contains (old_node, new_node, list_unpack_node)
753-
const std::vector<Value*> graph_outputs(
754-
graph->outputs().begin(), graph->outputs().end());
755782
std::vector<std::tuple<Node*, Node*, Node*>> replacement;
756783
DepthFirstGraphNodeIterator graph_it(graph);
757784
for (auto node = graph_it.next(); node != nullptr; node = graph_it.next()) {
@@ -775,20 +802,8 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
775802
continue;
776803
}
777804

778-
const bool checks_all_outputs =
779-
node->kind() == fromQualString("fb::equally_split") ||
780-
node->kind() == fromQualString("fb::gather_ranges_to_dense") ||
781-
node->kind() == fromQualString("fb::gather_ranges_to_dense_v2");
782-
783-
if (!checks_all_outputs) {
784-
// If any output of the ListUnpack node is unmanaged, disable fusion
785-
// since the fused op assumes all outputs are either managed or not.
786-
// Ops excluded here check all outputs.
787-
const std::vector<Value*> list_unpack_outputs_vec(
788-
list_unpack_outputs.begin(), list_unpack_outputs.end());
789-
if (alias_db.mayContainAlias(list_unpack_outputs_vec, graph_outputs)) {
790-
continue;
791-
}
805+
if (shouldNotFuseListUnpackSpecialCase(node)) {
806+
continue;
792807
}
793808

794809
const auto& new_sym = unfused_to_fused_it->second;
@@ -815,12 +830,6 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
815830
list_unpack_node->destroy();
816831
old_node->destroy();
817832
}
818-
819-
#ifndef NDEBUG
820-
graph->lint();
821-
AliasDb db2(graph);
822-
torch::jit::Lint(&db2);
823-
#endif
824833
} // namespace jit
825834

826835
void EnableStaticRuntimeLayerNorm(std::shared_ptr<torch::jit::Graph>& graph) {

0 commit comments

Comments
 (0)