Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Add standard resnet models #405

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 113 additions & 8 deletions classy_vision/models/resnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""

import math
from typing import Any, Dict
from typing import Any, Dict, List, Optional, Tuple, Union

import torch.nn as nn
from classy_vision.generic.util import is_pos_int
Expand Down Expand Up @@ -228,13 +228,13 @@ class ResNeXt(ClassyModel):
def __init__(
self,
num_blocks,
init_planes,
reduction,
small_input,
zero_init_bn_residuals,
base_width_and_cardinality,
basic_layer,
final_bn_relu,
init_planes: int = 64,
reduction: int = 4,
small_input: bool = False,
zero_init_bn_residuals: bool = False,
base_width_and_cardinality: Optional[Union[Tuple, List]] = None,
basic_layer: bool = False,
final_bn_relu: bool = True,
):
"""
Implementation of `ResNeXt <https://arxiv.org/pdf/1611.05431.pdf>`_.
Expand Down Expand Up @@ -414,3 +414,108 @@ def output_shape(self):
@property
def model_depth(self):
return sum(self.num_blocks)


@register_model("resnet18")
class ResNet18(ResNeXt):
def __init__(self):
super().__init__(
num_blocks=[2, 2, 2, 2], basic_layer=True, zero_init_bn_residuals=True
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnet34")
class ResNet34(ResNeXt):
def __init__(self):
super().__init__(
num_blocks=[3, 4, 6, 3], basic_layer=True, zero_init_bn_residuals=True
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnet50")
class ResNet50(ResNeXt):
def __init__(self):
super().__init__(
num_blocks=[3, 4, 6, 3], basic_layer=False, zero_init_bn_residuals=True
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnet101")
class ResNet101(ResNeXt):
def __init__(self):
super().__init__(
num_blocks=[3, 4, 23, 3], basic_layer=False, zero_init_bn_residuals=True
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnet152")
class ResNet152(ResNeXt):
def __init__(self):
super().__init__(
num_blocks=[3, 8, 36, 3], basic_layer=False, zero_init_bn_residuals=True
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnext50_32x4d")
class ResNeXt50(ResNeXt):
def __init__(self):
super().__init__(
num_blocks=[3, 4, 6, 3],
basic_layer=False,
zero_init_bn_residuals=True,
base_width_and_cardinality=(4, 32),
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnext101_32x4d")
class ResNeXt101(ResNeXt):
def __init__(self):
super().__init__(
num_blocks=[3, 4, 23, 3],
basic_layer=False,
zero_init_bn_residuals=True,
base_width_and_cardinality=(4, 32),
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnext152_32x4d")
class ResNeXt152(ResNeXt):
def __init__(self):
super().__init__(
num_blocks=[3, 8, 36, 3],
basic_layer=False,
zero_init_bn_residuals=True,
base_width_and_cardinality=(4, 32),
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()