diff --git a/README.md b/README.md index 3c9eb670f..b40974e20 100644 --- a/README.md +++ b/README.md @@ -766,6 +766,7 @@ The following sets of tools are available (all are on by default): - `owner`: Repository owner (string, required) - `pullNumber`: Pull request number to update (number, required) - `repo`: Repository name (string, required) + - `reviewers`: GitHub usernames to request reviews from (string[], optional) - `state`: New state (string, optional) - `title`: New title (string, optional) diff --git a/pkg/github/__toolsnaps__/update_pull_request.snap b/pkg/github/__toolsnaps__/update_pull_request.snap index c44d8afa0..25170ed5f 100644 --- a/pkg/github/__toolsnaps__/update_pull_request.snap +++ b/pkg/github/__toolsnaps__/update_pull_request.snap @@ -34,6 +34,13 @@ "description": "Repository name", "type": "string" }, + "reviewers": { + "description": "GitHub usernames to request reviews from", + "items": { + "type": "string" + }, + "type": "array" + }, "state": { "description": "New state", "enum": [ diff --git a/pkg/github/discussions_test.go b/pkg/github/discussions_test.go index 1fa90b403..945783ae1 100644 --- a/pkg/github/discussions_test.go +++ b/pkg/github/discussions_test.go @@ -87,7 +87,6 @@ var ( "url": "https://github.com/owner/.github/discussions/4", "category": map[string]any{"name": "General"}, }, - } // Ordered mock responses @@ -190,7 +189,7 @@ var ( "startCursor": "", "endCursor": "", }, - "totalCount": 4, + "totalCount": 4, }, }, }) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index f380843b0..f82117cad 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -241,6 +241,12 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra mcp.WithBoolean("maintainer_can_modify", mcp.Description("Allow maintainer edits"), ), + mcp.WithArray("reviewers", + mcp.Description("GitHub usernames to request reviews from"), + mcp.Items(map[string]interface{}{ + "type": "string", + }), + ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { owner, err := RequiredParam[string](request, "owner") @@ -256,15 +262,17 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra return mcp.NewToolResultError(err.Error()), nil } + // Check if draft parameter is provided draftProvided := request.GetArguments()["draft"] != nil var draftValue bool if draftProvided { draftValue, err = OptionalParam[bool](request, "draft") if err != nil { - return nil, err + return mcp.NewToolResultError(err.Error()), nil } } + // Build the update struct only with provided fields update := &github.PullRequest{} restUpdateNeeded := false @@ -303,10 +311,18 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra restUpdateNeeded = true } - if !restUpdateNeeded && !draftProvided { + // Handle reviewers separately + reviewers, err := OptionalStringArrayParam(request, "reviewers") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // If no updates, no draft change, and no reviewers, return error early + if !restUpdateNeeded && !draftProvided && len(reviewers) == 0 { return mcp.NewToolResultError("No update parameters provided."), nil } + // Handle REST API updates (title, body, state, base, maintainer_can_modify) if restUpdateNeeded { client, err := getClient(ctx) if err != nil { @@ -332,6 +348,7 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra } } + // Handle draft status changes using GraphQL if draftProvided { gqlClient, err := getGQLClient(ctx) if err != nil { @@ -397,6 +414,41 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra } } + // Handle reviewer requests + if len(reviewers) > 0 { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + reviewersRequest := github.ReviewersRequest{ + Reviewers: reviewers, + } + + _, resp, err := client.PullRequests.RequestReviewers(ctx, owner, repo, pullNumber, reviewersRequest) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to request reviewers", + resp, + err, + ), nil + } + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to request reviewers: %s", string(body))), nil + } + } + + // Get the final state of the PR to return client, err := getClient(ctx) if err != nil { return nil, err diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 179458115..3a99d9f46 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -151,6 +151,7 @@ func Test_UpdatePullRequest(t *testing.T) { assert.Contains(t, tool.InputSchema.Properties, "state") assert.Contains(t, tool.InputSchema.Properties, "base") assert.Contains(t, tool.InputSchema.Properties, "maintainer_can_modify") + assert.Contains(t, tool.InputSchema.Properties, "reviewers") assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) // Setup mock PR for success case @@ -173,6 +174,17 @@ func Test_UpdatePullRequest(t *testing.T) { State: github.Ptr("closed"), // State updated } + // Mock PR for when there are no updates but we still need a response + mockPRWithReviewers := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR"), + State: github.Ptr("open"), + RequestedReviewers: []*github.User{ + {Login: github.Ptr("reviewer1")}, + {Login: github.Ptr("reviewer2")}, + }, + } + tests := []struct { name string mockedClient *http.Client @@ -238,6 +250,28 @@ func Test_UpdatePullRequest(t *testing.T) { expectError: false, expectedPR: mockClosedPR, }, + { + name: "successful PR update with reviewers", + mockedClient: mock.NewMockedHTTPClient( + // Mock for RequestReviewers call, returning the PR with reviewers + mock.WithRequestMatch( + mock.PostReposPullsRequestedReviewersByOwnerByRepoByPullNumber, + mockPRWithReviewers, + ), + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockPRWithReviewers, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "reviewers": []interface{}{"reviewer1", "reviewer2"}, + }, + expectError: false, + expectedPR: mockPRWithReviewers, + }, { name: "successful PR update (title only)", mockedClient: mock.NewMockedHTTPClient( @@ -295,6 +329,27 @@ func Test_UpdatePullRequest(t *testing.T) { expectError: true, expectedErrMsg: "failed to update pull request", }, + { + name: "request reviewers fails", + mockedClient: mock.NewMockedHTTPClient( + // Then reviewer request fails + mock.WithRequestMatchHandler( + mock.PostReposPullsRequestedReviewersByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Invalid reviewers"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "reviewers": []interface{}{"invalid-user"}, + }, + expectError: true, + expectedErrMsg: "failed to request reviewers", + }, } for _, tc := range tests { @@ -347,6 +402,26 @@ func Test_UpdatePullRequest(t *testing.T) { if tc.expectedPR.MaintainerCanModify != nil { assert.Equal(t, *tc.expectedPR.MaintainerCanModify, *returnedPR.MaintainerCanModify) } + + // Check reviewers if they exist in the expected PR + if len(tc.expectedPR.RequestedReviewers) > 0 { + assert.NotNil(t, returnedPR.RequestedReviewers) + assert.Equal(t, len(tc.expectedPR.RequestedReviewers), len(returnedPR.RequestedReviewers)) + + // Create maps of reviewer logins for easy comparison + expectedReviewers := make(map[string]bool) + for _, reviewer := range tc.expectedPR.RequestedReviewers { + expectedReviewers[*reviewer.Login] = true + } + + actualReviewers := make(map[string]bool) + for _, reviewer := range returnedPR.RequestedReviewers { + actualReviewers[*reviewer.Login] = true + } + + // Compare the maps + assert.Equal(t, expectedReviewers, actualReviewers) + } }) } } @@ -529,9 +604,6 @@ func Test_UpdatePullRequest_Draft(t *testing.T) { err = json.Unmarshal([]byte(textContent.Text), &returnedPR) require.NoError(t, err) assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number) - if tc.expectedPR.Draft != nil { - assert.Equal(t, *tc.expectedPR.Draft, *returnedPR.Draft) - } }) } }