|
9 | 9 | """
|
10 | 10 |
|
11 | 11 | import math
|
12 |
| -from typing import Any, Dict |
| 12 | +from typing import Any, Dict, List, Optional, Tuple, Union |
13 | 13 |
|
14 | 14 | import torch.nn as nn
|
15 | 15 | from classy_vision.generic.util import is_pos_int
|
@@ -228,13 +228,13 @@ class ResNeXt(ClassyModel):
|
228 | 228 | def __init__(
|
229 | 229 | self,
|
230 | 230 | num_blocks,
|
231 |
| - init_planes, |
232 |
| - reduction, |
233 |
| - small_input, |
234 |
| - zero_init_bn_residuals, |
235 |
| - base_width_and_cardinality, |
236 |
| - basic_layer, |
237 |
| - final_bn_relu, |
| 231 | + init_planes: int = 64, |
| 232 | + reduction: int = 4, |
| 233 | + small_input: bool = False, |
| 234 | + zero_init_bn_residuals: bool = False, |
| 235 | + base_width_and_cardinality: Optional[Union[Tuple, List]] = None, |
| 236 | + basic_layer: bool = False, |
| 237 | + final_bn_relu: bool = True, |
238 | 238 | ):
|
239 | 239 | """
|
240 | 240 | Implementation of `ResNeXt <https://arxiv.org/pdf/1611.05431.pdf>`_.
|
@@ -414,3 +414,108 @@ def output_shape(self):
|
414 | 414 | @property
|
415 | 415 | def model_depth(self):
|
416 | 416 | return sum(self.num_blocks)
|
| 417 | + |
| 418 | + |
| 419 | +@register_model("resnet18") |
| 420 | +class ResNet18(ResNeXt): |
| 421 | + def __init__(self): |
| 422 | + super().__init__( |
| 423 | + num_blocks=[2, 2, 2, 2], basic_layer=True, zero_init_bn_residuals=True |
| 424 | + ) |
| 425 | + |
| 426 | + @classmethod |
| 427 | + def from_config(cls, config: Dict[str, Any]) -> "ResNeXt": |
| 428 | + return cls() |
| 429 | + |
| 430 | + |
| 431 | +@register_model("resnet34") |
| 432 | +class ResNet34(ResNeXt): |
| 433 | + def __init__(self): |
| 434 | + super().__init__( |
| 435 | + num_blocks=[3, 4, 6, 3], basic_layer=True, zero_init_bn_residuals=True |
| 436 | + ) |
| 437 | + |
| 438 | + @classmethod |
| 439 | + def from_config(cls, config: Dict[str, Any]) -> "ResNeXt": |
| 440 | + return cls() |
| 441 | + |
| 442 | + |
| 443 | +@register_model("resnet50") |
| 444 | +class ResNet50(ResNeXt): |
| 445 | + def __init__(self): |
| 446 | + super().__init__( |
| 447 | + num_blocks=[3, 4, 6, 3], basic_layer=False, zero_init_bn_residuals=True |
| 448 | + ) |
| 449 | + |
| 450 | + @classmethod |
| 451 | + def from_config(cls, config: Dict[str, Any]) -> "ResNeXt": |
| 452 | + return cls() |
| 453 | + |
| 454 | + |
| 455 | +@register_model("resnet101") |
| 456 | +class ResNet101(ResNeXt): |
| 457 | + def __init__(self): |
| 458 | + super().__init__( |
| 459 | + num_blocks=[3, 4, 23, 3], basic_layer=False, zero_init_bn_residuals=True |
| 460 | + ) |
| 461 | + |
| 462 | + @classmethod |
| 463 | + def from_config(cls, config: Dict[str, Any]) -> "ResNeXt": |
| 464 | + return cls() |
| 465 | + |
| 466 | + |
| 467 | +@register_model("resnet152") |
| 468 | +class ResNet152(ResNeXt): |
| 469 | + def __init__(self): |
| 470 | + super().__init__( |
| 471 | + num_blocks=[3, 8, 36, 3], basic_layer=False, zero_init_bn_residuals=True |
| 472 | + ) |
| 473 | + |
| 474 | + @classmethod |
| 475 | + def from_config(cls, config: Dict[str, Any]) -> "ResNeXt": |
| 476 | + return cls() |
| 477 | + |
| 478 | + |
| 479 | +@register_model("resnext50_32x4d") |
| 480 | +class ResNet50(ResNeXt): |
| 481 | + def __init__(self): |
| 482 | + super().__init__( |
| 483 | + num_blocks=[3, 4, 6, 3], |
| 484 | + basic_layer=False, |
| 485 | + zero_init_bn_residuals=True, |
| 486 | + base_width_and_cardinality=(4, 32), |
| 487 | + ) |
| 488 | + |
| 489 | + @classmethod |
| 490 | + def from_config(cls, config: Dict[str, Any]) -> "ResNeXt": |
| 491 | + return cls() |
| 492 | + |
| 493 | + |
| 494 | +@register_model("resnext101_32x4d") |
| 495 | +class ResNet101(ResNeXt): |
| 496 | + def __init__(self): |
| 497 | + super().__init__( |
| 498 | + num_blocks=[3, 4, 23, 3], |
| 499 | + basic_layer=False, |
| 500 | + zero_init_bn_residuals=True, |
| 501 | + base_width_and_cardinality=(4, 32), |
| 502 | + ) |
| 503 | + |
| 504 | + @classmethod |
| 505 | + def from_config(cls, config: Dict[str, Any]) -> "ResNeXt": |
| 506 | + return cls() |
| 507 | + |
| 508 | + |
| 509 | +@register_model("resnext152_32x4d") |
| 510 | +class ResNet152(ResNeXt): |
| 511 | + def __init__(self): |
| 512 | + super().__init__( |
| 513 | + num_blocks=[3, 8, 36, 3], |
| 514 | + basic_layer=False, |
| 515 | + zero_init_bn_residuals=True, |
| 516 | + base_width_and_cardinality=(4, 32), |
| 517 | + ) |
| 518 | + |
| 519 | + @classmethod |
| 520 | + def from_config(cls, config: Dict[str, Any]) -> "ResNeXt": |
| 521 | + return cls() |
0 commit comments