From 21e4e1b42bd694356b381bd0211ee391df0b8112 Mon Sep 17 00:00:00 2001 From: silverwind Date: Fri, 13 Feb 2026 13:26:21 +0000 Subject: [PATCH] feat: add edit_pull_request tool (#125) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Add `edit_pull_request` MCP tool to modify pull request properties - Supports editing title, body, base branch, assignees, milestone, state, and maintainer edit permission - Enables toggling WIP/draft status by modifying the title prefix Fixes https://gitea.com/gitea/gitea-mcp/issues/124 ## Test plan - [x] `go test ./...` passes - [x] Verified against gitea.com: toggled WIP on/off via title edit, changed PR state 🤖 Generated with [Claude Code](https://claude.ai/claude-code) Reviewed-on: https://gitea.com/gitea/gitea-mcp/pulls/125 Reviewed-by: Bo-Yi Wu (吳柏毅) Co-authored-by: silverwind Co-committed-by: silverwind --- operation/pull/pull.go | 85 +++++++++++++++++++++++++++++ operation/pull/pull_test.go | 104 ++++++++++++++++++++++++++++++++++++ 2 files changed, 189 insertions(+) diff --git a/operation/pull/pull.go b/operation/pull/pull.go index cbf839b..f0ce34c 100644 --- a/operation/pull/pull.go +++ b/operation/pull/pull.go @@ -6,6 +6,7 @@ import ( "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" + "gitea.com/gitea/gitea-mcp/pkg/ptr" "gitea.com/gitea/gitea-mcp/pkg/to" "gitea.com/gitea/gitea-mcp/pkg/tool" @@ -31,6 +32,7 @@ const ( DeletePullRequestReviewToolName = "delete_pull_request_review" DismissPullRequestReviewToolName = "dismiss_pull_request_review" MergePullRequestToolName = "merge_pull_request" + EditPullRequestToolName = "edit_pull_request" ) var ( @@ -182,6 +184,22 @@ var ( mcp.WithString("message", mcp.Description("custom merge commit message (optional)")), mcp.WithBoolean("delete_branch", mcp.Description("delete the branch after merge"), mcp.DefaultBool(false)), ) + + EditPullRequestTool = mcp.NewTool( + EditPullRequestToolName, + mcp.WithDescription("edit a pull request"), + mcp.WithString("owner", mcp.Required(), mcp.Description("repository owner")), + mcp.WithString("repo", mcp.Required(), mcp.Description("repository name")), + mcp.WithNumber("index", mcp.Required(), mcp.Description("pull request index")), + mcp.WithString("title", mcp.Description("pull request title")), + mcp.WithString("body", mcp.Description("pull request body content")), + mcp.WithString("base", mcp.Description("pull request base branch")), + mcp.WithString("assignee", mcp.Description("username to assign")), + mcp.WithArray("assignees", mcp.Description("usernames to assign"), mcp.Items(map[string]interface{}{"type": "string"})), + mcp.WithNumber("milestone", mcp.Description("milestone number")), + mcp.WithString("state", mcp.Description("pull request state"), mcp.Enum("open", "closed")), + mcp.WithBoolean("allow_maintainer_edit", mcp.Description("allow maintainer to edit the pull request")), + ) ) func init() { @@ -241,6 +259,10 @@ func init() { Tool: MergePullRequestTool, Handler: MergePullRequestFn, }) + Tool.RegisterWrite(server.ServerTool{ + Tool: EditPullRequestTool, + Handler: EditPullRequestFn, + }) } func GetPullRequestByIndexFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -876,3 +898,66 @@ func MergePullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call return to.TextResult(successMsg) } + +func EditPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + log.Debugf("Called EditPullRequestFn") + owner, ok := req.GetArguments()["owner"].(string) + if !ok { + return to.ErrorResult(fmt.Errorf("owner is required")) + } + repo, ok := req.GetArguments()["repo"].(string) + if !ok { + return to.ErrorResult(fmt.Errorf("repo is required")) + } + index, ok := req.GetArguments()["index"].(float64) + if !ok { + return to.ErrorResult(fmt.Errorf("index is required")) + } + + opt := gitea_sdk.EditPullRequestOption{} + + if title, ok := req.GetArguments()["title"].(string); ok { + opt.Title = title + } + if body, ok := req.GetArguments()["body"].(string); ok { + opt.Body = ptr.To(body) + } + if base, ok := req.GetArguments()["base"].(string); ok { + opt.Base = base + } + if assignee, ok := req.GetArguments()["assignee"].(string); ok { + opt.Assignee = assignee + } + if assigneesArg, exists := req.GetArguments()["assignees"]; exists { + if assigneesSlice, ok := assigneesArg.([]interface{}); ok { + var assignees []string + for _, a := range assigneesSlice { + if s, ok := a.(string); ok { + assignees = append(assignees, s) + } + } + opt.Assignees = assignees + } + } + if milestone, ok := req.GetArguments()["milestone"].(float64); ok { + opt.Milestone = int64(milestone) + } + if state, ok := req.GetArguments()["state"].(string); ok { + opt.State = ptr.To(gitea_sdk.StateType(state)) + } + if allowMaintainerEdit, ok := req.GetArguments()["allow_maintainer_edit"].(bool); ok { + opt.AllowMaintainerEdit = ptr.To(allowMaintainerEdit) + } + + client, err := gitea.ClientFromContext(ctx) + if err != nil { + return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) + } + + pr, _, err := client.EditPullRequest(owner, repo, int64(index), opt) + if err != nil { + return to.ErrorResult(fmt.Errorf("edit %v/%v/pr/%v err: %v", owner, repo, int64(index), err)) + } + + return to.TextResult(pr) +} diff --git a/operation/pull/pull_test.go b/operation/pull/pull_test.go index 8da2831..aaca693 100644 --- a/operation/pull/pull_test.go +++ b/operation/pull/pull_test.go @@ -13,6 +13,110 @@ import ( "github.com/mark3labs/mcp-go/mcp" ) +func TestEditPullRequestFn(t *testing.T) { + const ( + owner = "octo" + repo = "demo" + index = 7 + ) + + var ( + mu sync.Mutex + gotMethod string + gotPath string + gotBody map[string]any + ) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/version": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"version":"1.12.0"}`)) + case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"private":false}`)) + case fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index): + mu.Lock() + gotMethod = r.Method + gotPath = r.URL.Path + var body map[string]any + _ = json.NewDecoder(r.Body).Decode(&body) + gotBody = body + mu.Unlock() + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(fmt.Sprintf(`{"number":%d,"title":"%s","state":"open"}`, index, body["title"]))) + default: + http.NotFound(w, r) + } + }) + + server := httptest.NewServer(handler) + defer server.Close() + + origHost := flag.Host + origToken := flag.Token + origVersion := flag.Version + flag.Host = server.URL + flag.Token = "" + flag.Version = "test" + defer func() { + flag.Host = origHost + flag.Token = origToken + flag.Version = origVersion + }() + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "owner": owner, + "repo": repo, + "index": float64(index), + "title": "WIP: my feature", + "state": "open", + }, + }, + } + + result, err := EditPullRequestFn(context.Background(), req) + if err != nil { + t.Fatalf("EditPullRequestFn() error = %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if gotMethod != http.MethodPatch { + t.Fatalf("expected PATCH request, got %s", gotMethod) + } + if gotPath != fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index) { + t.Fatalf("unexpected path: %s", gotPath) + } + if gotBody["title"] != "WIP: my feature" { + t.Fatalf("expected title 'WIP: my feature', got %v", gotBody["title"]) + } + if gotBody["state"] != "open" { + t.Fatalf("expected state 'open', got %v", gotBody["state"]) + } + + if len(result.Content) == 0 { + t.Fatalf("expected content in result") + } + textContent, ok := mcp.AsTextContent(result.Content[0]) + if !ok { + t.Fatalf("expected text content, got %T", result.Content[0]) + } + + var parsed struct { + Result map[string]any `json:"Result"` + } + if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil { + t.Fatalf("unmarshal result text: %v", err) + } + if got := parsed.Result["title"].(string); got != "WIP: my feature" { + t.Fatalf("result title = %q, want %q", got, "WIP: my feature") + } +} + func TestGetPullRequestDiffFn(t *testing.T) { const ( owner = "octo"