Skip to content

feat: add GitHub notifications tools for managing user notifications #225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
300 changes: 300 additions & 0 deletions pkg/github/notifications.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
package github

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"time"

"github.com/github/github-mcp-server/pkg/translations"
"github.com/google/go-github/v69/github"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)

// getNotifications creates a tool to list notifications for the current user.
func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("get_notifications",
mcp.WithDescription(t("TOOL_GET_NOTIFICATIONS_DESCRIPTION", "Get notifications for the authenticated GitHub user")),
mcp.WithBoolean("all",
mcp.Description("If true, show notifications marked as read. Default: false"),
),
mcp.WithBoolean("participating",
mcp.Description("If true, only shows notifications in which the user is directly participating or mentioned. Default: false"),
),
mcp.WithString("since",
mcp.Description("Only show notifications updated after the given time (ISO 8601 format)"),
),
mcp.WithString("before",
mcp.Description("Only show notifications updated before the given time (ISO 8601 format)"),
),
mcp.WithNumber("per_page",
mcp.Description("Results per page (max 100). Default: 30"),
),
mcp.WithNumber("page",
mcp.Description("Page number of the results to fetch. Default: 1"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}

// Extract optional parameters with defaults
all, err := OptionalBoolParamWithDefault(request, "all", false)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

participating, err := OptionalBoolParamWithDefault(request, "participating", false)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

since, err := OptionalStringParamWithDefault(request, "since", "")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

before, err := OptionalStringParam(request, "before")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

perPage, err := OptionalIntParamWithDefault(request, "per_page", 30)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

page, err := OptionalIntParamWithDefault(request, "page", 1)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

// Build options
opts := &github.NotificationListOptions{
All: all,
Participating: participating,
ListOptions: github.ListOptions{
Page: page,
PerPage: perPage,
},
}

// Parse time parameters if provided
if since != "" {
sinceTime, err := time.Parse(time.RFC3339, since)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("invalid since time format, should be RFC3339/ISO8601: %v", err)), nil
}
opts.Since = sinceTime
}

if before != "" {
beforeTime, err := time.Parse(time.RFC3339, before)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("invalid before time format, should be RFC3339/ISO8601: %v", err)), nil
}
opts.Before = beforeTime
}

// Call GitHub API
notifications, resp, err := client.Activity.ListNotifications(ctx, opts)
if err != nil {
return nil, fmt.Errorf("failed to get notifications: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to get notifications: %s", string(body))), nil
}

// Marshal response to JSON
r, err := json.Marshal(notifications)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
}
}

