114 lines
2.8 KiB
Go
114 lines
2.8 KiB
Go
package llm
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
|
|
"github.com/paramah/ai_devs4/s01e02/internal/domain"
|
|
)
|
|
|
|
// OpenRouterProvider implements domain.LLMProvider for the OpenRouter API.
//
// Construct it with NewOpenRouterProvider: the zero value has a nil client,
// so sending a request through it would dereference a nil *http.Client.
type OpenRouterProvider struct {
	// apiKey is sent as a bearer token; model is placed in every request body.
	apiKey, model string
	// baseURL is the full URL of the chat-completions endpoint that is POSTed to.
	baseURL string
	// client performs the outgoing HTTP calls.
	client *http.Client
}
|
|
|
|
// NewOpenRouterProvider creates a new OpenRouter provider
|
|
func NewOpenRouterProvider(apiKey, model string) *OpenRouterProvider {
|
|
return &OpenRouterProvider{
|
|
apiKey: apiKey,
|
|
model: model,
|
|
baseURL: "https://openrouter.ai/api/v1/chat/completions",
|
|
client: &http.Client{},
|
|
}
|
|
}
|
|
|
|
type openRouterRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []domain.LLMMessage `json:"messages"`
|
|
Tools []domain.Tool `json:"tools,omitempty"`
|
|
ToolChoice interface{} `json:"tool_choice,omitempty"`
|
|
Temperature float64 `json:"temperature,omitempty"`
|
|
}
|
|
|
|
type openRouterResponse struct {
|
|
Choices []struct {
|
|
Message domain.LLMMessage `json:"message"`
|
|
FinishReason string `json:"finish_reason"`
|
|
} `json:"choices"`
|
|
Error *struct {
|
|
Message string `json:"message"`
|
|
} `json:"error,omitempty"`
|
|
}
|
|
|
|
// Chat sends a chat request with function calling support
|
|
func (p *OpenRouterProvider) Chat(ctx context.Context, request domain.LLMRequest) (*domain.LLMResponse, error) {
|
|
reqBody := openRouterRequest{
|
|
Model: p.model,
|
|
Messages: request.Messages,
|
|
Tools: request.Tools,
|
|
Temperature: request.Temperature,
|
|
}
|
|
|
|
if request.ToolChoice != "" {
|
|
if request.ToolChoice == "auto" {
|
|
reqBody.ToolChoice = "auto"
|
|
} else {
|
|
reqBody.ToolChoice = map[string]interface{}{
|
|
"type": "function",
|
|
"function": map[string]string{
|
|
"name": request.ToolChoice,
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshaling request: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
|
|
|
resp, err := p.client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sending request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading response: %w", err)
|
|
}
|
|
|
|
var apiResp openRouterResponse
|
|
if err := json.Unmarshal(body, &apiResp); err != nil {
|
|
return nil, fmt.Errorf("unmarshaling response: %w", err)
|
|
}
|
|
|
|
if apiResp.Error != nil {
|
|
return nil, fmt.Errorf("API error: %s", apiResp.Error.Message)
|
|
}
|
|
|
|
if len(apiResp.Choices) == 0 {
|
|
return nil, fmt.Errorf("no choices in response")
|
|
}
|
|
|
|
return &domain.LLMResponse{
|
|
Message: apiResp.Choices[0].Message,
|
|
FinishReason: apiResp.Choices[0].FinishReason,
|
|
}, nil
|
|
}
|