|
15 | 15 |
|
16 | 16 | from keras_cv.api_export import keras_cv_export
|
17 | 17 | from keras_cv.backend import keras
|
18 |
| -from keras_cv.utils.preset_utils import check_preset_class |
| 18 | +from keras_cv.utils.preset_utils import check_config_class |
| 19 | +from keras_cv.utils.preset_utils import list_presets |
| 20 | +from keras_cv.utils.preset_utils import list_subclasses |
19 | 21 | from keras_cv.utils.preset_utils import load_from_preset
|
20 | 22 | from keras_cv.utils.python_utils import classproperty
|
21 |
| -from keras_cv.utils.python_utils import format_docstring |
22 | 23 |
|
23 | 24 |
|
24 | 25 | @keras_cv_export("keras_cv.models.Backbone")
|
@@ -64,12 +65,18 @@ def from_config(cls, config):
|
64 | 65 | @classproperty
|
65 | 66 | def presets(cls):
|
66 | 67 | """Dictionary of preset names and configs."""
|
67 |
| - return {} |
| 68 | + presets = list_presets(cls) |
| 69 | + for subclass in list_subclasses(cls): |
| 70 | + presets.update(subclass.presets) |
| 71 | + return presets |
68 | 72 |
|
69 | 73 | @classproperty
|
70 | 74 | def presets_with_weights(cls):
|
71 | 75 | """Dictionary of preset names and configs that include weights."""
|
72 |
| - return {} |
| 76 | + presets = list_presets(cls, with_weights=True) |
| 77 | + for subclass in list_subclasses(cls): |
| 78 | + presets.update(subclass.presets) |
| 79 | + return presets |
73 | 80 |
|
74 | 81 | @classproperty
|
75 | 82 | def presets_without_weights(cls):
|
@@ -109,47 +116,19 @@ def from_preset(
|
109 | 116 | load_weights=False,
|
110 | 117 | ```
|
111 | 118 | """
|
112 |
| - # We support short IDs for official presets, e.g. `"bert_base_en"`. |
113 |
| - # Map these to a Kaggle Models handle. |
114 |
| - if preset in cls.presets: |
115 |
| - preset = cls.presets[preset]["kaggle_handle"] |
116 |
| - |
117 |
| - check_preset_class(preset, cls) |
| 119 | + preset_cls = check_config_class(preset) |
| 120 | + if not issubclass(preset_cls, cls): |
| 121 | + raise ValueError( |
| 122 | + f"Preset has type `{preset_cls.__name__}` which is not a " |
| 123 | + f"a subclass of calling class `{cls.__name__}`. Call " |
| 124 | + f"`from_preset` directly on `{preset_cls.__name__}` instead." |
| 125 | + ) |
118 | 126 | return load_from_preset(
|
119 | 127 | preset,
|
120 | 128 | load_weights=load_weights,
|
121 | 129 | config_overrides=kwargs,
|
122 | 130 | )
|
123 | 131 |
|
124 |
| - def __init_subclass__(cls, **kwargs): |
125 |
| - # Use __init_subclass__ to set up a correct docstring for from_preset. |
126 |
| - super().__init_subclass__(**kwargs) |
127 |
| - |
128 |
| - # If the subclass does not define from_preset, assign a wrapper so that |
129 |
| - # each class can have a distinct docstring. |
130 |
| - if "from_preset" not in cls.__dict__: |
131 |
| - |
132 |
| - def from_preset(calling_cls, *args, **kwargs): |
133 |
| - return super(cls, calling_cls).from_preset(*args, **kwargs) |
134 |
| - |
135 |
| - cls.from_preset = classmethod(from_preset) |
136 |
| - |
137 |
| - if not cls.presets: |
138 |
| - cls.from_preset.__func__.__doc__ = """Not implemented. |
139 |
| -
|
140 |
| - No presets available for this class. |
141 |
| - """ |
142 |
| - |
143 |
| - # Format and assign the docstring unless the subclass has overridden it. |
144 |
| - if cls.from_preset.__doc__ is None: |
145 |
| - cls.from_preset.__func__.__doc__ = Backbone.from_preset.__doc__ |
146 |
| - format_docstring( |
147 |
| - model_name=cls.__name__, |
148 |
| - example_preset_name=next(iter(cls.presets_with_weights), ""), |
149 |
| - preset_names='", "'.join(cls.presets), |
150 |
| - preset_with_weights_names='", "'.join(cls.presets_with_weights), |
151 |
| - )(cls.from_preset.__func__) |
152 |
| - |
153 | 132 | @property
|
154 | 133 | def pyramid_level_inputs(self):
|
155 | 134 | """Intermediate model outputs for feature extraction.
|
|
0 commit comments