Skip to content

Commit fa733ac

Browse files
committed
Add main and rework data pipeline
Signed-off-by: Pol Alvarez Vecino <[email protected]>
1 parent 45ab7e8 commit fa733ac

File tree

5 files changed

+60
-29
lines changed

5 files changed

+60
-29
lines changed

src/__main__.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import argparse
2+
import sys
3+
4+
from sourced.ml.cmd.args import add_repo2_args
5+
from sourced.ml.cmd import ArgumentDefaultsHelpFormatterNoNone
6+
from cmd.code2vec_extract_features import code2vec_extract_features
7+
8+
9+
def get_parser() -> argparse.ArgumentParser:
    """
    Build the command line parser together with its "extract" subcommand.

    The "extract" subcommand stores code2vec_extract_features as its
    `handler` default, so the caller can dispatch on `args.handler`.
    """
    formatter = ArgumentDefaultsHelpFormatterNoNone
    root_parser = argparse.ArgumentParser(formatter_class=formatter)
    commands = root_parser.add_subparsers(help="Commands", dest="command")

    extract = commands.add_parser(
        "extract",
        help="Extract features from input repositories",
        formatter_class=formatter)
    extract.set_defaults(handler=code2vec_extract_features)

    # Let sourced.ml's add_repo2_args register its arguments on the subcommand.
    add_repo2_args(extract)

    # code2vec specific args: (flag, default, help text) of the integer options.
    int_options = (
        ("--max-length", 5, "Max path length."),
        ("--max-width", 2, "Max path width."),
    )
    for flag, default, description in int_options:
        extract.add_argument(flag, type=int, default=default, help=description,
                             required=False)
    extract.add_argument("-o", "--output", type=str, required=True,
                         help="Output path for the Code2VecFeatures model")
    return root_parser
35+
36+
37+
def main():
    """
    Parse the command line and run the handler of the selected command.

    :return: whatever the selected handler returns; 1 when no command was \
             given, after printing the usage line.
    """
    parser = get_parser()
    args = parser.parse_args()

    # Only subcommands set `handler` (through set_defaults), so the attribute
    # is missing when the program is invoked without a command.
    handler = getattr(args, "handler", None)
    if handler is None:
        parser.print_usage()
        # Report failure so that sys.exit(main()) does not signal success
        # when nothing was run.
        return 1
    return handler(args)
51+
52+
53+
# Script entry point: the value returned by main() is passed to sys.exit()
# and therefore becomes the exit status of the process.
if __name__ == "__main__":
    sys.exit(main())

src/algorithms/uast_to_bag_paths.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@ def __call__(self, uast):
3030
dict_of_paths = Counter(path_contexts)
3131
self._log.info("Extracted paths successfully")
3232

33-
from pprint import pprint
34-
pprint(dict_of_paths)
35-
3633
return dict_of_paths
3734

3835
def _get_log_name(self):

src/cmd/__init__.py

Whitespace-only changes.
Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,29 @@
11
import logging
2-
import argparse
32
from uuid import uuid4
43

54
from extractors.paths import UastPathsBagExtractor
5+
from transformers.vocabulary2id import Vocabulary2Id
66
from sourced.ml.transformers import UastDeserializer, Uast2BagFeatures, create_uast_source, \
7-
UastRow2Document, Collector
7+
UastRow2Document, Moder
88
from sourced.ml.utils.engine import pipeline_graph, pause
9-
from sourced.ml.cmd.args import add_repo2_args
9+
1010

1111
@pause
12-
def code2vec(args):
12+
def code2vec_extract_features(args):
1313
log = logging.getLogger("code2vec")
1414
session_name = "code2vec-%s" % uuid4()
1515
root, start_point = create_uast_source(args, session_name)
1616

1717
res = start_point \
18+
.link(Moder("func")) \
1819
.link(UastRow2Document()) \
1920
.link(UastDeserializer()) \
2021
.link(Uast2BagFeatures([UastPathsBagExtractor(args.max_length, args.max_width)])) \
21-
.link(Collector()) \
22+
.link(Vocabulary2Id(root.session.sparkContext, args.output)) \
2223
.execute()
2324

2425
# TODO: Add rest of data pipeline: extract distinct paths and terminal nodes for embedding mapping
2526
# TODO: Add transformer to write bags and vocabs to a model
2627
# TODO: Add ML pipeline
2728

2829
pipeline_graph(args, log, root)
29-
30-
31-
def main():
32-
parser = argparse.ArgumentParser()
33-
34-
# sourced.engine args
35-
add_repo2_args(parser)
36-
37-
# code2vec specific args
38-
parser.add_argument('-g', '--max_length', type=int, default=5, help="Max path length.",
39-
required=False)
40-
parser.add_argument('-w', '--max_width', type=int, default=2, help="Max path width.",
41-
required=False)
42-
43-
args = parser.parse_args()
44-
45-
code2vec(args)
46-
47-
48-
if __name__ == '__main__':
49-
main()

0 commit comments

Comments
 (0)