Skip to content

Commit 7f51958

Browse files
authored
Merge pull request #10 from r0mainK/fix/model
Fix the model saving
2 parents 19d8c02 + 5d1114f commit 7f51958

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

src/models/code2vec_features.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@ def construct(self, value2index, path2index, value2freq, path2freq, path_context
1414
self._path2index = path2index
1515
self._value2freq = value2freq
1616
self._path2freq = path2freq
17-
1817
self._path_contexts = path_contexts
1918
return self
2019

2120
def _load_tree(self, tree):
22-
self.construct(value2index=tree["value2index"],
23-
path2index=tree["path2index"],
24-
value2freq=tree["value2freq"],
25-
path2freq=tree["path2freq"],
26-
path_contexts=tree["path_contexts"])
21+
self.construct(
22+
value2index=tree["value2index"],
23+
path2index={tuple(val[0]): key for (key, val) in tree["index2path_freq"].items()},
24+
value2freq=tree["value2freq"],
25+
path2freq={tuple(val[0]): val[1] for (_, val) in tree["index2path_freq"].items()},
26+
path_contexts=tree["path_contexts"])
2727

2828
@property
2929
def value2index(self):
@@ -86,9 +86,9 @@ def path2freq_items(self):
8686

8787
def _generate_tree(self):
8888
return {"value2index": self._value2index,
89-
"path2index": self._path2index,
89+
"index2path_freq": {val: (key, self._path2freq[key])
90+
for (key, val) in self._path2index.items()},
9091
"value2freq": self._value2freq,
91-
"path2freq": self._path2freq,
9292
"path_contexts": self._path_contexts}
9393

9494
def dump(self):
@@ -98,9 +98,9 @@ def dump(self):
9898
"First 10 path -> ID: %s\n" \
9999
"First 10 value -> frequency: %s\n" \
100100
"First 10 path -> frequency: %s" % \
101-
(len(self._value2index_freq),
102-
len(self.path2index_freq),
103-
list(islice(self._value2index, 10)),
104-
list(islice(self._path2index, 10)),
105-
list(islice(self._value2freq, 10)),
106-
list(islice(self._path2freq, 10)))
101+
(len(self._value2index),
102+
len(self._path2index),
103+
list(islice(self.value2index_items(), 10)),
104+
list(islice(self.path2index_items(), 10)),
105+
list(islice(self.value2freq_items(), 10)),
106+
list(islice(self.path2freq_items(), 10)))

0 commit comments

Comments
 (0)