Skip to content

Prototype for bot that adds label to PRs with merge conficts #187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions merge_conflict_labeler/merge_conflicts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import merge_conflicts, utils
55 changes: 55 additions & 0 deletions merge_conflict_labeler/merge_conflicts/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import asyncio
import os
import sys
import traceback

import aiohttp
import cachetools
from aiohttp import web
from gidgethub import aiohttp as gh_aiohttp
from gidgethub import routing
from gidgethub import sansio

from . import merge_conflicts

router = routing.Router(merge_conflicts.router)
cache = cachetools.LRUCache(maxsize=500)


async def main(request):
try:
body = await request.read()
secret = os.environ.get("GH_SECRET")
event = sansio.Event.from_http(request.headers, body, secret=secret)
print('GH delivery ID', event.delivery_id, file=sys.stderr)
if event.event == "ping":
return web.Response(status=200)
asyncio.ensure_future(identify_merge_conflicting_prs(event))

print("Answering")
return web.Response(status=200)
except Exception as exc:
traceback.print_exc(file=sys.stderr)
return web.Response(status=500)


async def identify_merge_conflicting_prs(event):
oauth_token = os.environ.get("GH_AUTH")
async with aiohttp.ClientSession() as session:
gh = gh_aiohttp.GitHubAPI(session, "pablogsal",
oauth_token=oauth_token,
cache=cache)
await router.dispatch(event, gh, session)
try:
print('GH requests remaining:', gh.rate_limit.remaining)
except AttributeError:
pass


if __name__ == "__main__": # pragma: no cover
app = web.Application()
app.router.add_post("/", main)
port = os.environ.get("PORT")
if port is not None:
port = int(port)
web.run_app(app, port=port)
32 changes: 32 additions & 0 deletions merge_conflict_labeler/merge_conflicts/merge_conflicts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import logging
import os

import gidgethub.routing

from .utils import get_open_prs

router = gidgethub.routing.Router()
logger = logging.getLogger(__name__)


@router.register("pull_request", action="closed")
async def label_conflicting_prs(event, gh, session, *args, **kwargs):
pull_request = event.data["pull_request"]
if pull_request["merged"] is False:
return

org = os.getenv("ORG")
repo = os.getenv("REPO")
pull_requests = get_open_prs(org, repo, session)
conflicting_pull_requests = [pr async for pr in pull_requests if pr["mergeable"] == 'CONFLICTING']
needs_rebase_pull_request = [pr for pr in conflicting_pull_requests if "needs rebase" not in pr["labels"]]

logger.info(f"Identified a total of {len(needs_rebase_pull_request)} pull requests with merge conflicts")

for pr in needs_rebase_pull_request:
try:
logger.debug(f"Working on pr {pr['number']}")
await gh.post(f'https://github.com/api/repos/{org}/{repo}/issues/{pr["number"]}/labels',
data=["needs_rebase"])
except Exception as e:
logger.exception(e)
47 changes: 47 additions & 0 deletions merge_conflict_labeler/merge_conflicts/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import json
import logging
import os
import textwrap
from string import Template

logger = logging.getLogger(__name__)

query_template = Template(textwrap.dedent(
r"""query {
repository(owner: "$org", name: "$repo") {
pullRequests($args states: OPEN) {
edges {
cursor
node {
number
url
labels(first: 20) {
nodes {
name
}
}
}
}
}
}
}
"""
))


async def get_open_prs(organization, repo, session):
gh_token = os.getenv("GH_AUTH")
query_args = "first:100,"
while True:
first_query = query_template.substitute(args=query_args, org=organization, repo=repo)
async with session.post(url='https://github.com/api/graphql',
data=json.dumps({"query": first_query}),
headers={"Authorization": f"bearer {gh_token}"}) as response:
pull_requests = json.loads(await response.text())["data"]["repository"]["pullRequests"]["edges"]
if not pull_requests:
return
for pr in pull_requests:
pr_info = pr["node"]
pr_info["labels"] = {node["name"] for node in pr_info["labels"]["nodes"]}
yield pr_info
query_args = f'after: "{pull_requests[-1]["cursor"]}", first:100,'
Empty file.
72 changes: 72 additions & 0 deletions merge_conflict_labeler/tests/test_label_conflicting_prs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import unittest.mock as mock

import pytest
from asynctest import patch
from merge_conflicts.merge_conflicts import label_conflicting_prs

from .testing_utils import generate


