1
1
import os
2
+ import sys
2
3
import inspect
3
4
4
5
import torch
12
13
from torchdrug import core , data
13
14
from torchdrug .core import Registry as R
14
15
15
-
16
- def make_configurable (cls , ignore_args = ()):
17
- ignore_args = set (ignore_args )
18
- return type (cls .__name__ , (cls , core .Configurable ), {"_ignore_args" : ignore_args })
16
+ module = sys .modules [__name__ ]
19
17
20
18
21
19
class PatchedModule (nn .Module ):
22
20
23
21
def __init__ (self ):
24
22
super (PatchedModule , self ).__init__ ()
23
+ # TODO: these hooks are bugged.
25
24
# self._register_state_dict_hook(PatchedModule.graph_state_dict)
26
25
# self._register_load_state_dict_pre_hook(PatchedModule.load_graph_state_dict)
27
26
@@ -54,7 +53,6 @@ def load_graph_state_dict(cls, state_dict, prefix, local_metadata, strict, missi
54
53
print ("successfully assigned %s" % (prefix + name ))
55
54
except :
56
55
error_msgs .append ("Can't construct Graph `%s` from tensors in the state dict" % key )
57
- raise ValueError
58
56
return state_dict
59
57
60
58
@property
@@ -137,22 +135,22 @@ def _get_build_directory(name, verbose):
137
135
for name , cls in inspect .getmembers (optim ):
138
136
if inspect .isclass (cls ) and issubclass (cls , Optimizer ):
139
137
setattr (optim , "_%s" % name , cls )
140
- cls = make_configurable (cls , ignore_args = ("params" ,))
138
+ cls = core . make_configurable (cls , ignore_args = ("params" ,))
141
139
cls = R .register ("optim.%s" % name )(cls )
142
140
setattr (optim , name , cls )
143
141
144
142
Scheduler = scheduler ._LRScheduler
145
143
for name , cls in inspect .getmembers (scheduler ):
146
144
if inspect .isclass (cls ) and issubclass (cls , Scheduler ):
147
145
setattr (scheduler , "_%s" % name , cls )
148
- cls = make_configurable (cls , ignore_args = ("optimizer" ,))
146
+ cls = core . make_configurable (cls , ignore_args = ("optimizer" ,))
149
147
cls = R .register ("scheduler.%s" % name )(cls )
150
148
setattr (optim , name , cls )
151
149
152
150
Dataset = dataset .Dataset
153
151
for name , cls in inspect .getmembers (dataset ):
154
152
if inspect .isclass (cls ) and issubclass (cls , Dataset ):
155
153
setattr (dataset , "_%s" % name , cls )
156
- cls = make_configurable (cls )
154
+ cls = core . make_configurable (cls )
157
155
cls = R .register ("dataset.%s" % name )(cls )
158
156
setattr (dataset , name , cls )
0 commit comments