implement a clustering poc #17

Merged 1 commit on Nov 18, 2022
18 changes: 18 additions & 0 deletions src/semantic_code_search/cli.py
@@ -7,6 +7,7 @@

from semantic_code_search.embed import do_embed
from semantic_code_search.query import do_query
from semantic_code_search.cluster import do_cluster


def git_root(path=None):
@@ -37,6 +38,11 @@ def query_func(args):
do_query(args, model)


def cluster_func(args):
model = SentenceTransformer(args.model_name_or_path)
do_cluster(args, model)


def main():
parser = argparse.ArgumentParser(
prog='sem', description='Search your codebase using natural language')
@@ -55,13 +61,25 @@ def main():
required=False, default=5, help='Number of results to return')
parser.add_argument('-e', '--editor', choices=[
'vscode', 'vim'], default='vscode', required=False, help='Editor to open selected result in')
parser.add_argument('-c', '--cluster', action='store_true', default=False,
required=False, help='Generate clusters of related functions and methods')
    parser.add_argument('--cluster-max-distance', metavar='THRESHOLD', type=float, default=0.2, required=False,
                        help='How close functions need to be to one another to be clustered. A distance of 0 means the code is identical; smaller values (e.g. 0.2, 0.3) are stricter and result in fewer matches')
    parser.add_argument('--cluster-min-lines', metavar='SIZE', type=int, default=0, required=False,
                        help='Ignore clusters whose code snippets are smaller than this size (in lines of code). Use this if you are not interested in small duplications (e.g. one-liners)')
parser.add_argument('--cluster-min-cluster-size', metavar='SIZE', type=int, default=2, required=False,
help='Ignore clusters smaller than this size. Use this if you want to find code that is similar and repeated many times (e.g. >5)')
    parser.add_argument('--cluster-ignore-identical', action='store_true', default=True,
                        required=False, help='Ignore identical code / exact duplicates (where distance is 0)')
parser.set_defaults(func=query_func)
parser.add_argument('query_text', nargs=argparse.REMAINDER)

args = parser.parse_args()

if args.embed:
embed_func(args)
elif args.cluster:
cluster_func(args)
else:
query_func(args)

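For context, a minimal sketch (not part of the diff; the parser below is a stripped-down stand-in for the full CLI, and the argument values are made up) of how the new --cluster flag routes dispatch:

import argparse

parser = argparse.ArgumentParser(prog='sem')
parser.add_argument('-c', '--cluster', action='store_true', default=False)
parser.add_argument('--cluster-max-distance', metavar='THRESHOLD', type=float, default=0.2)

# parse a hypothetical invocation such as: sem -c --cluster-max-distance 0.3
args = parser.parse_args(['-c', '--cluster-max-distance', '0.3'])
assert args.cluster and args.cluster_max_distance == 0.3
# main() checks args.embed first, then args.cluster, and falls back to query_func,
# so passing -c routes the run to cluster_func(args)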
84 changes: 84 additions & 0 deletions src/semantic_code_search/cluster.py
@@ -0,0 +1,84 @@
import gzip
import os
import pickle
from semantic_code_search.embed import do_embed
from sklearn.cluster import AgglomerativeClustering
import numpy as np
from textwrap import indent


def _get_clusters(dataset, distance_threshold):
embeddings = dataset.get('embeddings')
# Normalize the embeddings to unit length
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
dataset['embeddings'] = embeddings

clustering_model = AgglomerativeClustering(
n_clusters=None,
distance_threshold=distance_threshold,
compute_distances=True,
)
clustering_model.fit(embeddings)
cluster_assignment = clustering_model.labels_
cluster_distances = clustering_model.distances_
cluster_children = clustering_model.children_

clustered_functions = {}
for idx, cluster_id in enumerate(cluster_assignment):
if cluster_id not in clustered_functions:
clustered_functions[cluster_id] = []

ds_entry = dataset.get('functions')[idx]
ds_entry['idx'] = idx

clustered_functions[cluster_id].append(ds_entry)

# filter out clusters with only one function
clusters = []
for cluster_id, functions in clustered_functions.items():
if len(functions) > 1:
fx_idx = functions[0].get('idx')
distances = []
            for f in functions[1:]:
                f_idx = f.get('idx')
                # recover the merge distance of the dendrogram step that joined
                # these two leaves directly, if such a step exists
                for i, cc in enumerate(cluster_children):
                    if cc.tolist() == [fx_idx, f_idx]:
                        distances.append(cluster_distances[i])
            avg_distance = sum(distances) / len(distances) if distances else 0
clusters.append(
{'avg_distance': avg_distance, 'functions': functions})

return clusters


def do_cluster(args, model):
if not os.path.isfile(args.path_to_repo + '/' + '.embeddings'):
print('Embeddings not found in {}. Generating embeddings now.'.format(
args.path_to_repo))
do_embed(args, model)

with gzip.open(args.path_to_repo + '/' + '.embeddings', 'r') as f:
dataset = pickle.loads(f.read())
if dataset.get('model_name') != args.model_name_or_path:
print('Model name mismatch. Regenerating embeddings.')
dataset = do_embed(args, model)
clusters = _get_clusters(dataset, args.cluster_max_distance)

filtered_clusters = []
    for c in clusters:
        if args.cluster_ignore_identical and c.get('avg_distance') == 0:
            continue
        if any(len(f.get('text').split('\n')) <= args.cluster_min_lines for f in c.get('functions')):
            continue
        if len(c.get('functions')) < args.cluster_min_cluster_size:
            continue
        filtered_clusters.append(c)

for i, c in enumerate(filtered_clusters):
print('Cluster #{}: avg_distance: {:.3} ================================================\n'.format(
i, c.get('avg_distance')))
for f in c.get('functions'):
print(indent(f.get('file'), ' ') + ':' + str(f.get('line')))
print(indent(f.get('text'), ' ') + '\n')
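
For intuition, a toy illustration of the AgglomerativeClustering call used in _get_clusters (the vectors are made up, not from the PR). Because the embeddings are normalized to unit length, Euclidean distance tracks cosine similarity (||u - v||^2 = 2 * (1 - cos(u, v))), so distance_threshold effectively caps how dissimilar two snippets can be and still end up in the same cluster:

import numpy as np
from sklearn.cluster import AgglomerativeClustering

# three unit vectors: two nearly parallel, one orthogonal
embeddings = np.array([[1.0, 0.0], [0.99, 0.14], [0.0, 1.0]])
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

model = AgglomerativeClustering(n_clusters=None, distance_threshold=0.2, compute_distances=True)
model.fit(embeddings)
print(model.labels_)     # two clusters, e.g. [1 1 0]: only the near-parallel pair merges
print(model.distances_)  # merge distances, recorded because compute_distances=True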
2 changes: 1 addition & 1 deletion src/semantic_code_search/query.py
@@ -13,7 +13,7 @@
def _search(query_embedding, corpus_embeddings, functions, k=5, file_extension=None):
# TODO: filtering by file extension
cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
-    top_results = torch.topk(cos_scores, k=min(k, len(cos_scores) - 1), sorted=True)
+    top_results = torch.topk(cos_scores, k=min(k, len(cos_scores)), sorted=True)
out = []
for score, idx in zip(top_results[0], top_results[1]):
out.append((score, functions[idx]))
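As a sanity check on the one-line query.py change (toy tensor, assumed scores): the old clamp to len(cos_scores) - 1 silently dropped the lowest-scoring match whenever k reached the corpus size, and asked topk for k=0 on a single-function corpus; clamping to len(cos_scores) returns every scored function:

import torch

cos_scores = torch.tensor([0.9, 0.5, 0.1])
k = 5
# old: min(k, len(cos_scores) - 1) == 2 -> the last valid result was dropped
# new: min(k, len(cos_scores)) == 3 -> all scored functions are returned
values, indices = torch.topk(cos_scores, k=min(k, len(cos_scores)), sorted=True)
print(values)   # tensor([0.9000, 0.5000, 0.1000])
print(indices)  # tensor([0, 1, 2])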