feat: accept string or number for index parameters (#121)

This change makes index parameters more flexible by accepting both
numeric and string values. LLM agents often pass issue/PR indices
as strings (e.g., "123") since they appear as string identifiers in
URLs and CLI contexts. The implementation:

- Created pkg/params package with GetIndex() helper function
- Updated 25+ tool functions across issue, pull, label, and timetracking operations
- Improved error messages to say "must be a valid integer" instead of misleading "is required"
- Added comprehensive tests for both numeric and string inputs

This improves UX for MCP clients and LLMs while maintaining backward
compatibility with existing numeric callers.

Fixes: #121

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
James Pharaoh
2026-02-10 09:23:45 +00:00
parent 1f7392305f
commit 71dbc9d6da
6 changed files with 288 additions and 135 deletions

View File

@@ -6,6 +6,7 @@ import (
"gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/gitea"
"gitea.com/gitea/gitea-mcp/pkg/log" "gitea.com/gitea/gitea-mcp/pkg/log"
"gitea.com/gitea/gitea-mcp/pkg/params"
"gitea.com/gitea/gitea-mcp/pkg/ptr" "gitea.com/gitea/gitea-mcp/pkg/ptr"
"gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/to"
"gitea.com/gitea/gitea-mcp/pkg/tool" "gitea.com/gitea/gitea-mcp/pkg/tool"
@@ -136,17 +137,17 @@ func GetIssueByIndexFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
issue, _, err := client.GetIssue(owner, repo, int64(index)) issue, _, err := client.GetIssue(owner, repo, index)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get %v/%v/issue/%v err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("get %v/%v/issue/%v err: %v", owner, repo, index, err))
} }
return to.TextResult(issue) return to.TextResult(issue)
@@ -235,9 +236,9 @@ func CreateIssueCommentFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
body, ok := req.GetArguments()["body"].(string) body, ok := req.GetArguments()["body"].(string)
if !ok { if !ok {
@@ -250,9 +251,9 @@ func CreateIssueCommentFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
issueComment, _, err := client.CreateIssueComment(owner, repo, int64(index), opt) issueComment, _, err := client.CreateIssueComment(owner, repo, index, opt)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("create %v/%v/issue/%v/comment err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("create %v/%v/issue/%v/comment err: %v", owner, repo, index, err))
} }
return to.TextResult(issueComment) return to.TextResult(issueComment)
@@ -268,9 +269,9 @@ func EditIssueFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolRes
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
opt := gitea_sdk.EditIssueOption{} opt := gitea_sdk.EditIssueOption{}
@@ -307,9 +308,9 @@ func EditIssueFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolRes
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
issue, _, err := client.EditIssue(owner, repo, int64(index), opt) issue, _, err := client.EditIssue(owner, repo, index, opt)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("edit %v/%v/issue/%v err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("edit %v/%v/issue/%v err: %v", owner, repo, index, err))
} }
return to.TextResult(issue) return to.TextResult(issue)
@@ -358,18 +359,18 @@ func GetIssueCommentsByIndexFn(ctx context.Context, req mcp.CallToolRequest) (*m
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
opt := gitea_sdk.ListIssueCommentOptions{} opt := gitea_sdk.ListIssueCommentOptions{}
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
issue, _, err := client.ListIssueComments(owner, repo, int64(index), opt) issue, _, err := client.ListIssueComments(owner, repo, index, opt)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get %v/%v/issues/%v/comments err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("get %v/%v/issues/%v/comments err: %v", owner, repo, index, err))
} }
return to.TextResult(issue) return to.TextResult(issue)

View File

