diff --git a/merge_conflict_labeler/merge_conflicts/__init__.py b/merge_conflict_labeler/merge_conflicts/__init__.py new file mode 100644 index 0000000..bc775b7 --- /dev/null +++ b/merge_conflict_labeler/merge_conflicts/__init__.py @@ -0,0 +1 @@ +from . import merge_conflicts, utils diff --git a/merge_conflict_labeler/merge_conflicts/__main__.py b/merge_conflict_labeler/merge_conflicts/__main__.py new file mode 100644 index 0000000..ccab834 --- /dev/null +++ b/merge_conflict_labeler/merge_conflicts/__main__.py @@ -0,0 +1,55 @@ +import asyncio +import os +import sys +import traceback + +import aiohttp +import cachetools +from aiohttp import web +from gidgethub import aiohttp as gh_aiohttp +from gidgethub import routing +from gidgethub import sansio + +from . import merge_conflicts + +router = routing.Router(merge_conflicts.router) +cache = cachetools.LRUCache(maxsize=500) + + +async def main(request): + try: + body = await request.read() + secret = os.environ.get("GH_SECRET") + event = sansio.Event.from_http(request.headers, body, secret=secret) + print('GH delivery ID', event.delivery_id, file=sys.stderr) + if event.event == "ping": + return web.Response(status=200) + asyncio.ensure_future(identify_merge_conflicting_prs(event)) + + print("Answering") + return web.Response(status=200) + except Exception as exc: + traceback.print_exc(file=sys.stderr) + return web.Response(status=500) + + +async def identify_merge_conflicting_prs(event): + oauth_token = os.environ.get("GH_AUTH") + async with aiohttp.ClientSession() as session: + gh = gh_aiohttp.GitHubAPI(session, "pablogsal", + oauth_token=oauth_token, + cache=cache) + await router.dispatch(event, gh, session) + try: + print('GH requests remaining:', gh.rate_limit.remaining) + except AttributeError: + pass + + +if __name__ == "__main__": # pragma: no cover + app = web.Application() + app.router.add_post("/", main) + port = os.environ.get("PORT") + if port is not None: + port = int(port) + web.run_app(app, port=port) diff --git a/merge_conflict_labeler/merge_conflicts/merge_conflicts.py b/merge_conflict_labeler/merge_conflicts/merge_conflicts.py new file mode 100644 index 0000000..637ffa2 --- /dev/null +++ b/merge_conflict_labeler/merge_conflicts/merge_conflicts.py @@ -0,0 +1,32 @@ +import logging +import os + +import gidgethub.routing + +from .utils import get_open_prs + +router = gidgethub.routing.Router() +logger = logging.getLogger(__name__) + + +@router.register("pull_request", action="closed") +async def label_conflicting_prs(event, gh, session, *args, **kwargs): + pull_request = event.data["pull_request"] + if pull_request["merged"] is False: + return + + org = os.getenv("ORG") + repo = os.getenv("REPO") + pull_requests = get_open_prs(org, repo, session) + conflicting_pull_requests = [pr async for pr in pull_requests if pr["mergeable"] == 'CONFLICTING'] + needs_rebase_pull_request = [pr for pr in conflicting_pull_requests if "needs rebase" not in pr["labels"]] + + logger.info(f"Identified a total of {len(needs_rebase_pull_request)} pull requests with merge conflicts") + + for pr in needs_rebase_pull_request: + try: + logger.debug(f"Working on pr {pr['number']}") + await gh.post(f'https://api.github.com/repos/{org}/{repo}/issues/{pr["number"]}/labels', + data=["needs_rebase"]) + except Exception as e: + logger.exception(e) diff --git a/merge_conflict_labeler/merge_conflicts/utils.py b/merge_conflict_labeler/merge_conflicts/utils.py new file mode 100644 index 0000000..1b4b517 --- /dev/null +++ b/merge_conflict_labeler/merge_conflicts/utils.py @@ -0,0 +1,47 @@ +import json +import logging +import os +import textwrap +from string import Template + +logger = logging.getLogger(__name__) + +query_template = Template(textwrap.dedent( + r"""query { + repository(owner: "$org", name: "$repo") { + pullRequests($args states: OPEN) { + edges { + cursor + node { + number + url + labels(first: 20) { + nodes { + name + } + } + } + } + } + } + } + """ +)) + + +async def get_open_prs(organization, repo, session): + gh_token = os.getenv("GH_AUTH") + query_args = "first:100," + while True: + first_query = query_template.substitute(args=query_args, org=organization, repo=repo) + async with session.post(url='https://api.github.com/graphql', + data=json.dumps({"query": first_query}), + headers={"Authorization": f"bearer {gh_token}"}) as response: + pull_requests = json.loads(await response.text())["data"]["repository"]["pullRequests"]["edges"] + if not pull_requests: + return + for pr in pull_requests: + pr_info = pr["node"] + pr_info["labels"] = {node["name"] for node in pr_info["labels"]["nodes"]} + yield pr_info + query_args = f'after: "{pull_requests[-1]["cursor"]}", first:100,' diff --git a/merge_conflict_labeler/tests/__init__.py b/merge_conflict_labeler/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/merge_conflict_labeler/tests/test_label_conflicting_prs.py b/merge_conflict_labeler/tests/test_label_conflicting_prs.py new file mode 100644 index 0000000..bd7ac0f --- /dev/null +++ b/merge_conflict_labeler/tests/test_label_conflicting_prs.py @@ -0,0 +1,72 @@ +import unittest.mock as mock + +import pytest +from asynctest import patch +from merge_conflicts.merge_conflicts import label_conflicting_prs + +from .testing_utils import generate + + +@pytest.mark.asyncio +async def test_labeled_pr(): + # GIVEN + gh = mock.Mock() + session = mock.Mock() + pr = mock.Mock() + pr.data = {"pull_request": {"merged": True}} + + # WHEN + with patch("merge_conflicts.merge_conflicts.get_open_prs") as get_open_prs_mock: + get_open_prs_mock.return_value = generate( + [{"number": 2, "mergeable": "CONFLICTING", "labels": ["needs rebase"]}]) + await label_conflicting_prs(pr, gh, session) + + # THEN + assert gh.post.call_count == 0 + + +@pytest.mark.asyncio +async def test_unlabeled(): + # GIVEN + gh = mock.Mock() + session = mock.Mock() + pr = mock.Mock() + pr.data = {"pull_request": {"merged": True}} + + # WHEN + with patch("merge_conflicts.merge_conflicts.get_open_prs") as get_open_prs_mock: + get_open_prs_mock.return_value = generate([{"number": 2, "mergeable": "CONFLICTING", "labels": []}]) + await label_conflicting_prs(pr, gh, session) + + # THEN + assert gh.post.call_count == 1 + + +@pytest.mark.asyncio +async def test_label_creation(): + # GIVEN + gh = mock.Mock() + session = mock.Mock() + pr = mock.Mock() + pr.data = {"pull_request": {"merged": True}} + + prs = [ + {"number": 1, "mergeable": "CONFLICTING", "labels": []}, + {"number": 2, "mergeable": "CONFLICTING", "labels": []}, + {"number": 3, "mergeable": "CONFLICTING", "labels": ["needs rebase"]}, + {"number": 4, "mergeable": "CONFLICTING", "labels": []}, + ] + + # WHEN + with patch("merge_conflicts.merge_conflicts.get_open_prs") as get_open_prs_mock: + get_open_prs_mock.return_value = generate(prs) + await label_conflicting_prs(pr, gh, session) + + # THEN + expected_calls = [ + mock.call('https://api.github.com/repos/None/None/issues/1/labels', data=['needs_rebase']), + mock.call('https://api.github.com/repos/None/None/issues/2/labels', data=['needs_rebase']), + mock.call('https://api.github.com/repos/None/None/issues/4/labels', data=['needs_rebase']) + ] + + assert gh.post.mock_calls == expected_calls diff --git a/merge_conflict_labeler/tests/test_utils.py b/merge_conflict_labeler/tests/test_utils.py new file mode 100644 index 0000000..4518bda --- /dev/null +++ b/merge_conflict_labeler/tests/test_utils.py @@ -0,0 +1,83 @@ +import textwrap +import unittest.mock as mock + +import pytest +from merge_conflicts.utils import get_open_prs + +from .testing_utils import AsyncContextManagerMock, get_back + + +@pytest.mark.asyncio +async def test_get_open_prs(): + # GIVEN + session = mock.Mock() + response = mock.Mock() + post_mock = AsyncContextManagerMock(return_item=response) + session.post.return_value = post_mock + response.text.side_effect = [ + get_back(textwrap.dedent(""" + { + "data": { + "repository": { + "pullRequests": { + "edges": [ + { + "cursor": "Y3Vyc29yOnYyOpHOBk4feQ==", + "node": { + "number": 45, + "url": "https://github.com/python/cpython/pull/45", + "labels": { + "nodes": [ + { + "name": "CLA signed" + }, + { + "name": "type-documentation" + } + ] + } + } + }, + { + "cursor": "Y3Vyc29yOnYyOpHOBk6OZw==", + "node": { + "number": 57, + "url": "https://github.com/python/cpython/pull/57", + "labels": { + "nodes": [ + { + "name": "CLA signed" + } + ] + } + } + } + ] + } + } + } + }""")), + get_back(textwrap.dedent(""" + { + "data": { + "repository": { + "pullRequests": { + "edges": [] + } + } + } + } + """))] + + # WHEN + + response = [pr async for pr in get_open_prs("Python", "cpython", session)] + + # THEN + + expected_response = [{'number': 45, 'url': 'https://github.com/python/cpython/pull/45', + 'labels': {'type-documentation', 'CLA signed'}}, + {'number': 57, 'url': 'https://github.com/python/cpython/pull/57', + 'labels': {'CLA signed'}}] + + assert response == expected_response diff --git a/merge_conflict_labeler/tests/testing_utils.py b/merge_conflict_labeler/tests/testing_utils.py new file mode 100644 index 0000000..99e2a8e --- /dev/null +++ b/merge_conflict_labeler/tests/testing_utils.py @@ -0,0 +1,22 @@ +import unittest.mock as mock + + +class AsyncContextManagerMock(mock.MagicMock): + def __init__(self, return_item, *args, **kwargs): + self.__dict__["return_item"] = return_item + super().__init__(*args, **kwargs) + + async def __aenter__(self): + return self.__dict__["return_item"] + + async def __aexit__(self, *args): + pass + + +async def generate(iterable): + for item in iterable: + yield item + + +async def get_back(item): + return item