Files
s01e02/internal/usecase/agent_processor.go
2026-03-12 02:10:57 +01:00

240 lines
7.0 KiB
Go

package usecase
import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"github.com/paramah/ai_devs4/s01e02/internal/domain"
)
// AgentProcessorUseCase bundles everything needed to process persons with an
// LLM agent: the repositories the input data is loaded from, the API client
// and LLM provider that are called during processing, and the configuration
// values shared by those calls.
type AgentProcessorUseCase struct {
	// Sources of the persons to process and of the power plant locations.
	personRepo   domain.PersonRepository
	locationRepo domain.LocationRepository

	// Collaborators invoked while the agent runs.
	apiClient   domain.APIClient
	llmProvider domain.LLMProvider

	// apiKey is passed along with location and access-level requests;
	// outputDir is the base directory under which results are written.
	apiKey, outputDir string
}
// NewAgentProcessorUseCase returns a pointer to a freshly allocated
// AgentProcessorUseCase whose fields are set to the given dependencies and
// configuration values. The arguments are stored as-is; nothing is validated.
func NewAgentProcessorUseCase(
	personRepo domain.PersonRepository,
	locationRepo domain.LocationRepository,
	apiClient domain.APIClient,
	llmProvider domain.LLMProvider,
	apiKey string,
	outputDir string,
) *AgentProcessorUseCase {
	uc := new(AgentProcessorUseCase)

	uc.personRepo = personRepo
	uc.locationRepo = locationRepo
	uc.apiClient = apiClient
	uc.llmProvider = llmProvider
	uc.apiKey = apiKey
	uc.outputDir = outputDir

	return uc
}
// Execute loads the persons from inputFile and the power plant locations, then
// runs the agent for each person in order.
//
// A failure to load persons or locations aborts the run and is returned
// wrapped. A failure while processing an individual person is logged and the
// loop continues with the next person, so such failures are not reflected in
// the return value. If ctx is done before a person is started, the remaining
// persons are skipped and the context error is returned wrapped.
func (uc *AgentProcessorUseCase) Execute(ctx context.Context, inputFile string) error {
	// Load persons from file
	log.Printf("Loading persons from: %s", inputFile)
	persons, err := uc.personRepo.LoadPersons(ctx, inputFile)
	if err != nil {
		return fmt.Errorf("loading persons: %w", err)
	}
	log.Printf("Loaded %d persons", len(persons))

	// Load power plant locations
	log.Printf("Loading power plant locations...")
	locations, err := uc.locationRepo.LoadLocations(ctx, uc.apiKey)
	if err != nil {
		return fmt.Errorf("loading locations: %w", err)
	}
	log.Printf("Loaded %d power plant locations", len(locations))

	// Process each person with agent
	for i, person := range persons {
		// Without this check a cancelled context would still walk the whole
		// list, with every remaining person failing on its first call.
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("processing interrupted after %d of %d persons: %w", i, len(persons), err)
		}

		log.Printf("\n[%d/%d] Processing: %s %s", i+1, len(persons), person.Name, person.Surname)
		if err := uc.processPerson(ctx, person, locations); err != nil {
			// Best effort: one failing person must not stop the batch.
			log.Printf("Error processing %s %s: %v", person.Name, person.Surname, err)
			continue
		}
	}

	log.Printf("\nProcessing completed!")
	return nil
}
// processPerson runs the LLM agent loop for a single person and, when the
// agent obtained the person's coordinates, stores the closest power plant.
//
// The conversation starts with a system prompt and a user message naming the
// person. Each iteration sends the conversation to the LLM; every tool call in
// the reply is executed and its result appended as a "tool" message. The loop
// ends when the LLM replies without tool calls and with finish reason "stop",
// or after maxIterations rounds (a warning is logged in that case).
//
// If a location was captured and powerPlants is non-empty, the closest plant is
// written as JSON to <outputDir>/distances/<Name>_<Surname>.json; the
// directory is created when missing.
//
// It returns an error when the LLM call fails, when a tool call cannot be
// executed, or when the distance result cannot be encoded or written.
func (uc *AgentProcessorUseCase) processPerson(ctx context.Context, person domain.Person, powerPlants []domain.Location) error {
	const maxIterations = 10

	tools := domain.GetToolDefinitions()

	// Initial system message
	systemPrompt := fmt.Sprintf(`You are an agent that gathers information about people.
For the person %s %s (born: %d), you need to:
1. Call get_location to get their current location coordinates
2. Call get_access_level to get their access level (remember: birth_year parameter must be only the year as integer, e.g., %d)
After gathering the data, respond with "DONE".`, person.Name, person.Surname, person.Born, person.Born)

	messages := []domain.LLMMessage{
		{
			Role:    "system",
			Content: systemPrompt,
		},
		{
			Role:    "user",
			Content: fmt.Sprintf("Please gather information for %s %s.", person.Name, person.Surname),
		},
	}

	var personLocation *domain.PersonLocation
	completed := false

	for iteration := 0; iteration < maxIterations; iteration++ {
		log.Printf(" [Iteration %d] Calling LLM...", iteration+1)
		resp, err := uc.llmProvider.Chat(ctx, domain.LLMRequest{
			Messages:    messages,
			Tools:       tools,
			ToolChoice:  "auto",
			Temperature: 0.0,
		})
		if err != nil {
			return fmt.Errorf("LLM chat error: %w", err)
		}
		messages = append(messages, resp.Message)

		// No tool calls: either the agent is done, or we ask again.
		if len(resp.Message.ToolCalls) == 0 {
			if resp.FinishReason == "stop" {
				log.Printf(" ✓ Agent completed gathering data")
				completed = true
				break
			}
			continue
		}

		log.Printf(" → LLM requested %d tool call(s)", len(resp.Message.ToolCalls))
		for _, toolCall := range resp.Message.ToolCalls {
			result, loc, err := uc.executeToolCall(ctx, person, toolCall)
			if err != nil {
				return fmt.Errorf("executing tool call: %w", err)
			}
			// Store person location if we got it from get_location
			if loc != nil {
				personLocation = loc
			}
			messages = append(messages, domain.LLMMessage{
				Role:       "tool",
				Content:    result,
				ToolCallID: toolCall.ID,
			})
		}
	}

	if !completed {
		// Previously silent; data gathered so far is still used below.
		log.Printf(" ! Agent did not finish within %d iterations", maxIterations)
	}

	// Calculate distances if we have location
	if personLocation == nil || len(powerPlants) == 0 {
		return nil
	}
	log.Printf(" → Calculating distances to power plants...")
	closest := domain.FindClosestLocation(*personLocation, powerPlants)
	if closest == nil {
		return nil
	}
	log.Printf(" ✓ Closest power plant: %s (%.2f km)", closest.Location, closest.DistanceKm)

	// Save distance result. Encoding and write errors are reported instead of
	// being dropped, so a missing result file never goes unnoticed.
	distanceData, err := json.MarshalIndent(closest, "", " ")
	if err != nil {
		return fmt.Errorf("encoding distance result: %w", err)
	}
	distanceDir := filepath.Join(uc.outputDir, "distances")
	if err := os.MkdirAll(distanceDir, 0755); err != nil {
		return fmt.Errorf("creating distances directory: %w", err)
	}
	distanceFile := filepath.Join(distanceDir, fmt.Sprintf("%s_%s.json", person.Name, person.Surname))
	if err := os.WriteFile(distanceFile, distanceData, 0644); err != nil {
		return fmt.Errorf("saving distance result: %w", err)
	}
	return nil
}
// executeToolCall runs one tool call requested by the LLM and returns the text
// to send back to the LLM as the tool result.
//
// Supported tools are "get_location" and "get_access_level". For both, the raw
// API response is saved under the output directory and returned as the result
// text. For "get_location" the response is additionally parsed as a JSON array
// of objects; when its first element has numeric "latitude" and "longitude"
// values, a PersonLocation built from them is returned as the second result,
// otherwise that result is nil.
//
// API failures and unknown tool names are reported in the result text with a
// nil error so that the LLM sees them. The error result is non-nil only when
// the tool call arguments are not valid JSON.
func (uc *AgentProcessorUseCase) executeToolCall(ctx context.Context, person domain.Person, toolCall domain.ToolCall) (string, *domain.PersonLocation, error) {
	log.Printf(" → Executing: %s", toolCall.Function.Name)

	var args map[string]interface{}
	if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
		return "", nil, fmt.Errorf("parsing arguments: %w", err)
	}

	switch toolCall.Function.Name {
	case "get_location":
		// A missing or non-string argument yields "", as before.
		name, _ := args["name"].(string)
		surname, _ := args["surname"].(string)

		response, err := uc.apiClient.GetLocation(ctx, domain.LocationRequest{
			APIKey:  uc.apiKey,
			Name:    name,
			Surname: surname,
		})
		if err != nil {
			return fmt.Sprintf("Error: %v", err), nil, nil
		}
		uc.saveToolResponse("locations", name, surname, response)

		// Parse location to get coordinates (API returns array of locations)
		var locationData []map[string]interface{}
		if err := json.Unmarshal(response, &locationData); err != nil || len(locationData) == 0 {
			return string(response), nil, nil
		}
		// Take first location from the array
		lat, okLat := locationData[0]["latitude"].(float64)
		lon, okLon := locationData[0]["longitude"].(float64)
		if !okLat || !okLon {
			return string(response), nil, nil
		}
		return string(response), &domain.PersonLocation{
			Name:      name,
			Surname:   surname,
			Latitude:  lat,
			Longitude: lon,
		}, nil

	case "get_access_level":
		name, _ := args["name"].(string)
		surname, _ := args["surname"].(string)
		// JSON numbers decode to float64; anything else yields 0, as before.
		birthYear, _ := args["birth_year"].(float64)

		response, err := uc.apiClient.GetAccessLevel(ctx, domain.AccessLevelRequest{
			APIKey:    uc.apiKey,
			Name:      name,
			Surname:   surname,
			BirthYear: int(birthYear),
		})
		if err != nil {
			return fmt.Sprintf("Error: %v", err), nil, nil
		}
		uc.saveToolResponse("accesslevel", name, surname, response)
		return string(response), nil, nil

	default:
		return fmt.Sprintf("Unknown function: %s", toolCall.Function.Name), nil, nil
	}
}

// saveToolResponse writes data to <outputDir>/<subdir>/<name>_<surname>.json,
// creating the directory when needed. Saving is best effort: a failure is
// logged and does not interrupt the tool call, and the "Saved to" message is
// logged only after a successful write.
func (uc *AgentProcessorUseCase) saveToolResponse(subdir, name, surname string, data []byte) {
	// name and surname come from LLM-generated arguments; keeping only the
	// last path element stops a value containing separators from escaping dir.
	fileName := filepath.Base(fmt.Sprintf("%s_%s.json", name, surname))

	dir := filepath.Join(uc.outputDir, subdir)
	if err := os.MkdirAll(dir, 0755); err != nil {
		log.Printf(" ✗ Could not create directory %s: %v", dir, err)
		return
	}

	filePath := filepath.Join(dir, fileName)
	if err := os.WriteFile(filePath, data, 0644); err != nil {
		log.Printf(" ✗ Could not save %s: %v", filePath, err)
		return
	}
	log.Printf(" ✓ Saved to: %s", filePath)
}