@@ -6,6 +6,7 @@ import (
"gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/gitea"
"gitea.com/gitea/gitea-mcp/pkg/log" "gitea.com/gitea/gitea-mcp/pkg/log"
"gitea.com/gitea/gitea-mcp/pkg/params"
"gitea.com/gitea/gitea-mcp/pkg/ptr" "gitea.com/gitea/gitea-mcp/pkg/ptr"
"gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/to"
"gitea.com/gitea/gitea-mcp/pkg/tool" "gitea.com/gitea/gitea-mcp/pkg/tool"
@@ -380,9 +381,9 @@ func AddIssueLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("issue index is required")) return to.ErrorResult(err)
} }
labelsRaw, ok := req.GetArguments()["labels"].([]interface{}) labelsRaw, ok := req.GetArguments()["labels"].([]interface{})
if !ok { if !ok {
@@ -405,9 +406,9 @@ func AddIssueLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
issueLabels, _, err := client.AddIssueLabels(owner, repo, int64(index), opt) issueLabels, _, err := client.AddIssueLabels(owner, repo, index, opt)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("add labels to %v/%v/issue/%v err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("add labels to %v/%v/issue/%v err: %v", owner, repo, index, err))
} }
return to.TextResult(issueLabels) return to.TextResult(issueLabels)
} }
@@ -422,9 +423,9 @@ func ReplaceIssueLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("issue index is required")) return to.ErrorResult(err)
} }
labelsRaw, ok := req.GetArguments()["labels"].([]interface{}) labelsRaw, ok := req.GetArguments()["labels"].([]interface{})
if !ok { if !ok {
@@ -447,9 +448,9 @@ func ReplaceIssueLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
issueLabels, _, err := client.ReplaceIssueLabels(owner, repo, int64(index), opt) issueLabels, _, err := client.ReplaceIssueLabels(owner, repo, index, opt)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("replace labels on %v/%v/issue/%v err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("replace labels on %v/%v/issue/%v err: %v", owner, repo, index, err))
} }
return to.TextResult(issueLabels) return to.TextResult(issueLabels)
} }
@@ -464,18 +465,18 @@ func ClearIssueLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("issue index is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.ClearIssueLabels(owner, repo, int64(index)) _, err = client.ClearIssueLabels(owner, repo, index)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("clear labels on %v/%v/issue/%v err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("clear labels on %v/%v/issue/%v err: %v", owner, repo, index, err))
} }
return to.TextResult("Labels cleared successfully") return to.TextResult("Labels cleared successfully")
} }
@@ -490,9 +491,9 @@ func RemoveIssueLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("issue index is required")) return to.ErrorResult(err)
} }
labelID, ok := req.GetArguments()["label_id"].(float64) labelID, ok := req.GetArguments()["label_id"].(float64)
if !ok { if !ok {
@@ -503,9 +504,9 @@ func RemoveIssueLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.DeleteIssueLabel(owner, repo, int64(index), int64(labelID)) _, err = client.DeleteIssueLabel(owner, repo, index, int64(labelID))
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("remove label %v from %v/%v/issue/%v err: %v", int64(labelID), owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("remove label %v from %v/%v/issue/%v err: %v", int64(labelID), owner, repo, index, err))
} }
return to.TextResult("Label removed successfully") return to.TextResult("Label removed successfully")
} }

View File