@pytest.mark.asyncio
async def test_labeled_pr():
# GIVEN
gh = mock.Mock()
session = mock.Mock()
pr = mock.Mock()
pr.data = {"pull_request": {"merged": True}}

# WHEN
with patch("merge_conflicts.merge_conflicts.get_open_prs") as get_open_prs_mock:
get_open_prs_mock.return_value = generate(
[{"number": 2, "mergeable": "CONFLICTING", "labels": ["needs rebase"]}])
await label_conflicting_prs(pr, gh, session)

# THEN
assert gh.post.call_count == 0


@pytest.mark.asyncio
async def test_unlabeled():
# GIVEN
gh = mock.Mock()
session = mock.Mock()
pr = mock.Mock()
pr.data = {"pull_request": {"merged": True}}

# WHEN
with patch("merge_conflicts.merge_conflicts.get_open_prs") as get_open_prs_mock:
get_open_prs_mock.return_value = generate([{"number": 2, "mergeable": "CONFLICTING", "labels": []}])
await label_conflicting_prs(pr, gh, session)

# THEN
assert gh.post.call_count == 1


@pytest.mark.asyncio
async def test_label_creation():
# GIVEN
gh = mock.Mock()
session = mock.Mock()
pr = mock.Mock()
pr.data = {"pull_request": {"merged": True}}

prs = [
{"number": 1, "mergeable": "CONFLICTING", "labels": []},
{"number": 2, "mergeable": "CONFLICTING", "labels": []},
{"number": 3, "mergeable": "CONFLICTING", "labels": ["needs rebase"]},
{"number": 4, "mergeable": "CONFLICTING", "labels": []},
]

# WHEN
with patch("merge_conflicts.merge_conflicts.get_open_prs") as get_open_prs_mock:
get_open_prs_mock.return_value = generate(prs)
await label_conflicting_prs(pr, gh, session)

# THEN
expected_calls = [
mock.call('https://github.com/api/repos/None/None/issues/1/labels', data=['needs_rebase']),
mock.call('https://github.com/api/repos/None/None/issues/2/labels', data=['needs_rebase']),
mock.call('https://github.com/api/repos/None/None/issues/4/labels', data=['needs_rebase'])
]

assert gh.post.mock_calls == expected_calls
83 changes: 83 additions & 0 deletions merge_conflict_labeler/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import textwrap
import unittest.mock as mock

import pytest
from merge_conflicts.utils import get_open_prs

from .testing_utils import AsyncContextManagerMock, get_back


@pytest.mark.asyncio
async def test_get_open_prs():
# GIVEN
session = mock.Mock()
response = mock.Mock()
post_mock = AsyncContextManagerMock(return_item=response)
session.post.return_value = post_mock
response.text.side_effect = [
get_back(textwrap.dedent("""
{
"data": {
"repository": {
"pullRequests": {
"edges": [
{
"cursor": "Y3Vyc29yOnYyOpHOBk4feQ==",
"node": {
"number": 45,
"url": "https://github.com/python/cpython/pull/45",
"labels": {
"nodes": [
{
"name": "CLA signed"
},
{
"name": "type-documentation"
}
]
}
}
},
{
"cursor": "Y3Vyc29yOnYyOpHOBk6OZw==",
"node": {
"number": 57,
"url": "https://github.com/python/cpython/pull/57",
"labels": {
"nodes": [
{
"name": "CLA signed"
}
]
}
}
}
]
}
}
}
}""")),
get_back(textwrap.dedent("""
{
"data": {
"repository": {
"pullRequests": {
"edges": []
}
}
}
}
"""))]

# WHEN

response = [pr async for pr in get_open_prs("Python", "cpython", session)]

# THEN

expected_response = [{'number': 45, 'url': 'https://github.com/python/cpython/pull/45',
'labels': {'type-documentation', 'CLA signed'}},
{'number': 57, 'url': 'https://github.com/python/cpython/pull/57',
'labels': {'CLA signed'}}]

assert response == expected_response
22 changes: 22 additions & 0 deletions merge_conflict_labeler/tests/testing_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import unittest.mock as mock


class AsyncContextManagerMock(mock.MagicMock):
def __init__(self, return_item, *args, **kwargs):
self.__dict__["return_item"] = return_item
super().__init__(*args, **kwargs)

async def __aenter__(self):
return self.__dict__["return_item"]

async def __aexit__(self, *args):
pass


async def generate(iterable):
for item in iterable:
yield item


async def get_back(item):
return item