From 67a1e1e7fe231eb19b1331a4beedd64065c2b56e Mon Sep 17 00:00:00 2001 From: silverwind Date: Wed, 25 Feb 2026 23:28:14 +0000 Subject: [PATCH] feat: accept string values for all numeric input parameters (#138) ## Summary - MCP clients may send numbers as strings. This adds `ToInt64` and `GetOptionalInt` helpers to `pkg/params` and replaces all raw `.(float64)` type assertions across operation handlers to accept both `float64` and string inputs. ## Test plan - [x] Verify `go test ./...` passes - [x] Test with an MCP client that sends numeric parameters as strings *Created by Claude on behalf of @silverwind* Reviewed-on: https://gitea.com/gitea/gitea-mcp/pulls/138 Reviewed-by: Lunny Xiao Co-authored-by: silverwind Co-committed-by: silverwind --- operation/actions/logs.go | 24 +- operation/actions/runs.go | 75 ++- operation/actions/secrets.go | 21 +- operation/actions/variables.go | 21 +- operation/issue/issue.go | 27 +- operation/label/label.go | 88 ++-- operation/milestone/milestone.go | 41 +- operation/pull/pull.go | 102 ++-- operation/pull/pull_test.go | 622 +++++++++++++------------ operation/repo/commit.go | 13 +- operation/repo/release.go | 21 +- operation/repo/repo.go | 11 +- operation/repo/tag.go | 5 +- operation/search/search.go | 35 +- operation/timetracking/timetracking.go | 40 +- operation/user/user.go | 7 +- pkg/params/params.go | 45 +- pkg/params/params_test.go | 57 +++ 18 files changed, 628 insertions(+), 627 deletions(-) diff --git a/operation/actions/logs.go b/operation/actions/logs.go index 2dc1d87..cd99efd 100644 --- a/operation/actions/logs.go +++ b/operation/actions/logs.go @@ -11,6 +11,7 @@ import ( "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/params" "gitea.com/gitea/gitea-mcp/pkg/to" "github.com/mark3labs/mcp-go/mcp" @@ -116,22 +117,12 @@ func GetRepoActionJobLogPreviewFn(ctx context.Context, req mcp.CallToolRequest) if !ok || repo == "" { return to.ErrorResult(errors.New("repo is required")) } - jobIDFloat, ok := req.GetArguments()["job_id"].(float64) - if !ok || jobIDFloat <= 0 { + jobID, err := params.GetIndex(req.GetArguments(), "job_id") + if err != nil || jobID <= 0 { return to.ErrorResult(errors.New("job_id is required")) } - tailLinesFloat, _ := req.GetArguments()["tail_lines"].(float64) - maxBytesFloat, _ := req.GetArguments()["max_bytes"].(float64) - tailLines := int(tailLinesFloat) - if tailLines <= 0 { - tailLines = 200 - } - maxBytes := int(maxBytesFloat) - if maxBytes <= 0 { - maxBytes = 65536 - } - - jobID := int64(jobIDFloat) + tailLines := int(params.GetOptionalInt(req.GetArguments(), "tail_lines", 200)) + maxBytes := int(params.GetOptionalInt(req.GetArguments(), "max_bytes", 65536)) raw, usedPath, err := fetchJobLogBytes(ctx, owner, repo, jobID) if err != nil { return to.ErrorResult(fmt.Errorf("get job log err: %v", err)) @@ -161,12 +152,11 @@ func DownloadRepoActionJobLogFn(ctx context.Context, req mcp.CallToolRequest) (* if !ok || repo == "" { return to.ErrorResult(errors.New("repo is required")) } - jobIDFloat, ok := req.GetArguments()["job_id"].(float64) - if !ok || jobIDFloat <= 0 { + jobID, err := params.GetIndex(req.GetArguments(), "job_id") + if err != nil || jobID <= 0 { return to.ErrorResult(errors.New("job_id is required")) } outputPath, _ := req.GetArguments()["output_path"].(string) - jobID := int64(jobIDFloat) raw, usedPath, err := fetchJobLogBytes(ctx, owner, repo, jobID) if err != nil { diff --git a/operation/actions/runs.go b/operation/actions/runs.go index a980b02..1fb632b 100644 --- a/operation/actions/runs.go +++ b/operation/actions/runs.go @@ -11,6 +11,7 @@ import ( "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/params" "gitea.com/gitea/gitea-mcp/pkg/to" "github.com/mark3labs/mcp-go/mcp" @@ -155,14 +156,8 @@ func ListRepoActionWorkflowsFn(ctx context.Context, req mcp.CallToolRequest) (*m if !ok || repo == "" { return to.ErrorResult(errors.New("repo is required")) } - page, _ := req.GetArguments()["page"].(float64) - if page <= 0 { - page = 1 - } - pageSize, _ := req.GetArguments()["pageSize"].(float64) - if pageSize <= 0 { - pageSize = 50 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 50) query := url.Values{} query.Set("page", strconv.Itoa(int(page))) query.Set("limit", strconv.Itoa(int(pageSize))) @@ -271,14 +266,8 @@ func ListRepoActionRunsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca if !ok || repo == "" { return to.ErrorResult(errors.New("repo is required")) } - page, _ := req.GetArguments()["page"].(float64) - if page <= 0 { - page = 1 - } - pageSize, _ := req.GetArguments()["pageSize"].(float64) - if pageSize <= 0 { - pageSize = 50 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 50) statusFilter, _ := req.GetArguments()["status"].(string) query := url.Values{} @@ -311,15 +300,15 @@ func GetRepoActionRunFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call if !ok || repo == "" { return to.ErrorResult(errors.New("repo is required")) } - runID, ok := req.GetArguments()["run_id"].(float64) - if !ok || runID <= 0 { + runID, err := params.GetIndex(req.GetArguments(), "run_id") + if err != nil || runID <= 0 { return to.ErrorResult(errors.New("run_id is required")) } var result any - err := doJSONWithFallback(ctx, "GET", + err = doJSONWithFallback(ctx, "GET", []string{ - fmt.Sprintf("repos/%s/%s/actions/runs/%d", url.PathEscape(owner), url.PathEscape(repo), int64(runID)), + fmt.Sprintf("repos/%s/%s/actions/runs/%d", url.PathEscape(owner), url.PathEscape(repo), runID), }, nil, nil, &result, ) @@ -339,14 +328,14 @@ func CancelRepoActionRunFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.C if !ok || repo == "" { return to.ErrorResult(errors.New("repo is required")) } - runID, ok := req.GetArguments()["run_id"].(float64) - if !ok || runID <= 0 { + runID, err := params.GetIndex(req.GetArguments(), "run_id") + if err != nil || runID <= 0 { return to.ErrorResult(errors.New("run_id is required")) } - err := doJSONWithFallback(ctx, "POST", + err = doJSONWithFallback(ctx, "POST", []string{ - fmt.Sprintf("repos/%s/%s/actions/runs/%d/cancel", url.PathEscape(owner), url.PathEscape(repo), int64(runID)), + fmt.Sprintf("repos/%s/%s/actions/runs/%d/cancel", url.PathEscape(owner), url.PathEscape(repo), runID), }, nil, nil, nil, ) @@ -366,15 +355,15 @@ func RerunRepoActionRunFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca if !ok || repo == "" { return to.ErrorResult(errors.New("repo is required")) } - runID, ok := req.GetArguments()["run_id"].(float64) - if !ok || runID <= 0 { + runID, err := params.GetIndex(req.GetArguments(), "run_id") + if err != nil || runID <= 0 { return to.ErrorResult(errors.New("run_id is required")) } - err := doJSONWithFallback(ctx, "POST", + err = doJSONWithFallback(ctx, "POST", []string{ - fmt.Sprintf("repos/%s/%s/actions/runs/%d/rerun", url.PathEscape(owner), url.PathEscape(repo), int64(runID)), - fmt.Sprintf("repos/%s/%s/actions/runs/%d/rerun-failed-jobs", url.PathEscape(owner), url.PathEscape(repo), int64(runID)), + fmt.Sprintf("repos/%s/%s/actions/runs/%d/rerun", url.PathEscape(owner), url.PathEscape(repo), runID), + fmt.Sprintf("repos/%s/%s/actions/runs/%d/rerun-failed-jobs", url.PathEscape(owner), url.PathEscape(repo), runID), }, nil, nil, nil, ) @@ -398,14 +387,8 @@ func ListRepoActionJobsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca if !ok || repo == "" { return to.ErrorResult(errors.New("repo is required")) } - page, _ := req.GetArguments()["page"].(float64) - if page <= 0 { - page = 1 - } - pageSize, _ := req.GetArguments()["pageSize"].(float64) - if pageSize <= 0 { - pageSize = 50 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 50) statusFilter, _ := req.GetArguments()["status"].(string) query := url.Values{} @@ -438,27 +421,21 @@ func ListRepoActionRunJobsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp if !ok || repo == "" { return to.ErrorResult(errors.New("repo is required")) } - runID, ok := req.GetArguments()["run_id"].(float64) - if !ok || runID <= 0 { + runID, err := params.GetIndex(req.GetArguments(), "run_id") + if err != nil || runID <= 0 { return to.ErrorResult(errors.New("run_id is required")) } - page, _ := req.GetArguments()["page"].(float64) - if page <= 0 { - page = 1 - } - pageSize, _ := req.GetArguments()["pageSize"].(float64) - if pageSize <= 0 { - pageSize = 50 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 50) query := url.Values{} query.Set("page", strconv.Itoa(int(page))) query.Set("limit", strconv.Itoa(int(pageSize))) var result any - err := doJSONWithFallback(ctx, "GET", + err = doJSONWithFallback(ctx, "GET", []string{ - fmt.Sprintf("repos/%s/%s/actions/runs/%d/jobs", url.PathEscape(owner), url.PathEscape(repo), int64(runID)), + fmt.Sprintf("repos/%s/%s/actions/runs/%d/jobs", url.PathEscape(owner), url.PathEscape(repo), runID), }, query, nil, &result, ) diff --git a/operation/actions/secrets.go b/operation/actions/secrets.go index 9843c78..586cda1 100644 --- a/operation/actions/secrets.go +++ b/operation/actions/secrets.go @@ -9,6 +9,7 @@ import ( "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/params" "gitea.com/gitea/gitea-mcp/pkg/to" gitea_sdk "code.gitea.io/sdk/gitea" @@ -104,14 +105,8 @@ func ListRepoActionSecretsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp if !ok || repo == "" { return to.ErrorResult(errors.New("repo is required")) } - page, _ := req.GetArguments()["page"].(float64) - if page <= 0 { - page = 1 - } - pageSize, _ := req.GetArguments()["pageSize"].(float64) - if pageSize <= 0 { - pageSize = 100 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100) client, err := gitea.ClientFromContext(ctx) if err != nil { @@ -206,14 +201,8 @@ func ListOrgActionSecretsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp. if !ok || org == "" { return to.ErrorResult(errors.New("org is required")) } - page, _ := req.GetArguments()["page"].(float64) - if page <= 0 { - page = 1 - } - pageSize, _ := req.GetArguments()["pageSize"].(float64) - if pageSize <= 0 { - pageSize = 100 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100) client, err := gitea.ClientFromContext(ctx) if err != nil { diff --git a/operation/actions/variables.go b/operation/actions/variables.go index a7ceb65..dfeacfc 100644 --- a/operation/actions/variables.go +++ b/operation/actions/variables.go @@ -9,6 +9,7 @@ import ( "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/params" "gitea.com/gitea/gitea-mcp/pkg/to" gitea_sdk "code.gitea.io/sdk/gitea" @@ -139,14 +140,8 @@ func ListRepoActionVariablesFn(ctx context.Context, req mcp.CallToolRequest) (*m if !ok || repo == "" { return to.ErrorResult(errors.New("repo is required")) } - page, _ := req.GetArguments()["page"].(float64) - if page <= 0 { - page = 1 - } - pageSize, _ := req.GetArguments()["pageSize"].(float64) - if pageSize <= 0 { - pageSize = 100 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100) query := url.Values{} query.Set("page", strconv.Itoa(int(page))) @@ -278,14 +273,8 @@ func ListOrgActionVariablesFn(ctx context.Context, req mcp.CallToolRequest) (*mc if !ok || org == "" { return to.ErrorResult(errors.New("org is required")) } - page, _ := req.GetArguments()["page"].(float64) - if page <= 0 { - page = 1 - } - pageSize, _ := req.GetArguments()["pageSize"].(float64) - if pageSize <= 0 { - pageSize = 100 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100) client, err := gitea.ClientFromContext(ctx) if err != nil { diff --git a/operation/issue/issue.go b/operation/issue/issue.go index 97ecba9..a380a84 100644 --- a/operation/issue/issue.go +++ b/operation/issue/issue.go @@ -167,14 +167,8 @@ func ListRepoIssuesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo if !ok { state = "all" } - page, ok := req.GetArguments()["page"].(float64) - if !ok { - page = 1 - } - pageSize, ok := req.GetArguments()["pageSize"].(float64) - if !ok { - pageSize = 100 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100) opt := gitea_sdk.ListIssueOption{ State: gitea_sdk.StateType(state), ListOptions: gitea_sdk.ListOptions{ @@ -295,9 +289,10 @@ func EditIssueFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolRes } } opt.Assignees = assignees - milestone, ok := req.GetArguments()["milestone"].(float64) - if ok { - opt.Milestone = new(int64(milestone)) + if val, exists := req.GetArguments()["milestone"]; exists { + if milestone, ok := params.ToInt64(val); ok { + opt.Milestone = new(milestone) + } } state, ok := req.GetArguments()["state"].(string) if ok { @@ -326,9 +321,9 @@ func EditIssueCommentFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call if !ok { return to.ErrorResult(errors.New("repo is required")) } - commentID, ok := req.GetArguments()["commentID"].(float64) - if !ok { - return to.ErrorResult(errors.New("comment ID is required")) + commentID, err := params.GetIndex(req.GetArguments(), "commentID") + if err != nil { + return to.ErrorResult(err) } body, ok := req.GetArguments()["body"].(string) if !ok { @@ -341,9 +336,9 @@ func EditIssueCommentFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - issueComment, _, err := client.EditIssueComment(owner, repo, int64(commentID), opt) + issueComment, _, err := client.EditIssueComment(owner, repo, commentID, opt) if err != nil { - return to.ErrorResult(fmt.Errorf("edit %v/%v/issues/comments/%v err: %v", owner, repo, int64(commentID), err)) + return to.ErrorResult(fmt.Errorf("edit %v/%v/issues/comments/%v err: %v", owner, repo, commentID, err)) } return to.TextResult(issueComment) diff --git a/operation/label/label.go b/operation/label/label.go index b918c8e..acb6128 100644 --- a/operation/label/label.go +++ b/operation/label/label.go @@ -218,14 +218,8 @@ func ListRepoLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo if !ok { return to.ErrorResult(errors.New("repo is required")) } - page, ok := req.GetArguments()["page"].(float64) - if !ok { - page = 1 - } - pageSize, ok := req.GetArguments()["pageSize"].(float64) - if !ok { - pageSize = 100 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100) opt := gitea_sdk.ListLabelsOptions{ ListOptions: gitea_sdk.ListOptions{ @@ -254,18 +248,18 @@ func GetRepoLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTool if !ok { return to.ErrorResult(errors.New("repo is required")) } - id, ok := req.GetArguments()["id"].(float64) - if !ok { - return to.ErrorResult(errors.New("label ID is required")) + id, err := params.GetIndex(req.GetArguments(), "id") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - label, _, err := client.GetRepoLabel(owner, repo, int64(id)) + label, _, err := client.GetRepoLabel(owner, repo, id) if err != nil { - return to.ErrorResult(fmt.Errorf("get %v/%v/label/%v err: %v", owner, repo, int64(id), err)) + return to.ErrorResult(fmt.Errorf("get %v/%v/label/%v err: %v", owner, repo, id, err)) } return to.TextResult(label) } @@ -317,9 +311,9 @@ func EditRepoLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo if !ok { return to.ErrorResult(errors.New("repo is required")) } - id, ok := req.GetArguments()["id"].(float64) - if !ok { - return to.ErrorResult(errors.New("label ID is required")) + id, err := params.GetIndex(req.GetArguments(), "id") + if err != nil { + return to.ErrorResult(err) } opt := gitea_sdk.EditLabelOption{} @@ -337,9 +331,9 @@ func EditRepoLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - label, _, err := client.EditLabel(owner, repo, int64(id), opt) + label, _, err := client.EditLabel(owner, repo, id, opt) if err != nil { - return to.ErrorResult(fmt.Errorf("edit %v/%v/label/%v err: %v", owner, repo, int64(id), err)) + return to.ErrorResult(fmt.Errorf("edit %v/%v/label/%v err: %v", owner, repo, id, err)) } return to.TextResult(label) } @@ -354,18 +348,18 @@ func DeleteRepoLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT if !ok { return to.ErrorResult(errors.New("repo is required")) } - id, ok := req.GetArguments()["id"].(float64) - if !ok { - return to.ErrorResult(errors.New("label ID is required")) + id, err := params.GetIndex(req.GetArguments(), "id") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.DeleteLabel(owner, repo, int64(id)) + _, err = client.DeleteLabel(owner, repo, id) if err != nil { - return to.ErrorResult(fmt.Errorf("delete %v/%v/label/%v err: %v", owner, repo, int64(id), err)) + return to.ErrorResult(fmt.Errorf("delete %v/%v/label/%v err: %v", owner, repo, id, err)) } return to.TextResult("Label deleted successfully") } @@ -390,8 +384,8 @@ func AddIssueLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo } var labels []int64 for _, l := range labelsRaw { - if labelID, ok := l.(float64); ok { - labels = append(labels, int64(labelID)) + if labelID, ok := params.ToInt64(l); ok { + labels = append(labels, labelID) } else { return to.ErrorResult(errors.New("invalid label ID in labels array")) } @@ -432,8 +426,8 @@ func ReplaceIssueLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca } var labels []int64 for _, l := range labelsRaw { - if labelID, ok := l.(float64); ok { - labels = append(labels, int64(labelID)) + if labelID, ok := params.ToInt64(l); ok { + labels = append(labels, labelID) } else { return to.ErrorResult(errors.New("invalid label ID in labels array")) } @@ -494,18 +488,18 @@ func RemoveIssueLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call if err != nil { return to.ErrorResult(err) } - labelID, ok := req.GetArguments()["label_id"].(float64) - if !ok { - return to.ErrorResult(errors.New("label ID is required")) + labelID, err := params.GetIndex(req.GetArguments(), "label_id") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.DeleteIssueLabel(owner, repo, index, int64(labelID)) + _, err = client.DeleteIssueLabel(owner, repo, index, labelID) if err != nil { - return to.ErrorResult(fmt.Errorf("remove label %v from %v/%v/issue/%v err: %v", int64(labelID), owner, repo, index, err)) + return to.ErrorResult(fmt.Errorf("remove label %v from %v/%v/issue/%v err: %v", labelID, owner, repo, index, err)) } return to.TextResult("Label removed successfully") } @@ -516,14 +510,8 @@ func ListOrgLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo if !ok { return to.ErrorResult(errors.New("org is required")) } - page, ok := req.GetArguments()["page"].(float64) - if !ok { - page = 1 - } - pageSize, ok := req.GetArguments()["pageSize"].(float64) - if !ok { - pageSize = 100 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100) opt := gitea_sdk.ListOrgLabelsOptions{ ListOptions: gitea_sdk.ListOptions{ @@ -583,9 +571,9 @@ func EditOrgLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTool if !ok { return to.ErrorResult(errors.New("org is required")) } - id, ok := req.GetArguments()["id"].(float64) - if !ok { - return to.ErrorResult(errors.New("label ID is required")) + id, err := params.GetIndex(req.GetArguments(), "id") + if err != nil { + return to.ErrorResult(err) } opt := gitea_sdk.EditOrgLabelOption{} @@ -606,9 +594,9 @@ func EditOrgLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTool if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - label, _, err := client.EditOrgLabel(org, int64(id), opt) + label, _, err := client.EditOrgLabel(org, id, opt) if err != nil { - return to.ErrorResult(fmt.Errorf("edit %v/labels/%v err: %v", org, int64(id), err)) + return to.ErrorResult(fmt.Errorf("edit %v/labels/%v err: %v", org, id, err)) } return to.TextResult(label) } @@ -619,18 +607,18 @@ func DeleteOrgLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo if !ok { return to.ErrorResult(errors.New("org is required")) } - id, ok := req.GetArguments()["id"].(float64) - if !ok { - return to.ErrorResult(errors.New("label ID is required")) + id, err := params.GetIndex(req.GetArguments(), "id") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.DeleteOrgLabel(org, int64(id)) + _, err = client.DeleteOrgLabel(org, id) if err != nil { - return to.ErrorResult(fmt.Errorf("delete %v/labels/%v err: %v", org, int64(id), err)) + return to.ErrorResult(fmt.Errorf("delete %v/labels/%v err: %v", org, id, err)) } return to.TextResult("Label deleted successfully") } diff --git a/operation/milestone/milestone.go b/operation/milestone/milestone.go index 05ef9eb..8f853f3 100644 --- a/operation/milestone/milestone.go +++ b/operation/milestone/milestone.go @@ -7,6 +7,7 @@ import ( "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/params" "gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/tool" @@ -109,17 +110,17 @@ func GetMilestoneFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTool if !ok { return to.ErrorResult(errors.New("repo is required")) } - id, ok := req.GetArguments()["id"].(float64) - if !ok { - return to.ErrorResult(errors.New("id is required")) + id, err := params.GetIndex(req.GetArguments(), "id") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - milestone, _, err := client.GetMilestone(owner, repo, int64(id)) + milestone, _, err := client.GetMilestone(owner, repo, id) if err != nil { - return to.ErrorResult(fmt.Errorf("get %v/%v/milestone/%v err: %v", owner, repo, int64(id), err)) + return to.ErrorResult(fmt.Errorf("get %v/%v/milestone/%v err: %v", owner, repo, id, err)) } return to.TextResult(milestone) @@ -143,14 +144,8 @@ func ListMilestonesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo if !ok { name = "" } - page, ok := req.GetArguments()["page"].(float64) - if !ok { - page = 1 - } - pageSize, ok := req.GetArguments()["pageSize"].(float64) - if !ok { - pageSize = 100 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100) opt := gitea_sdk.ListMilestoneOption{ State: gitea_sdk.StateType(state), Name: name, @@ -216,9 +211,9 @@ func EditMilestoneFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo if !ok { return to.ErrorResult(errors.New("repo is required")) } - id, ok := req.GetArguments()["id"].(float64) - if !ok { - return to.ErrorResult(errors.New("id is required")) + id, err := params.GetIndex(req.GetArguments(), "id") + if err != nil { + return to.ErrorResult(err) } opt := gitea_sdk.EditMilestoneOption{} @@ -240,9 +235,9 @@ func EditMilestoneFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - milestone, _, err := client.EditMilestone(owner, repo, int64(id), opt) + milestone, _, err := client.EditMilestone(owner, repo, id, opt) if err != nil { - return to.ErrorResult(fmt.Errorf("edit %v/%v/milestone/%v err: %v", owner, repo, int64(id), err)) + return to.ErrorResult(fmt.Errorf("edit %v/%v/milestone/%v err: %v", owner, repo, id, err)) } return to.TextResult(milestone) @@ -258,17 +253,17 @@ func DeleteMilestoneFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT if !ok { return to.ErrorResult(errors.New("repo is required")) } - id, ok := req.GetArguments()["id"].(float64) - if !ok { - return to.ErrorResult(errors.New("id is required")) + id, err := params.GetIndex(req.GetArguments(), "id") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.DeleteMilestone(owner, repo, int64(id)) + _, err = client.DeleteMilestone(owner, repo, id) if err != nil { - return to.ErrorResult(fmt.Errorf("delete %v/%v/milestone/%v err: %v", owner, repo, int64(id), err)) + return to.ErrorResult(fmt.Errorf("delete %v/%v/milestone/%v err: %v", owner, repo, id, err)) } return to.TextResult("Milestone deleted successfully") diff --git a/operation/pull/pull.go b/operation/pull/pull.go index fdca9ef..d55df30 100644 --- a/operation/pull/pull.go +++ b/operation/pull/pull.go @@ -345,19 +345,13 @@ func ListRepoPullRequestsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp. if !ok { sort = "recentupdate" } - milestone, _ := req.GetArguments()["milestone"].(float64) - page, ok := req.GetArguments()["page"].(float64) - if !ok { - page = 1 - } - pageSize, ok := req.GetArguments()["pageSize"].(float64) - if !ok { - pageSize = 100 - } + milestone := params.GetOptionalInt(req.GetArguments(), "milestone", 0) + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100) opt := gitea_sdk.ListPullRequestsOptions{ State: gitea_sdk.StateType(state), Sort: sort, - Milestone: int64(milestone), + Milestone: milestone, ListOptions: gitea_sdk.ListOptions{ Page: int(page), PageSize: int(pageSize), @@ -555,14 +549,8 @@ func ListPullRequestReviewsFn(ctx context.Context, req mcp.CallToolRequest) (*mc if err != nil { return to.ErrorResult(err) } - page, ok := req.GetArguments()["page"].(float64) - if !ok { - page = 1 - } - pageSize, ok := req.GetArguments()["pageSize"].(float64) - if !ok { - pageSize = 100 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100) client, err := gitea.ClientFromContext(ctx) if err != nil { @@ -596,9 +584,9 @@ func GetPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*mcp. if err != nil { return to.ErrorResult(err) } - reviewID, ok := req.GetArguments()["review_id"].(float64) - if !ok { - return to.ErrorResult(errors.New("review_id is required")) + reviewID, err := params.GetIndex(req.GetArguments(), "review_id") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) @@ -606,9 +594,9 @@ func GetPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*mcp. return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - review, _, err := client.GetPullReview(owner, repo, index, int64(reviewID)) + review, _, err := client.GetPullReview(owner, repo, index, reviewID) if err != nil { - return to.ErrorResult(fmt.Errorf("get review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err)) + return to.ErrorResult(fmt.Errorf("get review %v for %v/%v/pr/%v err: %v", reviewID, owner, repo, index, err)) } return to.TextResult(review) @@ -628,9 +616,9 @@ func ListPullRequestReviewCommentsFn(ctx context.Context, req mcp.CallToolReques if err != nil { return to.ErrorResult(err) } - reviewID, ok := req.GetArguments()["review_id"].(float64) - if !ok { - return to.ErrorResult(errors.New("review_id is required")) + reviewID, err := params.GetIndex(req.GetArguments(), "review_id") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) @@ -638,9 +626,9 @@ func ListPullRequestReviewCommentsFn(ctx context.Context, req mcp.CallToolReques return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - comments, _, err := client.ListPullReviewComments(owner, repo, index, int64(reviewID)) + comments, _, err := client.ListPullReviewComments(owner, repo, index, reviewID) if err != nil { - return to.ErrorResult(fmt.Errorf("list review comments for review %v on %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err)) + return to.ErrorResult(fmt.Errorf("list review comments for review %v on %v/%v/pr/%v err: %v", reviewID, owner, repo, index, err)) } return to.TextResult(comments) @@ -685,11 +673,11 @@ func CreatePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m if body, ok := commentMap["body"].(string); ok { reviewComment.Body = body } - if oldLineNum, ok := commentMap["old_line_num"].(float64); ok { - reviewComment.OldLineNum = int64(oldLineNum) + if oldLineNum, ok := params.ToInt64(commentMap["old_line_num"]); ok { + reviewComment.OldLineNum = oldLineNum } - if newLineNum, ok := commentMap["new_line_num"].(float64); ok { - reviewComment.NewLineNum = int64(newLineNum) + if newLineNum, ok := params.ToInt64(commentMap["new_line_num"]); ok { + reviewComment.NewLineNum = newLineNum } opt.Comments = append(opt.Comments, reviewComment) } @@ -724,9 +712,9 @@ func SubmitPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m if err != nil { return to.ErrorResult(err) } - reviewID, ok := req.GetArguments()["review_id"].(float64) - if !ok { - return to.ErrorResult(errors.New("review_id is required")) + reviewID, err := params.GetIndex(req.GetArguments(), "review_id") + if err != nil { + return to.ErrorResult(err) } state, ok := req.GetArguments()["state"].(string) if !ok { @@ -745,9 +733,9 @@ func SubmitPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - review, _, err := client.SubmitPullReview(owner, repo, index, int64(reviewID), opt) + review, _, err := client.SubmitPullReview(owner, repo, index, reviewID, opt) if err != nil { - return to.ErrorResult(fmt.Errorf("submit review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err)) + return to.ErrorResult(fmt.Errorf("submit review %v for %v/%v/pr/%v err: %v", reviewID, owner, repo, index, err)) } return to.TextResult(review) @@ -767,9 +755,9 @@ func DeletePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m if err != nil { return to.ErrorResult(err) } - reviewID, ok := req.GetArguments()["review_id"].(float64) - if !ok { - return to.ErrorResult(errors.New("review_id is required")) + reviewID, err := params.GetIndex(req.GetArguments(), "review_id") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) @@ -777,14 +765,14 @@ func DeletePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.DeletePullReview(owner, repo, index, int64(reviewID)) + _, err = client.DeletePullReview(owner, repo, index, reviewID) if err != nil { - return to.ErrorResult(fmt.Errorf("delete review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err)) + return to.ErrorResult(fmt.Errorf("delete review %v for %v/%v/pr/%v err: %v", reviewID, owner, repo, index, err)) } successMsg := map[string]any{ "message": "Successfully deleted review", - "review_id": int64(reviewID), + "review_id": reviewID, "pr_index": index, "repository": fmt.Sprintf("%s/%s", owner, repo), } @@ -806,9 +794,9 @@ func DismissPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (* if err != nil { return to.ErrorResult(err) } - reviewID, ok := req.GetArguments()["review_id"].(float64) - if !ok { - return to.ErrorResult(errors.New("review_id is required")) + reviewID, err := params.GetIndex(req.GetArguments(), "review_id") + if err != nil { + return to.ErrorResult(err) } opt := gitea_sdk.DismissPullReviewOptions{} @@ -821,14 +809,14 @@ func DismissPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (* return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.DismissPullReview(owner, repo, index, int64(reviewID), opt) + _, err = client.DismissPullReview(owner, repo, index, reviewID, opt) if err != nil { - return to.ErrorResult(fmt.Errorf("dismiss review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err)) + return to.ErrorResult(fmt.Errorf("dismiss review %v for %v/%v/pr/%v err: %v", reviewID, owner, repo, index, err)) } successMsg := map[string]any{ "message": "Successfully dismissed review", - "review_id": int64(reviewID), + "review_id": reviewID, "pr_index": index, "repository": fmt.Sprintf("%s/%s", owner, repo), } @@ -917,9 +905,9 @@ func EditPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT if !ok { return to.ErrorResult(errors.New("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(errors.New("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } opt := gitea_sdk.EditPullRequestOption{} @@ -947,8 +935,10 @@ func EditPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT opt.Assignees = assignees } } - if milestone, ok := req.GetArguments()["milestone"].(float64); ok { - opt.Milestone = int64(milestone) + if val, exists := req.GetArguments()["milestone"]; exists { + if milestone, ok := params.ToInt64(val); ok { + opt.Milestone = milestone + } } if state, ok := req.GetArguments()["state"].(string); ok { opt.State = new(gitea_sdk.StateType(state)) @@ -962,9 +952,9 @@ func EditPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - pr, _, err := client.EditPullRequest(owner, repo, int64(index), opt) + pr, _, err := client.EditPullRequest(owner, repo, index, opt) if err != nil { - return to.ErrorResult(fmt.Errorf("edit %v/%v/pr/%v err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("edit %v/%v/pr/%v err: %v", owner, repo, index, err)) } return to.TextResult(pr) diff --git a/operation/pull/pull_test.go b/operation/pull/pull_test.go index 84343b9..a6cdcdc 100644 --- a/operation/pull/pull_test.go +++ b/operation/pull/pull_test.go @@ -20,100 +20,112 @@ func TestEditPullRequestFn(t *testing.T) { index = 7 ) - var ( - mu sync.Mutex - gotMethod string - gotPath string - gotBody map[string]any - ) + indexInputs := []struct { + name string + val any + }{ + {"float64", float64(index)}, + {"string", "7"}, + } + + for _, ii := range indexInputs { + t.Run(ii.name, func(t *testing.T) { + var ( + mu sync.Mutex + gotMethod string + gotPath string + gotBody map[string]any + ) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/version": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"version":"1.12.0"}`)) + case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"private":false}`)) + case fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index): + mu.Lock() + gotMethod = r.Method + gotPath = r.URL.Path + var body map[string]any + _ = json.NewDecoder(r.Body).Decode(&body) + gotBody = body + mu.Unlock() + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(fmt.Appendf(nil, `{"number":%d,"title":"%s","state":"open"}`, index, body["title"])) + default: + http.NotFound(w, r) + } + }) + + server := httptest.NewServer(handler) + defer server.Close() + + origHost := flag.Host + origToken := flag.Token + origVersion := flag.Version + flag.Host = server.URL + flag.Token = "" + flag.Version = "test" + defer func() { + flag.Host = origHost + flag.Token = origToken + flag.Version = origVersion + }() + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "owner": owner, + "repo": repo, + "index": ii.val, + "title": "WIP: my feature", + "state": "open", + }, + }, + } + + result, err := EditPullRequestFn(context.Background(), req) + if err != nil { + t.Fatalf("EditPullRequestFn() error = %v", err) + } - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/api/v1/version": - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"version":"1.12.0"}`)) - case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo): - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"private":false}`)) - case fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index): mu.Lock() - gotMethod = r.Method - gotPath = r.URL.Path - var body map[string]any - _ = json.NewDecoder(r.Body).Decode(&body) - gotBody = body - mu.Unlock() - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write(fmt.Appendf(nil, `{"number":%d,"title":"%s","state":"open"}`, index, body["title"])) - default: - http.NotFound(w, r) - } - }) + defer mu.Unlock() - server := httptest.NewServer(handler) - defer server.Close() + if gotMethod != http.MethodPatch { + t.Fatalf("expected PATCH request, got %s", gotMethod) + } + if gotPath != fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index) { + t.Fatalf("unexpected path: %s", gotPath) + } + if gotBody["title"] != "WIP: my feature" { + t.Fatalf("expected title 'WIP: my feature', got %v", gotBody["title"]) + } + if gotBody["state"] != "open" { + t.Fatalf("expected state 'open', got %v", gotBody["state"]) + } - origHost := flag.Host - origToken := flag.Token - origVersion := flag.Version - flag.Host = server.URL - flag.Token = "" - flag.Version = "test" - defer func() { - flag.Host = origHost - flag.Token = origToken - flag.Version = origVersion - }() + if len(result.Content) == 0 { + t.Fatalf("expected content in result") + } + textContent, ok := mcp.AsTextContent(result.Content[0]) + if !ok { + t.Fatalf("expected text content, got %T", result.Content[0]) + } - req := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Arguments: map[string]any{ - "owner": owner, - "repo": repo, - "index": float64(index), - "title": "WIP: my feature", - "state": "open", - }, - }, - } - - result, err := EditPullRequestFn(context.Background(), req) - if err != nil { - t.Fatalf("EditPullRequestFn() error = %v", err) - } - - mu.Lock() - defer mu.Unlock() - - if gotMethod != http.MethodPatch { - t.Fatalf("expected PATCH request, got %s", gotMethod) - } - if gotPath != fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index) { - t.Fatalf("unexpected path: %s", gotPath) - } - if gotBody["title"] != "WIP: my feature" { - t.Fatalf("expected title 'WIP: my feature', got %v", gotBody["title"]) - } - if gotBody["state"] != "open" { - t.Fatalf("expected state 'open', got %v", gotBody["state"]) - } - - if len(result.Content) == 0 { - t.Fatalf("expected content in result") - } - textContent, ok := mcp.AsTextContent(result.Content[0]) - if !ok { - t.Fatalf("expected text content, got %T", result.Content[0]) - } - - var parsed struct { - Result map[string]any `json:"Result"` - } - if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil { - t.Fatalf("unmarshal result text: %v", err) - } - if got := parsed.Result["title"].(string); got != "WIP: my feature" { - t.Fatalf("result title = %q, want %q", got, "WIP: my feature") + var parsed struct { + Result map[string]any `json:"Result"` + } + if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil { + t.Fatalf("unmarshal result text: %v", err) + } + if got := parsed.Result["title"].(string); got != "WIP: my feature" { + t.Fatalf("result title = %q, want %q", got, "WIP: my feature") + } + }) } } @@ -124,113 +136,125 @@ func TestMergePullRequestFn(t *testing.T) { index = 5 ) - var ( - mu sync.Mutex - gotMethod string - gotPath string - gotBody map[string]any - ) + indexInputs := []struct { + name string + val any + }{ + {"float64", float64(index)}, + {"string", "5"}, + } + + for _, ii := range indexInputs { + t.Run(ii.name, func(t *testing.T) { + var ( + mu sync.Mutex + gotMethod string + gotPath string + gotBody map[string]any + ) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/version": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"version":"1.12.0"}`)) + case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"private":false}`)) + case fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d/merge", owner, repo, index): + mu.Lock() + gotMethod = r.Method + gotPath = r.URL.Path + var body map[string]any + _ = json.NewDecoder(r.Body).Decode(&body) + gotBody = body + mu.Unlock() + w.WriteHeader(http.StatusOK) + default: + http.NotFound(w, r) + } + }) + + server := httptest.NewServer(handler) + defer server.Close() + + origHost := flag.Host + origToken := flag.Token + origVersion := flag.Version + flag.Host = server.URL + flag.Token = "" + flag.Version = "test" + defer func() { + flag.Host = origHost + flag.Token = origToken + flag.Version = origVersion + }() + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "owner": owner, + "repo": repo, + "index": ii.val, + "merge_style": "squash", + "title": "feat: my squashed commit", + "message": "Squash merge of PR #5", + "delete_branch": true, + }, + }, + } + + result, err := MergePullRequestFn(context.Background(), req) + if err != nil { + t.Fatalf("MergePullRequestFn() error = %v", err) + } - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/api/v1/version": - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"version":"1.12.0"}`)) - case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo): - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"private":false}`)) - case fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d/merge", owner, repo, index): mu.Lock() - gotMethod = r.Method - gotPath = r.URL.Path - var body map[string]any - _ = json.NewDecoder(r.Body).Decode(&body) - gotBody = body - mu.Unlock() - w.WriteHeader(http.StatusOK) - default: - http.NotFound(w, r) - } - }) + defer mu.Unlock() - server := httptest.NewServer(handler) - defer server.Close() + if gotMethod != http.MethodPost { + t.Fatalf("expected POST request, got %s", gotMethod) + } + if gotPath != fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d/merge", owner, repo, index) { + t.Fatalf("unexpected path: %s", gotPath) + } + if gotBody["Do"] != "squash" { + t.Fatalf("expected Do 'squash', got %v", gotBody["Do"]) + } + if gotBody["MergeTitleField"] != "feat: my squashed commit" { + t.Fatalf("expected MergeTitleField 'feat: my squashed commit', got %v", gotBody["MergeTitleField"]) + } + if gotBody["MergeMessageField"] != "Squash merge of PR #5" { + t.Fatalf("expected MergeMessageField 'Squash merge of PR #5', got %v", gotBody["MergeMessageField"]) + } + if gotBody["delete_branch_after_merge"] != true { + t.Fatalf("expected delete_branch_after_merge true, got %v", gotBody["delete_branch_after_merge"]) + } - origHost := flag.Host - origToken := flag.Token - origVersion := flag.Version - flag.Host = server.URL - flag.Token = "" - flag.Version = "test" - defer func() { - flag.Host = origHost - flag.Token = origToken - flag.Version = origVersion - }() + if len(result.Content) == 0 { + t.Fatalf("expected content in result") + } + textContent, ok := mcp.AsTextContent(result.Content[0]) + if !ok { + t.Fatalf("expected text content, got %T", result.Content[0]) + } - req := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Arguments: map[string]any{ - "owner": owner, - "repo": repo, - "index": float64(index), - "merge_style": "squash", - "title": "feat: my squashed commit", - "message": "Squash merge of PR #5", - "delete_branch": true, - }, - }, - } - - result, err := MergePullRequestFn(context.Background(), req) - if err != nil { - t.Fatalf("MergePullRequestFn() error = %v", err) - } - - mu.Lock() - defer mu.Unlock() - - if gotMethod != http.MethodPost { - t.Fatalf("expected POST request, got %s", gotMethod) - } - if gotPath != fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d/merge", owner, repo, index) { - t.Fatalf("unexpected path: %s", gotPath) - } - if gotBody["Do"] != "squash" { - t.Fatalf("expected Do 'squash', got %v", gotBody["Do"]) - } - if gotBody["MergeTitleField"] != "feat: my squashed commit" { - t.Fatalf("expected MergeTitleField 'feat: my squashed commit', got %v", gotBody["MergeTitleField"]) - } - if gotBody["MergeMessageField"] != "Squash merge of PR #5" { - t.Fatalf("expected MergeMessageField 'Squash merge of PR #5', got %v", gotBody["MergeMessageField"]) - } - if gotBody["delete_branch_after_merge"] != true { - t.Fatalf("expected delete_branch_after_merge true, got %v", gotBody["delete_branch_after_merge"]) - } - - if len(result.Content) == 0 { - t.Fatalf("expected content in result") - } - textContent, ok := mcp.AsTextContent(result.Content[0]) - if !ok { - t.Fatalf("expected text content, got %T", result.Content[0]) - } - - var parsed struct { - Result map[string]any `json:"Result"` - } - if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil { - t.Fatalf("unmarshal result text: %v", err) - } - if parsed.Result["merged"] != true { - t.Fatalf("expected merged=true, got %v", parsed.Result["merged"]) - } - if parsed.Result["merge_style"] != "squash" { - t.Fatalf("expected merge_style 'squash', got %v", parsed.Result["merge_style"]) - } - if parsed.Result["branch_deleted"] != true { - t.Fatalf("expected branch_deleted=true, got %v", parsed.Result["branch_deleted"]) + var parsed struct { + Result map[string]any `json:"Result"` + } + if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil { + t.Fatalf("unmarshal result text: %v", err) + } + if parsed.Result["merged"] != true { + t.Fatalf("expected merged=true, got %v", parsed.Result["merged"]) + } + if parsed.Result["merge_style"] != "squash" { + t.Fatalf("expected merge_style 'squash', got %v", parsed.Result["merge_style"]) + } + if parsed.Result["branch_deleted"] != true { + t.Fatalf("expected branch_deleted=true, got %v", parsed.Result["branch_deleted"]) + } + }) } } @@ -242,120 +266,132 @@ func TestGetPullRequestDiffFn(t *testing.T) { diffRaw = "diff --git a/file.txt b/file.txt\n+line\n" ) - var ( - mu sync.Mutex - diffRequested bool - binaryValue string - ) - errCh := make(chan error, 1) + indexInputs := []struct { + name string + val any + }{ + {"float64", float64(index)}, + {"string", "12"}, + } - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/api/v1/version": - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"version":"1.12.0"}`)) - case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo): - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"private":false}`)) - case fmt.Sprintf("/%s/%s/pulls/%d.diff", owner, repo, index): - if r.Method != http.MethodGet { - select { - case errCh <- fmt.Errorf("unexpected method: %s", r.Method): + for _, ii := range indexInputs { + t.Run(ii.name, func(t *testing.T) { + var ( + mu sync.Mutex + diffRequested bool + binaryValue string + ) + errCh := make(chan error, 1) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/version": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"version":"1.12.0"}`)) + case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"private":false}`)) + case fmt.Sprintf("/%s/%s/pulls/%d.diff", owner, repo, index): + if r.Method != http.MethodGet { + select { + case errCh <- fmt.Errorf("unexpected method: %s", r.Method): + default: + } + } + mu.Lock() + diffRequested = true + binaryValue = r.URL.Query().Get("binary") + mu.Unlock() + w.Header().Set("Content-Type", "text/plain") + _, _ = w.Write([]byte(diffRaw)) default: + select { + case errCh <- fmt.Errorf("unexpected request path: %s", r.URL.Path): + default: + } } + }) + + server := httptest.NewServer(handler) + defer server.Close() + + origHost := flag.Host + origToken := flag.Token + origVersion := flag.Version + flag.Host = server.URL + flag.Token = "" + flag.Version = "test" + defer func() { + flag.Host = origHost + flag.Token = origToken + flag.Version = origVersion + }() + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "owner": owner, + "repo": repo, + "index": ii.val, + "binary": true, + }, + }, } - mu.Lock() - diffRequested = true - binaryValue = r.URL.Query().Get("binary") - mu.Unlock() - w.Header().Set("Content-Type", "text/plain") - _, _ = w.Write([]byte(diffRaw)) - default: + + result, err := GetPullRequestDiffFn(context.Background(), req) + if err != nil { + t.Fatalf("GetPullRequestDiffFn() error = %v", err) + } + select { - case errCh <- fmt.Errorf("unexpected request path: %s", r.URL.Path): + case reqErr := <-errCh: + t.Fatalf("handler error: %v", reqErr) default: } - } - }) - server := httptest.NewServer(handler) - defer server.Close() + mu.Lock() + requested := diffRequested + gotBinary := binaryValue + mu.Unlock() - origHost := flag.Host - origToken := flag.Token - origVersion := flag.Version - flag.Host = server.URL - flag.Token = "" - flag.Version = "test" - defer func() { - flag.Host = origHost - flag.Token = origToken - flag.Version = origVersion - }() + if !requested { + t.Fatalf("expected diff request to be made") + } + if gotBinary != "true" { + t.Fatalf("expected binary=true query param, got %q", gotBinary) + } - req := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Arguments: map[string]any{ - "owner": owner, - "repo": repo, - "index": float64(index), - "binary": true, - }, - }, - } + if len(result.Content) == 0 { + t.Fatalf("expected content in result") + } - result, err := GetPullRequestDiffFn(context.Background(), req) - if err != nil { - t.Fatalf("GetPullRequestDiffFn() error = %v", err) - } + textContent, ok := mcp.AsTextContent(result.Content[0]) + if !ok { + t.Fatalf("expected text content, got %T", result.Content[0]) + } - select { - case reqErr := <-errCh: - t.Fatalf("handler error: %v", reqErr) - default: - } + var parsed struct { + Result map[string]any `json:"Result"` + } + if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil { + t.Fatalf("unmarshal result text: %v", err) + } - mu.Lock() - requested := diffRequested - gotBinary := binaryValue - mu.Unlock() - - if !requested { - t.Fatalf("expected diff request to be made") - } - if gotBinary != "true" { - t.Fatalf("expected binary=true query param, got %q", gotBinary) - } - - if len(result.Content) == 0 { - t.Fatalf("expected content in result") - } - - textContent, ok := mcp.AsTextContent(result.Content[0]) - if !ok { - t.Fatalf("expected text content, got %T", result.Content[0]) - } - - var parsed struct { - Result map[string]any `json:"Result"` - } - if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil { - t.Fatalf("unmarshal result text: %v", err) - } - - if got, ok := parsed.Result["diff"].(string); !ok || got != diffRaw { - t.Fatalf("diff = %q, want %q", got, diffRaw) - } - if got, ok := parsed.Result["binary"].(bool); !ok || got != true { - t.Fatalf("binary = %v, want true", got) - } - if got, ok := parsed.Result["index"].(float64); !ok || int64(got) != int64(index) { - t.Fatalf("index = %v, want %d", got, index) - } - if got, ok := parsed.Result["owner"].(string); !ok || got != owner { - t.Fatalf("owner = %q, want %q", got, owner) - } - if got, ok := parsed.Result["repo"].(string); !ok || got != repo { - t.Fatalf("repo = %q, want %q", got, repo) + if got, ok := parsed.Result["diff"].(string); !ok || got != diffRaw { + t.Fatalf("diff = %q, want %q", got, diffRaw) + } + if got, ok := parsed.Result["binary"].(bool); !ok || got != true { + t.Fatalf("binary = %v, want true", got) + } + if got, ok := parsed.Result["index"].(float64); !ok || int64(got) != int64(index) { + t.Fatalf("index = %v, want %d", got, index) + } + if got, ok := parsed.Result["owner"].(string); !ok || got != owner { + t.Fatalf("owner = %q, want %q", got, owner) + } + if got, ok := parsed.Result["repo"].(string); !ok || got != repo { + t.Fatalf("repo = %q, want %q", got, repo) + } + }) } } diff --git a/operation/repo/commit.go b/operation/repo/commit.go index 77fab4b..e66cb4e 100644 --- a/operation/repo/commit.go +++ b/operation/repo/commit.go @@ -7,6 +7,7 @@ import ( "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/params" "gitea.com/gitea/gitea-mcp/pkg/to" gitea_sdk "code.gitea.io/sdk/gitea" @@ -46,13 +47,13 @@ func ListRepoCommitsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT if !ok { return to.ErrorResult(errors.New("repo is required")) } - page, ok := req.GetArguments()["page"].(float64) - if !ok { - return to.ErrorResult(errors.New("page is required")) + page, err := params.GetIndex(req.GetArguments(), "page") + if err != nil { + return to.ErrorResult(err) } - pageSize, ok := req.GetArguments()["page_size"].(float64) - if !ok { - return to.ErrorResult(errors.New("page_size is required")) + pageSize, err := params.GetIndex(req.GetArguments(), "page_size") + if err != nil { + return to.ErrorResult(err) } sha, _ := req.GetArguments()["sha"].(string) path, _ := req.GetArguments()["path"].(string) diff --git a/operation/repo/release.go b/operation/repo/release.go index 04631f0..e8ec44c 100644 --- a/operation/repo/release.go +++ b/operation/repo/release.go @@ -8,6 +8,7 @@ import ( "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/params" "gitea.com/gitea/gitea-mcp/pkg/to" gitea_sdk "code.gitea.io/sdk/gitea" @@ -163,16 +164,16 @@ func DeleteReleaseFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo if !ok { return nil, errors.New("repo is required") } - id, ok := req.GetArguments()["id"].(float64) - if !ok { - return nil, errors.New("id is required") + id, err := params.GetIndex(req.GetArguments(), "id") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.DeleteRelease(owner, repo, int64(id)) + _, err = client.DeleteRelease(owner, repo, id) if err != nil { return nil, fmt.Errorf("delete release error: %v", err) } @@ -190,16 +191,16 @@ func GetReleaseFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolRe if !ok { return nil, errors.New("repo is required") } - id, ok := req.GetArguments()["id"].(float64) - if !ok { - return nil, errors.New("id is required") + id, err := params.GetIndex(req.GetArguments(), "id") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - release, _, err := client.GetRelease(owner, repo, int64(id)) + release, _, err := client.GetRelease(owner, repo, id) if err != nil { return nil, fmt.Errorf("get release error: %v", err) } @@ -250,8 +251,8 @@ func ListReleasesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTool if ok { pIsPreRelease = new(isPreRelease) } - page, _ := req.GetArguments()["page"].(float64) - pageSize, _ := req.GetArguments()["pageSize"].(float64) + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 20) client, err := gitea.ClientFromContext(ctx) if err != nil { diff --git a/operation/repo/repo.go b/operation/repo/repo.go index d7a191f..18dc18c 100644 --- a/operation/repo/repo.go +++ b/operation/repo/repo.go @@ -7,6 +7,7 @@ import ( "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/params" "gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/tool" @@ -191,14 +192,8 @@ func ForkRepoFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResu func ListMyReposFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { log.Debugf("Called ListMyReposFn") - page, ok := req.GetArguments()["page"].(float64) - if !ok { - page = 1 - } - pageSize, ok := req.GetArguments()["pageSize"].(float64) - if !ok { - pageSize = 100 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100) opt := gitea_sdk.ListReposOptions{ ListOptions: gitea_sdk.ListOptions{ Page: int(page), diff --git a/operation/repo/tag.go b/operation/repo/tag.go index 4f792e1..42803df 100644 --- a/operation/repo/tag.go +++ b/operation/repo/tag.go @@ -7,6 +7,7 @@ import ( "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/params" "gitea.com/gitea/gitea-mcp/pkg/to" gitea_sdk "code.gitea.io/sdk/gitea" @@ -183,8 +184,8 @@ func ListTagsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResu if !ok { return nil, errors.New("repo is required") } - page, _ := req.GetArguments()["page"].(float64) - pageSize, _ := req.GetArguments()["pageSize"].(float64) + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 20) client, err := gitea.ClientFromContext(ctx) if err != nil { diff --git a/operation/search/search.go b/operation/search/search.go index 83a5d14..bbb6127 100644 --- a/operation/search/search.go +++ b/operation/search/search.go @@ -7,6 +7,7 @@ import ( "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/params" "gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/tool" @@ -79,14 +80,8 @@ func UsersFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, if !ok { return to.ErrorResult(errors.New("keyword is required")) } - page, ok := req.GetArguments()["page"].(float64) - if !ok { - page = 1 - } - pageSize, ok := req.GetArguments()["pageSize"].(float64) - if !ok { - pageSize = 100 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100) opt := gitea_sdk.SearchUsersOption{ KeyWord: keyword, ListOptions: gitea_sdk.ListOptions{ @@ -116,14 +111,8 @@ func OrgTeamsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResu return to.ErrorResult(errors.New("query is required")) } includeDescription, _ := req.GetArguments()["includeDescription"].(bool) - page, ok := req.GetArguments()["page"].(float64) - if !ok { - page = 1 - } - pageSize, ok := req.GetArguments()["pageSize"].(float64) - if !ok { - pageSize = 100 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100) opt := gitea_sdk.SearchTeamsOptions{ Query: query, IncludeDescription: includeDescription, @@ -151,7 +140,7 @@ func ReposFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, } keywordIsTopic, _ := req.GetArguments()["keywordIsTopic"].(bool) keywordInDescription, _ := req.GetArguments()["keywordInDescription"].(bool) - ownerID, _ := req.GetArguments()["ownerID"].(float64) + ownerID := params.GetOptionalInt(req.GetArguments(), "ownerID", 0) var pIsPrivate *bool isPrivate, ok := req.GetArguments()["isPrivate"].(bool) if ok { @@ -164,19 +153,13 @@ func ReposFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, } sort, _ := req.GetArguments()["sort"].(string) order, _ := req.GetArguments()["order"].(string) - page, ok := req.GetArguments()["page"].(float64) - if !ok { - page = 1 - } - pageSize, ok := req.GetArguments()["pageSize"].(float64) - if !ok { - pageSize = 100 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100) opt := gitea_sdk.SearchRepoOptions{ Keyword: keyword, KeywordIsTopic: keywordIsTopic, KeywordInDescription: keywordInDescription, - OwnerID: int64(ownerID), + OwnerID: ownerID, IsPrivate: pIsPrivate, IsArchived: pIsArchived, Sort: sort, diff --git a/operation/timetracking/timetracking.go b/operation/timetracking/timetracking.go index c35e80f..fb82b8b 100644 --- a/operation/timetracking/timetracking.go +++ b/operation/timetracking/timetracking.go @@ -233,14 +233,8 @@ func ListTrackedTimesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call if err != nil { return to.ErrorResult(err) } - page, ok := req.GetArguments()["page"].(float64) - if !ok { - page = 1 - } - pageSize, ok := req.GetArguments()["pageSize"].(float64) - if !ok { - pageSize = 100 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100) client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) @@ -276,16 +270,16 @@ func AddTrackedTimeFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo return to.ErrorResult(err) } - timeSeconds, ok := req.GetArguments()["time"].(float64) - if !ok { - return to.ErrorResult(errors.New("time is required")) + timeSeconds, err := params.GetIndex(req.GetArguments(), "time") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } trackedTime, _, err := client.AddTime(owner, repo, index, gitea_sdk.AddTimeOption{ - Time: int64(timeSeconds), + Time: timeSeconds, }) if err != nil { return to.ErrorResult(fmt.Errorf("add tracked time to %s/%s#%d err: %v", owner, repo, index, err)) @@ -308,19 +302,19 @@ func DeleteTrackedTimeFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Cal if err != nil { return to.ErrorResult(err) } - id, ok := req.GetArguments()["id"].(float64) - if !ok { - return to.ErrorResult(errors.New("id is required")) + id, err := params.GetIndex(req.GetArguments(), "id") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.DeleteTime(owner, repo, index, int64(id)) + _, err = client.DeleteTime(owner, repo, index, id) if err != nil { - return to.ErrorResult(fmt.Errorf("delete tracked time %d from %s/%s#%d err: %v", int64(id), owner, repo, index, err)) + return to.ErrorResult(fmt.Errorf("delete tracked time %d from %s/%s#%d err: %v", id, owner, repo, index, err)) } - return to.TextResult(fmt.Sprintf("Tracked time entry %d deleted from issue %s/%s#%d", int64(id), owner, repo, index)) + return to.TextResult(fmt.Sprintf("Tracked time entry %d deleted from issue %s/%s#%d", id, owner, repo, index)) } func ListRepoTimesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -334,14 +328,8 @@ func ListRepoTimesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo return to.ErrorResult(errors.New("repo is required")) } - page, ok := req.GetArguments()["page"].(float64) - if !ok { - page = 1 - } - pageSize, ok := req.GetArguments()["pageSize"].(float64) - if !ok { - pageSize = 100 - } + page := params.GetOptionalInt(req.GetArguments(), "page", 1) + pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100) client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) diff --git a/operation/user/user.go b/operation/user/user.go index acb62f3..21447bf 100644 --- a/operation/user/user.go +++ b/operation/user/user.go @@ -6,6 +6,7 @@ import ( "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/params" "gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/tool" @@ -68,11 +69,11 @@ func registerTools() { // getIntArg parses an integer argument from the MCP request arguments map. // Returns def if missing, not a number, or less than 1. Used for pagination arguments. func getIntArg(req mcp.CallToolRequest, name string, def int) int { - val, ok := req.GetArguments()[name].(float64) - if !ok || val < 1 { + v := params.GetOptionalInt(req.GetArguments(), name, int64(def)) + if v < 1 { return def } - return int(val) + return int(v) } // GetUserInfoFn is the handler for "get_my_user_info" MCP tool requests. diff --git a/pkg/params/params.go b/pkg/params/params.go index 3f6275e..5ace6a3 100644 --- a/pkg/params/params.go +++ b/pkg/params/params.go @@ -5,7 +5,24 @@ import ( "strconv" ) -// GetIndex extracts an index parameter from MCP tool arguments. +// ToInt64 converts a value to int64, accepting both float64 (JSON number) and +// string representations. Returns false if the value cannot be converted. +func ToInt64(val any) (int64, bool) { + switch v := val.(type) { + case float64: + return int64(v), true + case string: + i, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return 0, false + } + return i, true + default: + return 0, false + } +} + +// GetIndex extracts a required integer parameter from MCP tool arguments. // It accepts both numeric (float64 from JSON) and string representations. // This provides better UX for LLM callers that may naturally use strings // for identifiers like issue/PR numbers. @@ -15,19 +32,27 @@ func GetIndex(args map[string]any, key string) (int64, error) { return 0, fmt.Errorf("%s is required", key) } - // Try float64 (JSON number type) - if f, ok := val.(float64); ok { - return int64(f), nil + if i, ok := ToInt64(val); ok { + return i, nil } - // Try string and parse to integer if s, ok := val.(string); ok { - i, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return 0, fmt.Errorf("%s must be a valid integer (got %q)", key, s) - } - return i, nil + return 0, fmt.Errorf("%s must be a valid integer (got %q)", key, s) } return 0, fmt.Errorf("%s must be a number or numeric string", key) } + +// GetOptionalInt extracts an optional integer parameter from MCP tool arguments. +// Returns defaultVal if the key is missing or the value cannot be parsed. +// Accepts both float64 (JSON number) and string representations. +func GetOptionalInt(args map[string]any, key string, defaultVal int64) int64 { + val, exists := args[key] + if !exists { + return defaultVal + } + if i, ok := ToInt64(val); ok { + return i + } + return defaultVal +} diff --git a/pkg/params/params_test.go b/pkg/params/params_test.go index f951cc4..85f4606 100644 --- a/pkg/params/params_test.go +++ b/pkg/params/params_test.go @@ -5,6 +5,63 @@ import ( "testing" ) +func TestToInt64(t *testing.T) { + tests := []struct { + name string + val any + want int64 + ok bool + }{ + {"float64", float64(42), 42, true}, + {"float64 zero", float64(0), 0, true}, + {"float64 negative", float64(-5), -5, true}, + {"string", "123", 123, true}, + {"string zero", "0", 0, true}, + {"string negative", "-10", -10, true}, + {"invalid string", "abc", 0, false}, + {"decimal string", "1.5", 0, false}, + {"bool", true, 0, false}, + {"nil", nil, 0, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := ToInt64(tt.val) + if ok != tt.ok { + t.Errorf("ToInt64() ok = %v, want %v", ok, tt.ok) + } + if got != tt.want { + t.Errorf("ToInt64() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetOptionalInt(t *testing.T) { + tests := []struct { + name string + args map[string]any + key string + defaultVal int64 + want int64 + }{ + {"present float64", map[string]any{"page": float64(3)}, "page", 1, 3}, + {"present string", map[string]any{"page": "5"}, "page", 1, 5}, + {"missing key", map[string]any{}, "page", 1, 1}, + {"invalid string", map[string]any{"page": "abc"}, "page", 1, 1}, + {"invalid type", map[string]any{"page": true}, "page", 1, 1}, + {"zero value", map[string]any{"id": float64(0)}, "id", 99, 0}, + {"string zero", map[string]any{"id": "0"}, "id", 99, 0}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetOptionalInt(tt.args, tt.key, tt.defaultVal) + if got != tt.want { + t.Errorf("GetOptionalInt() = %v, want %v", got, tt.want) + } + }) + } +} + func TestGetIndex(t *testing.T) { tests := []struct { name string