Skip to content

Commit d187dd8

Browse files
committed
Fix metaclass issue in patch.py; now compatible with PyTorch 1.9.0
1 parent 00cec6a commit d187dd8

File tree

3 files changed

+31
-11
lines changed

3 files changed

+31
-11
lines changed

torchdrug/core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .core import _MetaContainer, Registry, Configurable
1+
from .core import _MetaContainer, Registry, Configurable, make_configurable
22
from .engine import Engine
33
from .meter import Meter
44

torchdrug/core/core.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ def wrapper(obj):
187187
for key in keys[:-1]:
188188
entry = entry[key]
189189
if keys[-1] in entry:
190-
return obj
191190
raise KeyError("`%s` has already been registered by %s" % (name, entry[keys[-1]]))
192191

193192
entry[keys[-1]] = obj
@@ -304,7 +303,7 @@ def get_function(method):
304303
class Configurable(metaclass=_Configurable):
305304
"""
306305
Class for load/save configuration.
307-
It will automatically record every argument passed to ``__init__`` function.
306+
It will automatically record every argument passed to the ``__init__`` function.
308307
309308
This class is inspired by :meth:`state_dict()` in PyTorch, but designed for hyperparameters.
310309
@@ -334,3 +333,26 @@ class Configurable(metaclass=_Configurable):
334333
>>> gcn = Configurable.load_config_dict(config)
335334
"""
336335
pass
336+
337+
338+
def make_configurable(cls, module=None, ignore_args=()):
339+
"""
340+
Make a configurable class out of an existing class.
341+
The configurable class will automatically record every argument passed to its ``__init__`` function.
342+
343+
Parameters:
344+
cls (type): input class
345+
module (str, optional): bind the output class to this module.
346+
By default, bind to the original module of the input class.
347+
ignore_args (set of str, optional): arguments to ignore in the ``__init__`` function
348+
"""
349+
ignore_args = set(ignore_args)
350+
module = module or cls.__module__
351+
Metaclass = type(cls)
352+
if issubclass(Metaclass, _Configurable): # already a configurable class
353+
return cls
354+
if Metaclass != type: # already have a meta class
355+
MetaClass = type(_Configurable.__name__, (Metaclass, _Configurable), {})
356+
else:
357+
MetaClass = _Configurable
358+
return MetaClass(cls.__name__, (cls,), {"_ignore_args": ignore_args, "__module__": module})

torchdrug/patch.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import sys
23
import inspect
34

45
import torch
@@ -12,16 +13,14 @@
1213
from torchdrug import core, data
1314
from torchdrug.core import Registry as R
1415

15-
16-
def make_configurable(cls, ignore_args=()):
17-
ignore_args = set(ignore_args)
18-
return type(cls.__name__, (cls, core.Configurable), {"_ignore_args": ignore_args})
16+
module = sys.modules[__name__]
1917

2018

2119
class PatchedModule(nn.Module):
2220

2321
def __init__(self):
2422
super(PatchedModule, self).__init__()
23+
# TODO: these hooks are bugged.
2524
# self._register_state_dict_hook(PatchedModule.graph_state_dict)
2625
# self._register_load_state_dict_pre_hook(PatchedModule.load_graph_state_dict)
2726

@@ -54,7 +53,6 @@ def load_graph_state_dict(cls, state_dict, prefix, local_metadata, strict, missi
5453
print("successfully assigned %s" % (prefix + name))
5554
except:
5655
error_msgs.append("Can't construct Graph `%s` from tensors in the state dict" % key)
57-
raise ValueError
5856
return state_dict
5957

6058
@property
@@ -137,22 +135,22 @@ def _get_build_directory(name, verbose):
137135
for name, cls in inspect.getmembers(optim):
138136
if inspect.isclass(cls) and issubclass(cls, Optimizer):
139137
setattr(optim, "_%s" % name, cls)
140-
cls = make_configurable(cls, ignore_args=("params",))
138+
cls = core.make_configurable(cls, ignore_args=("params",))
141139
cls = R.register("optim.%s" % name)(cls)
142140
setattr(optim, name, cls)
143141

144142
Scheduler = scheduler._LRScheduler
145143
for name, cls in inspect.getmembers(scheduler):
146144
if inspect.isclass(cls) and issubclass(cls, Scheduler):
147145
setattr(scheduler, "_%s" % name, cls)
148-
cls = make_configurable(cls, ignore_args=("optimizer",))
146+
cls = core.make_configurable(cls, ignore_args=("optimizer",))
149147
cls = R.register("scheduler.%s" % name)(cls)
150148
setattr(optim, name, cls)
151149

152150
Dataset = dataset.Dataset
153151
for name, cls in inspect.getmembers(dataset):
154152
if inspect.isclass(cls) and issubclass(cls, Dataset):
155153
setattr(dataset, "_%s" % name, cls)
156-
cls = make_configurable(cls)
154+
cls = core.make_configurable(cls)
157155
cls = R.register("dataset.%s" % name)(cls)
158156
setattr(dataset, name, cls)

0 commit comments

Comments (0)