Skip to content

Commit 19d8c02

Browse files
authored
Merge pull request #5 from kafkasl/master
Add rest of data pipeline
2 parents 3284c30 + fa733ac commit 19d8c02

File tree

9 files changed

+261
-29
lines changed

9 files changed

+261
-29
lines changed

src/__main__.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import argparse
2+
import sys
3+
4+
from sourced.ml.cmd.args import add_repo2_args
5+
from sourced.ml.cmd import ArgumentDefaultsHelpFormatterNoNone
6+
from cmd.code2vec_extract_features import code2vec_extract_features
7+
8+
9+
def get_parser() -> argparse.ArgumentParser:
10+
"""
11+
Creates the cmdline argument parser.
12+
"""
13+
parser = argparse.ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatterNoNone)
14+
15+
# sourced.engine args
16+
17+
subparsers = parser.add_subparsers(help="Commands", dest="command")
18+
19+
extract_parser = subparsers.add_parser("extract",
20+
help="Extract features from input repositories",
21+
formatter_class=ArgumentDefaultsHelpFormatterNoNone)
22+
23+
extract_parser.set_defaults(handler=code2vec_extract_features)
24+
25+
add_repo2_args(extract_parser)
26+
27+
# code2vec specific args
28+
extract_parser.add_argument('--max-length', type=int, default=5, help="Max path length.",
29+
required=False)
30+
extract_parser.add_argument('--max-width', type=int, default=2, help="Max path width.",
31+
required=False)
32+
extract_parser.add_argument('-o', '--output', type=str,
33+
help="Output path for the Code2VecFeatures model", required=True)
34+
return parser
35+
36+
37+
def main():
38+
parser = get_parser()
39+
40+
args = parser.parse_args()
41+
42+
try:
43+
handler = args.handler
44+
except AttributeError:
45+
def print_usage(_):
46+
parser.print_usage()
47+
48+
handler = print_usage
49+
50+
return handler(args)
51+
52+
53+
if __name__ == "__main__":
54+
sys.exit(main())

src/algorithms/uast_to_bag_paths.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@ def __call__(self, uast):
3030
dict_of_paths = {str(path): val for path, val in Counter(path_contexts).items()}
3131
self._log.info("Extracted paths successfully")
3232

33-
from pprint import pprint
34-
pprint(dict_of_paths)
35-
3633
return dict_of_paths
3734

3835
def _get_log_name(self):

src/cmd/__init__.py

Whitespace-only changes.
Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,29 @@
11
import logging
2-
import argparse
32
from uuid import uuid4
43

54
from extractors.paths import UastPathsBagExtractor
5+
from transformers.vocabulary2id import Vocabulary2Id
66
from sourced.ml.transformers import UastDeserializer, Uast2BagFeatures, create_uast_source, \
7-
UastRow2Document, Collector
7+
UastRow2Document, Moder
88
from sourced.ml.utils.engine import pipeline_graph, pause
9-
from sourced.ml.cmd.args import add_repo2_args
9+
1010

1111
@pause
12-
def code2vec(args):
12+
def code2vec_extract_features(args):
1313
log = logging.getLogger("code2vec")
1414
session_name = "code2vec-%s" % uuid4()
1515
root, start_point = create_uast_source(args, session_name)
1616

1717
res = start_point \
18+
.link(Moder("func")) \
1819
.link(UastRow2Document()) \
1920
.link(UastDeserializer()) \
2021
.link(Uast2BagFeatures([UastPathsBagExtractor(args.max_length, args.max_width)])) \
21-
.link(Collector()) \
22+
.link(Vocabulary2Id(root.session.sparkContext, args.output)) \
2223
.execute()
2324

2425
# TODO: Add rest of data pipeline: extract distinct paths and terminal nodes for embedding mapping
2526
# TODO: Add transformer to write bags and vocabs to a model
2627
# TODO: Add ML pipeline
2728

2829
pipeline_graph(args, log, root)
29-
30-
31-
def main():
32-
parser = argparse.ArgumentParser()
33-
34-
# sourced.engine args
35-
add_repo2_args(parser)
36-
37-
# code2vec specific args
38-
parser.add_argument('-g', '--max_length', type=int, default=5, help="Max path length.",
39-
required=False)
40-
parser.add_argument('-w', '--max_width', type=int, default=2, help="Max path width.",
41-
required=False)
42-
43-
args = parser.parse_args()
44-
45-
code2vec(args)
46-
47-
48-
if __name__ == '__main__':
49-
main()

src/models/__init__.py

Whitespace-only changes.

