diff --git a/operation/issue/issue.go b/operation/issue/issue.go index ba62951..5c7387e 100644 --- a/operation/issue/issue.go +++ b/operation/issue/issue.go @@ -6,6 +6,7 @@ import ( "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/params" "gitea.com/gitea/gitea-mcp/pkg/ptr" "gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/tool" @@ -136,17 +137,17 @@ func GetIssueByIndexFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - issue, _, err := client.GetIssue(owner, repo, int64(index)) + issue, _, err := client.GetIssue(owner, repo, index) if err != nil { - return to.ErrorResult(fmt.Errorf("get %v/%v/issue/%v err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("get %v/%v/issue/%v err: %v", owner, repo, index, err)) } return to.TextResult(issue) @@ -235,9 +236,9 @@ func CreateIssueCommentFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } body, ok := req.GetArguments()["body"].(string) if !ok { @@ -250,9 +251,9 @@ func CreateIssueCommentFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - issueComment, _, err := client.CreateIssueComment(owner, repo, int64(index), opt) + issueComment, _, err := client.CreateIssueComment(owner, repo, index, opt) if err != nil { - return to.ErrorResult(fmt.Errorf("create %v/%v/issue/%v/comment err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("create %v/%v/issue/%v/comment err: %v", owner, repo, index, err)) } return to.TextResult(issueComment) @@ -268,9 +269,9 @@ func EditIssueFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolRes if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } opt := gitea_sdk.EditIssueOption{} @@ -307,9 +308,9 @@ func EditIssueFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolRes if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - issue, _, err := client.EditIssue(owner, repo, int64(index), opt) + issue, _, err := client.EditIssue(owner, repo, index, opt) if err != nil { - return to.ErrorResult(fmt.Errorf("edit %v/%v/issue/%v err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("edit %v/%v/issue/%v err: %v", owner, repo, index, err)) } return to.TextResult(issue) @@ -358,18 +359,18 @@ func GetIssueCommentsByIndexFn(ctx context.Context, req mcp.CallToolRequest) (*m if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } opt := gitea_sdk.ListIssueCommentOptions{} client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - issue, _, err := client.ListIssueComments(owner, repo, int64(index), opt) + issue, _, err := client.ListIssueComments(owner, repo, index, opt) if err != nil { - return to.ErrorResult(fmt.Errorf("get %v/%v/issues/%v/comments err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("get %v/%v/issues/%v/comments err: %v", owner, repo, index, err)) } return to.TextResult(issue) diff --git a/operation/label/label.go b/operation/label/label.go index e6bbbb8..2eee6c2 100644 --- a/operation/label/label.go +++ b/operation/label/label.go @@ -6,6 +6,7 @@ import ( "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/params" "gitea.com/gitea/gitea-mcp/pkg/ptr" "gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/tool" @@ -380,9 +381,9 @@ func AddIssueLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("issue index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } labelsRaw, ok := req.GetArguments()["labels"].([]interface{}) if !ok { @@ -405,9 +406,9 @@ func AddIssueLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - issueLabels, _, err := client.AddIssueLabels(owner, repo, int64(index), opt) + issueLabels, _, err := client.AddIssueLabels(owner, repo, index, opt) if err != nil { - return to.ErrorResult(fmt.Errorf("add labels to %v/%v/issue/%v err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("add labels to %v/%v/issue/%v err: %v", owner, repo, index, err)) } return to.TextResult(issueLabels) } @@ -422,9 +423,9 @@ func ReplaceIssueLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("issue index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } labelsRaw, ok := req.GetArguments()["labels"].([]interface{}) if !ok { @@ -447,9 +448,9 @@ func ReplaceIssueLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - issueLabels, _, err := client.ReplaceIssueLabels(owner, repo, int64(index), opt) + issueLabels, _, err := client.ReplaceIssueLabels(owner, repo, index, opt) if err != nil { - return to.ErrorResult(fmt.Errorf("replace labels on %v/%v/issue/%v err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("replace labels on %v/%v/issue/%v err: %v", owner, repo, index, err)) } return to.TextResult(issueLabels) } @@ -464,18 +465,18 @@ func ClearIssueLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("issue index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.ClearIssueLabels(owner, repo, int64(index)) + _, err = client.ClearIssueLabels(owner, repo, index) if err != nil { - return to.ErrorResult(fmt.Errorf("clear labels on %v/%v/issue/%v err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("clear labels on %v/%v/issue/%v err: %v", owner, repo, index, err)) } return to.TextResult("Labels cleared successfully") } @@ -490,9 +491,9 @@ func RemoveIssueLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("issue index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } labelID, ok := req.GetArguments()["label_id"].(float64) if !ok { @@ -503,9 +504,9 @@ func RemoveIssueLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.DeleteIssueLabel(owner, repo, int64(index), int64(labelID)) + _, err = client.DeleteIssueLabel(owner, repo, index, int64(labelID)) if err != nil { - return to.ErrorResult(fmt.Errorf("remove label %v from %v/%v/issue/%v err: %v", int64(labelID), owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("remove label %v from %v/%v/issue/%v err: %v", int64(labelID), owner, repo, index, err)) } return to.TextResult("Label removed successfully") } diff --git a/operation/pull/pull.go b/operation/pull/pull.go index f5ff583..642f4ee 100644 --- a/operation/pull/pull.go +++ b/operation/pull/pull.go @@ -6,6 +6,7 @@ import ( "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/params" "gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/tool" @@ -237,17 +238,17 @@ func GetPullRequestByIndexFn(ctx context.Context, req mcp.CallToolRequest) (*mcp if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - pr, _, err := client.GetPullRequest(owner, repo, int64(index)) + pr, _, err := client.GetPullRequest(owner, repo, index) if err != nil { - return to.ErrorResult(fmt.Errorf("get %v/%v/pr/%v err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("get %v/%v/pr/%v err: %v", owner, repo, index, err)) } return to.TextResult(pr) @@ -263,9 +264,9 @@ func GetPullRequestDiffFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } binary, _ := req.GetArguments()["binary"].(bool) @@ -273,17 +274,17 @@ func GetPullRequestDiffFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - diffBytes, _, err := client.GetPullRequestDiff(owner, repo, int64(index), gitea_sdk.PullRequestDiffOptions{ + diffBytes, _, err := client.GetPullRequestDiff(owner, repo, index, gitea_sdk.PullRequestDiffOptions{ Binary: binary, }) if err != nil { - return to.ErrorResult(fmt.Errorf("get %v/%v/pr/%v diff err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("get %v/%v/pr/%v diff err: %v", owner, repo, index, err)) } result := map[string]interface{}{ "diff": string(diffBytes), "binary": binary, - "index": int64(index), + "index": index, "repo": repo, "owner": owner, } @@ -388,9 +389,9 @@ func CreatePullRequestReviewerFn(ctx context.Context, req mcp.CallToolRequest) ( if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } var reviewers []string @@ -420,12 +421,12 @@ func CreatePullRequestReviewerFn(ctx context.Context, req mcp.CallToolRequest) ( return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.CreateReviewRequests(owner, repo, int64(index), gitea_sdk.PullReviewRequestOptions{ + _, err = client.CreateReviewRequests(owner, repo, index, gitea_sdk.PullReviewRequestOptions{ Reviewers: reviewers, TeamReviewers: teamReviewers, }) if err != nil { - return to.ErrorResult(fmt.Errorf("create review requests for %v/%v/pr/%v err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("create review requests for %v/%v/pr/%v err: %v", owner, repo, index, err)) } // Return a success message instead of the Response object which contains non-serializable functions @@ -433,7 +434,7 @@ func CreatePullRequestReviewerFn(ctx context.Context, req mcp.CallToolRequest) ( "message": "Successfully created review requests", "reviewers": reviewers, "team_reviewers": teamReviewers, - "pr_index": int64(index), + "pr_index": index, "repository": fmt.Sprintf("%s/%s", owner, repo), } @@ -450,9 +451,9 @@ func DeletePullRequestReviewerFn(ctx context.Context, req mcp.CallToolRequest) ( if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } var reviewers []string @@ -482,19 +483,19 @@ func DeletePullRequestReviewerFn(ctx context.Context, req mcp.CallToolRequest) ( return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.DeleteReviewRequests(owner, repo, int64(index), gitea_sdk.PullReviewRequestOptions{ + _, err = client.DeleteReviewRequests(owner, repo, index, gitea_sdk.PullReviewRequestOptions{ Reviewers: reviewers, TeamReviewers: teamReviewers, }) if err != nil { - return to.ErrorResult(fmt.Errorf("delete review requests for %v/%v/pr/%v err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("delete review requests for %v/%v/pr/%v err: %v", owner, repo, index, err)) } successMsg := map[string]interface{}{ "message": "Successfully deleted review requests", "reviewers": reviewers, "team_reviewers": teamReviewers, - "pr_index": int64(index), + "pr_index": index, "repository": fmt.Sprintf("%s/%s", owner, repo), } @@ -511,9 +512,9 @@ func ListPullRequestReviewsFn(ctx context.Context, req mcp.CallToolRequest) (*mc if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } page, ok := req.GetArguments()["page"].(float64) if !ok { @@ -529,14 +530,14 @@ func ListPullRequestReviewsFn(ctx context.Context, req mcp.CallToolRequest) (*mc return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - reviews, _, err := client.ListPullReviews(owner, repo, int64(index), gitea_sdk.ListPullReviewsOptions{ + reviews, _, err := client.ListPullReviews(owner, repo, index, gitea_sdk.ListPullReviewsOptions{ ListOptions: gitea_sdk.ListOptions{ Page: int(page), PageSize: int(pageSize), }, }) if err != nil { - return to.ErrorResult(fmt.Errorf("list reviews for %v/%v/pr/%v err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("list reviews for %v/%v/pr/%v err: %v", owner, repo, index, err)) } return to.TextResult(reviews) @@ -552,9 +553,9 @@ func GetPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*mcp. if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } reviewID, ok := req.GetArguments()["review_id"].(float64) if !ok { @@ -566,9 +567,9 @@ func GetPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*mcp. return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - review, _, err := client.GetPullReview(owner, repo, int64(index), int64(reviewID)) + review, _, err := client.GetPullReview(owner, repo, index, int64(reviewID)) if err != nil { - return to.ErrorResult(fmt.Errorf("get review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("get review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err)) } return to.TextResult(review) @@ -584,9 +585,9 @@ func ListPullRequestReviewCommentsFn(ctx context.Context, req mcp.CallToolReques if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } reviewID, ok := req.GetArguments()["review_id"].(float64) if !ok { @@ -598,9 +599,9 @@ func ListPullRequestReviewCommentsFn(ctx context.Context, req mcp.CallToolReques return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - comments, _, err := client.ListPullReviewComments(owner, repo, int64(index), int64(reviewID)) + comments, _, err := client.ListPullReviewComments(owner, repo, index, int64(reviewID)) if err != nil { - return to.ErrorResult(fmt.Errorf("list review comments for review %v on %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("list review comments for review %v on %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err)) } return to.TextResult(comments) @@ -616,9 +617,9 @@ func CreatePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } opt := gitea_sdk.CreatePullReviewOptions{} @@ -662,9 +663,9 @@ func CreatePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - review, _, err := client.CreatePullReview(owner, repo, int64(index), opt) + review, _, err := client.CreatePullReview(owner, repo, index, opt) if err != nil { - return to.ErrorResult(fmt.Errorf("create review for %v/%v/pr/%v err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("create review for %v/%v/pr/%v err: %v", owner, repo, index, err)) } return to.TextResult(review) @@ -680,9 +681,9 @@ func SubmitPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } reviewID, ok := req.GetArguments()["review_id"].(float64) if !ok { @@ -705,9 +706,9 @@ func SubmitPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - review, _, err := client.SubmitPullReview(owner, repo, int64(index), int64(reviewID), opt) + review, _, err := client.SubmitPullReview(owner, repo, index, int64(reviewID), opt) if err != nil { - return to.ErrorResult(fmt.Errorf("submit review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("submit review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err)) } return to.TextResult(review) @@ -723,9 +724,9 @@ func DeletePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } reviewID, ok := req.GetArguments()["review_id"].(float64) if !ok { @@ -737,15 +738,15 @@ func DeletePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.DeletePullReview(owner, repo, int64(index), int64(reviewID)) + _, err = client.DeletePullReview(owner, repo, index, int64(reviewID)) if err != nil { - return to.ErrorResult(fmt.Errorf("delete review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("delete review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err)) } successMsg := map[string]interface{}{ "message": "Successfully deleted review", "review_id": int64(reviewID), - "pr_index": int64(index), + "pr_index": index, "repository": fmt.Sprintf("%s/%s", owner, repo), } @@ -762,9 +763,9 @@ func DismissPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (* if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } reviewID, ok := req.GetArguments()["review_id"].(float64) if !ok { @@ -781,15 +782,15 @@ func DismissPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (* return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.DismissPullReview(owner, repo, int64(index), int64(reviewID), opt) + _, err = client.DismissPullReview(owner, repo, index, int64(reviewID), opt) if err != nil { - return to.ErrorResult(fmt.Errorf("dismiss review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("dismiss review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err)) } successMsg := map[string]interface{}{ "message": "Successfully dismissed review", "review_id": int64(reviewID), - "pr_index": int64(index), + "pr_index": index, "repository": fmt.Sprintf("%s/%s", owner, repo), } diff --git a/operation/timetracking/timetracking.go b/operation/timetracking/timetracking.go index 04a5804..0f6d169 100644 --- a/operation/timetracking/timetracking.go +++ b/operation/timetracking/timetracking.go @@ -8,6 +8,7 @@ import ( gitea_sdk "code.gitea.io/sdk/gitea" "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/params" "gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/tool" @@ -134,19 +135,19 @@ func StartStopwatchFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.StartIssueStopWatch(owner, repo, int64(index)) + _, err = client.StartIssueStopWatch(owner, repo, index) if err != nil { - return to.ErrorResult(fmt.Errorf("start stopwatch on %s/%s#%d err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("start stopwatch on %s/%s#%d err: %v", owner, repo, index, err)) } - return to.TextResult(fmt.Sprintf("Stopwatch started on issue %s/%s#%d", owner, repo, int64(index))) + return to.TextResult(fmt.Sprintf("Stopwatch started on issue %s/%s#%d", owner, repo, index)) } func StopStopwatchFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -159,19 +160,19 @@ func StopStopwatchFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.StopIssueStopWatch(owner, repo, int64(index)) + _, err = client.StopIssueStopWatch(owner, repo, index) if err != nil { - return to.ErrorResult(fmt.Errorf("stop stopwatch on %s/%s#%d err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("stop stopwatch on %s/%s#%d err: %v", owner, repo, index, err)) } - return to.TextResult(fmt.Sprintf("Stopwatch stopped on issue %s/%s#%d - time recorded", owner, repo, int64(index))) + return to.TextResult(fmt.Sprintf("Stopwatch stopped on issue %s/%s#%d - time recorded", owner, repo, index)) } func DeleteStopwatchFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -184,19 +185,19 @@ func DeleteStopwatchFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.DeleteIssueStopwatch(owner, repo, int64(index)) + _, err = client.DeleteIssueStopwatch(owner, repo, index) if err != nil { - return to.ErrorResult(fmt.Errorf("delete stopwatch on %s/%s#%d err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("delete stopwatch on %s/%s#%d err: %v", owner, repo, index, err)) } - return to.TextResult(fmt.Sprintf("Stopwatch deleted/cancelled on issue %s/%s#%d", owner, repo, int64(index))) + return to.TextResult(fmt.Sprintf("Stopwatch deleted/cancelled on issue %s/%s#%d", owner, repo, index)) } func GetMyStopwatchesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -227,9 +228,9 @@ func ListTrackedTimesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } page, ok := req.GetArguments()["page"].(float64) if !ok { @@ -244,17 +245,17 @@ func ListTrackedTimesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - times, _, err := client.ListIssueTrackedTimes(owner, repo, int64(index), gitea_sdk.ListTrackedTimesOptions{ + times, _, err := client.ListIssueTrackedTimes(owner, repo, index, gitea_sdk.ListTrackedTimesOptions{ ListOptions: gitea_sdk.ListOptions{ Page: int(page), PageSize: int(pageSize), }, }) if err != nil { - return to.ErrorResult(fmt.Errorf("list tracked times for %s/%s#%d err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("list tracked times for %s/%s#%d err: %v", owner, repo, index, err)) } if len(times) == 0 { - return to.TextResult(fmt.Sprintf("No tracked times for issue %s/%s#%d", owner, repo, int64(index))) + return to.TextResult(fmt.Sprintf("No tracked times for issue %s/%s#%d", owner, repo, index)) } return to.TextResult(times) } @@ -269,9 +270,9 @@ func AddTrackedTimeFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo if !ok { return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } timeSeconds, ok := req.GetArguments()["time"].(float64) @@ -282,11 +283,11 @@ func AddTrackedTimeFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - trackedTime, _, err := client.AddTime(owner, repo, int64(index), gitea_sdk.AddTimeOption{ + trackedTime, _, err := client.AddTime(owner, repo, index, gitea_sdk.AddTimeOption{ Time: int64(timeSeconds), }) if err != nil { - return to.ErrorResult(fmt.Errorf("add tracked time to %s/%s#%d err: %v", owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("add tracked time to %s/%s#%d err: %v", owner, repo, index, err)) } return to.TextResult(trackedTime) } @@ -302,9 +303,9 @@ func DeleteTrackedTimeFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Cal return to.ErrorResult(fmt.Errorf("repo is required")) } - index, ok := req.GetArguments()["index"].(float64) - if !ok { - return to.ErrorResult(fmt.Errorf("index is required")) + index, err := params.GetIndex(req.GetArguments(), "index") + if err != nil { + return to.ErrorResult(err) } id, ok := req.GetArguments()["id"].(float64) if !ok { @@ -314,11 +315,11 @@ func DeleteTrackedTimeFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Cal if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) } - _, err = client.DeleteTime(owner, repo, int64(index), int64(id)) + _, err = client.DeleteTime(owner, repo, index, int64(id)) if err != nil { - return to.ErrorResult(fmt.Errorf("delete tracked time %d from %s/%s#%d err: %v", int64(id), owner, repo, int64(index), err)) + return to.ErrorResult(fmt.Errorf("delete tracked time %d from %s/%s#%d err: %v", int64(id), owner, repo, index, err)) } - return to.TextResult(fmt.Sprintf("Tracked time entry %d deleted from issue %s/%s#%d", int64(id), owner, repo, int64(index))) + return to.TextResult(fmt.Sprintf("Tracked time entry %d deleted from issue %s/%s#%d", int64(id), owner, repo, index)) } func ListRepoTimesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { diff --git a/pkg/params/params.go b/pkg/params/params.go new file mode 100644 index 0000000..6afdafc --- /dev/null +++ b/pkg/params/params.go @@ -0,0 +1,33 @@ +package params + +import ( + "fmt" + "strconv" +) + +// GetIndex extracts an index parameter from MCP tool arguments. +// It accepts both numeric (float64 from JSON) and string representations. +// This provides better UX for LLM callers that may naturally use strings +// for identifiers like issue/PR numbers. +func GetIndex(args map[string]interface{}, key string) (int64, error) { + val, exists := args[key] + if !exists { + return 0, fmt.Errorf("%s is required", key) + } + + // Try float64 (JSON number type) + if f, ok := val.(float64); ok { + return int64(f), nil + } + + // Try string and parse to integer + if s, ok := val.(string); ok { + i, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0, fmt.Errorf("%s must be a valid integer (got %q)", key, s) + } + return i, nil + } + + return 0, fmt.Errorf("%s must be a number or numeric string", key) +} diff --git a/pkg/params/params_test.go b/pkg/params/params_test.go new file mode 100644 index 0000000..7461882 --- /dev/null +++ b/pkg/params/params_test.go @@ -0,0 +1,116 @@ +package params + +import ( + "testing" +) + +func TestGetIndex(t *testing.T) { + tests := []struct { + name string + args map[string]interface{} + key string + wantIndex int64 + wantErr bool + errMsg string + }{ + { + name: "valid float64", + args: map[string]interface{}{"index": float64(123)}, + key: "index", + wantIndex: 123, + wantErr: false, + }, + { + name: "valid string", + args: map[string]interface{}{"index": "456"}, + key: "index", + wantIndex: 456, + wantErr: false, + }, + { + name: "valid string with large number", + args: map[string]interface{}{"index": "999999"}, + key: "index", + wantIndex: 999999, + wantErr: false, + }, + { + name: "missing parameter", + args: map[string]interface{}{}, + key: "index", + wantErr: true, + errMsg: "index is required", + }, + { + name: "invalid string (not a number)", + args: map[string]interface{}{"index": "abc"}, + key: "index", + wantErr: true, + errMsg: "must be a valid integer", + }, + { + name: "invalid string (decimal)", + args: map[string]interface{}{"index": "12.34"}, + key: "index", + wantErr: true, + errMsg: "must be a valid integer", + }, + { + name: "invalid type (bool)", + args: map[string]interface{}{"index": true}, + key: "index", + wantErr: true, + errMsg: "must be a number or numeric string", + }, + { + name: "invalid type (map)", + args: map[string]interface{}{"index": map[string]string{"foo": "bar"}}, + key: "index", + wantErr: true, + errMsg: "must be a number or numeric string", + }, + { + name: "custom key name", + args: map[string]interface{}{"pr_index": "789"}, + key: "pr_index", + wantIndex: 789, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotIndex, err := GetIndex(tt.args, tt.key) + if tt.wantErr { + if err == nil { + t.Errorf("GetIndex() expected error but got nil") + return + } + if tt.errMsg != "" && !contains(err.Error(), tt.errMsg) { + t.Errorf("GetIndex() error = %v, want error containing %q", err, tt.errMsg) + } + return + } + if err != nil { + t.Errorf("GetIndex() unexpected error = %v", err) + return + } + if gotIndex != tt.wantIndex { + t.Errorf("GetIndex() = %v, want %v", gotIndex, tt.wantIndex) + } + }) + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || containsMiddle(s, substr))) +} + +func containsMiddle(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +}