Skip to content

Commit 4624a97

Browse files
finish up adding notifications endpoints and add tests
1 parent f998fc7 commit 4624a97

File tree

3 files changed

+529
-34
lines changed

3 files changed

+529
-34
lines changed

pkg/github/notifications.go

Lines changed: 245 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ import (
1515
"github.com/mark3labs/mcp-go/server"
1616
)
1717

18+
const (
19+
FilterDefault = "default"
20+
FilterIncludeRead = "include_read_notifications"
21+
FilterOnlyParticipating = "only_participating"
22+
)
23+
1824
// ListNotifications creates a tool to list notifications for the current user.
1925
func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
2026
return mcp.NewTool("list_notifications",
@@ -25,14 +31,20 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu
2531
}),
2632
mcp.WithString("filter",
2733
mcp.Description("Filter notifications to, use default unless specified. Read notifications are ones that have already been acknowledged by the user. Participating notifications are those that the user is directly involved in, such as issues or pull requests they have commented on or created."),
28-
mcp.Enum("default", "include_read_notifications", "only_participating"),
34+
mcp.Enum(FilterDefault, FilterIncludeRead, FilterOnlyParticipating),
2935
),
3036
mcp.WithString("since",
3137
mcp.Description("Only show notifications updated after the given time (ISO 8601 format)"),
3238
),
3339
mcp.WithString("before",
3440
mcp.Description("Only show notifications updated before the given time (ISO 8601 format)"),
3541
),
42+
mcp.WithString("owner",
43+
mcp.Description("Optional repository owner. If provided with repo, only notifications for this repository are listed."),
44+
),
45+
mcp.WithString("repo",
46+
mcp.Description("Optional repository name. If provided with owner, only notifications for this repository are listed."),
47+
),
3648
WithPagination(),
3749
),
3850
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
@@ -41,45 +53,42 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu
4153
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
4254
}
4355

44-
// Extract optional parameters with defaults
45-
all, err := OptionalParamWithDefault[bool](request, "all", false)
56+
filter, err := OptionalParam[string](request, "filter")
4657
if err != nil {
4758
return mcp.NewToolResultError(err.Error()), nil
4859
}
4960

50-
participating, err := OptionalParamWithDefault[bool](request, "participating", false)
61+
since, err := OptionalParam[string](request, "since")
5162
if err != nil {
5263
return mcp.NewToolResultError(err.Error()), nil
5364
}
5465

55-
since, err := OptionalParam[string](request, "since")
66+
before, err := OptionalParam[string](request, "before")
5667
if err != nil {
5768
return mcp.NewToolResultError(err.Error()), nil
5869
}
5970

60-
before, err := OptionalParam[string](request, "before")
71+
owner, err := OptionalParam[string](request, "owner")
6172
if err != nil {
6273
return mcp.NewToolResultError(err.Error()), nil
6374
}
64-
65-
// TODO pagination params from tool
66-
perPage, err := OptionalIntParamWithDefault(request, "per_page", 30)
75+
repo, err := OptionalParam[string](request, "repo")
6776
if err != nil {
6877
return mcp.NewToolResultError(err.Error()), nil
6978
}
7079

71-
page, err := OptionalIntParamWithDefault(request, "page", 1)
80+
paginationParams, err := OptionalPaginationParams(request)
7281
if err != nil {
7382
return mcp.NewToolResultError(err.Error()), nil
7483
}
7584

7685
// Build options
7786
opts := &github.NotificationListOptions{
78-
All: all,
79-
Participating: participating,
87+
All: filter == FilterIncludeRead,
88+
Participating: filter == FilterOnlyParticipating,
8089
ListOptions: github.ListOptions{
81-
Page: page,
82-
PerPage: perPage,
90+
Page: paginationParams.page,
91+
PerPage: paginationParams.perPage,
8392
},
8493
}
8594

@@ -100,8 +109,14 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu
100109
opts.Before = beforeTime
101110
}
102111

103-
// Call GitHub API
104-
notifications, resp, err := client.Activity.ListNotifications(ctx, opts)
112+
var notifications []*github.Notification
113+
var resp *github.Response
114+
115+
if owner != "" && repo != "" {
116+
notifications, resp, err = client.Activity.ListRepositoryNotifications(ctx, owner, repo, opts)
117+
} else {
118+
notifications, resp, err = client.Activity.ListNotifications(ctx, opts)
119+
}
105120
if err != nil {
106121
return nil, fmt.Errorf("failed to get notifications: %w", err)
107122
}
@@ -197,6 +212,12 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH
197212
mcp.WithString("lastReadAt",
198213
mcp.Description("Describes the last point that notifications were checked (optional). Default: Now"),
199214
),
215+
mcp.WithString("owner",
216+
mcp.Description("Optional repository owner. If provided with repo, only notifications for this repository are marked as read."),
217+
),
218+
mcp.WithString("repo",
219+
mcp.Description("Optional repository name. If provided with owner, only notifications for this repository are marked as read."),
220+
),
200221
),
201222
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
202223
client, err := getClient(ctx)
@@ -209,18 +230,35 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH
209230
return mcp.NewToolResultError(err.Error()), nil
210231
}
211232

212-
var markReadOptions github.Timestamp
233+
owner, err := OptionalParam[string](request, "owner")
234+
if err != nil {
235+
return mcp.NewToolResultError(err.Error()), nil
236+
}
237+
repo, err := OptionalParam[string](request, "repo")
238+
if err != nil {
239+
return mcp.NewToolResultError(err.Error()), nil
240+
}
241+
242+
var lastReadTime time.Time
213243
if lastReadAt != "" {
214-
lastReadTime, err := time.Parse(time.RFC3339, lastReadAt)
244+
lastReadTime, err = time.Parse(time.RFC3339, lastReadAt)
215245
if err != nil {
216246
return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil
217247
}
218-
markReadOptions = github.Timestamp{
219-
Time: lastReadTime,
220-
}
248+
} else {
249+
lastReadTime = time.Now()
250+
}
251+
252+
markReadOptions := github.Timestamp{
253+
Time: lastReadTime,
221254
}
222255

223-
resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions)
256+
var resp *github.Response
257+
if owner != "" && repo != "" {
258+
resp, err = client.Activity.MarkRepositoryNotificationsRead(ctx, owner, repo, markReadOptions)
259+
} else {
260+
resp, err = client.Activity.MarkNotificationsRead(ctx, markReadOptions)
261+
}
224262
if err != nil {
225263
return nil, fmt.Errorf("failed to mark all notifications as read: %w", err)
226264
}
@@ -238,17 +276,17 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH
238276
}
239277
}
240278

