@@ -395,6 +395,7 @@ def wrapped_fn(*args, **kwargs):
skip('nn.functional.max_unpool1d'),  # fails everywhere except on mac
skip('nn.functional.max_unpool2d'),  # fails everywhere except on windows
skip('nn.functional.max_unpool3d'),  # fails everywhere except on mac
+ xfail("native_batch_norm"),
xfail('nn.functional.rrelu')  # in-place test errors out with no formula implemented
}))
@@ -643,6 +644,7 @@ def fn(inp, *args, **kwargs):
xfail("nn.functional.batch_norm", 'without_cudnn'),
# view doesn't work on sparse
xfail("to_sparse"),
+ xfail("native_batch_norm"),
}))
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@@ -725,6 +727,7 @@ def vjp_of_vjp(*args_and_cotangents):
# ---------------------------- BUGS ------------------------------------
# All of the following are bugs and need to be fixed
skip('linalg.svdvals'),  # really annoying thing where it passes correctness check but not has_batch_rule
+ skip("native_batch_norm"),
xfail('__getitem__', ''),  # dynamic error
xfail('linalg.eig'),  # Uses aten::allclose
xfail('linalg.householder_product'),  # needs select_scatter
@@ -833,6 +836,7 @@ def test_vmapvjp(self, device, dtype, op):
# erroring because running_mean and running_var aren't differentiable
xfail('nn.functional.batch_norm'),
xfail('nn.functional.batch_norm', 'without_cudnn'),
+ xfail("native_batch_norm"),
# ----------------------------------------------------------------------
}
@@ -1030,6 +1034,7 @@ def test():
xfail('linalg.vecdot', ''),
xfail('segment_reduce', 'lengths'),
xfail('sparse.sampled_addmm', ''),
+ xfail("native_batch_norm"),
}))
def test_vmapvjp_has_batch_rule(self, device, dtype, op):
if not op.supports_autograd:
@@ -1095,6 +1100,7 @@ def test():
xfail('nn.functional.dropout3d', ''),
xfail('as_strided_scatter', ''),
xfail('sparse.sampled_addmm', ''),
+ xfail("native_batch_norm"),
}))
def test_vjpvmap(self, device, dtype, op):
# NB: there is no vjpvmap_has_batch_rule test because that is almost
@@ -1338,6 +1344,10 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
xfail('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
xfail('to_sparse'),  # Forward AD not implemented and no decomposition
xfail('view_as_complex'),  # RuntimeError: Tensor must have a last dimension with stride 1
+ # RuntimeError: Batch norm got a batched tensor as
+ # input while the running_mean or running_var, which will be updated in
+ # place, were not batched.
+ xfail("native_batch_norm"),
}))
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
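For context (not part of the diff): a minimal repro sketch of the failure quoted in the last hunk, assuming the functorch-era vmap API that this test suite targets and the public torch.native_batch_norm overload; the shapes and names below are illustrative only.

import torch
from functorch import vmap  # assumption: functorch package used by these tests

# Per-sample input is (N=3, C=2); vmap adds an extra leading batch dim of size 5.
x = torch.randn(5, 3, 2)
weight, bias = torch.randn(2), torch.randn(2)
running_mean, running_var = torch.zeros(2), torch.ones(2)

def f(inp):
    # training=True means running_mean/running_var get updated in place.
    return torch.native_batch_norm(inp, weight, bias, running_mean, running_var,
                                   True, 0.1, 1e-5)

# Expected to raise the RuntimeError quoted in the xfail comment above,
# because the input is batched under vmap while the running stats are not.
vmap(f)(x)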