From bdb59e0e34d1a08f2dc6ea1afafb6b0d287f7cca Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Fri, 11 Apr 2025 11:48:42 -0400 Subject: [PATCH 1/5] feat: add GitHub notifications tools for managing user notifications --- pkg/github/notifications.go | 236 ++++++++++++++++++++++++++++++++++++ pkg/github/server.go | 22 ++++ 2 files changed, 258 insertions(+) create mode 100644 pkg/github/notifications.go diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go new file mode 100644 index 00000000..9e32c143 --- /dev/null +++ b/pkg/github/notifications.go @@ -0,0 +1,236 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// getNotifications creates a tool to list notifications for the current user. +func getNotifications(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_notifications", + mcp.WithDescription(t("TOOL_GET_NOTIFICATIONS_DESCRIPTION", "Get notifications for the authenticated GitHub user")), + mcp.WithBoolean("all", + mcp.Description("If true, show notifications marked as read. Default: false"), + ), + mcp.WithBoolean("participating", + mcp.Description("If true, only shows notifications in which the user is directly participating or mentioned. Default: false"), + ), + mcp.WithString("since", + mcp.Description("Only show notifications updated after the given time (ISO 8601 format)"), + ), + mcp.WithString("before", + mcp.Description("Only show notifications updated before the given time (ISO 8601 format)"), + ), + mcp.WithNumber("per_page", + mcp.Description("Results per page (max 100). Default: 30"), + ), + mcp.WithNumber("page", + mcp.Description("Page number of the results to fetch. Default: 1"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract optional parameters with defaults + all, err := optionalParamWithDefault[bool](request, "all", false) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + participating, err := optionalParamWithDefault[bool](request, "participating", false) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + since, err := optionalParam[string](request, "since") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + before, err := optionalParam[string](request, "before") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + perPage, err := optionalIntParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + page, err := optionalIntParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Build options + opts := &github.NotificationListOptions{ + All: all, + Participating: participating, + ListOptions: github.ListOptions{ + Page: page, + PerPage: perPage, + }, + } + + // Parse time parameters if provided + if since != "" { + sinceTime, err := time.Parse(time.RFC3339, since) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid since time format, should be RFC3339/ISO8601: %v", err)), nil + } + opts.Since = sinceTime + } + + if before != "" { + beforeTime, err := time.Parse(time.RFC3339, before) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid before time format, should be RFC3339/ISO8601: %v", err)), nil + } + opts.Before = beforeTime + } + + // Call GitHub API + notifications, resp, err := client.Activity.ListNotifications(ctx, opts) + if err != nil { + return nil, fmt.Errorf("failed to get notifications: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get notifications: %s", string(body))), nil + } + + // Marshal response to JSON + r, err := json.Marshal(notifications) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +// markNotificationRead creates a tool to mark a notification as read. +func markNotificationRead(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("mark_notification_read", + mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_READ_DESCRIPTION", "Mark a notification as read")), + mcp.WithString("threadID", + mcp.Required(), + mcp.Description("The ID of the notification thread"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + threadID, err := requiredParam[string](request, "threadID") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + resp, err := client.Activity.MarkThreadRead(ctx, threadID) + if err != nil { + return nil, fmt.Errorf("failed to mark notification as read: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil + } + + return mcp.NewToolResultText("Notification marked as read"), nil + } +} + +// markAllNotificationsRead creates a tool to mark all notifications as read. +func markAllNotificationsRead(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("mark_all_notifications_read", + mcp.WithDescription(t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read")), + mcp.WithString("lastReadAt", + mcp.Description("Describes the last point that notifications were checked (optional). Default: Now"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + lastReadAt, err := optionalParam[string](request, "lastReadAt") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + var markReadOptions github.Timestamp + if lastReadAt != "" { + lastReadTime, err := time.Parse(time.RFC3339, lastReadAt) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil + } + markReadOptions = github.Timestamp{ + Time: lastReadTime, + } + } + + resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions) + if err != nil { + return nil, fmt.Errorf("failed to mark all notifications as read: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil + } + + return mcp.NewToolResultText("All notifications marked as read"), nil + } +} + +// getNotificationThread creates a tool to get a specific notification thread. +func getNotificationThread(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_notification_thread", + mcp.WithDescription(t("TOOL_GET_NOTIFICATION_THREAD_DESCRIPTION", "Get a specific notification thread")), + mcp.WithString("threadID", + mcp.Required(), + mcp.Description("The ID of the notification thread"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + threadID, err := requiredParam[string](request, "threadID") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + thread, resp, err := client.Activity.GetThread(ctx, threadID) + if err != nil { + return nil, fmt.Errorf("failed to get notification thread: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get notification thread: %s", string(body))), nil + } + + r, err := json.Marshal(thread) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} diff --git a/pkg/github/server.go b/pkg/github/server.go index 66dbfd1c..35a17f01 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -77,6 +77,14 @@ func NewServer(client *github.Client, readOnly bool, t translations.TranslationH // Add GitHub tools - Code Scanning s.AddTool(getCodeScanningAlert(client, t)) s.AddTool(listCodeScanningAlerts(client, t)) + + // Add GitHub tools - Notifications + s.AddTool(getNotifications(client, t)) + s.AddTool(getNotificationThread(client, t)) + if !readOnly { + s.AddTool(markNotificationRead(client, t)) + s.AddTool(markAllNotificationsRead(client, t)) + } return s } @@ -189,6 +197,20 @@ func optionalIntParam(r mcp.CallToolRequest, p string) (int, error) { return int(v), nil } +// optionalParamWithDefault is a generic helper function that can be used to fetch a requested parameter from the request +// with a default value if the parameter is not provided or is zero value. +func optionalParamWithDefault[T comparable](r mcp.CallToolRequest, p string, d T) (T, error) { + var zero T + v, err := optionalParam[T](r, p) + if err != nil { + return zero, err + } + if v == zero { + return d, nil + } + return v, nil +} + // optionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request // similar to optionalIntParam, but it also takes a default value. func optionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) { From dbdef790eb833edb954c3ea8890c2f172368f07d Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Sat, 12 Apr 2025 10:36:01 -0400 Subject: [PATCH 2/5] refactor: update notification functions to use GetClientFn . Fix conflicts --- pkg/github/notifications.go | 46 +++++++++++++++++++-------- pkg/github/server.go | 63 +++++++++++++++++++++++++++---------- 2 files changed, 79 insertions(+), 30 deletions(-) diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index 9e32c143..e040f6ef 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -15,7 +15,7 @@ import ( ) // getNotifications creates a tool to list notifications for the current user. -func getNotifications(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_notifications", mcp.WithDescription(t("TOOL_GET_NOTIFICATIONS_DESCRIPTION", "Get notifications for the authenticated GitHub user")), mcp.WithBoolean("all", @@ -38,33 +38,38 @@ func getNotifications(client *github.Client, t translations.TranslationHelperFun ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + // Extract optional parameters with defaults - all, err := optionalParamWithDefault[bool](request, "all", false) + all, err := OptionalBoolParamWithDefault(request, "all", false) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - participating, err := optionalParamWithDefault[bool](request, "participating", false) + participating, err := OptionalBoolParamWithDefault(request, "participating", false) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - since, err := optionalParam[string](request, "since") + since, err := OptionalStringParamWithDefault(request, "since", "") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - before, err := optionalParam[string](request, "before") + before, err := OptionalStringParam(request, "before") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - perPage, err := optionalIntParamWithDefault(request, "per_page", 30) + perPage, err := OptionalIntParamWithDefault(request, "per_page", 30) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - page, err := optionalIntParamWithDefault(request, "page", 1) + page, err := OptionalIntParamWithDefault(request, "page", 1) if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -122,7 +127,7 @@ func getNotifications(client *github.Client, t translations.TranslationHelperFun } // markNotificationRead creates a tool to mark a notification as read. -func markNotificationRead(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func MarkNotificationRead(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("mark_notification_read", mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_READ_DESCRIPTION", "Mark a notification as read")), mcp.WithString("threadID", @@ -131,6 +136,11 @@ func markNotificationRead(client *github.Client, t translations.TranslationHelpe ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getclient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + threadID, err := requiredParam[string](request, "threadID") if err != nil { return mcp.NewToolResultError(err.Error()), nil @@ -154,8 +164,8 @@ func markNotificationRead(client *github.Client, t translations.TranslationHelpe } } -// markAllNotificationsRead creates a tool to mark all notifications as read. -func markAllNotificationsRead(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// MarkAllNotificationsRead creates a tool to mark all notifications as read. +func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("mark_all_notifications_read", mcp.WithDescription(t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read")), mcp.WithString("lastReadAt", @@ -163,7 +173,12 @@ func markAllNotificationsRead(client *github.Client, t translations.TranslationH ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - lastReadAt, err := optionalParam[string](request, "lastReadAt") + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + lastReadAt, err := OptionalStringParam(request, "lastReadAt") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -197,8 +212,8 @@ func markAllNotificationsRead(client *github.Client, t translations.TranslationH } } -// getNotificationThread creates a tool to get a specific notification thread. -func getNotificationThread(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +// GetNotificationThread creates a tool to get a specific notification thread. +func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_notification_thread", mcp.WithDescription(t("TOOL_GET_NOTIFICATION_THREAD_DESCRIPTION", "Get a specific notification thread")), mcp.WithString("threadID", @@ -207,6 +222,11 @@ func getNotificationThread(client *github.Client, t translations.TranslationHelp ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + threadID, err := requiredParam[string](request, "threadID") if err != nil { return mcp.NewToolResultError(err.Error()), nil diff --git a/pkg/github/server.go b/pkg/github/server.go index 63772ee1..2a1b1fd0 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -91,12 +91,14 @@ func NewServer(getClient GetClientFn, version string, readOnly bool, t translati s.AddTool(GetCodeScanningAlert(getClient, t)) s.AddTool(ListCodeScanningAlerts(getClient, t)) - // Add GitHub tools - Notifications + // Add GitHub tools - Notifications + s.AddTool(GetNotifications(getClient, t)) + s.AddTool(GetNotificationThread(getClient, t)) if !readOnly { - s.AddTool(markNotificationRead(client, t)) - s.AddTool(markAllNotificationsRead(client, t)) + s.AddTool(MarkNotificationRead(getClient, t)) + s.AddTool(MarkAllNotificationsRead(getClient, t)) } - + return s } @@ -237,28 +239,55 @@ func OptionalIntParam(r mcp.CallToolRequest, p string) (int, error) { return int(v), nil } -// optionalParamWithDefault is a generic helper function that can be used to fetch a requested parameter from the request -// with a default value if the parameter is not provided or is zero value. -func optionalParamWithDefault[T comparable](r mcp.CallToolRequest, p string, d T) (T, error) { - var zero T - v, err := optionalParam[T](r, p) +// OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalIntParam, but it also takes a default value. +func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) { + v, err := OptionalIntParam(r, p) if err != nil { - return zero, err + return 0, err } - if v == zero { + if v == 0 { return d, nil } return v, nil } -// OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request -// similar to optionalIntParam, but it also takes a default value. -func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) { - v, err := OptionalIntParam(r, p) +// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalParam, but it also takes a default value. +func OptionalBoolParamWithDefault(r mcp.CallToolRequest, p string, d bool) (bool, error) { + v, err := OptionalParam[bool](r, p) if err != nil { - return 0, err + return false, err } - if v == 0 { + if v == false { + return d, nil + } + return v, nil +} + +// OptionalStringParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func OptionalStringParam(r mcp.CallToolRequest, p string) (string, error) { + v, err := OptionalParam[string](r, p) + if err != nil { + return "", err + } + if v == "" { + return "", nil + } + return v, nil +} + +// OptionalStringParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalParam, but it also takes a default value. +func OptionalStringParamWithDefault(r mcp.CallToolRequest, p string, d string) (string, error) { + v, err := OptionalParam[string](r, p) + if err != nil { + return "", err + } + if v == "" { return d, nil } return v, nil From f75c234d83df2ea4f0f53f3527e8a9312f2e5ed6 Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Sat, 12 Apr 2025 11:14:20 -0400 Subject: [PATCH 3/5] lint: simplify boolean check in OptionalBoolParamWithDefault function --- pkg/github/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/github/server.go b/pkg/github/server.go index 2a1b1fd0..114ae066 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -259,7 +259,7 @@ func OptionalBoolParamWithDefault(r mcp.CallToolRequest, p string, d bool) (bool if err != nil { return false, err } - if v == false { + if !v { return d, nil } return v, nil From 6eaa8d3fca0dd085665cac93741c6f8aa1ba9002 Mon Sep 17 00:00:00 2001 From: Ricardo Fearing <9965014+rfearing@users.noreply.github.com> Date: Wed, 16 Apr 2025 09:13:43 -0400 Subject: [PATCH 4/5] Notifications Mark as done with number implementation (#270) --- pkg/github/notifications.go | 44 +++++++++++++++++++++++++++++++++++++ pkg/github/server.go | 1 + 2 files changed, 45 insertions(+) diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index e040f6ef..d7252e39 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "strconv" "time" "github.com/github/github-mcp-server/pkg/translations" @@ -254,3 +255,46 @@ func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelp return mcp.NewToolResultText(string(r)), nil } } + +// markNotificationDone creates a tool to mark a notification as done. +func MarkNotificationDone(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("mark_notification_done", + mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_DONE_DESCRIPTION", "Mark a notification as done")), + mcp.WithString("threadID", + mcp.Required(), + mcp.Description("The ID of the notification thread"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getclient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + threadIDStr, err := requiredParam[string](request, "threadID") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + threadID, err := strconv.ParseInt(threadIDStr, 10, 64) + if err != nil { + return mcp.NewToolResultError("Invalid threadID: must be a numeric value"), nil + } + + resp, err := client.Activity.MarkThreadDone(ctx, threadID) + if err != nil { + return nil, fmt.Errorf("failed to mark notification as done: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as done: %s", string(body))), nil + } + + return mcp.NewToolResultText("Notification marked as done"), nil + } +} diff --git a/pkg/github/server.go b/pkg/github/server.go index 114ae066..c17e4a33 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -97,6 +97,7 @@ func NewServer(getClient GetClientFn, version string, readOnly bool, t translati if !readOnly { s.AddTool(MarkNotificationRead(getClient, t)) s.AddTool(MarkAllNotificationsRead(getClient, t)) + s.AddTool(MarkNotificationDone(getClient, t)) } return s From bc897a1518962d0637435b4e6823362fbdb18046 Mon Sep 17 00:00:00 2001 From: Avinash Sridhar Date: Fri, 18 Apr 2025 17:16:56 -0400 Subject: [PATCH 5/5] Fix merge conflicts --- pkg/github/server.go | 257 +++++++++++++++++++++++++++++++++++++++++++ pkg/github/tools.go | 14 +++ 2 files changed, 271 insertions(+) diff --git a/pkg/github/server.go b/pkg/github/server.go index e69de29b..c51e4732 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -0,0 +1,257 @@ +package github + +import ( + "errors" + "fmt" + + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// NewServer creates a new GitHub MCP server with the specified GH client and logger. + +func NewServer(version string, opts ...server.ServerOption) *server.MCPServer { + // Add default options + defaultOpts := []server.ServerOption{ + server.WithToolCapabilities(true), + server.WithResourceCapabilities(true, true), + server.WithLogging(), + } + opts = append(defaultOpts, opts...) + + // Create a new MCP server + s := server.NewMCPServer( + "github-mcp-server", + version, + opts..., + ) + return s +} + +// OptionalParamOK is a helper function that can be used to fetch a requested parameter from the request. +// It returns the value, a boolean indicating if the parameter was present, and an error if the type is wrong. +func OptionalParamOK[T any](r mcp.CallToolRequest, p string) (value T, ok bool, err error) { + // Check if the parameter is present in the request + val, exists := r.Params.Arguments[p] + if !exists { + // Not present, return zero value, false, no error + return + } + + // Check if the parameter is of the expected type + value, ok = val.(T) + if !ok { + // Present but wrong type + err = fmt.Errorf("parameter %s is not of type %T, is %T", p, value, val) + ok = true // Set ok to true because the parameter *was* present, even if wrong type + return + } + + // Present and correct type + ok = true + return +} + +// isAcceptedError checks if the error is an accepted error. +func isAcceptedError(err error) bool { + var acceptedError *github.AcceptedError + return errors.As(err, &acceptedError) +} + +// requiredParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type. +// 3. Checks if the parameter is not empty, i.e: non-zero value +func requiredParam[T comparable](r mcp.CallToolRequest, p string) (T, error) { + var zero T + + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return zero, fmt.Errorf("missing required parameter: %s", p) + } + + // Check if the parameter is of the expected type + if _, ok := r.Params.Arguments[p].(T); !ok { + return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) + } + + if r.Params.Arguments[p].(T) == zero { + return zero, fmt.Errorf("missing required parameter: %s", p) + + } + + return r.Params.Arguments[p].(T), nil +} + +// RequiredInt is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type. +// 3. Checks if the parameter is not empty, i.e: non-zero value +func RequiredInt(r mcp.CallToolRequest, p string) (int, error) { + v, err := requiredParam[float64](r, p) + if err != nil { + return 0, err + } + return int(v), nil +} + +// OptionalParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func OptionalParam[T any](r mcp.CallToolRequest, p string) (T, error) { + var zero T + + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return zero, nil + } + + // Check if the parameter is of the expected type + if _, ok := r.Params.Arguments[p].(T); !ok { + return zero, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, r.Params.Arguments[p]) + } + + return r.Params.Arguments[p].(T), nil +} + +// OptionalIntParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func OptionalIntParam(r mcp.CallToolRequest, p string) (int, error) { + v, err := OptionalParam[float64](r, p) + if err != nil { + return 0, err + } + return int(v), nil +} + +// OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalIntParam, but it also takes a default value. +func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) { + v, err := OptionalIntParam(r, p) + if err != nil { + return 0, err + } + if v == 0 { + return d, nil + } + return v, nil +} + +// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalParam, but it also takes a default value. +func OptionalBoolParamWithDefault(r mcp.CallToolRequest, p string, d bool) (bool, error) { + v, err := OptionalParam[bool](r, p) + if err != nil { + return false, err + } + if !v { + return d, nil + } + return v, nil +} + +// OptionalStringParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type and returns it +func OptionalStringParam(r mcp.CallToolRequest, p string) (string, error) { + v, err := OptionalParam[string](r, p) + if err != nil { + return "", err + } + if v == "" { + return "", nil + } + return v, nil +} + +// OptionalStringParamWithDefault is a helper function that can be used to fetch a requested parameter from the request +// similar to optionalParam, but it also takes a default value. +func OptionalStringParamWithDefault(r mcp.CallToolRequest, p string, d string) (string, error) { + v, err := OptionalParam[string](r, p) + if err != nil { + return "", err + } + if v == "" { + return d, nil + } + return v, nil +} + +// OptionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, iterates the elements and checks each is a string +func OptionalStringArrayParam(r mcp.CallToolRequest, p string) ([]string, error) { + // Check if the parameter is present in the request + if _, ok := r.Params.Arguments[p]; !ok { + return []string{}, nil + } + + switch v := r.Params.Arguments[p].(type) { + case nil: + return []string{}, nil + case []string: + return v, nil + case []any: + strSlice := make([]string, len(v)) + for i, v := range v { + s, ok := v.(string) + if !ok { + return []string{}, fmt.Errorf("parameter %s is not of type string, is %T", p, v) + } + strSlice[i] = s + } + return strSlice, nil + default: + return []string{}, fmt.Errorf("parameter %s could not be coerced to []string, is %T", p, r.Params.Arguments[p]) + } +} + +// WithPagination returns a ToolOption that adds "page" and "perPage" parameters to the tool. +// The "page" parameter is optional, min 1. The "perPage" parameter is optional, min 1, max 100. +func WithPagination() mcp.ToolOption { + return func(tool *mcp.Tool) { + mcp.WithNumber("page", + mcp.Description("Page number for pagination (min 1)"), + mcp.Min(1), + )(tool) + + mcp.WithNumber("perPage", + mcp.Description("Results per page for pagination (min 1, max 100)"), + mcp.Min(1), + mcp.Max(100), + )(tool) + } +} + +type PaginationParams struct { + page int + perPage int +} + +// OptionalPaginationParams returns the "page" and "perPage" parameters from the request, +// or their default values if not present, "page" default is 1, "perPage" default is 30. +// In future, we may want to make the default values configurable, or even have this +// function returned from `withPagination`, where the defaults are provided alongside +// the min/max values. +func OptionalPaginationParams(r mcp.CallToolRequest) (PaginationParams, error) { + page, err := OptionalIntParamWithDefault(r, "page", 1) + if err != nil { + return PaginationParams{}, err + } + perPage, err := OptionalIntParamWithDefault(r, "perPage", 30) + if err != nil { + return PaginationParams{}, err + } + return PaginationParams{ + page: page, + perPage: perPage, + }, nil +} diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 35dabaef..fd0f231b 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -78,6 +78,19 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, toolsets.NewServerTool(GetSecretScanningAlert(getClient, t)), toolsets.NewServerTool(ListSecretScanningAlerts(getClient, t)), ) + + notifications := toolsets.NewToolset("notifications", "GitHub Notifications related tools"). + AddReadTools( + + toolsets.NewServerTool(MarkNotificationRead(getClient, t)), + toolsets.NewServerTool(MarkAllNotificationsRead(getClient, t)), + toolsets.NewServerTool(MarkNotificationDone(getClient, t)), + ). + AddWriteTools( + toolsets.NewServerTool(GetNotifications(getClient, t)), + toolsets.NewServerTool(GetNotificationThread(getClient, t)), + ) + // Keep experiments alive so the system doesn't error out when it's always enabled experiments := toolsets.NewToolset("experiments", "Experimental features that are not considered stable yet") @@ -88,6 +101,7 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, tsg.AddToolset(pullRequests) tsg.AddToolset(codeSecurity) tsg.AddToolset(secretProtection) + tsg.AddToolset(notifications) tsg.AddToolset(experiments) // Enable the requested features