From e6e4a211bd862028d4606d6a3d2b5146ddc882a4 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sun, 30 Mar 2025 01:36:28 +0100 Subject: [PATCH] feat: enable secret protection --- pkg/github/repositories.go | 151 +++++++++++++++++++++++++++++++++++++ pkg/github/server.go | 7 +- 2 files changed, 157 insertions(+), 1 deletion(-) diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index e4302b88..9686e975 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -1,6 +1,7 @@ package github import ( + "bytes" "context" "encoding/json" "fmt" @@ -610,3 +611,153 @@ func pushFiles(client *github.Client, t translations.TranslationHelperFunc) (too return mcp.NewToolResultText(string(r)), nil } } + +func securityFeatureToggle(isEnabled bool) map[string]string { + if isEnabled { + return map[string]string{"status": "enabled"} + } + return map[string]string{"status": "disabled"} +} + +func toggleSecretProtectionFeatures(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("toggle_secret_protection_features", + mcp.WithDescription(t("TOOL_TOGGLE_SECRET_PROTECTION_FEATURES_DESCRIPTION", "Enable or disable Secret Protection features for a repository")), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithBoolean("secret_scanning", + mcp.Required(), + mcp.Description("Enable or disable secret scanning"), + ), + mcp.WithBoolean("secret_scanning_push_protection", + mcp.Required(), + mcp.Description("Enable or disable secret scanning push protection"), + ), + mcp.WithBoolean("secret_scanning_ai_detection", + mcp.Required(), + mcp.Description("Enable or disable secret scanning AI detection"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + secretScanningEnabled, err := requiredParam[bool](request, "secret_scanning") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pushProtectionEnabled, err := optionalParam[bool](request, "secret_scanning_push_protection") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + aiDetectionEnabled, err := optionalParam[bool](request, "secret_scanning_ai_detection") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + securityAndAnalysis := map[string]map[string]string{ + "secret_scanning": securityFeatureToggle(secretScanningEnabled), + "secret_scanning_push_protection": securityFeatureToggle(pushProtectionEnabled), + "secret_scanning_ai_detection": securityFeatureToggle(aiDetectionEnabled), + } + + requestBody := map[string]interface{}{ + "security_and_analysis": securityAndAnalysis, + } + + jsonBody, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } + + req, err := http.NewRequest( + "PATCH", + fmt.Sprintf("%srepos/%s/%s", client.BaseURL.String(), owner, repo), + bytes.NewBuffer(jsonBody), + ) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := client.Client().Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to toggle Secret Protection Features: %s", string(body))), nil + } + + return mcp.NewToolResultText("Secret Protection features toggled successfully"), nil + } +} + +// getRepositorySettings creates a tool to get repository settings including security features +func getRepositorySettings(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_repository_settings", + mcp.WithDescription(t("TOOL_GET_REPOSITORY_SETTINGS_DESCRIPTION", "Get repository settings including security features")), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + req, err := http.NewRequest( + "GET", + fmt.Sprintf("%srepos/%s/%s", client.BaseURL.String(), owner, repo), + nil, + ) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := client.Client().Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get security analysis settings: %s", string(body))), nil + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + return mcp.NewToolResultText(string(body)), nil + } +} diff --git a/pkg/github/server.go b/pkg/github/server.go index c01e0918..076bd01c 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "reflect" "strings" "github.com/github/github-mcp-server/pkg/translations" @@ -58,6 +59,7 @@ func NewServer(client *github.Client, readOnly bool, t translations.TranslationH } // Add GitHub tools - Repositories + s.AddTool(getRepositorySettings(client, t)) s.AddTool(searchRepositories(client, t)) s.AddTool(getFileContents(client, t)) s.AddTool(listCommits(client, t)) @@ -67,6 +69,7 @@ func NewServer(client *github.Client, readOnly bool, t translations.TranslationH s.AddTool(forkRepository(client, t)) s.AddTool(createBranch(client, t)) s.AddTool(pushFiles(client, t)) + s.AddTool(toggleSecretProtectionFeatures(client, t)) } // Add GitHub tools - Search @@ -157,7 +160,9 @@ func requiredParam[T comparable](r mcp.CallToolRequest, p string) (T, error) { return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) } - if r.Params.Arguments[p].(T) == zero { + // Check if the parameter is not empty, i.e: non-zero value + // Note: This check is not applicable for bool type, as false is a valid value + if r.Params.Arguments[p].(T) == zero && reflect.TypeOf(zero).Kind() != reflect.Bool { return zero, fmt.Errorf("missing required parameter: %s", p) }