Skip to content

Commit 7784ee4

Browse files
authored
Merge pull request pytorch#167 from NVIDIA/syncbn_api_update
update SyncBatchNorm API
2 parents 8e8dd35 + bc76b01 commit 7784ee4

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

apex/parallel/optimized_sync_batchnorm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class SyncBatchNorm(_BatchNorm):
5555
>>> inp = torch.randn(10, 14, 14, 100).cuda()
5656
"""
5757

58-
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last = False):
58+
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False):
5959
super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
6060
self.process_group = process_group
6161
self.channel_last = channel_last

apex/parallel/sync_batchnorm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ class SyncBatchNorm(_BatchNorm):
4848

4949
warned = False
5050

51-
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None):
51+
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False):
52+
if channel_last == True:
53+
raise AttributeError("channel_last is not supported by primitive SyncBatchNorm implementation. Try install apex with `--cuda_ext` if channel_last is desired.")
5254

5355
if not SyncBatchNorm.warned:
5456
print("Warning: using Python fallback for SyncBatchNorm, possibly because apex was installed without --cuda_ext. The exception raised when attempting to import the cuda backend was: ", self.syncbn_import_error)

0 commit comments

Comments
 (0)