240 lines
7.0 KiB
Go
240 lines
7.0 KiB
Go
package usecase
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
|
|
"github.com/paramah/ai_devs4/s01e02/internal/domain"
|
|
)
|
|
|
|
// AgentProcessorUseCase handles the processing of persons using LLM agent
|
|
type AgentProcessorUseCase struct {
|
|
personRepo domain.PersonRepository
|
|
locationRepo domain.LocationRepository
|
|
apiClient domain.APIClient
|
|
llmProvider domain.LLMProvider
|
|
apiKey string
|
|
outputDir string
|
|
}
|
|
|
|
// NewAgentProcessorUseCase creates a new use case instance
|
|
func NewAgentProcessorUseCase(
|
|
personRepo domain.PersonRepository,
|
|
locationRepo domain.LocationRepository,
|
|
apiClient domain.APIClient,
|
|
llmProvider domain.LLMProvider,
|
|
apiKey string,
|
|
outputDir string,
|
|
) *AgentProcessorUseCase {
|
|
return &AgentProcessorUseCase{
|
|
personRepo: personRepo,
|
|
locationRepo: locationRepo,
|
|
apiClient: apiClient,
|
|
llmProvider: llmProvider,
|
|
apiKey: apiKey,
|
|
outputDir: outputDir,
|
|
}
|
|
}
|
|
|
|
// Execute processes all persons using LLM agent
|
|
func (uc *AgentProcessorUseCase) Execute(ctx context.Context, inputFile string) error {
|
|
// Load persons from file
|
|
log.Printf("Loading persons from: %s", inputFile)
|
|
persons, err := uc.personRepo.LoadPersons(ctx, inputFile)
|
|
if err != nil {
|
|
return fmt.Errorf("loading persons: %w", err)
|
|
}
|
|
log.Printf("Loaded %d persons", len(persons))
|
|
|
|
// Load power plant locations
|
|
log.Printf("Loading power plant locations...")
|
|
locations, err := uc.locationRepo.LoadLocations(ctx, uc.apiKey)
|
|
if err != nil {
|
|
return fmt.Errorf("loading locations: %w", err)
|
|
}
|
|
log.Printf("Loaded %d power plant locations", len(locations))
|
|
|
|
// Process each person with agent
|
|
for i, person := range persons {
|
|
log.Printf("\n[%d/%d] Processing: %s %s", i+1, len(persons), person.Name, person.Surname)
|
|
|
|
if err := uc.processPerson(ctx, person, locations); err != nil {
|
|
log.Printf("Error processing %s %s: %v", person.Name, person.Surname, err)
|
|
continue
|
|
}
|
|
}
|
|
|
|
log.Printf("\nProcessing completed!")
|
|
return nil
|
|
}
|
|
|
|
// processPerson uses LLM agent to gather data for a person
|
|
func (uc *AgentProcessorUseCase) processPerson(ctx context.Context, person domain.Person, powerPlants []domain.Location) error {
|
|
tools := domain.GetToolDefinitions()
|
|
|
|
// Initial system message
|
|
systemPrompt := fmt.Sprintf(`You are an agent that gathers information about people.
|
|
For the person %s %s (born: %d), you need to:
|
|
1. Call get_location to get their current location coordinates
|
|
2. Call get_access_level to get their access level (remember: birth_year parameter must be only the year as integer, e.g., %d)
|
|
|
|
After gathering the data, respond with "DONE".`, person.Name, person.Surname, person.Born, person.Born)
|
|
|
|
messages := []domain.LLMMessage{
|
|
{
|
|
Role: "system",
|
|
Content: systemPrompt,
|
|
},
|
|
{
|
|
Role: "user",
|
|
Content: fmt.Sprintf("Please gather information for %s %s.", person.Name, person.Surname),
|
|
},
|
|
}
|
|
|
|
maxIterations := 10
|
|
var personLocation *domain.PersonLocation
|
|
|
|
for iteration := 0; iteration < maxIterations; iteration++ {
|
|
log.Printf(" [Iteration %d] Calling LLM...", iteration+1)
|
|
|
|
resp, err := uc.llmProvider.Chat(ctx, domain.LLMRequest{
|
|
Messages: messages,
|
|
Tools: tools,
|
|
ToolChoice: "auto",
|
|
Temperature: 0.0,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("LLM chat error: %w", err)
|
|
}
|
|
|
|
messages = append(messages, resp.Message)
|
|
|
|
// Check if LLM wants to call functions
|
|
if len(resp.Message.ToolCalls) > 0 {
|
|
log.Printf(" → LLM requested %d tool call(s)", len(resp.Message.ToolCalls))
|
|
|
|
for _, toolCall := range resp.Message.ToolCalls {
|
|
result, loc, err := uc.executeToolCall(ctx, person, toolCall)
|
|
if err != nil {
|
|
return fmt.Errorf("executing tool call: %w", err)
|
|
}
|
|
|
|
// Store person location if we got it from get_location
|
|
if loc != nil {
|
|
personLocation = loc
|
|
}
|
|
|
|
messages = append(messages, domain.LLMMessage{
|
|
Role: "tool",
|
|
Content: result,
|
|
ToolCallID: toolCall.ID,
|
|
})
|
|
}
|
|
} else if resp.FinishReason == "stop" {
|
|
log.Printf(" ✓ Agent completed gathering data")
|
|
break
|
|
}
|
|
}
|
|
|
|
// Calculate distances if we have location
|
|
if personLocation != nil && len(powerPlants) > 0 {
|
|
log.Printf(" → Calculating distances to power plants...")
|
|
closest := domain.FindClosestLocation(*personLocation, powerPlants)
|
|
if closest != nil {
|
|
log.Printf(" ✓ Closest power plant: %s (%.2f km)", closest.Location, closest.DistanceKm)
|
|
|
|
// Save distance result
|
|
distanceFile := filepath.Join(uc.outputDir, "distances", fmt.Sprintf("%s_%s.json", person.Name, person.Surname))
|
|
distanceData, _ := json.MarshalIndent(closest, "", " ")
|
|
os.WriteFile(distanceFile, distanceData, 0644)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// executeToolCall executes a tool call from the LLM
|
|
func (uc *AgentProcessorUseCase) executeToolCall(ctx context.Context, person domain.Person, toolCall domain.ToolCall) (string, *domain.PersonLocation, error) {
|
|
log.Printf(" → Executing: %s", toolCall.Function.Name)
|
|
|
|
var args map[string]interface{}
|
|
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
|
return "", nil, fmt.Errorf("parsing arguments: %w", err)
|
|
}
|
|
|
|
switch toolCall.Function.Name {
|
|
case "get_location":
|
|
name, _ := args["name"].(string)
|
|
surname, _ := args["surname"].(string)
|
|
|
|
req := domain.LocationRequest{
|
|
APIKey: uc.apiKey,
|
|
Name: name,
|
|
Surname: surname,
|
|
}
|
|
|
|
response, err := uc.apiClient.GetLocation(ctx, req)
|
|
if err != nil {
|
|
return fmt.Sprintf("Error: %v", err), nil, nil
|
|
}
|
|
|
|
// Save response
|
|
fileName := fmt.Sprintf("%s_%s.json", name, surname)
|
|
filePath := filepath.Join(uc.outputDir, "locations", fileName)
|
|
os.WriteFile(filePath, response, 0644)
|
|
log.Printf(" ✓ Saved to: %s", filePath)
|
|
|
|
// Parse location to get coordinates (API returns array of locations)
|
|
var locationData []map[string]interface{}
|
|
if err := json.Unmarshal(response, &locationData); err == nil && len(locationData) > 0 {
|
|
// Take first location from the array
|
|
firstLoc := locationData[0]
|
|
if lat, ok := firstLoc["latitude"].(float64); ok {
|
|
if lon, ok := firstLoc["longitude"].(float64); ok {
|
|
personLoc := &domain.PersonLocation{
|
|
Name: name,
|
|
Surname: surname,
|
|
Latitude: lat,
|
|
Longitude: lon,
|
|
}
|
|
return string(response), personLoc, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
return string(response), nil, nil
|
|
|
|
case "get_access_level":
|
|
name, _ := args["name"].(string)
|
|
surname, _ := args["surname"].(string)
|
|
birthYear, _ := args["birth_year"].(float64)
|
|
|
|
req := domain.AccessLevelRequest{
|
|
APIKey: uc.apiKey,
|
|
Name: name,
|
|
Surname: surname,
|
|
BirthYear: int(birthYear),
|
|
}
|
|
|
|
response, err := uc.apiClient.GetAccessLevel(ctx, req)
|
|
if err != nil {
|
|
return fmt.Sprintf("Error: %v", err), nil, nil
|
|
}
|
|
|
|
// Save response
|
|
fileName := fmt.Sprintf("%s_%s.json", name, surname)
|
|
filePath := filepath.Join(uc.outputDir, "accesslevel", fileName)
|
|
os.WriteFile(filePath, response, 0644)
|
|
log.Printf(" ✓ Saved to: %s", filePath)
|
|
|
|
return string(response), nil, nil
|
|
|
|
default:
|
|
return fmt.Sprintf("Unknown function: %s", toolCall.Function.Name), nil, nil
|
|
}
|
|
}
|