131 lines
3.3 KiB
Go
131 lines
3.3 KiB
Go
package llm
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/paramah/ai_devs4/s01e01/internal/domain"
)
|
|
|
|
// LMStudioProvider implements domain.LLMProvider by talking to a locally
// running LM Studio server. Its fields are unexported; construct it with
// NewLMStudioProvider.
type LMStudioProvider struct {
	// baseURL is the server root the endpoint path is appended to, and model
	// is the model identifier sent with every request.
	baseURL, model string

	// client performs the HTTP round trips.
	client *http.Client
}
|
|
|
|
// NewLMStudioProvider creates a new LM Studio provider
|
|
func NewLMStudioProvider(baseURL, model string) *LMStudioProvider {
|
|
return &LMStudioProvider{
|
|
baseURL: baseURL,
|
|
model: model,
|
|
client: &http.Client{},
|
|
}
|
|
}
|
|
|
|
type lmStudioRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []map[string]interface{} `json:"messages"`
|
|
ResponseFormat *lmResponseFormat `json:"response_format,omitempty"`
|
|
Temperature float64 `json:"temperature,omitempty"`
|
|
}
|
|
|
|
// lmResponseFormat is the "response_format" object of a chat completion
// request. Type is always encoded; JSONSchema is encoded under "json_schema"
// and omitted when it is nil.
type lmResponseFormat struct {
	Type       string      `json:"type"`
	JSONSchema interface{} `json:"json_schema,omitempty"`
}
|
|
|
|
// lmStudioResponse captures the parts of a chat completion response body
// that Complete reads: the message content of each choice and the raw
// "error" value, whose decoding is deferred because its JSON shape is not
// fixed (Complete accepts a string, an object, or anything else).
type lmStudioResponse struct {
	Choices []struct {
		Message struct{ Content string `json:"content"` } `json:"message"`
	} `json:"choices"`

	// Error stays empty when the body has no "error" key.
	Error json.RawMessage `json:"error,omitempty"`
}
|
|
|
|
// Complete sends a request to LM Studio local server
|
|
func (p *LMStudioProvider) Complete(ctx context.Context, request domain.LLMRequest) (*domain.LLMResponse, error) {
|
|
reqBody := lmStudioRequest{
|
|
Model: p.model,
|
|
Messages: []map[string]interface{}{
|
|
{
|
|
"role": "user",
|
|
"content": request.Prompt,
|
|
},
|
|
},
|
|
Temperature: 0.7,
|
|
}
|
|
|
|
if request.Schema != nil {
|
|
reqBody.ResponseFormat = &lmResponseFormat{
|
|
Type: "json_schema",
|
|
JSONSchema: request.Schema,
|
|
}
|
|
}
|
|
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshaling request: %w", err)
|
|
}
|
|
|
|
url := p.baseURL + "/v1/chat/completions"
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := p.client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sending request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading response: %w", err)
|
|
}
|
|
|
|
// Check HTTP status
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var apiResp lmStudioResponse
|
|
if err := json.Unmarshal(body, &apiResp); err != nil {
|
|
return nil, fmt.Errorf("unmarshaling response: %w\nResponse body: %s", err, string(body))
|
|
}
|
|
|
|
// Check for error in response
|
|
if len(apiResp.Error) > 0 {
|
|
// Try to parse as string
|
|
var errStr string
|
|
if err := json.Unmarshal(apiResp.Error, &errStr); err == nil {
|
|
return nil, fmt.Errorf("API error: %s", errStr)
|
|
}
|
|
// Try to parse as object with message field
|
|
var errObj struct {
|
|
Message string `json:"message"`
|
|
}
|
|
if err := json.Unmarshal(apiResp.Error, &errObj); err == nil {
|
|
return nil, fmt.Errorf("API error: %s", errObj.Message)
|
|
}
|
|
// Fallback to raw error
|
|
return nil, fmt.Errorf("API error: %s", string(apiResp.Error))
|
|
}
|
|
|
|
if len(apiResp.Choices) == 0 {
|
|
return nil, fmt.Errorf("no choices in response. Response body: %s", string(body))
|
|
}
|
|
|
|
return &domain.LLMResponse{
|
|
Content: apiResp.Choices[0].Message.Content,
|
|
}, nil
|
|
}
|