@@ -6,6 +6,7 @@ import (
"gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/gitea"
"gitea.com/gitea/gitea-mcp/pkg/log" "gitea.com/gitea/gitea-mcp/pkg/log"
"gitea.com/gitea/gitea-mcp/pkg/params"
"gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/to"
"gitea.com/gitea/gitea-mcp/pkg/tool" "gitea.com/gitea/gitea-mcp/pkg/tool"
@@ -237,17 +238,17 @@ func GetPullRequestByIndexFn(ctx context.Context, req mcp.CallToolRequest) (*mcp
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
pr, _, err := client.GetPullRequest(owner, repo, int64(index)) pr, _, err := client.GetPullRequest(owner, repo, index)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get %v/%v/pr/%v err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("get %v/%v/pr/%v err: %v", owner, repo, index, err))
} }
return to.TextResult(pr) return to.TextResult(pr)
@@ -263,9 +264,9 @@ func GetPullRequestDiffFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
binary, _ := req.GetArguments()["binary"].(bool) binary, _ := req.GetArguments()["binary"].(bool)
@@ -273,17 +274,17 @@ func GetPullRequestDiffFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
diffBytes, _, err := client.GetPullRequestDiff(owner, repo, int64(index), gitea_sdk.PullRequestDiffOptions{ diffBytes, _, err := client.GetPullRequestDiff(owner, repo, index, gitea_sdk.PullRequestDiffOptions{
Binary: binary, Binary: binary,
}) })
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get %v/%v/pr/%v diff err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("get %v/%v/pr/%v diff err: %v", owner, repo, index, err))
} }
result := map[string]interface{}{ result := map[string]interface{}{
"diff": string(diffBytes), "diff": string(diffBytes),
"binary": binary, "binary": binary,
"index": int64(index), "index": index,
"repo": repo, "repo": repo,
"owner": owner, "owner": owner,
} }
@@ -388,9 +389,9 @@ func CreatePullRequestReviewerFn(ctx context.Context, req mcp.CallToolRequest) (
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
var reviewers []string var reviewers []string
@@ -420,12 +421,12 @@ func CreatePullRequestReviewerFn(ctx context.Context, req mcp.CallToolRequest) (
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.CreateReviewRequests(owner, repo, int64(index), gitea_sdk.PullReviewRequestOptions{ _, err = client.CreateReviewRequests(owner, repo, index, gitea_sdk.PullReviewRequestOptions{
Reviewers: reviewers, Reviewers: reviewers,
TeamReviewers: teamReviewers, TeamReviewers: teamReviewers,
}) })
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("create review requests for %v/%v/pr/%v err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("create review requests for %v/%v/pr/%v err: %v", owner, repo, index, err))
} }
// Return a success message instead of the Response object which contains non-serializable functions // Return a success message instead of the Response object which contains non-serializable functions
@@ -433,7 +434,7 @@ func CreatePullRequestReviewerFn(ctx context.Context, req mcp.CallToolRequest) (
"message": "Successfully created review requests", "message": "Successfully created review requests",
"reviewers": reviewers, "reviewers": reviewers,
"team_reviewers": teamReviewers, "team_reviewers": teamReviewers,
"pr_index": int64(index), "pr_index": index,
"repository": fmt.Sprintf("%s/%s", owner, repo), "repository": fmt.Sprintf("%s/%s", owner, repo),
} }
@@ -450,9 +451,9 @@ func DeletePullRequestReviewerFn(ctx context.Context, req mcp.CallToolRequest) (
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
var reviewers []string var reviewers []string
@@ -482,19 +483,19 @@ func DeletePullRequestReviewerFn(ctx context.Context, req mcp.CallToolRequest) (
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.DeleteReviewRequests(owner, repo, int64(index), gitea_sdk.PullReviewRequestOptions{ _, err = client.DeleteReviewRequests(owner, repo, index, gitea_sdk.PullReviewRequestOptions{
Reviewers: reviewers, Reviewers: reviewers,
TeamReviewers: teamReviewers, TeamReviewers: teamReviewers,
}) })
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("delete review requests for %v/%v/pr/%v err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("delete review requests for %v/%v/pr/%v err: %v", owner, repo, index, err))
} }
successMsg := map[string]interface{}{ successMsg := map[string]interface{}{
"message": "Successfully deleted review requests", "message": "Successfully deleted review requests",
"reviewers": reviewers, "reviewers": reviewers,
"team_reviewers": teamReviewers, "team_reviewers": teamReviewers,
"pr_index": int64(index), "pr_index": index,
"repository": fmt.Sprintf("%s/%s", owner, repo), "repository": fmt.Sprintf("%s/%s", owner, repo),
} }
@@ -511,9 +512,9 @@ func ListPullRequestReviewsFn(ctx context.Context, req mcp.CallToolRequest) (*mc
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
page, ok := req.GetArguments()["page"].(float64) page, ok := req.GetArguments()["page"].(float64)
if !ok { if !ok {
@@ -529,14 +530,14 @@ func ListPullRequestReviewsFn(ctx context.Context, req mcp.CallToolRequest) (*mc
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
reviews, _, err := client.ListPullReviews(owner, repo, int64(index), gitea_sdk.ListPullReviewsOptions{ reviews, _, err := client.ListPullReviews(owner, repo, index, gitea_sdk.ListPullReviewsOptions{
ListOptions: gitea_sdk.ListOptions{ ListOptions: gitea_sdk.ListOptions{
Page: int(page), Page: int(page),
PageSize: int(pageSize), PageSize: int(pageSize),
}, },
}) })
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("list reviews for %v/%v/pr/%v err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("list reviews for %v/%v/pr/%v err: %v", owner, repo, index, err))
} }
return to.TextResult(reviews) return to.TextResult(reviews)
@@ -552,9 +553,9 @@ func GetPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
reviewID, ok := req.GetArguments()["review_id"].(float64) reviewID, ok := req.GetArguments()["review_id"].(float64)
if !ok { if !ok {
@@ -566,9 +567,9 @@ func GetPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
review, _, err := client.GetPullReview(owner, repo, int64(index), int64(reviewID)) review, _, err := client.GetPullReview(owner, repo, index, int64(reviewID))
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("get review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err))
} }
return to.TextResult(review) return to.TextResult(review)
@@ -584,9 +585,9 @@ func ListPullRequestReviewCommentsFn(ctx context.Context, req mcp.CallToolReques
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
reviewID, ok := req.GetArguments()["review_id"].(float64) reviewID, ok := req.GetArguments()["review_id"].(float64)
if !ok { if !ok {
@@ -598,9 +599,9 @@ func ListPullRequestReviewCommentsFn(ctx context.Context, req mcp.CallToolReques
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
comments, _, err := client.ListPullReviewComments(owner, repo, int64(index), int64(reviewID)) comments, _, err := client.ListPullReviewComments(owner, repo, index, int64(reviewID))
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("list review comments for review %v on %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("list review comments for review %v on %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err))
} }
return to.TextResult(comments) return to.TextResult(comments)
@@ -616,9 +617,9 @@ func CreatePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
opt := gitea_sdk.CreatePullReviewOptions{} opt := gitea_sdk.CreatePullReviewOptions{}
@@ -662,9 +663,9 @@ func CreatePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
review, _, err := client.CreatePullReview(owner, repo, int64(index), opt) review, _, err := client.CreatePullReview(owner, repo, index, opt)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("create review for %v/%v/pr/%v err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("create review for %v/%v/pr/%v err: %v", owner, repo, index, err))
} }
return to.TextResult(review) return to.TextResult(review)
@@ -680,9 +681,9 @@ func SubmitPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
reviewID, ok := req.GetArguments()["review_id"].(float64) reviewID, ok := req.GetArguments()["review_id"].(float64)
if !ok { if !ok {
@@ -705,9 +706,9 @@ func SubmitPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
review, _, err := client.SubmitPullReview(owner, repo, int64(index), int64(reviewID), opt) review, _, err := client.SubmitPullReview(owner, repo, index, int64(reviewID), opt)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("submit review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("submit review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err))
} }
return to.TextResult(review) return to.TextResult(review)
@@ -723,9 +724,9 @@ func DeletePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
reviewID, ok := req.GetArguments()["review_id"].(float64) reviewID, ok := req.GetArguments()["review_id"].(float64)
if !ok { if !ok {
@@ -737,15 +738,15 @@ func DeletePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.DeletePullReview(owner, repo, int64(index), int64(reviewID)) _, err = client.DeletePullReview(owner, repo, index, int64(reviewID))
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("delete review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("delete review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err))
} }
successMsg := map[string]interface{}{ successMsg := map[string]interface{}{
"message": "Successfully deleted review", "message": "Successfully deleted review",
"review_id": int64(reviewID), "review_id": int64(reviewID),
"pr_index": int64(index), "pr_index": index,
"repository": fmt.Sprintf("%s/%s", owner, repo), "repository": fmt.Sprintf("%s/%s", owner, repo),
} }
@@ -762,9 +763,9 @@ func DismissPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
reviewID, ok := req.GetArguments()["review_id"].(float64) reviewID, ok := req.GetArguments()["review_id"].(float64)
if !ok { if !ok {
@@ -781,15 +782,15 @@ func DismissPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.DismissPullReview(owner, repo, int64(index), int64(reviewID), opt) _, err = client.DismissPullReview(owner, repo, index, int64(reviewID), opt)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("dismiss review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("dismiss review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err))
} }
successMsg := map[string]interface{}{ successMsg := map[string]interface{}{
"message": "Successfully dismissed review", "message": "Successfully dismissed review",
"review_id": int64(reviewID), "review_id": int64(reviewID),
"pr_index": int64(index), "pr_index": index,
"repository": fmt.Sprintf("%s/%s", owner, repo), "repository": fmt.Sprintf("%s/%s", owner, repo),
} }

View File

@@ -8,6 +8,7 @@ import (
gitea_sdk "code.gitea.io/sdk/gitea" gitea_sdk "code.gitea.io/sdk/gitea"
"gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/gitea"
"gitea.com/gitea/gitea-mcp/pkg/log" "gitea.com/gitea/gitea-mcp/pkg/log"
"gitea.com/gitea/gitea-mcp/pkg/params"
"gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/to"
"gitea.com/gitea/gitea-mcp/pkg/tool" "gitea.com/gitea/gitea-mcp/pkg/tool"
@@ -134,19 +135,19 @@ func StartStopwatchFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.StartIssueStopWatch(owner, repo, int64(index)) _, err = client.StartIssueStopWatch(owner, repo, index)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("start stopwatch on %s/%s#%d err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("start stopwatch on %s/%s#%d err: %v", owner, repo, index, err))
} }
return to.TextResult(fmt.Sprintf("Stopwatch started on issue %s/%s#%d", owner, repo, int64(index))) return to.TextResult(fmt.Sprintf("Stopwatch started on issue %s/%s#%d", owner, repo, index))
} }
func StopStopwatchFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func StopStopwatchFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
@@ -159,19 +160,19 @@ func StopStopwatchFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.StopIssueStopWatch(owner, repo, int64(index)) _, err = client.StopIssueStopWatch(owner, repo, index)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("stop stopwatch on %s/%s#%d err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("stop stopwatch on %s/%s#%d err: %v", owner, repo, index, err))
} }
return to.TextResult(fmt.Sprintf("Stopwatch stopped on issue %s/%s#%d - time recorded", owner, repo, int64(index))) return to.TextResult(fmt.Sprintf("Stopwatch stopped on issue %s/%s#%d - time recorded", owner, repo, index))
} }
func DeleteStopwatchFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func DeleteStopwatchFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
@@ -184,19 +185,19 @@ func DeleteStopwatchFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.DeleteIssueStopwatch(owner, repo, int64(index)) _, err = client.DeleteIssueStopwatch(owner, repo, index)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("delete stopwatch on %s/%s#%d err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("delete stopwatch on %s/%s#%d err: %v", owner, repo, index, err))
} }
return to.TextResult(fmt.Sprintf("Stopwatch deleted/cancelled on issue %s/%s#%d", owner, repo, int64(index))) return to.TextResult(fmt.Sprintf("Stopwatch deleted/cancelled on issue %s/%s#%d", owner, repo, index))
} }
func GetMyStopwatchesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func GetMyStopwatchesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
@@ -227,9 +228,9 @@ func ListTrackedTimesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
page, ok := req.GetArguments()["page"].(float64) page, ok := req.GetArguments()["page"].(float64)
if !ok { if !ok {
@@ -244,17 +245,17 @@ func ListTrackedTimesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
times, _, err := client.ListIssueTrackedTimes(owner, repo, int64(index), gitea_sdk.ListTrackedTimesOptions{ times, _, err := client.ListIssueTrackedTimes(owner, repo, index, gitea_sdk.ListTrackedTimesOptions{
ListOptions: gitea_sdk.ListOptions{ ListOptions: gitea_sdk.ListOptions{
Page: int(page), Page: int(page),
PageSize: int(pageSize), PageSize: int(pageSize),
}, },
}) })
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("list tracked times for %s/%s#%d err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("list tracked times for %s/%s#%d err: %v", owner, repo, index, err))
} }
if len(times) == 0 { if len(times) == 0 {
return to.TextResult(fmt.Sprintf("No tracked times for issue %s/%s#%d", owner, repo, int64(index))) return to.TextResult(fmt.Sprintf("No tracked times for issue %s/%s#%d", owner, repo, index))
} }
return to.TextResult(times) return to.TextResult(times)
} }
@@ -269,9 +270,9 @@ func AddTrackedTimeFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo
if !ok { if !ok {
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
timeSeconds, ok := req.GetArguments()["time"].(float64) timeSeconds, ok := req.GetArguments()["time"].(float64)
@@ -282,11 +283,11 @@ func AddTrackedTimeFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
trackedTime, _, err := client.AddTime(owner, repo, int64(index), gitea_sdk.AddTimeOption{ trackedTime, _, err := client.AddTime(owner, repo, index, gitea_sdk.AddTimeOption{
Time: int64(timeSeconds), Time: int64(timeSeconds),
}) })
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("add tracked time to %s/%s#%d err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("add tracked time to %s/%s#%d err: %v", owner, repo, index, err))
} }
return to.TextResult(trackedTime) return to.TextResult(trackedTime)
} }
@@ -302,9 +303,9 @@ func DeleteTrackedTimeFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Cal
return to.ErrorResult(fmt.Errorf("repo is required")) return to.ErrorResult(fmt.Errorf("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(fmt.Errorf("index is required")) return to.ErrorResult(err)
} }
id, ok := req.GetArguments()["id"].(float64) id, ok := req.GetArguments()["id"].(float64)
if !ok { if !ok {
@@ -314,11 +315,11 @@ func DeleteTrackedTimeFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Cal
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.DeleteTime(owner, repo, int64(index), int64(id)) _, err = client.DeleteTime(owner, repo, index, int64(id))
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("delete tracked time %d from %s/%s#%d err: %v", int64(id), owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("delete tracked time %d from %s/%s#%d err: %v", int64(id), owner, repo, index, err))
} }
return to.TextResult(fmt.Sprintf("Tracked time entry %d deleted from issue %s/%s#%d", int64(id), owner, repo, int64(index))) return to.TextResult(fmt.Sprintf("Tracked time entry %d deleted from issue %s/%s#%d", int64(id), owner, repo, index))
} }
func ListRepoTimesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func ListRepoTimesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {

33
pkg/params/params.go Normal file
View File

@@ -0,0 +1,33 @@
package params
import (
"fmt"
"strconv"
)
// GetIndex extracts an index parameter from MCP tool arguments.
// It accepts both numeric (float64 from JSON) and string representations.
// This provides better UX for LLM callers that may naturally use strings
// for identifiers like issue/PR numbers.
func GetIndex(args map[string]interface{}, key string) (int64, error) {
val, exists := args[key]
if !exists {
return 0, fmt.Errorf("%s is required", key)
}
// Try float64 (JSON number type)
if f, ok := val.(float64); ok {
return int64(f), nil
}
// Try string and parse to integer
if s, ok := val.(string); ok {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return 0, fmt.Errorf("%s must be a valid integer (got %q)", key, s)
}
return i, nil
}
return 0, fmt.Errorf("%s must be a number or numeric string", key)
}

