feat: accept string values for all numeric input parameters (#138)

## Summary

- MCP clients may send numbers as strings. This adds `ToInt64` and `GetOptionalInt` helpers to `pkg/params` and replaces all raw `.(float64)` type assertions across operation handlers to accept both `float64` and string inputs.

## Test plan

- [x] Verify `go test ./...` passes
- [x] Test with an MCP client that sends numeric parameters as strings

*Created by Claude on behalf of @silverwind*

Reviewed-on: https://gitea.com/gitea/gitea-mcp/pulls/138
Reviewed-by: Lunny Xiao <xiaolunwen@gmail.com>
Co-authored-by: silverwind <me@silverwind.io>
Co-committed-by: silverwind <me@silverwind.io>
This commit is contained in:
silverwind
2026-02-25 23:28:14 +00:00
committed by Lunny Xiao
parent 4a2935d898
commit 67a1e1e7fe
18 changed files with 628 additions and 627 deletions

View File

@@ -11,6 +11,7 @@ import (
"gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/gitea"
"gitea.com/gitea/gitea-mcp/pkg/log" "gitea.com/gitea/gitea-mcp/pkg/log"
"gitea.com/gitea/gitea-mcp/pkg/params"
"gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/to"
"github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/mcp"
@@ -116,22 +117,12 @@ func GetRepoActionJobLogPreviewFn(ctx context.Context, req mcp.CallToolRequest)
if !ok || repo == "" { if !ok || repo == "" {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
jobIDFloat, ok := req.GetArguments()["job_id"].(float64) jobID, err := params.GetIndex(req.GetArguments(), "job_id")
if !ok || jobIDFloat <= 0 { if err != nil || jobID <= 0 {
return to.ErrorResult(errors.New("job_id is required")) return to.ErrorResult(errors.New("job_id is required"))
} }
tailLinesFloat, _ := req.GetArguments()["tail_lines"].(float64) tailLines := int(params.GetOptionalInt(req.GetArguments(), "tail_lines", 200))
maxBytesFloat, _ := req.GetArguments()["max_bytes"].(float64) maxBytes := int(params.GetOptionalInt(req.GetArguments(), "max_bytes", 65536))
tailLines := int(tailLinesFloat)
if tailLines <= 0 {
tailLines = 200
}
maxBytes := int(maxBytesFloat)
if maxBytes <= 0 {
maxBytes = 65536
}
jobID := int64(jobIDFloat)
raw, usedPath, err := fetchJobLogBytes(ctx, owner, repo, jobID) raw, usedPath, err := fetchJobLogBytes(ctx, owner, repo, jobID)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get job log err: %v", err)) return to.ErrorResult(fmt.Errorf("get job log err: %v", err))
@@ -161,12 +152,11 @@ func DownloadRepoActionJobLogFn(ctx context.Context, req mcp.CallToolRequest) (*
if !ok || repo == "" { if !ok || repo == "" {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
jobIDFloat, ok := req.GetArguments()["job_id"].(float64) jobID, err := params.GetIndex(req.GetArguments(), "job_id")
if !ok || jobIDFloat <= 0 { if err != nil || jobID <= 0 {
return to.ErrorResult(errors.New("job_id is required")) return to.ErrorResult(errors.New("job_id is required"))
} }
outputPath, _ := req.GetArguments()["output_path"].(string) outputPath, _ := req.GetArguments()["output_path"].(string)
jobID := int64(jobIDFloat)
raw, usedPath, err := fetchJobLogBytes(ctx, owner, repo, jobID) raw, usedPath, err := fetchJobLogBytes(ctx, owner, repo, jobID)
if err != nil { if err != nil {

View File

@@ -11,6 +11,7 @@ import (
"gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/gitea"
"gitea.com/gitea/gitea-mcp/pkg/log" "gitea.com/gitea/gitea-mcp/pkg/log"
"gitea.com/gitea/gitea-mcp/pkg/params"
"gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/to"
"github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/mcp"
@@ -155,14 +156,8 @@ func ListRepoActionWorkflowsFn(ctx context.Context, req mcp.CallToolRequest) (*m
if !ok || repo == "" { if !ok || repo == "" {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
page, _ := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if page <= 0 { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 50)
page = 1
}
pageSize, _ := req.GetArguments()["pageSize"].(float64)
if pageSize <= 0 {
pageSize = 50
}
query := url.Values{} query := url.Values{}
query.Set("page", strconv.Itoa(int(page))) query.Set("page", strconv.Itoa(int(page)))
query.Set("limit", strconv.Itoa(int(pageSize))) query.Set("limit", strconv.Itoa(int(pageSize)))
@@ -271,14 +266,8 @@ func ListRepoActionRunsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca
if !ok || repo == "" { if !ok || repo == "" {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
page, _ := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if page <= 0 { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 50)
page = 1
}
pageSize, _ := req.GetArguments()["pageSize"].(float64)
if pageSize <= 0 {
pageSize = 50
}
statusFilter, _ := req.GetArguments()["status"].(string) statusFilter, _ := req.GetArguments()["status"].(string)
query := url.Values{} query := url.Values{}
@@ -311,15 +300,15 @@ func GetRepoActionRunFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call
if !ok || repo == "" { if !ok || repo == "" {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
runID, ok := req.GetArguments()["run_id"].(float64) runID, err := params.GetIndex(req.GetArguments(), "run_id")
if !ok || runID <= 0 { if err != nil || runID <= 0 {
return to.ErrorResult(errors.New("run_id is required")) return to.ErrorResult(errors.New("run_id is required"))
} }
var result any var result any
err := doJSONWithFallback(ctx, "GET", err = doJSONWithFallback(ctx, "GET",
[]string{ []string{
fmt.Sprintf("repos/%s/%s/actions/runs/%d", url.PathEscape(owner), url.PathEscape(repo), int64(runID)), fmt.Sprintf("repos/%s/%s/actions/runs/%d", url.PathEscape(owner), url.PathEscape(repo), runID),
}, },
nil, nil, &result, nil, nil, &result,
) )
@@ -339,14 +328,14 @@ func CancelRepoActionRunFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.C
if !ok || repo == "" { if !ok || repo == "" {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
runID, ok := req.GetArguments()["run_id"].(float64) runID, err := params.GetIndex(req.GetArguments(), "run_id")
if !ok || runID <= 0 { if err != nil || runID <= 0 {
return to.ErrorResult(errors.New("run_id is required")) return to.ErrorResult(errors.New("run_id is required"))
} }
err := doJSONWithFallback(ctx, "POST", err = doJSONWithFallback(ctx, "POST",
[]string{ []string{
fmt.Sprintf("repos/%s/%s/actions/runs/%d/cancel", url.PathEscape(owner), url.PathEscape(repo), int64(runID)), fmt.Sprintf("repos/%s/%s/actions/runs/%d/cancel", url.PathEscape(owner), url.PathEscape(repo), runID),
}, },
nil, nil, nil, nil, nil, nil,
) )
@@ -366,15 +355,15 @@ func RerunRepoActionRunFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca
if !ok || repo == "" { if !ok || repo == "" {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
runID, ok := req.GetArguments()["run_id"].(float64) runID, err := params.GetIndex(req.GetArguments(), "run_id")
if !ok || runID <= 0 { if err != nil || runID <= 0 {
return to.ErrorResult(errors.New("run_id is required")) return to.ErrorResult(errors.New("run_id is required"))
} }
err := doJSONWithFallback(ctx, "POST", err = doJSONWithFallback(ctx, "POST",
[]string{ []string{
fmt.Sprintf("repos/%s/%s/actions/runs/%d/rerun", url.PathEscape(owner), url.PathEscape(repo), int64(runID)), fmt.Sprintf("repos/%s/%s/actions/runs/%d/rerun", url.PathEscape(owner), url.PathEscape(repo), runID),
fmt.Sprintf("repos/%s/%s/actions/runs/%d/rerun-failed-jobs", url.PathEscape(owner), url.PathEscape(repo), int64(runID)), fmt.Sprintf("repos/%s/%s/actions/runs/%d/rerun-failed-jobs", url.PathEscape(owner), url.PathEscape(repo), runID),
}, },
nil, nil, nil, nil, nil, nil,
) )
@@ -398,14 +387,8 @@ func ListRepoActionJobsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca
if !ok || repo == "" { if !ok || repo == "" {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
page, _ := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if page <= 0 { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 50)
page = 1
}
pageSize, _ := req.GetArguments()["pageSize"].(float64)
if pageSize <= 0 {
pageSize = 50
}
statusFilter, _ := req.GetArguments()["status"].(string) statusFilter, _ := req.GetArguments()["status"].(string)
query := url.Values{} query := url.Values{}
@@ -438,27 +421,21 @@ func ListRepoActionRunJobsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp
if !ok || repo == "" { if !ok || repo == "" {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
runID, ok := req.GetArguments()["run_id"].(float64) runID, err := params.GetIndex(req.GetArguments(), "run_id")
if !ok || runID <= 0 { if err != nil || runID <= 0 {
return to.ErrorResult(errors.New("run_id is required")) return to.ErrorResult(errors.New("run_id is required"))
} }
page, _ := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if page <= 0 { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 50)
page = 1
}
pageSize, _ := req.GetArguments()["pageSize"].(float64)
if pageSize <= 0 {
pageSize = 50
}
query := url.Values{} query := url.Values{}
query.Set("page", strconv.Itoa(int(page))) query.Set("page", strconv.Itoa(int(page)))
query.Set("limit", strconv.Itoa(int(pageSize))) query.Set("limit", strconv.Itoa(int(pageSize)))
var result any var result any
err := doJSONWithFallback(ctx, "GET", err = doJSONWithFallback(ctx, "GET",
[]string{ []string{
fmt.Sprintf("repos/%s/%s/actions/runs/%d/jobs", url.PathEscape(owner), url.PathEscape(repo), int64(runID)), fmt.Sprintf("repos/%s/%s/actions/runs/%d/jobs", url.PathEscape(owner), url.PathEscape(repo), runID),
}, },
query, nil, &result, query, nil, &result,
) )

View File

@@ -9,6 +9,7 @@ import (
"gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/gitea"
"gitea.com/gitea/gitea-mcp/pkg/log" "gitea.com/gitea/gitea-mcp/pkg/log"
"gitea.com/gitea/gitea-mcp/pkg/params"
"gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/to"
gitea_sdk "code.gitea.io/sdk/gitea" gitea_sdk "code.gitea.io/sdk/gitea"
@@ -104,14 +105,8 @@ func ListRepoActionSecretsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp
if !ok || repo == "" { if !ok || repo == "" {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
page, _ := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if page <= 0 { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
page = 1
}
pageSize, _ := req.GetArguments()["pageSize"].(float64)
if pageSize <= 0 {
pageSize = 100
}
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
@@ -206,14 +201,8 @@ func ListOrgActionSecretsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.
if !ok || org == "" { if !ok || org == "" {
return to.ErrorResult(errors.New("org is required")) return to.ErrorResult(errors.New("org is required"))
} }
page, _ := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if page <= 0 { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
page = 1
}
pageSize, _ := req.GetArguments()["pageSize"].(float64)
if pageSize <= 0 {
pageSize = 100
}
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {

View File

@@ -9,6 +9,7 @@ import (
"gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/gitea"
"gitea.com/gitea/gitea-mcp/pkg/log" "gitea.com/gitea/gitea-mcp/pkg/log"
"gitea.com/gitea/gitea-mcp/pkg/params"
"gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/to"
gitea_sdk "code.gitea.io/sdk/gitea" gitea_sdk "code.gitea.io/sdk/gitea"
@@ -139,14 +140,8 @@ func ListRepoActionVariablesFn(ctx context.Context, req mcp.CallToolRequest) (*m
if !ok || repo == "" { if !ok || repo == "" {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
page, _ := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if page <= 0 { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
page = 1
}
pageSize, _ := req.GetArguments()["pageSize"].(float64)
if pageSize <= 0 {
pageSize = 100
}
query := url.Values{} query := url.Values{}
query.Set("page", strconv.Itoa(int(page))) query.Set("page", strconv.Itoa(int(page)))
@@ -278,14 +273,8 @@ func ListOrgActionVariablesFn(ctx context.Context, req mcp.CallToolRequest) (*mc
if !ok || org == "" { if !ok || org == "" {
return to.ErrorResult(errors.New("org is required")) return to.ErrorResult(errors.New("org is required"))
} }
page, _ := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if page <= 0 { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
page = 1
}
pageSize, _ := req.GetArguments()["pageSize"].(float64)
if pageSize <= 0 {
pageSize = 100
}
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {

View File

@@ -167,14 +167,8 @@ func ListRepoIssuesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo
if !ok { if !ok {
state = "all" state = "all"
} }
page, ok := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if !ok { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
page = 1
}
pageSize, ok := req.GetArguments()["pageSize"].(float64)
if !ok {
pageSize = 100
}
opt := gitea_sdk.ListIssueOption{ opt := gitea_sdk.ListIssueOption{
State: gitea_sdk.StateType(state), State: gitea_sdk.StateType(state),
ListOptions: gitea_sdk.ListOptions{ ListOptions: gitea_sdk.ListOptions{
@@ -295,9 +289,10 @@ func EditIssueFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolRes
} }
} }
opt.Assignees = assignees opt.Assignees = assignees
milestone, ok := req.GetArguments()["milestone"].(float64) if val, exists := req.GetArguments()["milestone"]; exists {
if ok { if milestone, ok := params.ToInt64(val); ok {
opt.Milestone = new(int64(milestone)) opt.Milestone = new(milestone)
}
} }
state, ok := req.GetArguments()["state"].(string) state, ok := req.GetArguments()["state"].(string)
if ok { if ok {
@@ -326,9 +321,9 @@ func EditIssueCommentFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call
if !ok { if !ok {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
commentID, ok := req.GetArguments()["commentID"].(float64) commentID, err := params.GetIndex(req.GetArguments(), "commentID")
if !ok { if err != nil {
return to.ErrorResult(errors.New("comment ID is required")) return to.ErrorResult(err)
} }
body, ok := req.GetArguments()["body"].(string) body, ok := req.GetArguments()["body"].(string)
if !ok { if !ok {
@@ -341,9 +336,9 @@ func EditIssueCommentFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
issueComment, _, err := client.EditIssueComment(owner, repo, int64(commentID), opt) issueComment, _, err := client.EditIssueComment(owner, repo, commentID, opt)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("edit %v/%v/issues/comments/%v err: %v", owner, repo, int64(commentID), err)) return to.ErrorResult(fmt.Errorf("edit %v/%v/issues/comments/%v err: %v", owner, repo, commentID, err))
} }
return to.TextResult(issueComment) return to.TextResult(issueComment)

View File

@@ -218,14 +218,8 @@ func ListRepoLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo
if !ok { if !ok {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
page, ok := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if !ok { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
page = 1
}
pageSize, ok := req.GetArguments()["pageSize"].(float64)
if !ok {
pageSize = 100
}
opt := gitea_sdk.ListLabelsOptions{ opt := gitea_sdk.ListLabelsOptions{
ListOptions: gitea_sdk.ListOptions{ ListOptions: gitea_sdk.ListOptions{
@@ -254,18 +248,18 @@ func GetRepoLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTool
if !ok { if !ok {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
id, ok := req.GetArguments()["id"].(float64) id, err := params.GetIndex(req.GetArguments(), "id")
if !ok { if err != nil {
return to.ErrorResult(errors.New("label ID is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
label, _, err := client.GetRepoLabel(owner, repo, int64(id)) label, _, err := client.GetRepoLabel(owner, repo, id)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get %v/%v/label/%v err: %v", owner, repo, int64(id), err)) return to.ErrorResult(fmt.Errorf("get %v/%v/label/%v err: %v", owner, repo, id, err))
} }
return to.TextResult(label) return to.TextResult(label)
} }
@@ -317,9 +311,9 @@ func EditRepoLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo
if !ok { if !ok {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
id, ok := req.GetArguments()["id"].(float64) id, err := params.GetIndex(req.GetArguments(), "id")
if !ok { if err != nil {
return to.ErrorResult(errors.New("label ID is required")) return to.ErrorResult(err)
} }
opt := gitea_sdk.EditLabelOption{} opt := gitea_sdk.EditLabelOption{}
@@ -337,9 +331,9 @@ func EditRepoLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
label, _, err := client.EditLabel(owner, repo, int64(id), opt) label, _, err := client.EditLabel(owner, repo, id, opt)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("edit %v/%v/label/%v err: %v", owner, repo, int64(id), err)) return to.ErrorResult(fmt.Errorf("edit %v/%v/label/%v err: %v", owner, repo, id, err))
} }
return to.TextResult(label) return to.TextResult(label)
} }
@@ -354,18 +348,18 @@ func DeleteRepoLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT
if !ok { if !ok {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
id, ok := req.GetArguments()["id"].(float64) id, err := params.GetIndex(req.GetArguments(), "id")
if !ok { if err != nil {
return to.ErrorResult(errors.New("label ID is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.DeleteLabel(owner, repo, int64(id)) _, err = client.DeleteLabel(owner, repo, id)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("delete %v/%v/label/%v err: %v", owner, repo, int64(id), err)) return to.ErrorResult(fmt.Errorf("delete %v/%v/label/%v err: %v", owner, repo, id, err))
} }
return to.TextResult("Label deleted successfully") return to.TextResult("Label deleted successfully")
} }
@@ -390,8 +384,8 @@ func AddIssueLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo
} }
var labels []int64 var labels []int64
for _, l := range labelsRaw { for _, l := range labelsRaw {
if labelID, ok := l.(float64); ok { if labelID, ok := params.ToInt64(l); ok {
labels = append(labels, int64(labelID)) labels = append(labels, labelID)
} else { } else {
return to.ErrorResult(errors.New("invalid label ID in labels array")) return to.ErrorResult(errors.New("invalid label ID in labels array"))
} }
@@ -432,8 +426,8 @@ func ReplaceIssueLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca
} }
var labels []int64 var labels []int64
for _, l := range labelsRaw { for _, l := range labelsRaw {
if labelID, ok := l.(float64); ok { if labelID, ok := params.ToInt64(l); ok {
labels = append(labels, int64(labelID)) labels = append(labels, labelID)
} else { } else {
return to.ErrorResult(errors.New("invalid label ID in labels array")) return to.ErrorResult(errors.New("invalid label ID in labels array"))
} }
@@ -494,18 +488,18 @@ func RemoveIssueLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call
if err != nil { if err != nil {
return to.ErrorResult(err) return to.ErrorResult(err)
} }
labelID, ok := req.GetArguments()["label_id"].(float64) labelID, err := params.GetIndex(req.GetArguments(), "label_id")
if !ok { if err != nil {
return to.ErrorResult(errors.New("label ID is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.DeleteIssueLabel(owner, repo, index, int64(labelID)) _, err = client.DeleteIssueLabel(owner, repo, index, labelID)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("remove label %v from %v/%v/issue/%v err: %v", int64(labelID), owner, repo, index, err)) return to.ErrorResult(fmt.Errorf("remove label %v from %v/%v/issue/%v err: %v", labelID, owner, repo, index, err))
} }
return to.TextResult("Label removed successfully") return to.TextResult("Label removed successfully")
} }
@@ -516,14 +510,8 @@ func ListOrgLabelsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo
if !ok { if !ok {
return to.ErrorResult(errors.New("org is required")) return to.ErrorResult(errors.New("org is required"))
} }
page, ok := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if !ok { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
page = 1
}
pageSize, ok := req.GetArguments()["pageSize"].(float64)
if !ok {
pageSize = 100
}
opt := gitea_sdk.ListOrgLabelsOptions{ opt := gitea_sdk.ListOrgLabelsOptions{
ListOptions: gitea_sdk.ListOptions{ ListOptions: gitea_sdk.ListOptions{
@@ -583,9 +571,9 @@ func EditOrgLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTool
if !ok { if !ok {
return to.ErrorResult(errors.New("org is required")) return to.ErrorResult(errors.New("org is required"))
} }
id, ok := req.GetArguments()["id"].(float64) id, err := params.GetIndex(req.GetArguments(), "id")
if !ok { if err != nil {
return to.ErrorResult(errors.New("label ID is required")) return to.ErrorResult(err)
} }
opt := gitea_sdk.EditOrgLabelOption{} opt := gitea_sdk.EditOrgLabelOption{}
@@ -606,9 +594,9 @@ func EditOrgLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTool
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
label, _, err := client.EditOrgLabel(org, int64(id), opt) label, _, err := client.EditOrgLabel(org, id, opt)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("edit %v/labels/%v err: %v", org, int64(id), err)) return to.ErrorResult(fmt.Errorf("edit %v/labels/%v err: %v", org, id, err))
} }
return to.TextResult(label) return to.TextResult(label)
} }
@@ -619,18 +607,18 @@ func DeleteOrgLabelFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo
if !ok { if !ok {
return to.ErrorResult(errors.New("org is required")) return to.ErrorResult(errors.New("org is required"))
} }
id, ok := req.GetArguments()["id"].(float64) id, err := params.GetIndex(req.GetArguments(), "id")
if !ok { if err != nil {
return to.ErrorResult(errors.New("label ID is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.DeleteOrgLabel(org, int64(id)) _, err = client.DeleteOrgLabel(org, id)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("delete %v/labels/%v err: %v", org, int64(id), err)) return to.ErrorResult(fmt.Errorf("delete %v/labels/%v err: %v", org, id, err))
} }
return to.TextResult("Label deleted successfully") return to.TextResult("Label deleted successfully")
} }

View File

@@ -7,6 +7,7 @@ import (
"gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/gitea"
"gitea.com/gitea/gitea-mcp/pkg/log" "gitea.com/gitea/gitea-mcp/pkg/log"
"gitea.com/gitea/gitea-mcp/pkg/params"
"gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/to"
"gitea.com/gitea/gitea-mcp/pkg/tool" "gitea.com/gitea/gitea-mcp/pkg/tool"
@@ -109,17 +110,17 @@ func GetMilestoneFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTool
if !ok { if !ok {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
id, ok := req.GetArguments()["id"].(float64) id, err := params.GetIndex(req.GetArguments(), "id")
if !ok { if err != nil {
return to.ErrorResult(errors.New("id is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
milestone, _, err := client.GetMilestone(owner, repo, int64(id)) milestone, _, err := client.GetMilestone(owner, repo, id)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get %v/%v/milestone/%v err: %v", owner, repo, int64(id), err)) return to.ErrorResult(fmt.Errorf("get %v/%v/milestone/%v err: %v", owner, repo, id, err))
} }
return to.TextResult(milestone) return to.TextResult(milestone)
@@ -143,14 +144,8 @@ func ListMilestonesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo
if !ok { if !ok {
name = "" name = ""
} }
page, ok := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if !ok { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
page = 1
}
pageSize, ok := req.GetArguments()["pageSize"].(float64)
if !ok {
pageSize = 100
}
opt := gitea_sdk.ListMilestoneOption{ opt := gitea_sdk.ListMilestoneOption{
State: gitea_sdk.StateType(state), State: gitea_sdk.StateType(state),
Name: name, Name: name,
@@ -216,9 +211,9 @@ func EditMilestoneFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo
if !ok { if !ok {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
id, ok := req.GetArguments()["id"].(float64) id, err := params.GetIndex(req.GetArguments(), "id")
if !ok { if err != nil {
return to.ErrorResult(errors.New("id is required")) return to.ErrorResult(err)
} }
opt := gitea_sdk.EditMilestoneOption{} opt := gitea_sdk.EditMilestoneOption{}
@@ -240,9 +235,9 @@ func EditMilestoneFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
milestone, _, err := client.EditMilestone(owner, repo, int64(id), opt) milestone, _, err := client.EditMilestone(owner, repo, id, opt)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("edit %v/%v/milestone/%v err: %v", owner, repo, int64(id), err)) return to.ErrorResult(fmt.Errorf("edit %v/%v/milestone/%v err: %v", owner, repo, id, err))
} }
return to.TextResult(milestone) return to.TextResult(milestone)
@@ -258,17 +253,17 @@ func DeleteMilestoneFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT
if !ok { if !ok {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
id, ok := req.GetArguments()["id"].(float64) id, err := params.GetIndex(req.GetArguments(), "id")
if !ok { if err != nil {
return to.ErrorResult(errors.New("id is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.DeleteMilestone(owner, repo, int64(id)) _, err = client.DeleteMilestone(owner, repo, id)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("delete %v/%v/milestone/%v err: %v", owner, repo, int64(id), err)) return to.ErrorResult(fmt.Errorf("delete %v/%v/milestone/%v err: %v", owner, repo, id, err))
} }
return to.TextResult("Milestone deleted successfully") return to.TextResult("Milestone deleted successfully")

View File

@@ -345,19 +345,13 @@ func ListRepoPullRequestsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.
if !ok { if !ok {
sort = "recentupdate" sort = "recentupdate"
} }
milestone, _ := req.GetArguments()["milestone"].(float64) milestone := params.GetOptionalInt(req.GetArguments(), "milestone", 0)
page, ok := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if !ok { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
page = 1
}
pageSize, ok := req.GetArguments()["pageSize"].(float64)
if !ok {
pageSize = 100
}
opt := gitea_sdk.ListPullRequestsOptions{ opt := gitea_sdk.ListPullRequestsOptions{
State: gitea_sdk.StateType(state), State: gitea_sdk.StateType(state),
Sort: sort, Sort: sort,
Milestone: int64(milestone), Milestone: milestone,
ListOptions: gitea_sdk.ListOptions{ ListOptions: gitea_sdk.ListOptions{
Page: int(page), Page: int(page),
PageSize: int(pageSize), PageSize: int(pageSize),
@@ -555,14 +549,8 @@ func ListPullRequestReviewsFn(ctx context.Context, req mcp.CallToolRequest) (*mc
if err != nil { if err != nil {
return to.ErrorResult(err) return to.ErrorResult(err)
} }
page, ok := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if !ok { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
page = 1
}
pageSize, ok := req.GetArguments()["pageSize"].(float64)
if !ok {
pageSize = 100
}
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
@@ -596,9 +584,9 @@ func GetPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.
if err != nil { if err != nil {
return to.ErrorResult(err) return to.ErrorResult(err)
} }
reviewID, ok := req.GetArguments()["review_id"].(float64) reviewID, err := params.GetIndex(req.GetArguments(), "review_id")
if !ok { if err != nil {
return to.ErrorResult(errors.New("review_id is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
@@ -606,9 +594,9 @@ func GetPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
review, _, err := client.GetPullReview(owner, repo, index, int64(reviewID)) review, _, err := client.GetPullReview(owner, repo, index, reviewID)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err)) return to.ErrorResult(fmt.Errorf("get review %v for %v/%v/pr/%v err: %v", reviewID, owner, repo, index, err))
} }
return to.TextResult(review) return to.TextResult(review)
@@ -628,9 +616,9 @@ func ListPullRequestReviewCommentsFn(ctx context.Context, req mcp.CallToolReques
if err != nil { if err != nil {
return to.ErrorResult(err) return to.ErrorResult(err)
} }
reviewID, ok := req.GetArguments()["review_id"].(float64) reviewID, err := params.GetIndex(req.GetArguments(), "review_id")
if !ok { if err != nil {
return to.ErrorResult(errors.New("review_id is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
@@ -638,9 +626,9 @@ func ListPullRequestReviewCommentsFn(ctx context.Context, req mcp.CallToolReques
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
comments, _, err := client.ListPullReviewComments(owner, repo, index, int64(reviewID)) comments, _, err := client.ListPullReviewComments(owner, repo, index, reviewID)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("list review comments for review %v on %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err)) return to.ErrorResult(fmt.Errorf("list review comments for review %v on %v/%v/pr/%v err: %v", reviewID, owner, repo, index, err))
} }
return to.TextResult(comments) return to.TextResult(comments)
@@ -685,11 +673,11 @@ func CreatePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m
if body, ok := commentMap["body"].(string); ok { if body, ok := commentMap["body"].(string); ok {
reviewComment.Body = body reviewComment.Body = body
} }
if oldLineNum, ok := commentMap["old_line_num"].(float64); ok { if oldLineNum, ok := params.ToInt64(commentMap["old_line_num"]); ok {
reviewComment.OldLineNum = int64(oldLineNum) reviewComment.OldLineNum = oldLineNum
} }
if newLineNum, ok := commentMap["new_line_num"].(float64); ok { if newLineNum, ok := params.ToInt64(commentMap["new_line_num"]); ok {
reviewComment.NewLineNum = int64(newLineNum) reviewComment.NewLineNum = newLineNum
} }
opt.Comments = append(opt.Comments, reviewComment) opt.Comments = append(opt.Comments, reviewComment)
} }
@@ -724,9 +712,9 @@ func SubmitPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m
if err != nil { if err != nil {
return to.ErrorResult(err) return to.ErrorResult(err)
} }
reviewID, ok := req.GetArguments()["review_id"].(float64) reviewID, err := params.GetIndex(req.GetArguments(), "review_id")
if !ok { if err != nil {
return to.ErrorResult(errors.New("review_id is required")) return to.ErrorResult(err)
} }
state, ok := req.GetArguments()["state"].(string) state, ok := req.GetArguments()["state"].(string)
if !ok { if !ok {
@@ -745,9 +733,9 @@ func SubmitPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
review, _, err := client.SubmitPullReview(owner, repo, index, int64(reviewID), opt) review, _, err := client.SubmitPullReview(owner, repo, index, reviewID, opt)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("submit review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err)) return to.ErrorResult(fmt.Errorf("submit review %v for %v/%v/pr/%v err: %v", reviewID, owner, repo, index, err))
} }
return to.TextResult(review) return to.TextResult(review)
@@ -767,9 +755,9 @@ func DeletePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m
if err != nil { if err != nil {
return to.ErrorResult(err) return to.ErrorResult(err)
} }
reviewID, ok := req.GetArguments()["review_id"].(float64) reviewID, err := params.GetIndex(req.GetArguments(), "review_id")
if !ok { if err != nil {
return to.ErrorResult(errors.New("review_id is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
@@ -777,14 +765,14 @@ func DeletePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.DeletePullReview(owner, repo, index, int64(reviewID)) _, err = client.DeletePullReview(owner, repo, index, reviewID)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("delete review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err)) return to.ErrorResult(fmt.Errorf("delete review %v for %v/%v/pr/%v err: %v", reviewID, owner, repo, index, err))
} }
successMsg := map[string]any{ successMsg := map[string]any{
"message": "Successfully deleted review", "message": "Successfully deleted review",
"review_id": int64(reviewID), "review_id": reviewID,
"pr_index": index, "pr_index": index,
"repository": fmt.Sprintf("%s/%s", owner, repo), "repository": fmt.Sprintf("%s/%s", owner, repo),
} }
@@ -806,9 +794,9 @@ func DismissPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*
if err != nil { if err != nil {
return to.ErrorResult(err) return to.ErrorResult(err)
} }
reviewID, ok := req.GetArguments()["review_id"].(float64) reviewID, err := params.GetIndex(req.GetArguments(), "review_id")
if !ok { if err != nil {
return to.ErrorResult(errors.New("review_id is required")) return to.ErrorResult(err)
} }
opt := gitea_sdk.DismissPullReviewOptions{} opt := gitea_sdk.DismissPullReviewOptions{}
@@ -821,14 +809,14 @@ func DismissPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.DismissPullReview(owner, repo, index, int64(reviewID), opt) _, err = client.DismissPullReview(owner, repo, index, reviewID, opt)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("dismiss review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err)) return to.ErrorResult(fmt.Errorf("dismiss review %v for %v/%v/pr/%v err: %v", reviewID, owner, repo, index, err))
} }
successMsg := map[string]any{ successMsg := map[string]any{
"message": "Successfully dismissed review", "message": "Successfully dismissed review",
"review_id": int64(reviewID), "review_id": reviewID,
"pr_index": index, "pr_index": index,
"repository": fmt.Sprintf("%s/%s", owner, repo), "repository": fmt.Sprintf("%s/%s", owner, repo),
} }
@@ -917,9 +905,9 @@ func EditPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT
if !ok { if !ok {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
index, ok := req.GetArguments()["index"].(float64) index, err := params.GetIndex(req.GetArguments(), "index")
if !ok { if err != nil {
return to.ErrorResult(errors.New("index is required")) return to.ErrorResult(err)
} }
opt := gitea_sdk.EditPullRequestOption{} opt := gitea_sdk.EditPullRequestOption{}
@@ -947,8 +935,10 @@ func EditPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT
opt.Assignees = assignees opt.Assignees = assignees
} }
} }
if milestone, ok := req.GetArguments()["milestone"].(float64); ok { if val, exists := req.GetArguments()["milestone"]; exists {
opt.Milestone = int64(milestone) if milestone, ok := params.ToInt64(val); ok {
opt.Milestone = milestone
}
} }
if state, ok := req.GetArguments()["state"].(string); ok { if state, ok := req.GetArguments()["state"].(string); ok {
opt.State = new(gitea_sdk.StateType(state)) opt.State = new(gitea_sdk.StateType(state))
@@ -962,9 +952,9 @@ func EditPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
pr, _, err := client.EditPullRequest(owner, repo, int64(index), opt) pr, _, err := client.EditPullRequest(owner, repo, index, opt)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("edit %v/%v/pr/%v err: %v", owner, repo, int64(index), err)) return to.ErrorResult(fmt.Errorf("edit %v/%v/pr/%v err: %v", owner, repo, index, err))
} }
return to.TextResult(pr) return to.TextResult(pr)

View File

@@ -20,100 +20,112 @@ func TestEditPullRequestFn(t *testing.T) {
index = 7 index = 7
) )
var ( indexInputs := []struct {
mu sync.Mutex name string
gotMethod string val any
gotPath string }{
gotBody map[string]any {"float64", float64(index)},
) {"string", "7"},
}
for _, ii := range indexInputs {
t.Run(ii.name, func(t *testing.T) {
var (
mu sync.Mutex
gotMethod string
gotPath string
gotBody map[string]any
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v1/version":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"private":false}`))
case fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index):
mu.Lock()
gotMethod = r.Method
gotPath = r.URL.Path
var body map[string]any
_ = json.NewDecoder(r.Body).Decode(&body)
gotBody = body
mu.Unlock()
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(fmt.Appendf(nil, `{"number":%d,"title":"%s","state":"open"}`, index, body["title"]))
default:
http.NotFound(w, r)
}
})
server := httptest.NewServer(handler)
defer server.Close()
origHost := flag.Host
origToken := flag.Token
origVersion := flag.Version
flag.Host = server.URL
flag.Token = ""
flag.Version = "test"
defer func() {
flag.Host = origHost
flag.Token = origToken
flag.Version = origVersion
}()
req := mcp.CallToolRequest{
Params: mcp.CallToolParams{
Arguments: map[string]any{
"owner": owner,
"repo": repo,
"index": ii.val,
"title": "WIP: my feature",
"state": "open",
},
},
}
result, err := EditPullRequestFn(context.Background(), req)
if err != nil {
t.Fatalf("EditPullRequestFn() error = %v", err)
}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v1/version":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"private":false}`))
case fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index):
mu.Lock() mu.Lock()
gotMethod = r.Method defer mu.Unlock()
gotPath = r.URL.Path
var body map[string]any
_ = json.NewDecoder(r.Body).Decode(&body)
gotBody = body
mu.Unlock()
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(fmt.Appendf(nil, `{"number":%d,"title":"%s","state":"open"}`, index, body["title"]))
default:
http.NotFound(w, r)
}
})
server := httptest.NewServer(handler) if gotMethod != http.MethodPatch {
defer server.Close() t.Fatalf("expected PATCH request, got %s", gotMethod)
}
if gotPath != fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index) {
t.Fatalf("unexpected path: %s", gotPath)
}
if gotBody["title"] != "WIP: my feature" {
t.Fatalf("expected title 'WIP: my feature', got %v", gotBody["title"])
}
if gotBody["state"] != "open" {
t.Fatalf("expected state 'open', got %v", gotBody["state"])
}
origHost := flag.Host if len(result.Content) == 0 {
origToken := flag.Token t.Fatalf("expected content in result")
origVersion := flag.Version }
flag.Host = server.URL textContent, ok := mcp.AsTextContent(result.Content[0])
flag.Token = "" if !ok {
flag.Version = "test" t.Fatalf("expected text content, got %T", result.Content[0])
defer func() { }
flag.Host = origHost
flag.Token = origToken
flag.Version = origVersion
}()
req := mcp.CallToolRequest{ var parsed struct {
Params: mcp.CallToolParams{ Result map[string]any `json:"Result"`
Arguments: map[string]any{ }
"owner": owner, if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil {
"repo": repo, t.Fatalf("unmarshal result text: %v", err)
"index": float64(index), }
"title": "WIP: my feature", if got := parsed.Result["title"].(string); got != "WIP: my feature" {
"state": "open", t.Fatalf("result title = %q, want %q", got, "WIP: my feature")
}, }
}, })
}
result, err := EditPullRequestFn(context.Background(), req)
if err != nil {
t.Fatalf("EditPullRequestFn() error = %v", err)
}
mu.Lock()
defer mu.Unlock()
if gotMethod != http.MethodPatch {
t.Fatalf("expected PATCH request, got %s", gotMethod)
}
if gotPath != fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index) {
t.Fatalf("unexpected path: %s", gotPath)
}
if gotBody["title"] != "WIP: my feature" {
t.Fatalf("expected title 'WIP: my feature', got %v", gotBody["title"])
}
if gotBody["state"] != "open" {
t.Fatalf("expected state 'open', got %v", gotBody["state"])
}
if len(result.Content) == 0 {
t.Fatalf("expected content in result")
}
textContent, ok := mcp.AsTextContent(result.Content[0])
if !ok {
t.Fatalf("expected text content, got %T", result.Content[0])
}
var parsed struct {
Result map[string]any `json:"Result"`
}
if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil {
t.Fatalf("unmarshal result text: %v", err)
}
if got := parsed.Result["title"].(string); got != "WIP: my feature" {
t.Fatalf("result title = %q, want %q", got, "WIP: my feature")
} }
} }
@@ -124,113 +136,125 @@ func TestMergePullRequestFn(t *testing.T) {
index = 5 index = 5
) )
var ( indexInputs := []struct {
mu sync.Mutex name string
gotMethod string val any
gotPath string }{
gotBody map[string]any {"float64", float64(index)},
) {"string", "5"},
}
for _, ii := range indexInputs {
t.Run(ii.name, func(t *testing.T) {
var (
mu sync.Mutex
gotMethod string
gotPath string
gotBody map[string]any
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v1/version":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"private":false}`))
case fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d/merge", owner, repo, index):
mu.Lock()
gotMethod = r.Method
gotPath = r.URL.Path
var body map[string]any
_ = json.NewDecoder(r.Body).Decode(&body)
gotBody = body
mu.Unlock()
w.WriteHeader(http.StatusOK)
default:
http.NotFound(w, r)
}
})
server := httptest.NewServer(handler)
defer server.Close()
origHost := flag.Host
origToken := flag.Token
origVersion := flag.Version
flag.Host = server.URL
flag.Token = ""
flag.Version = "test"
defer func() {
flag.Host = origHost
flag.Token = origToken
flag.Version = origVersion
}()
req := mcp.CallToolRequest{
Params: mcp.CallToolParams{
Arguments: map[string]any{
"owner": owner,
"repo": repo,
"index": ii.val,
"merge_style": "squash",
"title": "feat: my squashed commit",
"message": "Squash merge of PR #5",
"delete_branch": true,
},
},
}
result, err := MergePullRequestFn(context.Background(), req)
if err != nil {
t.Fatalf("MergePullRequestFn() error = %v", err)
}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v1/version":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"private":false}`))
case fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d/merge", owner, repo, index):
mu.Lock() mu.Lock()
gotMethod = r.Method defer mu.Unlock()
gotPath = r.URL.Path
var body map[string]any
_ = json.NewDecoder(r.Body).Decode(&body)
gotBody = body
mu.Unlock()
w.WriteHeader(http.StatusOK)
default:
http.NotFound(w, r)
}
})
server := httptest.NewServer(handler) if gotMethod != http.MethodPost {
defer server.Close() t.Fatalf("expected POST request, got %s", gotMethod)
}
if gotPath != fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d/merge", owner, repo, index) {
t.Fatalf("unexpected path: %s", gotPath)
}
if gotBody["Do"] != "squash" {
t.Fatalf("expected Do 'squash', got %v", gotBody["Do"])
}
if gotBody["MergeTitleField"] != "feat: my squashed commit" {
t.Fatalf("expected MergeTitleField 'feat: my squashed commit', got %v", gotBody["MergeTitleField"])
}
if gotBody["MergeMessageField"] != "Squash merge of PR #5" {
t.Fatalf("expected MergeMessageField 'Squash merge of PR #5', got %v", gotBody["MergeMessageField"])
}
if gotBody["delete_branch_after_merge"] != true {
t.Fatalf("expected delete_branch_after_merge true, got %v", gotBody["delete_branch_after_merge"])
}
origHost := flag.Host if len(result.Content) == 0 {
origToken := flag.Token t.Fatalf("expected content in result")
origVersion := flag.Version }
flag.Host = server.URL textContent, ok := mcp.AsTextContent(result.Content[0])
flag.Token = "" if !ok {
flag.Version = "test" t.Fatalf("expected text content, got %T", result.Content[0])
defer func() { }
flag.Host = origHost
flag.Token = origToken
flag.Version = origVersion
}()
req := mcp.CallToolRequest{ var parsed struct {
Params: mcp.CallToolParams{ Result map[string]any `json:"Result"`
Arguments: map[string]any{ }
"owner": owner, if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil {
"repo": repo, t.Fatalf("unmarshal result text: %v", err)
"index": float64(index), }
"merge_style": "squash", if parsed.Result["merged"] != true {
"title": "feat: my squashed commit", t.Fatalf("expected merged=true, got %v", parsed.Result["merged"])
"message": "Squash merge of PR #5", }
"delete_branch": true, if parsed.Result["merge_style"] != "squash" {
}, t.Fatalf("expected merge_style 'squash', got %v", parsed.Result["merge_style"])
}, }
} if parsed.Result["branch_deleted"] != true {
t.Fatalf("expected branch_deleted=true, got %v", parsed.Result["branch_deleted"])
result, err := MergePullRequestFn(context.Background(), req) }
if err != nil { })
t.Fatalf("MergePullRequestFn() error = %v", err)
}
mu.Lock()
defer mu.Unlock()
if gotMethod != http.MethodPost {
t.Fatalf("expected POST request, got %s", gotMethod)
}
if gotPath != fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d/merge", owner, repo, index) {
t.Fatalf("unexpected path: %s", gotPath)
}
if gotBody["Do"] != "squash" {
t.Fatalf("expected Do 'squash', got %v", gotBody["Do"])
}
if gotBody["MergeTitleField"] != "feat: my squashed commit" {
t.Fatalf("expected MergeTitleField 'feat: my squashed commit', got %v", gotBody["MergeTitleField"])
}
if gotBody["MergeMessageField"] != "Squash merge of PR #5" {
t.Fatalf("expected MergeMessageField 'Squash merge of PR #5', got %v", gotBody["MergeMessageField"])
}
if gotBody["delete_branch_after_merge"] != true {
t.Fatalf("expected delete_branch_after_merge true, got %v", gotBody["delete_branch_after_merge"])
}
if len(result.Content) == 0 {
t.Fatalf("expected content in result")
}
textContent, ok := mcp.AsTextContent(result.Content[0])
if !ok {
t.Fatalf("expected text content, got %T", result.Content[0])
}
var parsed struct {
Result map[string]any `json:"Result"`
}
if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil {
t.Fatalf("unmarshal result text: %v", err)
}
if parsed.Result["merged"] != true {
t.Fatalf("expected merged=true, got %v", parsed.Result["merged"])
}
if parsed.Result["merge_style"] != "squash" {
t.Fatalf("expected merge_style 'squash', got %v", parsed.Result["merge_style"])
}
if parsed.Result["branch_deleted"] != true {
t.Fatalf("expected branch_deleted=true, got %v", parsed.Result["branch_deleted"])
} }
} }
@@ -242,120 +266,132 @@ func TestGetPullRequestDiffFn(t *testing.T) {
diffRaw = "diff --git a/file.txt b/file.txt\n+line\n" diffRaw = "diff --git a/file.txt b/file.txt\n+line\n"
) )
var ( indexInputs := []struct {
mu sync.Mutex name string
diffRequested bool val any
binaryValue string }{
) {"float64", float64(index)},
errCh := make(chan error, 1) {"string", "12"},
}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { for _, ii := range indexInputs {
switch r.URL.Path { t.Run(ii.name, func(t *testing.T) {
case "/api/v1/version": var (
w.Header().Set("Content-Type", "application/json") mu sync.Mutex
_, _ = w.Write([]byte(`{"version":"1.12.0"}`)) diffRequested bool
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo): binaryValue string
w.Header().Set("Content-Type", "application/json") )
_, _ = w.Write([]byte(`{"private":false}`)) errCh := make(chan error, 1)
case fmt.Sprintf("/%s/%s/pulls/%d.diff", owner, repo, index):
if r.Method != http.MethodGet { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
select { switch r.URL.Path {
case errCh <- fmt.Errorf("unexpected method: %s", r.Method): case "/api/v1/version":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"private":false}`))
case fmt.Sprintf("/%s/%s/pulls/%d.diff", owner, repo, index):
if r.Method != http.MethodGet {
select {
case errCh <- fmt.Errorf("unexpected method: %s", r.Method):
default:
}
}
mu.Lock()
diffRequested = true
binaryValue = r.URL.Query().Get("binary")
mu.Unlock()
w.Header().Set("Content-Type", "text/plain")
_, _ = w.Write([]byte(diffRaw))
default: default:
select {
case errCh <- fmt.Errorf("unexpected request path: %s", r.URL.Path):
default:
}
} }
})
server := httptest.NewServer(handler)
defer server.Close()
origHost := flag.Host
origToken := flag.Token
origVersion := flag.Version
flag.Host = server.URL
flag.Token = ""
flag.Version = "test"
defer func() {
flag.Host = origHost
flag.Token = origToken
flag.Version = origVersion
}()
req := mcp.CallToolRequest{
Params: mcp.CallToolParams{
Arguments: map[string]any{
"owner": owner,
"repo": repo,
"index": ii.val,
"binary": true,
},
},
} }
mu.Lock()
diffRequested = true result, err := GetPullRequestDiffFn(context.Background(), req)
binaryValue = r.URL.Query().Get("binary") if err != nil {
mu.Unlock() t.Fatalf("GetPullRequestDiffFn() error = %v", err)
w.Header().Set("Content-Type", "text/plain") }
_, _ = w.Write([]byte(diffRaw))
default:
select { select {
case errCh <- fmt.Errorf("unexpected request path: %s", r.URL.Path): case reqErr := <-errCh:
t.Fatalf("handler error: %v", reqErr)
default: default:
} }
}
})
server := httptest.NewServer(handler) mu.Lock()
defer server.Close() requested := diffRequested
gotBinary := binaryValue
mu.Unlock()
origHost := flag.Host if !requested {
origToken := flag.Token t.Fatalf("expected diff request to be made")
origVersion := flag.Version }
flag.Host = server.URL if gotBinary != "true" {
flag.Token = "" t.Fatalf("expected binary=true query param, got %q", gotBinary)
flag.Version = "test" }
defer func() {
flag.Host = origHost
flag.Token = origToken
flag.Version = origVersion
}()
req := mcp.CallToolRequest{ if len(result.Content) == 0 {
Params: mcp.CallToolParams{ t.Fatalf("expected content in result")
Arguments: map[string]any{ }
"owner": owner,
"repo": repo,
"index": float64(index),
"binary": true,
},
},
}
result, err := GetPullRequestDiffFn(context.Background(), req) textContent, ok := mcp.AsTextContent(result.Content[0])
if err != nil { if !ok {
t.Fatalf("GetPullRequestDiffFn() error = %v", err) t.Fatalf("expected text content, got %T", result.Content[0])
} }
select { var parsed struct {
case reqErr := <-errCh: Result map[string]any `json:"Result"`
t.Fatalf("handler error: %v", reqErr) }
default: if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil {
} t.Fatalf("unmarshal result text: %v", err)
}
mu.Lock() if got, ok := parsed.Result["diff"].(string); !ok || got != diffRaw {
requested := diffRequested t.Fatalf("diff = %q, want %q", got, diffRaw)
gotBinary := binaryValue }
mu.Unlock() if got, ok := parsed.Result["binary"].(bool); !ok || got != true {
t.Fatalf("binary = %v, want true", got)
if !requested { }
t.Fatalf("expected diff request to be made") if got, ok := parsed.Result["index"].(float64); !ok || int64(got) != int64(index) {
} t.Fatalf("index = %v, want %d", got, index)
if gotBinary != "true" { }
t.Fatalf("expected binary=true query param, got %q", gotBinary) if got, ok := parsed.Result["owner"].(string); !ok || got != owner {
} t.Fatalf("owner = %q, want %q", got, owner)
}
if len(result.Content) == 0 { if got, ok := parsed.Result["repo"].(string); !ok || got != repo {
t.Fatalf("expected content in result") t.Fatalf("repo = %q, want %q", got, repo)
} }
})
textContent, ok := mcp.AsTextContent(result.Content[0])
if !ok {
t.Fatalf("expected text content, got %T", result.Content[0])
}
var parsed struct {
Result map[string]any `json:"Result"`
}
if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil {
t.Fatalf("unmarshal result text: %v", err)
}
if got, ok := parsed.Result["diff"].(string); !ok || got != diffRaw {
t.Fatalf("diff = %q, want %q", got, diffRaw)
}
if got, ok := parsed.Result["binary"].(bool); !ok || got != true {
t.Fatalf("binary = %v, want true", got)
}
if got, ok := parsed.Result["index"].(float64); !ok || int64(got) != int64(index) {
t.Fatalf("index = %v, want %d", got, index)
}
if got, ok := parsed.Result["owner"].(string); !ok || got != owner {
t.Fatalf("owner = %q, want %q", got, owner)
}
if got, ok := parsed.Result["repo"].(string); !ok || got != repo {
t.Fatalf("repo = %q, want %q", got, repo)
} }
} }

View File

@@ -7,6 +7,7 @@ import (
"gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/gitea"
"gitea.com/gitea/gitea-mcp/pkg/log" "gitea.com/gitea/gitea-mcp/pkg/log"
"gitea.com/gitea/gitea-mcp/pkg/params"
"gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/to"
gitea_sdk "code.gitea.io/sdk/gitea" gitea_sdk "code.gitea.io/sdk/gitea"
@@ -46,13 +47,13 @@ func ListRepoCommitsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT
if !ok { if !ok {
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
page, ok := req.GetArguments()["page"].(float64) page, err := params.GetIndex(req.GetArguments(), "page")
if !ok { if err != nil {
return to.ErrorResult(errors.New("page is required")) return to.ErrorResult(err)
} }
pageSize, ok := req.GetArguments()["page_size"].(float64) pageSize, err := params.GetIndex(req.GetArguments(), "page_size")
if !ok { if err != nil {
return to.ErrorResult(errors.New("page_size is required")) return to.ErrorResult(err)
} }
sha, _ := req.GetArguments()["sha"].(string) sha, _ := req.GetArguments()["sha"].(string)
path, _ := req.GetArguments()["path"].(string) path, _ := req.GetArguments()["path"].(string)

View File

@@ -8,6 +8,7 @@ import (
"gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/gitea"
"gitea.com/gitea/gitea-mcp/pkg/log" "gitea.com/gitea/gitea-mcp/pkg/log"
"gitea.com/gitea/gitea-mcp/pkg/params"
"gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/to"
gitea_sdk "code.gitea.io/sdk/gitea" gitea_sdk "code.gitea.io/sdk/gitea"
@@ -163,16 +164,16 @@ func DeleteReleaseFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo
if !ok { if !ok {
return nil, errors.New("repo is required") return nil, errors.New("repo is required")
} }
id, ok := req.GetArguments()["id"].(float64) id, err := params.GetIndex(req.GetArguments(), "id")
if !ok { if err != nil {
return nil, errors.New("id is required") return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.DeleteRelease(owner, repo, int64(id)) _, err = client.DeleteRelease(owner, repo, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("delete release error: %v", err) return nil, fmt.Errorf("delete release error: %v", err)
} }
@@ -190,16 +191,16 @@ func GetReleaseFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolRe
if !ok { if !ok {
return nil, errors.New("repo is required") return nil, errors.New("repo is required")
} }
id, ok := req.GetArguments()["id"].(float64) id, err := params.GetIndex(req.GetArguments(), "id")
if !ok { if err != nil {
return nil, errors.New("id is required") return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
release, _, err := client.GetRelease(owner, repo, int64(id)) release, _, err := client.GetRelease(owner, repo, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get release error: %v", err) return nil, fmt.Errorf("get release error: %v", err)
} }
@@ -250,8 +251,8 @@ func ListReleasesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTool
if ok { if ok {
pIsPreRelease = new(isPreRelease) pIsPreRelease = new(isPreRelease)
} }
page, _ := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
pageSize, _ := req.GetArguments()["pageSize"].(float64) pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 20)
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {

View File

@@ -7,6 +7,7 @@ import (
"gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/gitea"
"gitea.com/gitea/gitea-mcp/pkg/log" "gitea.com/gitea/gitea-mcp/pkg/log"
"gitea.com/gitea/gitea-mcp/pkg/params"
"gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/to"
"gitea.com/gitea/gitea-mcp/pkg/tool" "gitea.com/gitea/gitea-mcp/pkg/tool"
@@ -191,14 +192,8 @@ func ForkRepoFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResu
func ListMyReposFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func ListMyReposFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
log.Debugf("Called ListMyReposFn") log.Debugf("Called ListMyReposFn")
page, ok := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if !ok { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
page = 1
}
pageSize, ok := req.GetArguments()["pageSize"].(float64)
if !ok {
pageSize = 100
}
opt := gitea_sdk.ListReposOptions{ opt := gitea_sdk.ListReposOptions{
ListOptions: gitea_sdk.ListOptions{ ListOptions: gitea_sdk.ListOptions{
Page: int(page), Page: int(page),

View File

@@ -7,6 +7,7 @@ import (
"gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/gitea"
"gitea.com/gitea/gitea-mcp/pkg/log" "gitea.com/gitea/gitea-mcp/pkg/log"
"gitea.com/gitea/gitea-mcp/pkg/params"
"gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/to"
gitea_sdk "code.gitea.io/sdk/gitea" gitea_sdk "code.gitea.io/sdk/gitea"
@@ -183,8 +184,8 @@ func ListTagsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResu
if !ok { if !ok {
return nil, errors.New("repo is required") return nil, errors.New("repo is required")
} }
page, _ := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
pageSize, _ := req.GetArguments()["pageSize"].(float64) pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 20)
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {

View File

@@ -7,6 +7,7 @@ import (
"gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/gitea"
"gitea.com/gitea/gitea-mcp/pkg/log" "gitea.com/gitea/gitea-mcp/pkg/log"
"gitea.com/gitea/gitea-mcp/pkg/params"
"gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/to"
"gitea.com/gitea/gitea-mcp/pkg/tool" "gitea.com/gitea/gitea-mcp/pkg/tool"
@@ -79,14 +80,8 @@ func UsersFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult,
if !ok { if !ok {
return to.ErrorResult(errors.New("keyword is required")) return to.ErrorResult(errors.New("keyword is required"))
} }
page, ok := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if !ok { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
page = 1
}
pageSize, ok := req.GetArguments()["pageSize"].(float64)
if !ok {
pageSize = 100
}
opt := gitea_sdk.SearchUsersOption{ opt := gitea_sdk.SearchUsersOption{
KeyWord: keyword, KeyWord: keyword,
ListOptions: gitea_sdk.ListOptions{ ListOptions: gitea_sdk.ListOptions{
@@ -116,14 +111,8 @@ func OrgTeamsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResu
return to.ErrorResult(errors.New("query is required")) return to.ErrorResult(errors.New("query is required"))
} }
includeDescription, _ := req.GetArguments()["includeDescription"].(bool) includeDescription, _ := req.GetArguments()["includeDescription"].(bool)
page, ok := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if !ok { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
page = 1
}
pageSize, ok := req.GetArguments()["pageSize"].(float64)
if !ok {
pageSize = 100
}
opt := gitea_sdk.SearchTeamsOptions{ opt := gitea_sdk.SearchTeamsOptions{
Query: query, Query: query,
IncludeDescription: includeDescription, IncludeDescription: includeDescription,
@@ -151,7 +140,7 @@ func ReposFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult,
} }
keywordIsTopic, _ := req.GetArguments()["keywordIsTopic"].(bool) keywordIsTopic, _ := req.GetArguments()["keywordIsTopic"].(bool)
keywordInDescription, _ := req.GetArguments()["keywordInDescription"].(bool) keywordInDescription, _ := req.GetArguments()["keywordInDescription"].(bool)
ownerID, _ := req.GetArguments()["ownerID"].(float64) ownerID := params.GetOptionalInt(req.GetArguments(), "ownerID", 0)
var pIsPrivate *bool var pIsPrivate *bool
isPrivate, ok := req.GetArguments()["isPrivate"].(bool) isPrivate, ok := req.GetArguments()["isPrivate"].(bool)
if ok { if ok {
@@ -164,19 +153,13 @@ func ReposFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult,
} }
sort, _ := req.GetArguments()["sort"].(string) sort, _ := req.GetArguments()["sort"].(string)
order, _ := req.GetArguments()["order"].(string) order, _ := req.GetArguments()["order"].(string)
page, ok := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if !ok { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
page = 1
}
pageSize, ok := req.GetArguments()["pageSize"].(float64)
if !ok {
pageSize = 100
}
opt := gitea_sdk.SearchRepoOptions{ opt := gitea_sdk.SearchRepoOptions{
Keyword: keyword, Keyword: keyword,
KeywordIsTopic: keywordIsTopic, KeywordIsTopic: keywordIsTopic,
KeywordInDescription: keywordInDescription, KeywordInDescription: keywordInDescription,
OwnerID: int64(ownerID), OwnerID: ownerID,
IsPrivate: pIsPrivate, IsPrivate: pIsPrivate,
IsArchived: pIsArchived, IsArchived: pIsArchived,
Sort: sort, Sort: sort,

View File

@@ -233,14 +233,8 @@ func ListTrackedTimesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call
if err != nil { if err != nil {
return to.ErrorResult(err) return to.ErrorResult(err)
} }
page, ok := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if !ok { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
page = 1
}
pageSize, ok := req.GetArguments()["pageSize"].(float64)
if !ok {
pageSize = 100
}
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
@@ -276,16 +270,16 @@ func AddTrackedTimeFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo
return to.ErrorResult(err) return to.ErrorResult(err)
} }
timeSeconds, ok := req.GetArguments()["time"].(float64) timeSeconds, err := params.GetIndex(req.GetArguments(), "time")
if !ok { if err != nil {
return to.ErrorResult(errors.New("time is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
trackedTime, _, err := client.AddTime(owner, repo, index, gitea_sdk.AddTimeOption{ trackedTime, _, err := client.AddTime(owner, repo, index, gitea_sdk.AddTimeOption{
Time: int64(timeSeconds), Time: timeSeconds,
}) })
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("add tracked time to %s/%s#%d err: %v", owner, repo, index, err)) return to.ErrorResult(fmt.Errorf("add tracked time to %s/%s#%d err: %v", owner, repo, index, err))
@@ -308,19 +302,19 @@ func DeleteTrackedTimeFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Cal
if err != nil { if err != nil {
return to.ErrorResult(err) return to.ErrorResult(err)
} }
id, ok := req.GetArguments()["id"].(float64) id, err := params.GetIndex(req.GetArguments(), "id")
if !ok { if err != nil {
return to.ErrorResult(errors.New("id is required")) return to.ErrorResult(err)
} }
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
} }
_, err = client.DeleteTime(owner, repo, index, int64(id)) _, err = client.DeleteTime(owner, repo, index, id)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("delete tracked time %d from %s/%s#%d err: %v", int64(id), owner, repo, index, err)) return to.ErrorResult(fmt.Errorf("delete tracked time %d from %s/%s#%d err: %v", id, owner, repo, index, err))
} }
return to.TextResult(fmt.Sprintf("Tracked time entry %d deleted from issue %s/%s#%d", int64(id), owner, repo, index)) return to.TextResult(fmt.Sprintf("Tracked time entry %d deleted from issue %s/%s#%d", id, owner, repo, index))
} }
func ListRepoTimesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func ListRepoTimesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
@@ -334,14 +328,8 @@ func ListRepoTimesFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo
return to.ErrorResult(errors.New("repo is required")) return to.ErrorResult(errors.New("repo is required"))
} }
page, ok := req.GetArguments()["page"].(float64) page := params.GetOptionalInt(req.GetArguments(), "page", 1)
if !ok { pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
page = 1
}
pageSize, ok := req.GetArguments()["pageSize"].(float64)
if !ok {
pageSize = 100
}
client, err := gitea.ClientFromContext(ctx) client, err := gitea.ClientFromContext(ctx)
if err != nil { if err != nil {
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))

View File

@@ -6,6 +6,7 @@ import (
"gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/gitea"
"gitea.com/gitea/gitea-mcp/pkg/log" "gitea.com/gitea/gitea-mcp/pkg/log"
"gitea.com/gitea/gitea-mcp/pkg/params"
"gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/to"
"gitea.com/gitea/gitea-mcp/pkg/tool" "gitea.com/gitea/gitea-mcp/pkg/tool"
@@ -68,11 +69,11 @@ func registerTools() {
// getIntArg parses an integer argument from the MCP request arguments map. // getIntArg parses an integer argument from the MCP request arguments map.
// Returns def if missing, not a number, or less than 1. Used for pagination arguments. // Returns def if missing, not a number, or less than 1. Used for pagination arguments.
func getIntArg(req mcp.CallToolRequest, name string, def int) int { func getIntArg(req mcp.CallToolRequest, name string, def int) int {
val, ok := req.GetArguments()[name].(float64) v := params.GetOptionalInt(req.GetArguments(), name, int64(def))
if !ok || val < 1 { if v < 1 {
return def return def
} }
return int(val) return int(v)
} }
// GetUserInfoFn is the handler for "get_my_user_info" MCP tool requests. // GetUserInfoFn is the handler for "get_my_user_info" MCP tool requests.

View File

@@ -5,7 +5,24 @@ import (
"strconv" "strconv"
) )
// GetIndex extracts an index parameter from MCP tool arguments. // ToInt64 converts a value to int64, accepting both float64 (JSON number) and
// string representations. Returns false if the value cannot be converted.
func ToInt64(val any) (int64, bool) {
switch v := val.(type) {
case float64:
return int64(v), true
case string:
i, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return 0, false
}
return i, true
default:
return 0, false
}
}
// GetIndex extracts a required integer parameter from MCP tool arguments.
// It accepts both numeric (float64 from JSON) and string representations. // It accepts both numeric (float64 from JSON) and string representations.
// This provides better UX for LLM callers that may naturally use strings // This provides better UX for LLM callers that may naturally use strings
// for identifiers like issue/PR numbers. // for identifiers like issue/PR numbers.
@@ -15,19 +32,27 @@ func GetIndex(args map[string]any, key string) (int64, error) {
return 0, fmt.Errorf("%s is required", key) return 0, fmt.Errorf("%s is required", key)
} }
// Try float64 (JSON number type) if i, ok := ToInt64(val); ok {
if f, ok := val.(float64); ok { return i, nil
return int64(f), nil
} }
// Try string and parse to integer
if s, ok := val.(string); ok { if s, ok := val.(string); ok {
i, err := strconv.ParseInt(s, 10, 64) return 0, fmt.Errorf("%s must be a valid integer (got %q)", key, s)
if err != nil {
return 0, fmt.Errorf("%s must be a valid integer (got %q)", key, s)
}
return i, nil
} }
return 0, fmt.Errorf("%s must be a number or numeric string", key) return 0, fmt.Errorf("%s must be a number or numeric string", key)
} }
// GetOptionalInt extracts an optional integer parameter from MCP tool arguments.
// Returns defaultVal if the key is missing or the value cannot be parsed.
// Accepts both float64 (JSON number) and string representations.
func GetOptionalInt(args map[string]any, key string, defaultVal int64) int64 {
val, exists := args[key]
if !exists {
return defaultVal
}
if i, ok := ToInt64(val); ok {
return i
}
return defaultVal
}

View File

@@ -5,6 +5,63 @@ import (
"testing" "testing"
) )
func TestToInt64(t *testing.T) {
tests := []struct {
name string
val any
want int64
ok bool
}{
{"float64", float64(42), 42, true},
{"float64 zero", float64(0), 0, true},
{"float64 negative", float64(-5), -5, true},
{"string", "123", 123, true},
{"string zero", "0", 0, true},
{"string negative", "-10", -10, true},
{"invalid string", "abc", 0, false},
{"decimal string", "1.5", 0, false},
{"bool", true, 0, false},
{"nil", nil, 0, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := ToInt64(tt.val)
if ok != tt.ok {
t.Errorf("ToInt64() ok = %v, want %v", ok, tt.ok)
}
if got != tt.want {
t.Errorf("ToInt64() = %v, want %v", got, tt.want)
}
})
}
}
func TestGetOptionalInt(t *testing.T) {
tests := []struct {
name string
args map[string]any
key string
defaultVal int64
want int64
}{
{"present float64", map[string]any{"page": float64(3)}, "page", 1, 3},
{"present string", map[string]any{"page": "5"}, "page", 1, 5},
{"missing key", map[string]any{}, "page", 1, 1},
{"invalid string", map[string]any{"page": "abc"}, "page", 1, 1},
{"invalid type", map[string]any{"page": true}, "page", 1, 1},
{"zero value", map[string]any{"id": float64(0)}, "id", 99, 0},
{"string zero", map[string]any{"id": "0"}, "id", 99, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GetOptionalInt(tt.args, tt.key, tt.defaultVal)
if got != tt.want {
t.Errorf("GetOptionalInt() = %v, want %v", got, tt.want)
}
})
}
}
func TestGetIndex(t *testing.T) { func TestGetIndex(t *testing.T) {
tests := []struct { tests := []struct {
name string name string