diff --git a/operation/operation.go b/operation/operation.go index 65daa86..da0c811 100644 --- a/operation/operation.go +++ b/operation/operation.go @@ -52,18 +52,34 @@ func RegisterTool(s *server.MCPServer) { s.DeleteTools("") } +// parseBearerToken extracts the Bearer token from an Authorization header. +// Returns the token and true if valid, empty string and false otherwise. +func parseBearerToken(authHeader string) (string, bool) { + const bearerPrefix = "Bearer " + if len(authHeader) < len(bearerPrefix) || !strings.HasPrefix(authHeader, bearerPrefix) { + return "", false + } + + token := strings.TrimSpace(authHeader[len(bearerPrefix):]) + if token == "" { + return "", false + } + + return token, true +} + func getContextWithToken(ctx context.Context, r *http.Request) context.Context { authHeader := r.Header.Get("Authorization") if authHeader == "" { return ctx } - parts := strings.Split(authHeader, " ") - if len(parts) != 2 || parts[0] != "Bearer" { + token, ok := parseBearerToken(authHeader) + if !ok { return ctx } - return context.WithValue(ctx, mcpContext.TokenContextKey, parts[1]) + return context.WithValue(ctx, mcpContext.TokenContextKey, token) } func Run() error { diff --git a/operation/operation_test.go b/operation/operation_test.go new file mode 100644 index 0000000..758ae04 --- /dev/null +++ b/operation/operation_test.go @@ -0,0 +1,81 @@ +package operation + +import ( + "testing" +) + +func TestParseBearerToken(t *testing.T) { + tests := []struct { + name string + header string + wantToken string + wantOK bool + }{ + { + name: "valid token", + header: "Bearer validtoken", + wantToken: "validtoken", + wantOK: true, + }, + { + name: "token with spaces trimmed", + header: "Bearer spacedToken ", + wantToken: "spacedToken", + wantOK: true, + }, + { + name: "lowercase bearer should fail", + header: "bearer lowercase", + wantToken: "", + wantOK: false, + }, + { + name: "bearer with no token", + header: "Bearer ", + wantToken: "", + wantOK: false, + }, + { + name: "bearer with only spaces", + header: "Bearer ", + wantToken: "", + wantOK: false, + }, + { + name: "missing space after Bearer", + header: "Bearertoken", + wantToken: "", + wantOK: false, + }, + { + name: "different auth type", + header: "Basic dXNlcjpwYXNz", + wantToken: "", + wantOK: false, + }, + { + name: "empty header", + header: "", + wantToken: "", + wantOK: false, + }, + { + name: "token with internal spaces", + header: "Bearer token with spaces", + wantToken: "token with spaces", + wantOK: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotToken, gotOK := parseBearerToken(tt.header) + if gotToken != tt.wantToken { + t.Errorf("parseBearerToken() token = %q, want %q", gotToken, tt.wantToken) + } + if gotOK != tt.wantOK { + t.Errorf("parseBearerToken() ok = %v, want %v", gotOK, tt.wantOK) + } + }) + } +}