mirror of
https://gitea.com/gitea/gitea-mcp.git
synced 2026-02-27 09:05:12 +00:00
feat: accept string values for all numeric input parameters (#138)
## Summary - MCP clients may send numbers as strings. This adds `ToInt64` and `GetOptionalInt` helpers to `pkg/params` and replaces all raw `.(float64)` type assertions across operation handlers to accept both `float64` and string inputs. ## Test plan - [x] Verify `go test ./...` passes - [x] Test with an MCP client that sends numeric parameters as strings *Created by Claude on behalf of @silverwind* Reviewed-on: https://gitea.com/gitea/gitea-mcp/pulls/138 Reviewed-by: Lunny Xiao <xiaolunwen@gmail.com> Co-authored-by: silverwind <me@silverwind.io> Co-committed-by: silverwind <me@silverwind.io>
This commit is contained in:
@@ -345,19 +345,13 @@ func ListRepoPullRequestsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.
|
||||
if !ok {
|
||||
sort = "recentupdate"
|
||||
}
|
||||
milestone, _ := req.GetArguments()["milestone"].(float64)
|
||||
page, ok := req.GetArguments()["page"].(float64)
|
||||
if !ok {
|
||||
page = 1
|
||||
}
|
||||
pageSize, ok := req.GetArguments()["pageSize"].(float64)
|
||||
if !ok {
|
||||
pageSize = 100
|
||||
}
|
||||
milestone := params.GetOptionalInt(req.GetArguments(), "milestone", 0)
|
||||
page := params.GetOptionalInt(req.GetArguments(), "page", 1)
|
||||
pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
|
||||
opt := gitea_sdk.ListPullRequestsOptions{
|
||||
State: gitea_sdk.StateType(state),
|
||||
Sort: sort,
|
||||
Milestone: int64(milestone),
|
||||
Milestone: milestone,
|
||||
ListOptions: gitea_sdk.ListOptions{
|
||||
Page: int(page),
|
||||
PageSize: int(pageSize),
|
||||
@@ -555,14 +549,8 @@ func ListPullRequestReviewsFn(ctx context.Context, req mcp.CallToolRequest) (*mc
|
||||
if err != nil {
|
||||
return to.ErrorResult(err)
|
||||
}
|
||||
page, ok := req.GetArguments()["page"].(float64)
|
||||
if !ok {
|
||||
page = 1
|
||||
}
|
||||
pageSize, ok := req.GetArguments()["pageSize"].(float64)
|
||||
if !ok {
|
||||
pageSize = 100
|
||||
}
|
||||
page := params.GetOptionalInt(req.GetArguments(), "page", 1)
|
||||
pageSize := params.GetOptionalInt(req.GetArguments(), "pageSize", 100)
|
||||
|
||||
client, err := gitea.ClientFromContext(ctx)
|
||||
if err != nil {
|
||||
@@ -596,9 +584,9 @@ func GetPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.
|
||||
if err != nil {
|
||||
return to.ErrorResult(err)
|
||||
}
|
||||
reviewID, ok := req.GetArguments()["review_id"].(float64)
|
||||
if !ok {
|
||||
return to.ErrorResult(errors.New("review_id is required"))
|
||||
reviewID, err := params.GetIndex(req.GetArguments(), "review_id")
|
||||
if err != nil {
|
||||
return to.ErrorResult(err)
|
||||
}
|
||||
|
||||
client, err := gitea.ClientFromContext(ctx)
|
||||
@@ -606,9 +594,9 @@ func GetPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.
|
||||
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
|
||||
}
|
||||
|
||||
review, _, err := client.GetPullReview(owner, repo, index, int64(reviewID))
|
||||
review, _, err := client.GetPullReview(owner, repo, index, reviewID)
|
||||
if err != nil {
|
||||
return to.ErrorResult(fmt.Errorf("get review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err))
|
||||
return to.ErrorResult(fmt.Errorf("get review %v for %v/%v/pr/%v err: %v", reviewID, owner, repo, index, err))
|
||||
}
|
||||
|
||||
return to.TextResult(review)
|
||||
@@ -628,9 +616,9 @@ func ListPullRequestReviewCommentsFn(ctx context.Context, req mcp.CallToolReques
|
||||
if err != nil {
|
||||
return to.ErrorResult(err)
|
||||
}
|
||||
reviewID, ok := req.GetArguments()["review_id"].(float64)
|
||||
if !ok {
|
||||
return to.ErrorResult(errors.New("review_id is required"))
|
||||
reviewID, err := params.GetIndex(req.GetArguments(), "review_id")
|
||||
if err != nil {
|
||||
return to.ErrorResult(err)
|
||||
}
|
||||
|
||||
client, err := gitea.ClientFromContext(ctx)
|
||||
@@ -638,9 +626,9 @@ func ListPullRequestReviewCommentsFn(ctx context.Context, req mcp.CallToolReques
|
||||
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
|
||||
}
|
||||
|
||||
comments, _, err := client.ListPullReviewComments(owner, repo, index, int64(reviewID))
|
||||
comments, _, err := client.ListPullReviewComments(owner, repo, index, reviewID)
|
||||
if err != nil {
|
||||
return to.ErrorResult(fmt.Errorf("list review comments for review %v on %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err))
|
||||
return to.ErrorResult(fmt.Errorf("list review comments for review %v on %v/%v/pr/%v err: %v", reviewID, owner, repo, index, err))
|
||||
}
|
||||
|
||||
return to.TextResult(comments)
|
||||
@@ -685,11 +673,11 @@ func CreatePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m
|
||||
if body, ok := commentMap["body"].(string); ok {
|
||||
reviewComment.Body = body
|
||||
}
|
||||
if oldLineNum, ok := commentMap["old_line_num"].(float64); ok {
|
||||
reviewComment.OldLineNum = int64(oldLineNum)
|
||||
if oldLineNum, ok := params.ToInt64(commentMap["old_line_num"]); ok {
|
||||
reviewComment.OldLineNum = oldLineNum
|
||||
}
|
||||
if newLineNum, ok := commentMap["new_line_num"].(float64); ok {
|
||||
reviewComment.NewLineNum = int64(newLineNum)
|
||||
if newLineNum, ok := params.ToInt64(commentMap["new_line_num"]); ok {
|
||||
reviewComment.NewLineNum = newLineNum
|
||||
}
|
||||
opt.Comments = append(opt.Comments, reviewComment)
|
||||
}
|
||||
@@ -724,9 +712,9 @@ func SubmitPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m
|
||||
if err != nil {
|
||||
return to.ErrorResult(err)
|
||||
}
|
||||
reviewID, ok := req.GetArguments()["review_id"].(float64)
|
||||
if !ok {
|
||||
return to.ErrorResult(errors.New("review_id is required"))
|
||||
reviewID, err := params.GetIndex(req.GetArguments(), "review_id")
|
||||
if err != nil {
|
||||
return to.ErrorResult(err)
|
||||
}
|
||||
state, ok := req.GetArguments()["state"].(string)
|
||||
if !ok {
|
||||
@@ -745,9 +733,9 @@ func SubmitPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m
|
||||
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
|
||||
}
|
||||
|
||||
review, _, err := client.SubmitPullReview(owner, repo, index, int64(reviewID), opt)
|
||||
review, _, err := client.SubmitPullReview(owner, repo, index, reviewID, opt)
|
||||
if err != nil {
|
||||
return to.ErrorResult(fmt.Errorf("submit review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err))
|
||||
return to.ErrorResult(fmt.Errorf("submit review %v for %v/%v/pr/%v err: %v", reviewID, owner, repo, index, err))
|
||||
}
|
||||
|
||||
return to.TextResult(review)
|
||||
@@ -767,9 +755,9 @@ func DeletePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m
|
||||
if err != nil {
|
||||
return to.ErrorResult(err)
|
||||
}
|
||||
reviewID, ok := req.GetArguments()["review_id"].(float64)
|
||||
if !ok {
|
||||
return to.ErrorResult(errors.New("review_id is required"))
|
||||
reviewID, err := params.GetIndex(req.GetArguments(), "review_id")
|
||||
if err != nil {
|
||||
return to.ErrorResult(err)
|
||||
}
|
||||
|
||||
client, err := gitea.ClientFromContext(ctx)
|
||||
@@ -777,14 +765,14 @@ func DeletePullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*m
|
||||
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
|
||||
}
|
||||
|
||||
_, err = client.DeletePullReview(owner, repo, index, int64(reviewID))
|
||||
_, err = client.DeletePullReview(owner, repo, index, reviewID)
|
||||
if err != nil {
|
||||
return to.ErrorResult(fmt.Errorf("delete review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err))
|
||||
return to.ErrorResult(fmt.Errorf("delete review %v for %v/%v/pr/%v err: %v", reviewID, owner, repo, index, err))
|
||||
}
|
||||
|
||||
successMsg := map[string]any{
|
||||
"message": "Successfully deleted review",
|
||||
"review_id": int64(reviewID),
|
||||
"review_id": reviewID,
|
||||
"pr_index": index,
|
||||
"repository": fmt.Sprintf("%s/%s", owner, repo),
|
||||
}
|
||||
@@ -806,9 +794,9 @@ func DismissPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*
|
||||
if err != nil {
|
||||
return to.ErrorResult(err)
|
||||
}
|
||||
reviewID, ok := req.GetArguments()["review_id"].(float64)
|
||||
if !ok {
|
||||
return to.ErrorResult(errors.New("review_id is required"))
|
||||
reviewID, err := params.GetIndex(req.GetArguments(), "review_id")
|
||||
if err != nil {
|
||||
return to.ErrorResult(err)
|
||||
}
|
||||
|
||||
opt := gitea_sdk.DismissPullReviewOptions{}
|
||||
@@ -821,14 +809,14 @@ func DismissPullRequestReviewFn(ctx context.Context, req mcp.CallToolRequest) (*
|
||||
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
|
||||
}
|
||||
|
||||
_, err = client.DismissPullReview(owner, repo, index, int64(reviewID), opt)
|
||||
_, err = client.DismissPullReview(owner, repo, index, reviewID, opt)
|
||||
if err != nil {
|
||||
return to.ErrorResult(fmt.Errorf("dismiss review %v for %v/%v/pr/%v err: %v", int64(reviewID), owner, repo, index, err))
|
||||
return to.ErrorResult(fmt.Errorf("dismiss review %v for %v/%v/pr/%v err: %v", reviewID, owner, repo, index, err))
|
||||
}
|
||||
|
||||
successMsg := map[string]any{
|
||||
"message": "Successfully dismissed review",
|
||||
"review_id": int64(reviewID),
|
||||
"review_id": reviewID,
|
||||
"pr_index": index,
|
||||
"repository": fmt.Sprintf("%s/%s", owner, repo),
|
||||
}
|
||||
@@ -917,9 +905,9 @@ func EditPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT
|
||||
if !ok {
|
||||
return to.ErrorResult(errors.New("repo is required"))
|
||||
}
|
||||
index, ok := req.GetArguments()["index"].(float64)
|
||||
if !ok {
|
||||
return to.ErrorResult(errors.New("index is required"))
|
||||
index, err := params.GetIndex(req.GetArguments(), "index")
|
||||
if err != nil {
|
||||
return to.ErrorResult(err)
|
||||
}
|
||||
|
||||
opt := gitea_sdk.EditPullRequestOption{}
|
||||
@@ -947,8 +935,10 @@ func EditPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT
|
||||
opt.Assignees = assignees
|
||||
}
|
||||
}
|
||||
if milestone, ok := req.GetArguments()["milestone"].(float64); ok {
|
||||
opt.Milestone = int64(milestone)
|
||||
if val, exists := req.GetArguments()["milestone"]; exists {
|
||||
if milestone, ok := params.ToInt64(val); ok {
|
||||
opt.Milestone = milestone
|
||||
}
|
||||
}
|
||||
if state, ok := req.GetArguments()["state"].(string); ok {
|
||||
opt.State = new(gitea_sdk.StateType(state))
|
||||
@@ -962,9 +952,9 @@ func EditPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT
|
||||
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
|
||||
}
|
||||
|
||||
pr, _, err := client.EditPullRequest(owner, repo, int64(index), opt)
|
||||
pr, _, err := client.EditPullRequest(owner, repo, index, opt)
|
||||
if err != nil {
|
||||
return to.ErrorResult(fmt.Errorf("edit %v/%v/pr/%v err: %v", owner, repo, int64(index), err))
|
||||
return to.ErrorResult(fmt.Errorf("edit %v/%v/pr/%v err: %v", owner, repo, index, err))
|
||||
}
|
||||
|
||||
return to.TextResult(pr)
|
||||
|
||||
@@ -20,100 +20,112 @@ func TestEditPullRequestFn(t *testing.T) {
|
||||
index = 7
|
||||
)
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
gotMethod string
|
||||
gotPath string
|
||||
gotBody map[string]any
|
||||
)
|
||||
indexInputs := []struct {
|
||||
name string
|
||||
val any
|
||||
}{
|
||||
{"float64", float64(index)},
|
||||
{"string", "7"},
|
||||
}
|
||||
|
||||
for _, ii := range indexInputs {
|
||||
t.Run(ii.name, func(t *testing.T) {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
gotMethod string
|
||||
gotPath string
|
||||
gotBody map[string]any
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v1/version":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
|
||||
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"private":false}`))
|
||||
case fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index):
|
||||
mu.Lock()
|
||||
gotMethod = r.Method
|
||||
gotPath = r.URL.Path
|
||||
var body map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&body)
|
||||
gotBody = body
|
||||
mu.Unlock()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(fmt.Appendf(nil, `{"number":%d,"title":"%s","state":"open"}`, index, body["title"]))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
server := httptest.NewServer(handler)
|
||||
defer server.Close()
|
||||
|
||||
origHost := flag.Host
|
||||
origToken := flag.Token
|
||||
origVersion := flag.Version
|
||||
flag.Host = server.URL
|
||||
flag.Token = ""
|
||||
flag.Version = "test"
|
||||
defer func() {
|
||||
flag.Host = origHost
|
||||
flag.Token = origToken
|
||||
flag.Version = origVersion
|
||||
}()
|
||||
|
||||
req := mcp.CallToolRequest{
|
||||
Params: mcp.CallToolParams{
|
||||
Arguments: map[string]any{
|
||||
"owner": owner,
|
||||
"repo": repo,
|
||||
"index": ii.val,
|
||||
"title": "WIP: my feature",
|
||||
"state": "open",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := EditPullRequestFn(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("EditPullRequestFn() error = %v", err)
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v1/version":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
|
||||
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"private":false}`))
|
||||
case fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index):
|
||||
mu.Lock()
|
||||
gotMethod = r.Method
|
||||
gotPath = r.URL.Path
|
||||
var body map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&body)
|
||||
gotBody = body
|
||||
mu.Unlock()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(fmt.Appendf(nil, `{"number":%d,"title":"%s","state":"open"}`, index, body["title"]))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
defer mu.Unlock()
|
||||
|
||||
server := httptest.NewServer(handler)
|
||||
defer server.Close()
|
||||
if gotMethod != http.MethodPatch {
|
||||
t.Fatalf("expected PATCH request, got %s", gotMethod)
|
||||
}
|
||||
if gotPath != fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index) {
|
||||
t.Fatalf("unexpected path: %s", gotPath)
|
||||
}
|
||||
if gotBody["title"] != "WIP: my feature" {
|
||||
t.Fatalf("expected title 'WIP: my feature', got %v", gotBody["title"])
|
||||
}
|
||||
if gotBody["state"] != "open" {
|
||||
t.Fatalf("expected state 'open', got %v", gotBody["state"])
|
||||
}
|
||||
|
||||
origHost := flag.Host
|
||||
origToken := flag.Token
|
||||
origVersion := flag.Version
|
||||
flag.Host = server.URL
|
||||
flag.Token = ""
|
||||
flag.Version = "test"
|
||||
defer func() {
|
||||
flag.Host = origHost
|
||||
flag.Token = origToken
|
||||
flag.Version = origVersion
|
||||
}()
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatalf("expected content in result")
|
||||
}
|
||||
textContent, ok := mcp.AsTextContent(result.Content[0])
|
||||
if !ok {
|
||||
t.Fatalf("expected text content, got %T", result.Content[0])
|
||||
}
|
||||
|
||||
req := mcp.CallToolRequest{
|
||||
Params: mcp.CallToolParams{
|
||||
Arguments: map[string]any{
|
||||
"owner": owner,
|
||||
"repo": repo,
|
||||
"index": float64(index),
|
||||
"title": "WIP: my feature",
|
||||
"state": "open",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := EditPullRequestFn(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("EditPullRequestFn() error = %v", err)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if gotMethod != http.MethodPatch {
|
||||
t.Fatalf("expected PATCH request, got %s", gotMethod)
|
||||
}
|
||||
if gotPath != fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index) {
|
||||
t.Fatalf("unexpected path: %s", gotPath)
|
||||
}
|
||||
if gotBody["title"] != "WIP: my feature" {
|
||||
t.Fatalf("expected title 'WIP: my feature', got %v", gotBody["title"])
|
||||
}
|
||||
if gotBody["state"] != "open" {
|
||||
t.Fatalf("expected state 'open', got %v", gotBody["state"])
|
||||
}
|
||||
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatalf("expected content in result")
|
||||
}
|
||||
textContent, ok := mcp.AsTextContent(result.Content[0])
|
||||
if !ok {
|
||||
t.Fatalf("expected text content, got %T", result.Content[0])
|
||||
}
|
||||
|
||||
var parsed struct {
|
||||
Result map[string]any `json:"Result"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil {
|
||||
t.Fatalf("unmarshal result text: %v", err)
|
||||
}
|
||||
if got := parsed.Result["title"].(string); got != "WIP: my feature" {
|
||||
t.Fatalf("result title = %q, want %q", got, "WIP: my feature")
|
||||
var parsed struct {
|
||||
Result map[string]any `json:"Result"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil {
|
||||
t.Fatalf("unmarshal result text: %v", err)
|
||||
}
|
||||
if got := parsed.Result["title"].(string); got != "WIP: my feature" {
|
||||
t.Fatalf("result title = %q, want %q", got, "WIP: my feature")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,113 +136,125 @@ func TestMergePullRequestFn(t *testing.T) {
|
||||
index = 5
|
||||
)
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
gotMethod string
|
||||
gotPath string
|
||||
gotBody map[string]any
|
||||
)
|
||||
indexInputs := []struct {
|
||||
name string
|
||||
val any
|
||||
}{
|
||||
{"float64", float64(index)},
|
||||
{"string", "5"},
|
||||
}
|
||||
|
||||
for _, ii := range indexInputs {
|
||||
t.Run(ii.name, func(t *testing.T) {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
gotMethod string
|
||||
gotPath string
|
||||
gotBody map[string]any
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v1/version":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
|
||||
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"private":false}`))
|
||||
case fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d/merge", owner, repo, index):
|
||||
mu.Lock()
|
||||
gotMethod = r.Method
|
||||
gotPath = r.URL.Path
|
||||
var body map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&body)
|
||||
gotBody = body
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
server := httptest.NewServer(handler)
|
||||
defer server.Close()
|
||||
|
||||
origHost := flag.Host
|
||||
origToken := flag.Token
|
||||
origVersion := flag.Version
|
||||
flag.Host = server.URL
|
||||
flag.Token = ""
|
||||
flag.Version = "test"
|
||||
defer func() {
|
||||
flag.Host = origHost
|
||||
flag.Token = origToken
|
||||
flag.Version = origVersion
|
||||
}()
|
||||
|
||||
req := mcp.CallToolRequest{
|
||||
Params: mcp.CallToolParams{
|
||||
Arguments: map[string]any{
|
||||
"owner": owner,
|
||||
"repo": repo,
|
||||
"index": ii.val,
|
||||
"merge_style": "squash",
|
||||
"title": "feat: my squashed commit",
|
||||
"message": "Squash merge of PR #5",
|
||||
"delete_branch": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := MergePullRequestFn(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("MergePullRequestFn() error = %v", err)
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v1/version":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
|
||||
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"private":false}`))
|
||||
case fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d/merge", owner, repo, index):
|
||||
mu.Lock()
|
||||
gotMethod = r.Method
|
||||
gotPath = r.URL.Path
|
||||
var body map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&body)
|
||||
gotBody = body
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
defer mu.Unlock()
|
||||
|
||||
server := httptest.NewServer(handler)
|
||||
defer server.Close()
|
||||
if gotMethod != http.MethodPost {
|
||||
t.Fatalf("expected POST request, got %s", gotMethod)
|
||||
}
|
||||
if gotPath != fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d/merge", owner, repo, index) {
|
||||
t.Fatalf("unexpected path: %s", gotPath)
|
||||
}
|
||||
if gotBody["Do"] != "squash" {
|
||||
t.Fatalf("expected Do 'squash', got %v", gotBody["Do"])
|
||||
}
|
||||
if gotBody["MergeTitleField"] != "feat: my squashed commit" {
|
||||
t.Fatalf("expected MergeTitleField 'feat: my squashed commit', got %v", gotBody["MergeTitleField"])
|
||||
}
|
||||
if gotBody["MergeMessageField"] != "Squash merge of PR #5" {
|
||||
t.Fatalf("expected MergeMessageField 'Squash merge of PR #5', got %v", gotBody["MergeMessageField"])
|
||||
}
|
||||
if gotBody["delete_branch_after_merge"] != true {
|
||||
t.Fatalf("expected delete_branch_after_merge true, got %v", gotBody["delete_branch_after_merge"])
|
||||
}
|
||||
|
||||
origHost := flag.Host
|
||||
origToken := flag.Token
|
||||
origVersion := flag.Version
|
||||
flag.Host = server.URL
|
||||
flag.Token = ""
|
||||
flag.Version = "test"
|
||||
defer func() {
|
||||
flag.Host = origHost
|
||||
flag.Token = origToken
|
||||
flag.Version = origVersion
|
||||
}()
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatalf("expected content in result")
|
||||
}
|
||||
textContent, ok := mcp.AsTextContent(result.Content[0])
|
||||
if !ok {
|
||||
t.Fatalf("expected text content, got %T", result.Content[0])
|
||||
}
|
||||
|
||||
req := mcp.CallToolRequest{
|
||||
Params: mcp.CallToolParams{
|
||||
Arguments: map[string]any{
|
||||
"owner": owner,
|
||||
"repo": repo,
|
||||
"index": float64(index),
|
||||
"merge_style": "squash",
|
||||
"title": "feat: my squashed commit",
|
||||
"message": "Squash merge of PR #5",
|
||||
"delete_branch": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := MergePullRequestFn(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("MergePullRequestFn() error = %v", err)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if gotMethod != http.MethodPost {
|
||||
t.Fatalf("expected POST request, got %s", gotMethod)
|
||||
}
|
||||
if gotPath != fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d/merge", owner, repo, index) {
|
||||
t.Fatalf("unexpected path: %s", gotPath)
|
||||
}
|
||||
if gotBody["Do"] != "squash" {
|
||||
t.Fatalf("expected Do 'squash', got %v", gotBody["Do"])
|
||||
}
|
||||
if gotBody["MergeTitleField"] != "feat: my squashed commit" {
|
||||
t.Fatalf("expected MergeTitleField 'feat: my squashed commit', got %v", gotBody["MergeTitleField"])
|
||||
}
|
||||
if gotBody["MergeMessageField"] != "Squash merge of PR #5" {
|
||||
t.Fatalf("expected MergeMessageField 'Squash merge of PR #5', got %v", gotBody["MergeMessageField"])
|
||||
}
|
||||
if gotBody["delete_branch_after_merge"] != true {
|
||||
t.Fatalf("expected delete_branch_after_merge true, got %v", gotBody["delete_branch_after_merge"])
|
||||
}
|
||||
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatalf("expected content in result")
|
||||
}
|
||||
textContent, ok := mcp.AsTextContent(result.Content[0])
|
||||
if !ok {
|
||||
t.Fatalf("expected text content, got %T", result.Content[0])
|
||||
}
|
||||
|
||||
var parsed struct {
|
||||
Result map[string]any `json:"Result"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil {
|
||||
t.Fatalf("unmarshal result text: %v", err)
|
||||
}
|
||||
if parsed.Result["merged"] != true {
|
||||
t.Fatalf("expected merged=true, got %v", parsed.Result["merged"])
|
||||
}
|
||||
if parsed.Result["merge_style"] != "squash" {
|
||||
t.Fatalf("expected merge_style 'squash', got %v", parsed.Result["merge_style"])
|
||||
}
|
||||
if parsed.Result["branch_deleted"] != true {
|
||||
t.Fatalf("expected branch_deleted=true, got %v", parsed.Result["branch_deleted"])
|
||||
var parsed struct {
|
||||
Result map[string]any `json:"Result"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil {
|
||||
t.Fatalf("unmarshal result text: %v", err)
|
||||
}
|
||||
if parsed.Result["merged"] != true {
|
||||
t.Fatalf("expected merged=true, got %v", parsed.Result["merged"])
|
||||
}
|
||||
if parsed.Result["merge_style"] != "squash" {
|
||||
t.Fatalf("expected merge_style 'squash', got %v", parsed.Result["merge_style"])
|
||||
}
|
||||
if parsed.Result["branch_deleted"] != true {
|
||||
t.Fatalf("expected branch_deleted=true, got %v", parsed.Result["branch_deleted"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -242,120 +266,132 @@ func TestGetPullRequestDiffFn(t *testing.T) {
|
||||
diffRaw = "diff --git a/file.txt b/file.txt\n+line\n"
|
||||
)
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
diffRequested bool
|
||||
binaryValue string
|
||||
)
|
||||
errCh := make(chan error, 1)
|
||||
indexInputs := []struct {
|
||||
name string
|
||||
val any
|
||||
}{
|
||||
{"float64", float64(index)},
|
||||
{"string", "12"},
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v1/version":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
|
||||
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"private":false}`))
|
||||
case fmt.Sprintf("/%s/%s/pulls/%d.diff", owner, repo, index):
|
||||
if r.Method != http.MethodGet {
|
||||
select {
|
||||
case errCh <- fmt.Errorf("unexpected method: %s", r.Method):
|
||||
for _, ii := range indexInputs {
|
||||
t.Run(ii.name, func(t *testing.T) {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
diffRequested bool
|
||||
binaryValue string
|
||||
)
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v1/version":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
|
||||
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"private":false}`))
|
||||
case fmt.Sprintf("/%s/%s/pulls/%d.diff", owner, repo, index):
|
||||
if r.Method != http.MethodGet {
|
||||
select {
|
||||
case errCh <- fmt.Errorf("unexpected method: %s", r.Method):
|
||||
default:
|
||||
}
|
||||
}
|
||||
mu.Lock()
|
||||
diffRequested = true
|
||||
binaryValue = r.URL.Query().Get("binary")
|
||||
mu.Unlock()
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
_, _ = w.Write([]byte(diffRaw))
|
||||
default:
|
||||
select {
|
||||
case errCh <- fmt.Errorf("unexpected request path: %s", r.URL.Path):
|
||||
default:
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
server := httptest.NewServer(handler)
|
||||
defer server.Close()
|
||||
|
||||
origHost := flag.Host
|
||||
origToken := flag.Token
|
||||
origVersion := flag.Version
|
||||
flag.Host = server.URL
|
||||
flag.Token = ""
|
||||
flag.Version = "test"
|
||||
defer func() {
|
||||
flag.Host = origHost
|
||||
flag.Token = origToken
|
||||
flag.Version = origVersion
|
||||
}()
|
||||
|
||||
req := mcp.CallToolRequest{
|
||||
Params: mcp.CallToolParams{
|
||||
Arguments: map[string]any{
|
||||
"owner": owner,
|
||||
"repo": repo,
|
||||
"index": ii.val,
|
||||
"binary": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
mu.Lock()
|
||||
diffRequested = true
|
||||
binaryValue = r.URL.Query().Get("binary")
|
||||
mu.Unlock()
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
_, _ = w.Write([]byte(diffRaw))
|
||||
default:
|
||||
|
||||
result, err := GetPullRequestDiffFn(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPullRequestDiffFn() error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case errCh <- fmt.Errorf("unexpected request path: %s", r.URL.Path):
|
||||
case reqErr := <-errCh:
|
||||
t.Fatalf("handler error: %v", reqErr)
|
||||
default:
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
server := httptest.NewServer(handler)
|
||||
defer server.Close()
|
||||
mu.Lock()
|
||||
requested := diffRequested
|
||||
gotBinary := binaryValue
|
||||
mu.Unlock()
|
||||
|
||||
origHost := flag.Host
|
||||
origToken := flag.Token
|
||||
origVersion := flag.Version
|
||||
flag.Host = server.URL
|
||||
flag.Token = ""
|
||||
flag.Version = "test"
|
||||
defer func() {
|
||||
flag.Host = origHost
|
||||
flag.Token = origToken
|
||||
flag.Version = origVersion
|
||||
}()
|
||||
if !requested {
|
||||
t.Fatalf("expected diff request to be made")
|
||||
}
|
||||
if gotBinary != "true" {
|
||||
t.Fatalf("expected binary=true query param, got %q", gotBinary)
|
||||
}
|
||||
|
||||
req := mcp.CallToolRequest{
|
||||
Params: mcp.CallToolParams{
|
||||
Arguments: map[string]any{
|
||||
"owner": owner,
|
||||
"repo": repo,
|
||||
"index": float64(index),
|
||||
"binary": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatalf("expected content in result")
|
||||
}
|
||||
|
||||
result, err := GetPullRequestDiffFn(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPullRequestDiffFn() error = %v", err)
|
||||
}
|
||||
textContent, ok := mcp.AsTextContent(result.Content[0])
|
||||
if !ok {
|
||||
t.Fatalf("expected text content, got %T", result.Content[0])
|
||||
}
|
||||
|
||||
select {
|
||||
case reqErr := <-errCh:
|
||||
t.Fatalf("handler error: %v", reqErr)
|
||||
default:
|
||||
}
|
||||
var parsed struct {
|
||||
Result map[string]any `json:"Result"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil {
|
||||
t.Fatalf("unmarshal result text: %v", err)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
requested := diffRequested
|
||||
gotBinary := binaryValue
|
||||
mu.Unlock()
|
||||
|
||||
if !requested {
|
||||
t.Fatalf("expected diff request to be made")
|
||||
}
|
||||
if gotBinary != "true" {
|
||||
t.Fatalf("expected binary=true query param, got %q", gotBinary)
|
||||
}
|
||||
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatalf("expected content in result")
|
||||
}
|
||||
|
||||
textContent, ok := mcp.AsTextContent(result.Content[0])
|
||||
if !ok {
|
||||
t.Fatalf("expected text content, got %T", result.Content[0])
|
||||
}
|
||||
|
||||
var parsed struct {
|
||||
Result map[string]any `json:"Result"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil {
|
||||
t.Fatalf("unmarshal result text: %v", err)
|
||||
}
|
||||
|
||||
if got, ok := parsed.Result["diff"].(string); !ok || got != diffRaw {
|
||||
t.Fatalf("diff = %q, want %q", got, diffRaw)
|
||||
}
|
||||
if got, ok := parsed.Result["binary"].(bool); !ok || got != true {
|
||||
t.Fatalf("binary = %v, want true", got)
|
||||
}
|
||||
if got, ok := parsed.Result["index"].(float64); !ok || int64(got) != int64(index) {
|
||||
t.Fatalf("index = %v, want %d", got, index)
|
||||
}
|
||||
if got, ok := parsed.Result["owner"].(string); !ok || got != owner {
|
||||
t.Fatalf("owner = %q, want %q", got, owner)
|
||||
}
|
||||
if got, ok := parsed.Result["repo"].(string); !ok || got != repo {
|
||||
t.Fatalf("repo = %q, want %q", got, repo)
|
||||
if got, ok := parsed.Result["diff"].(string); !ok || got != diffRaw {
|
||||
t.Fatalf("diff = %q, want %q", got, diffRaw)
|
||||
}
|
||||
if got, ok := parsed.Result["binary"].(bool); !ok || got != true {
|
||||
t.Fatalf("binary = %v, want true", got)
|
||||
}
|
||||
if got, ok := parsed.Result["index"].(float64); !ok || int64(got) != int64(index) {
|
||||
t.Fatalf("index = %v, want %d", got, index)
|
||||
}
|
||||
if got, ok := parsed.Result["owner"].(string); !ok || got != owner {
|
||||
t.Fatalf("owner = %q, want %q", got, owner)
|
||||
}
|
||||
if got, ok := parsed.Result["repo"].(string); !ok || got != repo {
|
||||
t.Fatalf("repo = %q, want %q", got, repo)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user