Commit 7975732

Enabling few more torchbench models with AOT Autograd
1 parent 47b9857 commit 7975732

3 files changed: +31 -13 lines changed

torchbench.py

Lines changed: 13 additions & 6 deletions
@@ -71,19 +71,16 @@
     "maml",
     # Known issues with training
     "demucs",  # https://github.com/pytorch/benchmark/pull/639
-    "densenet121",  # https://github.com/pytorch/benchmark/issues/652
-    "hf_Albert",  # https://github.com/pytorch/benchmark/issues/652
     "hf_Reformer",  # Can only be used in the training phase
-    # AOT Autograd known issues
-    "dlrm",  # No sparse support
-    "resnet50_quantized_qat",  # Con2DBnRelu
     # Known TorchDynamo bug
     "hf_GPT2",  # Hard to debug stashed tensor issue
     "tacotron2",  # Model uses Variable
 }

 # Some models have bad train dataset. We read eval dataset.
-ONLY_EVAL_DATASET = {"yolov3"}
+# yolov3 - seems to have different number of inputs between eval and train
+# densenet121 - OOM for train, using eval for now.
+ONLY_EVAL_DATASET = {"yolov3", "densenet121"}

 # These models support only train mode. So accuracy checking can't be done in
 # eval mode.
@@ -93,6 +90,8 @@
 REQUIRE_HIGHER_TOLERANCE = {
     "alexnet",
     "attention_is_all_you_need_pytorch",
+    "densenet121",
+    "hf_Albert",
     "vgg16",
     "mobilenet_v3_large",
 }
@@ -574,6 +573,11 @@ def main():
         action="store_true",
         help="Generates AOT Autograd stats like how mnay graphs are sent to AOT",
     )
+    parser.add_argument(
+        "--disable-functionalization",
+        action="store_true",
+        help="Disables functionalization",
+    )
     group = parser.add_mutually_exclusive_group()
     group.add_argument(
         "--coverage", action="store_true", help="(default) " + help(coverage_experiment)
@@ -856,6 +860,9 @@ def main():
     if output_filename:
         output_filename = os.path.join(torchdynamo.config.base_dir, output_filename)

+    if args.disable_functionalization:
+        torchdynamo.config.normalize_ir = False
+
     if args.minimum_call_count:
         torchdynamo.config.minimum_call_count = args.minimum_call_count
     if args.only:
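For context, the new flag is plain argparse-to-config plumbing: when --disable-functionalization is passed, the script turns torchdynamo.config.normalize_ir off before any benchmark runs. Below is a minimal, self-contained sketch of the same pattern; the SimpleNamespace object is a stand-in for torchdynamo.config (an assumption made so the snippet runs without TorchDynamo installed).

import argparse
from types import SimpleNamespace

# Stand-in for torchdynamo.config; the real script mutates torchdynamo.config.normalize_ir.
config = SimpleNamespace(normalize_ir=True)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-functionalization",
    action="store_true",
    help="Disables functionalization",
)
args = parser.parse_args(["--disable-functionalization"])

if args.disable_functionalization:
    config.normalize_ir = False

print(config.normalize_ir)  # prints False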

torchdynamo/optimizations/training.py

Lines changed: 14 additions & 7 deletions
@@ -1,7 +1,13 @@
+import copy
 import logging
+from distutils.debug import DEBUG

 import torch
+from cv2 import norm

+import torchdynamo
+from torchdynamo import config
+from torchdynamo.testing import same
 from torchdynamo.utils import clone_inputs
 from torchdynamo.utils import count_calls
 from torchdynamo.utils import counters
@@ -27,13 +33,14 @@ def __init__(self, gm: torch.fx.GraphModule, example_inputs):
         counters["aot_autograd"]["total"] += 1
         self.use_fallback = False
         self.original_example_inputs = example_inputs
-        try:
-            self.gm = normalize_ir(gm, self.example_inputs)
-        except Exception:
-            log.debug("TorchDynamo unable to remove mutation")
-            self.gm = gm
-            self.use_fallback = True
-            pass
+        self.gm = gm
+        if config.normalize_ir:
+            try:
+                self.gm = normalize_ir(gm, self.example_inputs)
+            except Exception:
+                log.debug("TorchDynamo unable to remove mutation")
+                self.use_fallback = True
+                pass

         gm_inputs = list(filter(lambda x: x.op == "placeholder", gm.graph.nodes))
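The constructor change keeps the un-normalized graph module as the default, attempts normalize_ir only when config.normalize_ir is enabled, and records a fallback if normalization raises. A rough standalone sketch of that guard-and-fallback shape follows; maybe_normalize is a hypothetical helper, and config and normalize_ir are passed in as stand-ins for the module-level names used in TorchDynamo.

import logging

log = logging.getLogger(__name__)

def maybe_normalize(gm, example_inputs, config, normalize_ir):
    """Return (graph_module, use_fallback): normalized when enabled and possible."""
    result, use_fallback = gm, False
    if config.normalize_ir:
        try:
            result = normalize_ir(gm, example_inputs)
        except Exception:
            # Same behavior as the diff: keep the original module and mark the fallback.
            log.debug("TorchDynamo unable to remove mutation")
            use_fallback = True
    return result, use_fallback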

torchdynamo/testing.py

Lines changed: 4 additions & 0 deletions
@@ -84,6 +84,10 @@ def same(a, b, cos_similarity=False, tol=1e-4):
                 return False
         return True
     elif isinstance(a, torch.Tensor):
+        if a.is_sparse:
+            assert b.is_sparse
+            a = a.to_dense()
+            b = b.to_dense()
         assert isinstance(b, torch.Tensor)
         if cos_similarity:
             # TRT will bring error loss larger than current threshold. Use cosine similarity as replacement
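The densify step exists because the tolerance comparison that follows in same() assumes dense (strided) tensors; converting both sides with to_dense() first lets the rest of the function proceed unchanged. A small standalone illustration of the same check on a sparse COO tensor (the example tensors and tolerances here are illustrative, not taken from the benchmark suite):

import torch

indices = torch.tensor([[0, 1], [1, 0]])
values = torch.tensor([3.0, 4.0])
a = torch.sparse_coo_tensor(indices, values, (2, 2))
b = a.clone()

# Mirror of the new branch in same(): densify both sides before comparing.
if a.is_sparse:
    assert b.is_sparse
    a = a.to_dense()
    b = b.to_dense()

print(torch.allclose(a, b, atol=1e-4, rtol=1e-4))  # True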