116
pkg/params/params_test.go Normal file
View File

@@ -0,0 +1,116 @@
package params
import (
"testing"
)
func TestGetIndex(t *testing.T) {
tests := []struct {
name string
args map[string]interface{}
key string
wantIndex int64
wantErr bool
errMsg string
}{
{
name: "valid float64",
args: map[string]interface{}{"index": float64(123)},
key: "index",
wantIndex: 123,
wantErr: false,
},
{
name: "valid string",
args: map[string]interface{}{"index": "456"},
key: "index",
wantIndex: 456,
wantErr: false,
},
{
name: "valid string with large number",
args: map[string]interface{}{"index": "999999"},
key: "index",
wantIndex: 999999,
wantErr: false,
},
{
name: "missing parameter",
args: map[string]interface{}{},
key: "index",
wantErr: true,
errMsg: "index is required",
},
{
name: "invalid string (not a number)",
args: map[string]interface{}{"index": "abc"},
key: "index",
wantErr: true,
errMsg: "must be a valid integer",
},
{
name: "invalid string (decimal)",
args: map[string]interface{}{"index": "12.34"},
key: "index",
wantErr: true,
errMsg: "must be a valid integer",
},
{
name: "invalid type (bool)",
args: map[string]interface{}{"index": true},
key: "index",
wantErr: true,
errMsg: "must be a number or numeric string",
},
{
name: "invalid type (map)",
args: map[string]interface{}{"index": map[string]string{"foo": "bar"}},
key: "index",
wantErr: true,
errMsg: "must be a number or numeric string",
},
{
name: "custom key name",
args: map[string]interface{}{"pr_index": "789"},
key: "pr_index",
wantIndex: 789,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotIndex, err := GetIndex(tt.args, tt.key)
if tt.wantErr {
if err == nil {
t.Errorf("GetIndex() expected error but got nil")
return
}
if tt.errMsg != "" && !contains(err.Error(), tt.errMsg) {
t.Errorf("GetIndex() error = %v, want error containing %q", err, tt.errMsg)
}
return
}
if err != nil {
t.Errorf("GetIndex() unexpected error = %v", err)
return
}
if gotIndex != tt.wantIndex {
t.Errorf("GetIndex() = %v, want %v", gotIndex, tt.wantIndex)
}
})
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || containsMiddle(s, substr)))
}
func containsMiddle(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}