Skip to content

Commit 9fca616

Browse files
prabhat00155 authored and facebook-github-bot committed
[fbsync] Multi-pretrained weight support - initial API + ResNet50 (#4610)
Summary: * Adding lightweight API for models. * Adding resnet50. * Fix preset * Add fake categories. * Fixing mypy. * Add string=>weight conversion support on Enums. * Temporarily hardcoding imagenet categories. * Minor refactoring. Reviewed By: fmassa Differential Revision: D31649970 fbshipit-source-id: b4908da7be972c0a19949e75d61f2051e785494c
1 parent da0aa01 commit 9fca616

File tree

7 files changed

+1199
-0
lines changed

7 files changed

+1199
-0
lines changed

torchvision/prototype/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
from . import datasets
2+
from . import models
3+
from . import transforms
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .resnet import *

torchvision/prototype/models/_api.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from collections import OrderedDict
2+
from dataclasses import dataclass, fields
3+
from enum import Enum
4+
from typing import Any, Callable, Dict
5+
6+
from ..._internally_replaced_utils import load_state_dict_from_url
7+
8+
9+
# Public names exported on a star-import of this module.
__all__ = ["Weights", "WeightEntry"]
10+
11+
12+
@dataclass
class WeightEntry:
    """
    Bundle of the attributes that describe one set of pre-trained weights.

    Args:
        url (str): Where the weights can be found.
        transforms (Callable): A constructor for the preprocessing method (or validation preset transforms) that
            the model needs. A constructor is stored instead of a ready-made object because the object may hold
            memory, so its initialization is deferred until it is actually required.
        meta (Dict[str, Any]): Meta-data about the weights and the model configuration. Entries may be purely
            informative (e.g. number of parameters/flops, recipe link/methods used in training), configuration
            parameters required to build the model (e.g. `num_classes`), or meta-data required to use the model
            (e.g. the `classes` of a classification model).
    """

    url: str
    transforms: Callable
    meta: Dict[str, Any]
32+
33+
34+
class Weights(Enum):
    """
    This class is the parent class of all model weights. Each model building method receives an optional `weights`
    parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
    `WeightEntry`.

    The fields of the underlying `WeightEntry` (`url`, `transforms`, `meta`) can be read directly from a member,
    e.g. `member.url` is equivalent to `member.value.url`.

    Args:
        value (WeightEntry): The data class entry with the weight information.
    """

    def __init__(self, value: WeightEntry):
        self._value_ = value

    @classmethod
    def verify(cls, obj: Any) -> Any:
        """
        Validate a user-supplied `weights` argument and normalise it.

        Args:
            obj (Any): `None`, the name of a member of this enum, a member of this enum, or a `WeightEntry`.

        Returns:
            `None` when `obj` is `None`; the matching member when `obj` is a string; otherwise `obj` unchanged.

        Raises:
            TypeError: If `obj` is neither a string, a member of this enum nor a `WeightEntry`.
            ValueError: If `obj` is a string that does not name a member of this enum.
        """
        if obj is None:
            return None
        # isinstance (rather than an exact type check) so that str subclasses are resolved by name as well.
        if isinstance(obj, str):
            return cls.from_str(obj)
        if not isinstance(obj, (cls, WeightEntry)):
            raise TypeError(
                f"Invalid Weight class provided; expected {cls.__name__} " f"but received {obj.__class__.__name__}."
            )
        return obj

    @classmethod
    def from_str(cls, value: str) -> "Weights":
        """
        Return the member of this enum whose name equals `value`.

        Raises:
            ValueError: If no member has that name.
        """
        for member in cls:
            if member.name == value:
                return member
        raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")

    def state_dict(self, progress: bool) -> OrderedDict:
        """
        Load the state dict stored at this member's `url`.

        Args:
            progress (bool): Forwarded to `load_state_dict_from_url` as its `progress` argument.
        """
        return load_state_dict_from_url(self.url, progress=progress)

    def __repr__(self):
        return f"{self.__class__.__name__}.{self._name_}"

    def __getattr__(self, name):
        # Only invoked when regular attribute lookup fails.
        # Be able to fetch WeightEntry attributes directly
        if any(f.name == name for f in fields(WeightEntry)):
            return object.__getattribute__(self.value, name)
        # Neither Enum nor object defines an instance-level __getattr__ to delegate to, so raise the
        # conventional error ourselves instead of surfacing a confusing "'super' object ..." message.
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

0 commit comments

Comments
 (0)