Skip to content

Commit 0495b05

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Add test to check that classification models are FX-compatible (#3662)
Summary: * Add test to check that classification models are FX-compatible * Replace torch.equal with torch.allclose * remove skipling Reviewed By: fmassa Differential Revision: D29264313 fbshipit-source-id: 4e57e255c6ce680fc6deee6a9980a7d189e23597 Co-authored-by: Nicolas Hug <[email protected]>
1 parent eaea921 commit 0495b05

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

test/test_models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import functools
88
import operator
99
import torch
10+
import torch.fx
1011
import torch.nn as nn
1112
import torchvision
1213
from torchvision import models
@@ -140,6 +141,13 @@ def get_export_import_copy(m):
140141
assert_export_import_module(sm, args)
141142

142143

144+
def _check_fx_compatible(model, inputs):
145+
model_fx = torch.fx.symbolic_trace(model)
146+
out = model(inputs)
147+
out_fx = model_fx(inputs)
148+
torch.testing.assert_close(out, out_fx)
149+
150+
143151
# If 'unwrapper' is provided it will be called with the script model outputs
144152
# before they are compared to the eager model outputs. This is useful if the
145153
# model outputs are different between TorchScript / Eager mode
@@ -408,6 +416,7 @@ def test_classification_model(model_name, dev):
408416
_assert_expected(out.cpu(), model_name, prec=0.1)
409417
assert out.shape[-1] == 50
410418
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
419+
_check_fx_compatible(model, x)
411420

412421
if dev == torch.device("cuda"):
413422
with torch.cuda.amp.autocast():

0 commit comments

Comments
 (0)