From 94911e179350136c7957c2f07482a2821b711d97 Mon Sep 17 00:00:00 2001 From: ekka Date: Mon, 6 May 2019 18:08:37 +0530 Subject: [PATCH 1/7] Minor refactoring of ShuffleNetV2 Added progress flag following #875. Further the following refactoring was also done: 1) added `version` argument in shufflenetv2 method and removed the operations for converting the `width_mult` arg to float and string. 2) removed `num_classes` argument and **kwargs from functions except `ShuffleNetV2` --- torchvision/models/shufflenetv2.py | 33 +++++++++++++++--------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index eb66a1abeec..20e1c38f4d2 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn +from .utils import load_state_dict_from_url __all__ = ['ShuffleNetV2', 'shufflenetv2', 'shufflenetv2_x0_5', 'shufflenetv2_x1_0', @@ -89,7 +90,7 @@ def __init__(self, num_classes=1000, width_mult=1): super(ShuffleNetV2, self).__init__() try: - self.stage_out_channels = self._getStages(float(width_mult)) + self.stage_out_channels = self._getStages(width_mult) except KeyError: raise ValueError('width_mult {} is not supported'.format(width_mult)) @@ -145,36 +146,34 @@ def _getStages(mult): return stages[str(mult)] -def shufflenetv2(pretrained=False, num_classes=1000, width_mult=1, **kwargs): - model = ShuffleNetV2(num_classes=num_classes, width_mult=width_mult) +def _shufflenetv2(version, pretrained=False, progress, width_mult=1.0): + model = ShuffleNetV2(width_mult=width_mult) if pretrained: - # change width_mult to float - if isinstance(width_mult, int): - width_mult = float(width_mult) - model_type = ('_'.join([ShuffleNetV2.__name__, 'x' + str(width_mult)])) + arch = 'shufflenetv2_x' + version try: - model_url = model_urls[model_type.lower()] + model_url = model_urls[arch] except KeyError: raise ValueError('model {} is not support'.format(model_type)) if model_url is None: raise NotImplementedError('pretrained {} is not supported'.format(model_type)) - model.load_state_dict(torch.utils.model_zoo.load_url(model_url)) + state_dict = load_state_dict_from_url(model_urls, progress=progress) + model.load_state_dict(state_dict) return model -def shufflenetv2_x0_5(pretrained=False, num_classes=1000, **kwargs): - return shufflenetv2(pretrained, num_classes, 0.5) +def shufflenetv2_x0_5(pretrained=False, progress=True): + return _shufflenetv2('0.5', pretrained, progress, 0.5) -def shufflenetv2_x1_0(pretrained=False, num_classes=1000, **kwargs): - return shufflenetv2(pretrained, num_classes, 1) +def shufflenetv2_x1_0(pretrained=False, progress=True): + return _shufflenetv2('1.0', pretrained, progress, 1.0) -def shufflenetv2_x1_5(pretrained=False, num_classes=1000, **kwargs): - return shufflenetv2(pretrained, num_classes, 1.5) +def shufflenetv2_x1_5(pretrained=False, progress=True): + return _shufflenetv2('1.5', pretrained, progress, 1.5) -def shufflenetv2_x2_0(pretrained=False, num_classes=1000, **kwargs): - return shufflenetv2(pretrained, num_classes, 2) +def shufflenetv2_x2_0(pretrained=False, progress=True): + return _shufflenetv2('2.0', pretrained, progress, 2.0) From 4b5d8ee5af1d5dc8a614b65116803882b0892c36 Mon Sep 17 00:00:00 2001 From: ekka Date: Mon, 6 May 2019 18:30:30 +0530 Subject: [PATCH 2/7] removed `version` arg --- torchvision/models/shufflenetv2.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 20e1c38f4d2..0eebae0ab44 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -146,11 +146,11 @@ def _getStages(mult): return stages[str(mult)] -def _shufflenetv2(version, pretrained=False, progress, width_mult=1.0): +def _shufflenetv2(width_mult=1.0, pretrained, progress): model = ShuffleNetV2(width_mult=width_mult) if pretrained: - arch = 'shufflenetv2_x' + version + arch = 'shufflenetv2_x' + str(width_mult) try: model_url = model_urls[arch] except KeyError: @@ -164,16 +164,16 @@ def _shufflenetv2(version, pretrained=False, progress, width_mult=1.0): def shufflenetv2_x0_5(pretrained=False, progress=True): - return _shufflenetv2('0.5', pretrained, progress, 0.5) + return _shufflenetv2(0.5, pretrained, progress) def shufflenetv2_x1_0(pretrained=False, progress=True): - return _shufflenetv2('1.0', pretrained, progress, 1.0) + return _shufflenetv2(1.0, pretrained, progress) def shufflenetv2_x1_5(pretrained=False, progress=True): - return _shufflenetv2('1.5', pretrained, progress, 1.5) + return _shufflenetv2(1.5, pretrained, progress) def shufflenetv2_x2_0(pretrained=False, progress=True): - return _shufflenetv2('2.0', pretrained, progress, 2.0) + return _shufflenetv2(2.0, pretrained, progress) From e38fccaf72313391f192871db432749857b63785 Mon Sep 17 00:00:00 2001 From: ekka Date: Mon, 6 May 2019 18:56:32 +0530 Subject: [PATCH 3/7] Update shufflenetv2.py --- torchvision/models/shufflenetv2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 0eebae0ab44..a2b982d128c 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -146,7 +146,7 @@ def _getStages(mult): return stages[str(mult)] -def _shufflenetv2(width_mult=1.0, pretrained, progress): +def _shufflenetv2(pretrained, progress, width_mult=1.0): model = ShuffleNetV2(width_mult=width_mult) if pretrained: @@ -164,16 +164,16 @@ def _shufflenetv2(width_mult=1.0, pretrained, progress): def shufflenetv2_x0_5(pretrained=False, progress=True): - return _shufflenetv2(0.5, pretrained, progress) + return _shufflenetv2(pretrained, progress, 0.5) def shufflenetv2_x1_0(pretrained=False, progress=True): - return _shufflenetv2(1.0, pretrained, progress) + return _shufflenetv2(pretrained, progress, 1.0) def shufflenetv2_x1_5(pretrained=False, progress=True): - return _shufflenetv2(1.5, pretrained, progress) + return _shufflenetv2(pretrained, progress, 1.5) def shufflenetv2_x2_0(pretrained=False, progress=True): - return _shufflenetv2(2.0, pretrained, progress) + return _shufflenetv2(pretrained, progress, 2.0) From 6764c63d0e4e67e1f38563366b548dae8896994d Mon Sep 17 00:00:00 2001 From: ekka Date: Mon, 6 May 2019 19:34:31 +0530 Subject: [PATCH 4/7] Removed the try except block --- torchvision/models/shufflenetv2.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index a2b982d128c..c2e0adb6326 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -4,9 +4,7 @@ import torch.nn as nn from .utils import load_state_dict_from_url -__all__ = ['ShuffleNetV2', 'shufflenetv2', - 'shufflenetv2_x0_5', 'shufflenetv2_x1_0', - 'shufflenetv2_x1_5', 'shufflenetv2_x2_0'] +__all__ = ['ShuffleNetV2', 'shufflenetv2_x0_5', 'shufflenetv2_x1_0', 'shufflenetv2_x1_5', 'shufflenetv2_x2_0'] model_urls = { 'shufflenetv2_x0.5': @@ -151,14 +149,12 @@ def _shufflenetv2(pretrained, progress, width_mult=1.0): if pretrained: arch = 'shufflenetv2_x' + str(width_mult) - try: - model_url = model_urls[arch] - except KeyError: - raise ValueError('model {} is not support'.format(model_type)) + model_url = model_urls[arch] if model_url is None: - raise NotImplementedError('pretrained {} is not supported'.format(model_type)) - state_dict = load_state_dict_from_url(model_urls, progress=progress) - model.load_state_dict(state_dict) + raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) + else: + state_dict = load_state_dict_from_url(model_urls, progress=progress) + model.load_state_dict(state_dict) return model From 7770ceea47e7ed7d31073a0ce7dd9ffc6b425e49 Mon Sep 17 00:00:00 2001 From: ekka Date: Mon, 6 May 2019 19:53:37 +0530 Subject: [PATCH 5/7] Update shufflenetv2.py --- torchvision/models/shufflenetv2.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index c2e0adb6326..985a0f0488b 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -144,8 +144,8 @@ def _getStages(mult): return stages[str(mult)] -def _shufflenetv2(pretrained, progress, width_mult=1.0): - model = ShuffleNetV2(width_mult=width_mult) +def _shufflenetv2(pretrained, progress, width_mult, **kwargs): + model = ShuffleNetV2(width_mult=width_mult, **kwargs) if pretrained: arch = 'shufflenetv2_x' + str(width_mult) @@ -159,17 +159,17 @@ def _shufflenetv2(pretrained, progress, width_mult=1.0): return model -def shufflenetv2_x0_5(pretrained=False, progress=True): - return _shufflenetv2(pretrained, progress, 0.5) +def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs): + return _shufflenetv2(pretrained, progress, 0.5, **kwargs) -def shufflenetv2_x1_0(pretrained=False, progress=True): - return _shufflenetv2(pretrained, progress, 1.0) +def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs): + return _shufflenetv2(pretrained, progress, 1.0, **kwargs) -def shufflenetv2_x1_5(pretrained=False, progress=True): - return _shufflenetv2(pretrained, progress, 1.5) +def shufflenetv2_x1_5(pretrained=False, progress=True, **kwargs): + return _shufflenetv2(pretrained, progress, 1.5, **kwargs) -def shufflenetv2_x2_0(pretrained=False, progress=True): - return _shufflenetv2(pretrained, progress, 2.0) +def shufflenetv2_x2_0(pretrained=False, progress=True, **kwargs): + return _shufflenetv2(pretrained, progress, 2.0, **kwargs) From 15b619fb86b2d1b87be79081ba5dc83e9a07f805 Mon Sep 17 00:00:00 2001 From: ekka Date: Tue, 7 May 2019 19:21:44 +0530 Subject: [PATCH 6/7] Changed version from float to str --- torchvision/models/shufflenetv2.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 985a0f0488b..122e96850fc 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -84,16 +84,13 @@ def forward(self, x): class ShuffleNetV2(nn.Module): - def __init__(self, num_classes=1000, width_mult=1): + def __init__(self, num_classes=1000, width_mult='1.0'): super(ShuffleNetV2, self).__init__() - try: - self.stage_out_channels = self._getStages(width_mult) - except KeyError: - raise ValueError('width_mult {} is not supported'.format(width_mult)) - + self.stage_out_channels = self._getStages(width_mult) input_channels = 3 output_channels = self.stage_out_channels[0] + self.conv1 = nn.Sequential( nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), nn.BatchNorm2d(output_channels), @@ -141,14 +138,14 @@ def _getStages(mult): '1.5': [24, 176, 352, 704, 1024], '2.0': [24, 244, 488, 976, 2048], } - return stages[str(mult)] + return stages[mult] def _shufflenetv2(pretrained, progress, width_mult, **kwargs): model = ShuffleNetV2(width_mult=width_mult, **kwargs) if pretrained: - arch = 'shufflenetv2_x' + str(width_mult) + arch = 'shufflenetv2_x' + width_mult model_url = model_urls[arch] if model_url is None: raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) @@ -160,16 +157,16 @@ def _shufflenetv2(pretrained, progress, width_mult, **kwargs): def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs): - return _shufflenetv2(pretrained, progress, 0.5, **kwargs) + return _shufflenetv2(pretrained, progress, '0.5', **kwargs) def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs): - return _shufflenetv2(pretrained, progress, 1.0, **kwargs) + return _shufflenetv2(pretrained, progress, '1.0', **kwargs) def shufflenetv2_x1_5(pretrained=False, progress=True, **kwargs): - return _shufflenetv2(pretrained, progress, 1.5, **kwargs) + return _shufflenetv2(pretrained, progress, '1.5', **kwargs) def shufflenetv2_x2_0(pretrained=False, progress=True, **kwargs): - return _shufflenetv2(pretrained, progress, 2.0, **kwargs) + return _shufflenetv2(pretrained, progress, '2.0', **kwargs) From 3029b526e633e09899e73d5316d04f7b270775c7 Mon Sep 17 00:00:00 2001 From: ekka Date: Tue, 7 May 2019 20:35:52 +0530 Subject: [PATCH 7/7] Replace `width_mult` with `stages_out_channels` Removes the need of `_getStages` function. --- torchvision/models/shufflenetv2.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 122e96850fc..5726cea2a22 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -84,10 +84,10 @@ def forward(self, x): class ShuffleNetV2(nn.Module): - def __init__(self, num_classes=1000, width_mult='1.0'): + def __init__(self, stage_out_channels, num_classes=1000): super(ShuffleNetV2, self).__init__() - self.stage_out_channels = self._getStages(width_mult) + self.stage_out_channels = stage_out_channels input_channels = 3 output_channels = self.stage_out_channels[0] @@ -130,22 +130,11 @@ def forward(self, x): x = self.fc(x) return x - @staticmethod - def _getStages(mult): - stages = { - '0.5': [24, 48, 96, 192, 1024], - '1.0': [24, 116, 232, 464, 1024], - '1.5': [24, 176, 352, 704, 1024], - '2.0': [24, 244, 488, 976, 2048], - } - return stages[mult] - -def _shufflenetv2(pretrained, progress, width_mult, **kwargs): - model = ShuffleNetV2(width_mult=width_mult, **kwargs) +def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs): + model = ShuffleNetV2(stage_out_channels=stage_out_channels, **kwargs) if pretrained: - arch = 'shufflenetv2_x' + width_mult model_url = model_urls[arch] if model_url is None: raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) @@ -157,16 +146,16 @@ def _shufflenetv2(pretrained, progress, width_mult, **kwargs): def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs): - return _shufflenetv2(pretrained, progress, '0.5', **kwargs) + return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, [24, 48, 96, 192, 1024], **kwargs) def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs): - return _shufflenetv2(pretrained, progress, '1.0', **kwargs) + return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, [24, 116, 232, 464, 1024], **kwargs) def shufflenetv2_x1_5(pretrained=False, progress=True, **kwargs): - return _shufflenetv2(pretrained, progress, '1.5', **kwargs) + return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, [24, 176, 352, 704, 1024], **kwargs) def shufflenetv2_x2_0(pretrained=False, progress=True, **kwargs): - return _shufflenetv2(pretrained, progress, '2.0', **kwargs) + return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, [24, 244, 488, 976, 2048], **kwargs)