Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 59dc6d7

Browse files
vreisfacebook-github-bot
authored andcommitted
Add standard resnet models
Summary: It is annoying to configure the ResNet blocks all the time. Add the standard models to the code so we can refer to them by name Differential Revision: D20050757 fbshipit-source-id: 42ca5da802ea57736589495f1fa34125612e0839
1 parent 4715b0a commit 59dc6d7

File tree

1 file changed

+113
-8
lines changed

1 file changed

+113
-8
lines changed

classy_vision/models/resnext.py

Lines changed: 113 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"""
1010

1111
import math
12-
from typing import Any, Dict
12+
from typing import Any, Dict, List, Optional, Tuple, Union
1313

1414
import torch.nn as nn
1515
from classy_vision.generic.util import is_pos_int
@@ -228,13 +228,13 @@ class ResNeXt(ClassyModel):
228228
def __init__(
229229
self,
230230
num_blocks,
231-
init_planes,
232-
reduction,
233-
small_input,
234-
zero_init_bn_residuals,
235-
base_width_and_cardinality,
236-
basic_layer,
237-
final_bn_relu,
231+
init_planes: int = 64,
232+
reduction: int = 4,
233+
small_input: bool = False,
234+
zero_init_bn_residuals: bool = False,
235+
base_width_and_cardinality: Optional[Union[Tuple, List]] = None,
236+
basic_layer: bool = False,
237+
final_bn_relu: bool = True,
238238
):
239239
"""
240240
Implementation of `ResNeXt <https://arxiv.org/pdf/1611.05431.pdf>`_.
@@ -414,3 +414,108 @@ def output_shape(self):
414414
@property
415415
def model_depth(self):
416416
return sum(self.num_blocks)
417+
418+
419+
@register_model("resnet18")
420+
class ResNet18(ResNeXt):
421+
def __init__(self):
422+
super().__init__(
423+
num_blocks=[2, 2, 2, 2], basic_layer=True, zero_init_bn_residuals=True
424+
)
425+
426+
@classmethod
427+
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
428+
return cls()
429+
430+
431+
@register_model("resnet34")
432+
class ResNet34(ResNeXt):
433+
def __init__(self):
434+
super().__init__(
435+
num_blocks=[3, 4, 6, 3], basic_layer=True, zero_init_bn_residuals=True
436+
)
437+
438+
@classmethod
439+
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
440+
return cls()
441+
442+
443+
@register_model("resnet50")
444+
class ResNet50(ResNeXt):
445+
def __init__(self):
446+
super().__init__(
447+
num_blocks=[3, 4, 6, 3], basic_layer=False, zero_init_bn_residuals=True
448+
)
449+
450+
@classmethod
451+
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
452+
return cls()
453+
454+
455+
@register_model("resnet101")
456+
class ResNet101(ResNeXt):
457+
def __init__(self):
458+
super().__init__(
459+
num_blocks=[3, 4, 23, 3], basic_layer=False, zero_init_bn_residuals=True
460+
)
461+
462+
@classmethod
463+
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
464+
return cls()
465+
466+
467+
@register_model("resnet152")
468+
class ResNet152(ResNeXt):
469+
def __init__(self):
470+
super().__init__(
471+
num_blocks=[3, 8, 36, 3], basic_layer=False, zero_init_bn_residuals=True
472+
)
473+
474+
@classmethod
475+
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
476+
return cls()
477+
478+
479+
@register_model("resnext50_32x4d")
480+
class ResNet50(ResNeXt):
481+
def __init__(self):
482+
super().__init__(
483+
num_blocks=[3, 4, 6, 3],
484+
basic_layer=False,
485+
zero_init_bn_residuals=True,
486+
base_width_and_cardinality=(4, 32),
487+
)
488+
489+
@classmethod
490+
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
491+
return cls()
492+
493+
494+
@register_model("resnext101_32x4d")
495+
class ResNet101(ResNeXt):
496+
def __init__(self):
497+
super().__init__(
498+
num_blocks=[3, 4, 23, 3],
499+
basic_layer=False,
500+
zero_init_bn_residuals=True,
501+
base_width_and_cardinality=(4, 32),
502+
)
503+
504+
@classmethod
505+
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
506+
return cls()
507+
508+
509+
@register_model("resnext152_32x4d")
510+
class ResNet152(ResNeXt):
511+
def __init__(self):
512+
super().__init__(
513+
num_blocks=[3, 8, 36, 3],
514+
basic_layer=False,
515+
zero_init_bn_residuals=True,
516+
base_width_and_cardinality=(4, 32),
517+
)
518+
519+
@classmethod
520+
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
521+
return cls()

0 commit comments

Comments
 (0)