// markNotificationRead creates a tool to mark a notification as read.
func MarkNotificationRead(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("mark_notification_read",
mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_READ_DESCRIPTION", "Mark a notification as read")),
mcp.WithString("threadID",
mcp.Required(),
mcp.Description("The ID of the notification thread"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
client, err := getclient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}

threadID, err := requiredParam[string](request, "threadID")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

resp, err := client.Activity.MarkThreadRead(ctx, threadID)
if err != nil {
return nil, fmt.Errorf("failed to mark notification as read: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil
}

return mcp.NewToolResultText("Notification marked as read"), nil
}
}

// MarkAllNotificationsRead creates a tool to mark all notifications as read.
func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("mark_all_notifications_read",
mcp.WithDescription(t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read")),
mcp.WithString("lastReadAt",
mcp.Description("Describes the last point that notifications were checked (optional). Default: Now"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}

lastReadAt, err := OptionalStringParam(request, "lastReadAt")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

var markReadOptions github.Timestamp
if lastReadAt != "" {
lastReadTime, err := time.Parse(time.RFC3339, lastReadAt)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil
}
markReadOptions = github.Timestamp{
Time: lastReadTime,
}
}

resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions)
if err != nil {
return nil, fmt.Errorf("failed to mark all notifications as read: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil
}

return mcp.NewToolResultText("All notifications marked as read"), nil
}
}

// GetNotificationThread creates a tool to get a specific notification thread.
func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("get_notification_thread",
mcp.WithDescription(t("TOOL_GET_NOTIFICATION_THREAD_DESCRIPTION", "Get a specific notification thread")),
mcp.WithString("threadID",
mcp.Required(),
mcp.Description("The ID of the notification thread"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}

threadID, err := requiredParam[string](request, "threadID")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

thread, resp, err := client.Activity.GetThread(ctx, threadID)
if err != nil {
return nil, fmt.Errorf("failed to get notification thread: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to get notification thread: %s", string(body))), nil
}

r, err := json.Marshal(thread)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
}
}

// markNotificationDone creates a tool to mark a notification as done.
func MarkNotificationDone(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("mark_notification_done",
mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_DONE_DESCRIPTION", "Mark a notification as done")),
mcp.WithString("threadID",
mcp.Required(),
mcp.Description("The ID of the notification thread"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
client, err := getclient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}

threadIDStr, err := requiredParam[string](request, "threadID")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

threadID, err := strconv.ParseInt(threadIDStr, 10, 64)
if err != nil {
return mcp.NewToolResultError("Invalid threadID: must be a numeric value"), nil
}

resp, err := client.Activity.MarkThreadDone(ctx, threadID)
if err != nil {
return nil, fmt.Errorf("failed to mark notification as done: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as done: %s", string(body))), nil
}

return mcp.NewToolResultText("Notification marked as done"), nil
}
}
41 changes: 41 additions & 0 deletions pkg/github/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,47 @@ func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, e
return v, nil
}

// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
// similar to optionalParam, but it also takes a default value.
func OptionalBoolParamWithDefault(r mcp.CallToolRequest, p string, d bool) (bool, error) {
v, err := OptionalParam[bool](r, p)
if err != nil {
return false, err
}
if !v {
return d, nil
}
return v, nil
}

// OptionalStringParam is a helper function that can be used to fetch a requested parameter from the request.
// It does the following checks:
// 1. Checks if the parameter is present in the request, if not, it returns its zero-value
// 2. If it is present, it checks if the parameter is of the expected type and returns it
func OptionalStringParam(r mcp.CallToolRequest, p string) (string, error) {
v, err := OptionalParam[string](r, p)
if err != nil {
return "", err
}
if v == "" {
return "", nil
}
return v, nil
}

// OptionalStringParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
// similar to optionalParam, but it also takes a default value.
func OptionalStringParamWithDefault(r mcp.CallToolRequest, p string, d string) (string, error) {
v, err := OptionalParam[string](r, p)
if err != nil {
return "", err
}
if v == "" {
return d, nil
}
return v, nil
}

// OptionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request.
// It does the following checks:
// 1. Checks if the parameter is present in the request, if not, it returns its zero-value
Expand Down
14 changes: 14 additions & 0 deletions pkg/github/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,19 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
toolsets.NewServerTool(GetSecretScanningAlert(getClient, t)),
toolsets.NewServerTool(ListSecretScanningAlerts(getClient, t)),
)

notifications := toolsets.NewToolset("notifications", "GitHub Notifications related tools").
AddReadTools(

toolsets.NewServerTool(MarkNotificationRead(getClient, t)),
toolsets.NewServerTool(MarkAllNotificationsRead(getClient, t)),
toolsets.NewServerTool(MarkNotificationDone(getClient, t)),
).
AddWriteTools(
toolsets.NewServerTool(GetNotifications(getClient, t)),
toolsets.NewServerTool(GetNotificationThread(getClient, t)),
)

// Keep experiments alive so the system doesn't error out when it's always enabled
experiments := toolsets.NewToolset("experiments", "Experimental features that are not considered stable yet")

Expand All @@ -88,6 +101,7 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
tsg.AddToolset(pullRequests)
tsg.AddToolset(codeSecurity)
tsg.AddToolset(secretProtection)
tsg.AddToolset(notifications)
tsg.AddToolset(experiments)
// Enable the requested features

Expand Down