
Commit d341b66

kit1980 authored and pytorchmergebot committed
Revert [dynamo] support group=None when rewriting collectives (#12018) (#120677)
This reverts commit 298c686.

Pull Request resolved: #120677
Approved by: https://github.com/yifuwang, https://github.com/huydhn
1 parent fdae936 commit d341b66

14 files changed: +20 -96 lines changed
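The PR being reverted taught Dynamo's collective rewrite (the `traceable_collective_remaps` machinery referenced in the functions.py hunk below) to accept `group=None` or an omitted `group`; after this revert, only calls that pass an explicit process group are rewritten. A minimal sketch of the two call shapes, assuming a single-rank "gloo" group initialized via env:// on localhost (none of this setup comes from the commit itself):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

def allreduce_explicit(t):
    # Explicit process group: the form the rewrite still handles.
    dist.all_reduce(t, group=dist.group.WORLD)
    return t

def allreduce_default(t):
    # group omitted (i.e. None, the default WORLD group): the form whose
    # Dynamo rewrite support is removed by this revert.
    dist.all_reduce(t)
    return t

print(allreduce_explicit(torch.ones(2)), allreduce_default(torch.ones(2)))

# Invoking this wrapper is what exercises the Dynamo rewrite; it is only
# constructed here, since running it needs a suitable distributed setup.
compiled = torch.compile(allreduce_explicit, fullgraph=True)

dist.destroy_process_group()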

benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv

Lines changed: 1 addition & 1 deletion
@@ -234,7 +234,7 @@ mobilenet_v3_large,pass,0
-moco,pass,5
+moco,pass,11

benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ mobilenet_v3_large,pass,7
-moco,pass,11
+moco,pass,17

benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv

Lines changed: 1 addition & 1 deletion
@@ -230,7 +230,7 @@ mobilenet_v3_large,pass,0
-moco,pass,5
+moco,pass,11

benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv

Lines changed: 1 addition & 1 deletion
@@ -178,7 +178,7 @@ mobilenet_v3_large,pass,7
-moco,pass,11
+moco,pass,17

benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv

Lines changed: 1 addition & 1 deletion
@@ -230,7 +230,7 @@ mobilenet_v3_large,pass,0
-moco,pass,5
+moco,pass,11

benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv

Lines changed: 1 addition & 1 deletion
@@ -178,7 +178,7 @@ mobilenet_v3_large,pass,7
-moco,pass,11
+moco,pass,17

benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv

Lines changed: 1 addition & 1 deletion
@@ -234,7 +234,7 @@ mobilenet_v3_large,pass,0
-moco,pass,5
+moco,pass,11

benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ mobilenet_v3_large,pass,7
-moco,pass,11
+moco,pass,17

benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv

Lines changed: 1 addition & 1 deletion
@@ -234,7 +234,7 @@ mobilenet_v3_large,pass,0
-moco,pass,5
+moco,pass,11

benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ mobilenet_v3_large,pass,7
-moco,pass,11
+moco,pass,17

test/distributed/test_inductor_collectives.py

Lines changed: 7 additions & 32 deletions
@@ -22,11 +22,7 @@
     run_with_both_funcol_impls_with_arg,
     skip_if_lt_x_gpu,
 )
-from torch.testing._internal.common_utils import (
-    instantiate_parametrized_tests,
-    parametrize,
-    requires_cuda,
-)
+from torch.testing._internal.common_utils import instantiate_parametrized_tests, requires_cuda
 from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
 from torch.utils._triton import has_triton
 from torch._inductor.utils import run_and_get_triton_code
@@ -829,43 +825,22 @@ def func(inp, out, *, pg):
         assert same(outputs, correct_outputs)

     @run_with_both_funcol_impls
-    @parametrize(
-        "pg_mode",
-        [
-            "kwargs",
-            "kwargs_none",
-            "unspecified",
-        ]
-    )
-    def test_dynamo_rewrite_dist_allreduce(self, pg_mode):
-
-        def func(tensor, *args, **kwargs):
+    def test_dynamo_rewrite_dist_allreduce(self):
+
+        def func(tensor, pg):
             torch.distributed.all_reduce(
                 tensor,
-                *args,
-                **kwargs,
+                group=pg
             )

         counter = CompileCounter()
         compiled = torch.compile(func, backend=counter, fullgraph=True)

-        args = []
-        kwargs = {}
-
-        # TODO(yifu): test positional and positional_none
-        # once explicit reduce op is supported
-        if pg_mode == "kwargs":
-            kwargs["group"] = GroupMember.WORLD
-        elif pg_mode == "kwargs_none":
-            kwargs["group"] = None
-        else:
-            assert pg_mode == "unspecified"
-
         inputs_compiled = torch.ones(2, device=self.device)
         inputs_eager = torch.ones(2, device=self.device)

-        compiled(inputs_compiled, *args, **kwargs)
-        func(inputs_eager, *args, **kwargs)
+        compiled(inputs_compiled, GroupMember.WORLD)
+        func(inputs_eager, GroupMember.WORLD)

         assert counter.frame_count == 1
         # should test more precisely, but the 3 is supposed to be (all_reduce, wait, copy_)
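For reference, the dropped "pg_mode" parametrization covered three ways of passing the process group; a sketch of those call shapes outside the test harness (the GroupMember import path is the conventional one and is assumed here, not shown in this diff):

import torch.distributed as dist
from torch.distributed.distributed_c10d import GroupMember

def allreduce_kwargs(t):
    dist.all_reduce(t, group=GroupMember.WORLD)  # pg_mode == "kwargs"

def allreduce_kwargs_none(t):
    dist.all_reduce(t, group=None)               # pg_mode == "kwargs_none"

def allreduce_unspecified(t):
    dist.all_reduce(t)                           # pg_mode == "unspecified"

Only the first, explicit-group shape remains under test after the revert.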

torch/_dynamo/variables/distributed.py

Lines changed: 0 additions & 29 deletions
@@ -6,8 +6,6 @@
 import torch
 from .. import variables
 from ..exc import unimplemented
-from ..guards import GuardBuilder, install_guard
-from ..source import AttrSource, GlobalSource
 from ..utils import istype
 from .base import VariableTracker
 from .constant import ConstantVariable
@@ -257,30 +255,3 @@ def is_process_group(value):
         from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

         return istype(value, (ProcessGroup, FakeProcessGroup))
-
-    @staticmethod
-    def get_global_pg_variable():
-        """
-        Make a ProcessGroupVariable from torch.distributed.group.WORLD and
-        intall guards.
-        """
-        import torch.distributed as dist
-
-        source = AttrSource(
-            AttrSource(
-                base=AttrSource(
-                    base=GlobalSource(global_name="torch"),
-                    member="distributed",
-                    get_static=False,
-                ),
-                member="group",
-                get_static=False,
-            ),
-            member="WORLD",
-            get_static=False,
-        )
-        install_guard(source.make_guard(GuardBuilder.ID_MATCH))
-        return ProcessGroupVariable(
-            dist.group.WORLD,
-            source=source,
-        )
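The deleted helper built a guarded reference to torch.distributed.group.WORLD so that a missing or None group could be substituted with the default group. The eager-mode fact it leaned on is that all_reduce with group=None (or with group omitted) targets the WORLD group. A small check of that equivalence, assuming a single-rank "gloo" group (the setup details are illustrative, not from this commit):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

a, b = torch.ones(2), torch.ones(2)
dist.all_reduce(a)                          # group omitted -> default group
dist.all_reduce(b, group=dist.group.WORLD)  # explicit WORLD group
assert torch.equal(a, b)

dist.destroy_process_group()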

torch/_dynamo/variables/functions.py

Lines changed: 0 additions & 12 deletions
@@ -17,7 +17,6 @@
 from ..utils import check_constant_args, get_first_attr, identity, istype, make_cell
 from .base import MutableLocal, typestr, VariableTracker
 from .constant import ConstantVariable
-from .distributed import ProcessGroupVariable

 if TYPE_CHECKING:
     from torch._guards import Source
@@ -687,21 +686,10 @@ def call_function(
         # call_function must check any unsupported arguments and graph-break.
         # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn,
         # since that's the contract for putting a mapping in `traceable_collective_remaps`
-
-        # Merge args into kwargs so positional and keyword args
-        # can be processed the same way.
-        signature = inspect.signature(self.fn)
-        kwargs = dict(signature.bind(*args, **kwargs).arguments)
-        args = ()
-
         if "async_op" in kwargs and kwargs["async_op"].as_python_constant():
             unimplemented(
                 f"CollectiveFunctionRewriteVariable can't support async_op=True for {self.fn}"
             )
-
-        if kwargs.get("group") is None or kwargs["group"].value is None:
-            kwargs["group"] = ProcessGroupVariable.get_global_pg_variable()
-
         return self.replacement_var.call_function(tx, args, kwargs)
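The first deleted block normalized positional arguments into kwargs with inspect.signature(...).bind(...). A standalone illustration of that idiom (the simplified all_reduce signature below is a stand-in for demonstration, not the real one):

import inspect

def all_reduce(tensor, op="SUM", group=None, async_op=False):
    """Stand-in signature used only to demonstrate the binding idiom."""

sig = inspect.signature(all_reduce)
print(dict(sig.bind("t0", group="my_pg").arguments))
# -> {'tensor': 't0', 'group': 'my_pg'}
print(dict(sig.bind("t0", "SUM", None).arguments))
# -> {'tensor': 't0', 'op': 'SUM', 'group': None}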

torch/distributed/_functional_collectives.py

Lines changed: 3 additions & 13 deletions
@@ -1035,20 +1035,10 @@ def all_gather_inplace(
     assert (
         not async_op
     ), "Can't remap async version of inplace op to functional collective"
-    assert all(
-        t.size(0) == tensor.size(0) for t in tensor_list
-    ), "Remapping variable size all_gather is not yet supported"
-
     output = all_gather_tensor(tensor, 0, group, tag)
-
-    # Use aten.slice as instead of aten.split because the latter causes
-    # tensor.shape(0) to be unnecessarily baked in when it's a SymInt.
-    output_splits = []
-    offset = 0
-    for t in tensor_list:
-        output_splits.append(output[offset : offset + t.size(0)])
-        offset += t.size(0)
-    for dst, src in zip(tensor_list, output_splits):
+    for dst, src in zip(
+        tensor_list, output.split([t.size(0) for t in tensor_list], dim=0)
+    ):
         dst.copy_(src)
     return tensor_list
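The restored code relies on Tensor.split accepting a list of per-chunk sizes, which yields the same views as the deleted manual-offset slicing (the slicing was originally preferred to avoid baking a SymInt size into the graph, per the removed comment). A small CPU-only check with made-up shapes:

import torch

output = torch.arange(10, dtype=torch.float32).unsqueeze(1)  # shape [10, 1]
tensor_list = [torch.empty(3, 1), torch.empty(3, 1), torch.empty(4, 1)]

# Deleted approach: explicit offsets and slicing.
splits_slice, offset = [], 0
for t in tensor_list:
    splits_slice.append(output[offset : offset + t.size(0)])
    offset += t.size(0)

# Restored approach: one split() call with per-destination sizes.
splits_split = output.split([t.size(0) for t in tensor_list], dim=0)

for a, b in zip(splits_slice, splits_split):
    assert torch.equal(a, b)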
