Skip to content

Commit 6e132eb

Browse files
committed
Add optional string array param
1 parent 6f7458a commit 6e132eb

File tree

4 files changed

+107
-11
lines changed

4 files changed

+107
-11
lines changed

Diff for: pkg/github/issues.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -286,13 +286,13 @@ func createIssue(client *github.Client, t translations.TranslationHelperFunc) (t
286286
}
287287

288288
// Get assignees
289-
assignees, err := optionalParam[[]string](request, "assignees")
289+
assignees, err := optionalStringArrayParam(request, "assignees")
290290
if err != nil {
291291
return mcp.NewToolResultError(err.Error()), nil
292292
}
293293

294294
// Get labels
295-
labels, err := optionalParam[[]string](request, "labels")
295+
labels, err := optionalStringArrayParam(request, "labels")
296296
if err != nil {
297297
return mcp.NewToolResultError(err.Error()), nil
298298
}
@@ -401,7 +401,7 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to
401401
}
402402

403403
// Get labels
404-
opts.Labels, err = optionalParam[[]string](request, "labels")
404+
opts.Labels, err = optionalStringArrayParam(request, "labels")
405405
if err != nil {
406406
return mcp.NewToolResultError(err.Error()), nil
407407
}
@@ -548,7 +548,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t
548548
}
549549

550550
// Get labels
551-
labels, err := optionalParam[[]string](request, "labels")
551+
labels, err := optionalStringArrayParam(request, "labels")
552552
if err != nil {
553553
return mcp.NewToolResultError(err.Error()), nil
554554
}
@@ -557,7 +557,7 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t
557557
}
558558

559559
// Get assignees
560-
assignees, err := optionalParam[[]string](request, "assignees")
560+
assignees, err := optionalStringArrayParam(request, "assignees")
561561
if err != nil {
562562
return mcp.NewToolResultError(err.Error()), nil
563563
}

Diff for: pkg/github/issues_test.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -436,8 +436,8 @@ func Test_CreateIssue(t *testing.T) {
436436
"repo": "repo",
437437
"title": "Test Issue",
438438
"body": "This is a test issue",
439-
"assignees": []string{"user1", "user2"},
440-
"labels": []string{"bug", "help wanted"},
439+
"assignees": []any{"user1", "user2"},
440+
"labels": []any{"bug", "help wanted"},
441441
"milestone": float64(5),
442442
},
443443
expectError: false,
@@ -636,7 +636,7 @@ func Test_ListIssues(t *testing.T) {
636636
"owner": "owner",
637637
"repo": "repo",
638638
"state": "open",
639-
"labels": []string{"bug", "enhancement"},
639+
"labels": []any{"bug", "enhancement"},
640640
"sort": "created",
641641
"direction": "desc",
642642
"since": "2023-01-01T00:00:00Z",
@@ -790,8 +790,8 @@ func Test_UpdateIssue(t *testing.T) {
790790
"title": "Updated Issue Title",
791791
"body": "Updated issue description",
792792
"state": "closed",
793-
"labels": []string{"bug", "priority"},
794-
"assignees": []string{"assignee1", "assignee2"},
793+
"labels": []any{"bug", "priority"},
794+
"assignees": []any{"assignee1", "assignee2"},
795795
"milestone": float64(5),
796796
},
797797
expectError: false,

Diff for: pkg/github/server.go

+29-1
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ func optionalParam[T any](r mcp.CallToolRequest, p string) (T, error) {
171171

172172
// Check if the parameter is of the expected type
173173
if _, ok := r.Params.Arguments[p].(T); !ok {
174-
return zero, fmt.Errorf("parameter %s is not of type %T", p, zero)
174+
return zero, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, r.Params.Arguments[p])
175175
}
176176

177177
return r.Params.Arguments[p].(T), nil
@@ -201,3 +201,31 @@ func optionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, e
201201
}
202202
return v, nil
203203
}
204+
205+
// optionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request.
206+
// It does the following checks:
207+
// 1. Checks if the parameter is present in the request, if not, it returns its zero-value
208+
// 2. If it is present, iterates the elements and checks each is a string
209+
func optionalStringArrayParam(r mcp.CallToolRequest, p string) ([]string, error) {
210+
// Check if the parameter is present in the request
211+
if _, ok := r.Params.Arguments[p]; !ok {
212+
return []string{}, nil
213+
}
214+
215+
switch v := r.Params.Arguments[p].(type) {
216+
case []string:
217+
return v, nil
218+
case []any:
219+
strSlice := make([]string, len(v))
220+
for i, v := range v {
221+
s, ok := v.(string)
222+
if !ok {
223+
return []string{}, fmt.Errorf("parameter %s is not of type string, is %T", p, v)
224+
}
225+
strSlice[i] = s
226+
}
227+
return strSlice, nil
228+
default:
229+
return []string{}, fmt.Errorf("parameter %s could not be coerced to []string, is %T", p, r.Params.Arguments[p])
230+
}
231+
}

Diff for: pkg/github/server_test.go

+68
Original file line numberDiff line numberDiff line change
@@ -483,3 +483,71 @@ func Test_OptionalBooleanParam(t *testing.T) {
483483
})
484484
}
485485
}
486+
487+
func TestOptionalStringArrayParam(t *testing.T) {
488+
tests := []struct {
489+
name string
490+
params map[string]interface{}
491+
paramName string
492+
expected []string
493+
expectError bool
494+
}{
495+
{
496+
name: "parameter not in request",
497+
params: map[string]any{},
498+
paramName: "flag",
499+
expected: []string{},
500+
expectError: false,
501+
},
502+
{
503+
name: "valid any array parameter",
504+
params: map[string]any{
505+
"flag": []any{"v1", "v2"},
506+
},
507+
paramName: "flag",
508+
expected: []string{"v1", "v2"},
509+
expectError: false,
510+
},
511+
{
512+
name: "valid string array parameter",
513+
params: map[string]any{
514+
"flag": []string{"v1", "v2"},
515+
},
516+
paramName: "flag",
517+
expected: []string{"v1", "v2"},
518+
expectError: false,
519+
},
520+
{
521+
name: "wrong type parameter",
522+
params: map[string]any{
523+
"flag": 1,
524+
},
525+
paramName: "flag",
526+
expected: []string{},
527+
expectError: true,
528+
},
529+
{
530+
name: "wrong slice type parameter",
531+
params: map[string]any{
532+
"flag": []any{"foo", 2},
533+
},
534+
paramName: "flag",
535+
expected: []string{},
536+
expectError: true,
537+
},
538+
}
539+
540+
for _, tc := range tests {
541+
t.Run(tc.name, func(t *testing.T) {
542+
request := createMCPRequest(tc.params)
543+
result, err := optionalStringArrayParam(request, tc.paramName)
544+
545+
if tc.expectError {
546+
assert.Error(t, err)
547+
} else {
548+
assert.NoError(t, err)
549+
assert.Equal(t, tc.expected, result)
550+
}
551+
})
552+
}
553+
}

0 commit comments

Comments
 (0)