|
| 1 | +from collections import OrderedDict |
| 2 | +from dataclasses import dataclass, fields |
| 3 | +from enum import Enum |
| 4 | +from typing import Any, Callable, Dict |
| 5 | + |
| 6 | +from ..._internally_replaced_utils import load_state_dict_from_url |
| 7 | + |
| 8 | + |
| 9 | +__all__ = ["Weights", "WeightEntry"] |
| 10 | + |
| 11 | + |
| 12 | +@dataclass |
| 13 | +class WeightEntry: |
| 14 | + """ |
| 15 | + This class is used to group important attributes associated with the pre-trained weights. |
| 16 | +
|
| 17 | + Args: |
| 18 | + url (str): The location where we find the weights. |
| 19 | + transforms (Callable): A callable that constructs the preprocessing method (or validation preset transforms) |
| 20 | + needed to use the model. The reason we attach a constructor method rather than an already constructed |
| 21 | + object is because the specific object might have memory and thus we want to delay initialization until |
| 22 | + needed. |
| 23 | + meta (Dict[str, Any]): Stores meta-data related to the weights of the model and its configuration. These can be |
| 24 | + informative attributes (for example the number of parameters/flops, recipe link/methods used in training |
| 25 | + etc), configuration parameters (for example the `num_classes`) needed to construct the model or important |
| 26 | + meta-data (for example the `classes` of a classification model) needed to use the model. |
| 27 | + """ |
| 28 | + |
| 29 | + url: str |
| 30 | + transforms: Callable |
| 31 | + meta: Dict[str, Any] |
| 32 | + |
| 33 | + |
| 34 | +class Weights(Enum): |
| 35 | + """ |
| 36 | + This class is the parent class of all model weights. Each model building method receives an optional `weights` |
| 37 | + parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type |
| 38 | + `WeightEntry`. |
| 39 | +
|
| 40 | + Args: |
| 41 | + value (WeightEntry): The data class entry with the weight information. |
| 42 | + """ |
| 43 | + |
| 44 | + def __init__(self, value: WeightEntry): |
| 45 | + self._value_ = value |
| 46 | + |
| 47 | + @classmethod |
| 48 | + def verify(cls, obj: Any) -> Any: |
| 49 | + if obj is not None: |
| 50 | + if type(obj) is str: |
| 51 | + obj = cls.from_str(obj) |
| 52 | + elif not isinstance(obj, cls) and not isinstance(obj, WeightEntry): |
| 53 | + raise TypeError( |
| 54 | + f"Invalid Weight class provided; expected {cls.__name__} " f"but received {obj.__class__.__name__}." |
| 55 | + ) |
| 56 | + return obj |
| 57 | + |
| 58 | + @classmethod |
| 59 | + def from_str(cls, value: str) -> "Weights": |
| 60 | + for v in cls: |
| 61 | + if v._name_ == value: |
| 62 | + return v |
| 63 | + raise ValueError(f"Invalid value {value} for enum {cls.__name__}.") |
| 64 | + |
| 65 | + def state_dict(self, progress: bool) -> OrderedDict: |
| 66 | + return load_state_dict_from_url(self.url, progress=progress) |
| 67 | + |
| 68 | + def __repr__(self): |
| 69 | + return f"{self.__class__.__name__}.{self._name_}" |
| 70 | + |
| 71 | + def __getattr__(self, name): |
| 72 | + # Be able to fetch WeightEntry attributes directly |
| 73 | + for f in fields(WeightEntry): |
| 74 | + if f.name == name: |
| 75 | + return object.__getattribute__(self.value, name) |
| 76 | + return super().__getattr__(name) |
0 commit comments