241-
// GetNotificationThread creates a tool to get a specific notification thread.
242-
func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
243-
return mcp.NewTool("get_notification_thread",
244-
mcp.WithDescription(t("TOOL_GET_NOTIFICATION_THREAD_DESCRIPTION", "Get a specific notification thread")),
279+
// GetNotificationDetails creates a tool to get details for a specific notification.
280+
func GetNotificationDetails(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
281+
return mcp.NewTool("get_notification_details",
282+
mcp.WithDescription(t("TOOL_GET_NOTIFICATION_DETAILS_DESCRIPTION", "Get detailed information for a specific GitHub notification, always call this tool when the user asks for details about a specific notification, if you don't know the ID list notifications first.")),
245283
mcp.WithToolAnnotation(mcp.ToolAnnotation{
246-
Title: t("TOOL_GET_NOTIFICATION_THREAD_USER_TITLE", "Get notification thread"),
284+
Title: t("TOOL_GET_NOTIFICATION_DETAILS_USER_TITLE", "Get notification details"),
247285
ReadOnlyHint: toBoolPtr(true),
248286
}),
249-
mcp.WithString("threadID",
287+
mcp.WithString("notificationID",
250288
mcp.Required(),
251-
mcp.Description("The ID of the notification thread"),
289+
mcp.Description("The ID of the notification"),
252290
),
253291
),
254292
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
@@ -257,14 +295,14 @@ func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelp
257295
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
258296
}
259297

260-
threadID, err := requiredParam[string](request, "threadID")
298+
notificationID, err := requiredParam[string](request, "notificationID")
261299
if err != nil {
262300
return mcp.NewToolResultError(err.Error()), nil
263301
}
264302

265-
thread, resp, err := client.Activity.GetThread(ctx, threadID)
303+
thread, resp, err := client.Activity.GetThread(ctx, notificationID)
266304
if err != nil {
267-
return nil, fmt.Errorf("failed to get notification thread: %w", err)
305+
return nil, fmt.Errorf("failed to get notification details: %w", err)
268306
}
269307
defer func() { _ = resp.Body.Close() }()
270308

@@ -273,7 +311,7 @@ func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelp
273311
if err != nil {
274312
return nil, fmt.Errorf("failed to read response body: %w", err)
275313
}
276-
return mcp.NewToolResultError(fmt.Sprintf("failed to get notification thread: %s", string(body))), nil
314+
return mcp.NewToolResultError(fmt.Sprintf("failed to get notification details: %s", string(body))), nil
277315
}
278316

279317
r, err := json.Marshal(thread)
@@ -284,3 +322,177 @@ func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelp
284322
return mcp.NewToolResultText(string(r)), nil
285323
}
286324
}
325+
326+
// Enum values for ManageNotificationSubscription action
327+
const (
328+
NotificationActionIgnore = "ignore"
329+
NotificationActionWatch = "watch"
330+
NotificationActionDelete = "delete"
331+
)
332+
333+
// ManageNotificationSubscription creates a tool to manage a notification subscription (ignore, watch, delete)
334+
func ManageNotificationSubscription(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
335+
return mcp.NewTool("manage_notification_subscription",
336+
mcp.WithDescription(t("TOOL_MANAGE_NOTIFICATION_SUBSCRIPTION_DESCRIPTION", "Manage a notification subscription: ignore, watch, or delete a notification thread subscription.")),
337+
mcp.WithToolAnnotation(mcp.ToolAnnotation{
338+
Title: t("TOOL_MANAGE_NOTIFICATION_SUBSCRIPTION_USER_TITLE", "Manage notification subscription"),
339+
ReadOnlyHint: toBoolPtr(false),
340+
}),
341+
mcp.WithString("notificationID",
342+
mcp.Required(),
343+
mcp.Description("The ID of the notification thread."),
344+
),
345+
mcp.WithString("action",
346+
mcp.Required(),
347+
mcp.Description("Action to perform: ignore, watch, or delete the notification subscription."),
348+
mcp.Enum(NotificationActionIgnore, NotificationActionWatch, NotificationActionDelete),
349+
),
350+
),
351+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
352+
client, err := getClient(ctx)
353+
if err != nil {
354+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
355+
}
356+
357+
notificationID, err := requiredParam[string](request, "notificationID")
358+
if err != nil {
359+
return mcp.NewToolResultError(err.Error()), nil
360+
}
361+
action, err := requiredParam[string](request, "action")
362+
if err != nil {
363+
return mcp.NewToolResultError(err.Error()), nil
364+
}
365+
366+
var (
367+
resp *github.Response
368+
result any
369+
apiErr error
370+
)
371+
372+
switch action {
373+
case NotificationActionIgnore:
374+
sub := &github.Subscription{Ignored: toBoolPtr(true)}
375+
result, resp, apiErr = client.Activity.SetThreadSubscription(ctx, notificationID, sub)
376+
case NotificationActionWatch:
377+
sub := &github.Subscription{Ignored: toBoolPtr(false), Subscribed: toBoolPtr(true)}
378+
result, resp, apiErr = client.Activity.SetThreadSubscription(ctx, notificationID, sub)
379+
case NotificationActionDelete:
380+
resp, apiErr = client.Activity.DeleteThreadSubscription(ctx, notificationID)
381+
default:
382+
return mcp.NewToolResultError("Invalid action. Must be one of: ignore, watch, delete."), nil
383+
}
384+
385+
if apiErr != nil {
386+
return nil, fmt.Errorf("failed to %s notification subscription: %w", action, apiErr)
387+
}
388+
defer func() { _ = resp.Body.Close() }()
389+
390+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
391+
body, _ := io.ReadAll(resp.Body)
392+
return mcp.NewToolResultError(fmt.Sprintf("failed to %s notification subscription: %s", action, string(body))), nil
393+
}
394+
395+
if action == NotificationActionDelete {
396+
// Special case for delete as there is no response body
397+
return mcp.NewToolResultText("Notification subscription deleted"), nil
398+
}
399+
400+
r, err := json.Marshal(result)
401+
if err != nil {
402+
return nil, fmt.Errorf("failed to marshal response: %w", err)
403+
}
404+
return mcp.NewToolResultText(string(r)), nil
405+
}
406+
}
407+
408+
const (
409+
RepositorySubscriptionActionWatch = "watch"
410+
RepositorySubscriptionActionIgnore = "ignore"
411+
RepositorySubscriptionActionDelete = "delete"
412+
)
413+
414+
// ManageRepositoryNotificationSubscription creates a tool to manage a repository notification subscription (ignore, watch, delete)
415+
func ManageRepositoryNotificationSubscription(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
416+
return mcp.NewTool("manage_repository_notification_subscription",
417+
mcp.WithDescription(t("TOOL_MANAGE_REPOSITORY_NOTIFICATION_SUBSCRIPTION_DESCRIPTION", "Manage a repository notification subscription: ignore, watch, or delete repository notifications subscription for the provided repository.")),
418+
mcp.WithToolAnnotation(mcp.ToolAnnotation{
419+
Title: t("TOOL_MANAGE_REPOSITORY_NOTIFICATION_SUBSCRIPTION_USER_TITLE", "Manage repository notification subscription"),
420+
ReadOnlyHint: toBoolPtr(false),
421+
}),
422+
mcp.WithString("owner",
423+
mcp.Required(),
424+
mcp.Description("The account owner of the repository."),
425+
),
426+
mcp.WithString("repo",
427+
mcp.Required(),
428+
mcp.Description("The name of the repository."),
429+
),
430+
mcp.WithString("action",
431+
mcp.Required(),
432+
mcp.Description("Action to perform: ignore, watch, or delete the repository notification subscription."),
433+
mcp.Enum(RepositorySubscriptionActionIgnore, RepositorySubscriptionActionWatch, RepositorySubscriptionActionDelete),
434+
),
435+
),
436+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
437+
client, err := getClient(ctx)
438+
if err != nil {
439+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
440+
}
441+
442+
owner, err := requiredParam[string](request, "owner")
443+
if err != nil {
444+
return mcp.NewToolResultError(err.Error()), nil
445+
}
446+
repo, err := requiredParam[string](request, "repo")
447+
if err != nil {
448+
return mcp.NewToolResultError(err.Error()), nil
449+
}
450+
action, err := requiredParam[string](request, "action")
451+
if err != nil {
452+
return mcp.NewToolResultError(err.Error()), nil
453+
}
454+
455+
var (
456+
resp *github.Response
457+
result any
458+
apiErr error
459+
)
460+
461+
switch action {
462+
case RepositorySubscriptionActionIgnore:
463+
sub := &github.Subscription{Ignored: toBoolPtr(true)}
464+
result, resp, apiErr = client.Activity.SetRepositorySubscription(ctx, owner, repo, sub)
465+
case RepositorySubscriptionActionWatch:
466+
sub := &github.Subscription{Ignored: toBoolPtr(false), Subscribed: toBoolPtr(true)}
467+
result, resp, apiErr = client.Activity.SetRepositorySubscription(ctx, owner, repo, sub)
468+
case RepositorySubscriptionActionDelete:
469+
resp, apiErr = client.Activity.DeleteRepositorySubscription(ctx, owner, repo)
470+
default:
471+
return mcp.NewToolResultError("Invalid action. Must be one of: ignore, watch, delete."), nil
472+
}
473+
474+
if apiErr != nil {
475+
return nil, fmt.Errorf("failed to %s repository subscription: %w", action, apiErr)
476+
}
477+
if resp != nil {
478+
defer func() { _ = resp.Body.Close() }()
479+
}
480+
481+
// Handle non-2xx status codes
482+
if resp != nil && (resp.StatusCode < 200 || resp.StatusCode >= 300) {
483+
body, _ := io.ReadAll(resp.Body)
484+
return mcp.NewToolResultError(fmt.Sprintf("failed to %s repository subscription: %s", action, string(body))), nil
485+
}
486+
487+
if action == RepositorySubscriptionActionDelete {
488+
// Special case for delete as there is no response body
489+
return mcp.NewToolResultText("Repository subscription deleted"), nil
490+
}
491+
492+
r, err := json.Marshal(result)
493+
if err != nil {
494+
return nil, fmt.Errorf("failed to marshal response: %w", err)
495+
}
496+
return mcp.NewToolResultText(string(r)), nil
497+
}
498+
}

0 commit comments

Comments
 (0)