src/models/code2vec_features.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
from modelforge import register_model, Model
2+
from itertools import islice
3+
4+
5+
@register_model
6+
class Code2VecFeatures(Model):
7+
"""
8+
Code2VecFeatures model - path contexts from source code.
9+
"""
10+
NAME = "code2vec_features"
11+
12+
def construct(self, value2index, path2index, value2freq, path2freq, path_contexts):
13+
self._value2index = value2index
14+
self._path2index = path2index
15+
self._value2freq = value2freq
16+
self._path2freq = path2freq
17+
18+
self._path_contexts = path_contexts
19+
return self
20+
21+
def _load_tree(self, tree):
22+
self.construct(value2index=tree["value2index"],
23+
path2index=tree["path2index"],
24+
value2freq=tree["value2freq"],
25+
path2freq=tree["path2freq"],
26+
path_contexts=tree["path_contexts"])
27+
28+
@property
29+
def value2index(self):
30+
"""
31+
Dict mapping value -> ID.
32+
"""
33+
return self._value2index
34+
35+
@property
36+
def path2index(self):
37+
"""
38+
Dict mapping path -> ID.
39+
"""
40+
return self._path2index
41+
42+
@property
43+
def value2freq(self):
44+
"""
45+
Dict mapping value -> frequency.
46+
"""
47+
return self._value2freq
48+
49+
@property
50+
def path2freq(self):
51+
"""
52+
Dict mapping path -> frequency.
53+
"""
54+
return self._path2freq
55+
56+
@property
57+
def path_contexts(self):
58+
"""
59+
List with the processed source code identifiers.
60+
"""
61+
return self._path_contexts
62+
63+
def value2index_items(self):
64+
"""
65+
Returns the tuples belonging to value -> index mapping.
66+
"""
67+
return self._value2index.items()
68+
69+
def path2index_items(self):
70+
"""
71+
Returns the tuples belonging to path -> index mapping.
72+
"""
73+
return self._path2index.items()
74+
75+
def value2freq_items(self):
76+
"""
77+
Returns the tuples belonging to value -> freq mapping.
78+
"""
79+
return self._value2freq.items()
80+
81+
def path2freq_items(self):
82+
"""
83+
Returns the tuples belonging to path -> freq mapping.
84+
"""
85+
return self._path2freq.items()
86+
87+
def _generate_tree(self):
88+
return {"value2index": self._value2index,
89+
"path2index": self._path2index,
90+
"value2freq": self._value2freq,
91+
"path2freq": self._path2freq,
92+
"path_contexts": self._path_contexts}
93+
94+
def dump(self):
95+
return "Number of values: %s\n" \
96+
"Number of paths: %s\n" \
97+
"First 10 value -> ID: %s\n" \
98+
"First 10 path -> ID: %s\n" \
99+
"First 10 value -> frequency: %s\n" \
100+
"First 10 path -> frequency: %s" % \
101+
(len(self._value2index_freq),
102+
len(self.path2index_freq),
103+
list(islice(self._value2index, 10)),
104+
list(islice(self._path2index, 10)),
105+
list(islice(self._value2freq, 10)),
106+
list(islice(self._path2freq, 10)))

src/transformers/__init__.py

Whitespace-only changes.

src/transformers/vocabulary2id.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import operator
2+
3+
from pyspark import RDD, Row
4+
from models.code2vec_features import Code2VecFeatures
5+
6+
from ast import literal_eval as make_tuple
7+
from sourced.ml.transformers import Transformer
8+
9+
10+
class Vocabulary2Id(Transformer):
11+
def __init__(self, sc, output: str, **kwargs):
12+
super().__init__(**kwargs)
13+
self.output = output
14+
self.sc = sc
15+
16+
def __call__(self, rows: RDD):
17+
value2index, path2index, value2freq, path2freq = self.build_vocabularies(rows)
18+
19+
doc2path_contexts = self.build_doc2pc(value2index, path2index, rows)
20+
21+
doc2path_contexts = doc2path_contexts.collect()
22+
23+
Code2VecFeatures().construct(value2index=value2index,
24+
path2index=path2index,
25+
value2freq=value2freq,
26+
path2freq=path2freq,
27+
path_contexts=doc2path_contexts).save(
28+
self.output)
29+
30+
@staticmethod
31+
def _unstringify_path_context(row):
32+
"""
33+
Takes a row containing ((pc, doc), freq) and returns a tuple (u, path, v)
34+
(removes namespace prefix v.)
35+
"""
36+
return make_tuple(row[0][0][2:])
37+
38+
def build_vocabularies(self, rows: RDD):
39+
"""
40+
Process rows to gather values and paths with their frequencies.
41+
:param rows: row structure is ((key, doc), val) where:
42+
* key: str with the path context
43+
* doc: file name
44+
* val: number of occurrences of key in doc
45+
"""
46+
47+
def _flatten_row(row: Row):
48+
# 2: removes the namespace v. from the string to parse it as tuple
49+
k = Vocabulary2Id._unstringify_path_context(row)
50+
return [(k[0], 1), (k[1], 1), (k[2], 1)]
51+
52+
rows = rows \
53+
.flatMap(_flatten_row) \
54+
.reduceByKey(operator.add) \
55+
.persist()
56+
57+
values = rows.filter(lambda x: type(x[0]) == str).collect()
58+
paths = rows.filter(lambda x: type(x[0]) == tuple).collect()
59+
60+
value2index = {w: id for id, (w, _) in enumerate(values)}
61+
path2index = {w: id for id, (w, _) in enumerate(paths)}
62+
value2freq = {w: freq for _, (w, freq) in enumerate(values)}
63+
path2freq = {w: freq for _, (w, freq) in enumerate(paths)}
64+
65+
rows.unpersist()
66+
67+
return value2index, path2index, value2freq, path2freq
68+
69+
def build_doc2pc(self, value2index: dict, path2index: dict, rows: RDD):
70+
"""
71+
Process rows and build elements (doc, [path_context_1, path_context_2, ...])
72+
:param value2index_freq: value -> id
73+
:param path2index_freq: path -> id
74+
"""
75+
76+
bc_value2index = self.sc.broadcast(value2index)
77+
bc_path2index = self.sc.broadcast(path2index)
78+
79+
def _doc2pc(row: Row):
80+
(u, path, v), doc = Vocabulary2Id._unstringify_path_context(row), row[0][1]
81+
82+
return doc, (bc_value2index.value[u], bc_path2index.value[path],
83+
bc_value2index.value[v])
84+
85+
rows = rows \
86+
.map(_doc2pc) \
87+
.distinct() \
88+
.combineByKey(lambda value: [value],
89+
lambda x, value: x + [value],
90+
lambda x, y: x + y)
91+
92+
bc_value2index.unpersist(blocking=True)
93+
bc_path2index.unpersist(blocking=True)
94+
95+
return rows

0 commit comments

Comments
 (0)