final
This commit is contained in:
239
internal/usecase/agent_processor.go
Normal file
239
internal/usecase/agent_processor.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package usecase
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/paramah/ai_devs4/s01e02/internal/domain"
|
||||
)
|
||||
|
||||
// AgentProcessorUseCase handles the processing of persons using LLM agent
|
||||
type AgentProcessorUseCase struct {
|
||||
personRepo domain.PersonRepository
|
||||
locationRepo domain.LocationRepository
|
||||
apiClient domain.APIClient
|
||||
llmProvider domain.LLMProvider
|
||||
apiKey string
|
||||
outputDir string
|
||||
}
|
||||
|
||||
// NewAgentProcessorUseCase creates a new use case instance
|
||||
func NewAgentProcessorUseCase(
|
||||
personRepo domain.PersonRepository,
|
||||
locationRepo domain.LocationRepository,
|
||||
apiClient domain.APIClient,
|
||||
llmProvider domain.LLMProvider,
|
||||
apiKey string,
|
||||
outputDir string,
|
||||
) *AgentProcessorUseCase {
|
||||
return &AgentProcessorUseCase{
|
||||
personRepo: personRepo,
|
||||
locationRepo: locationRepo,
|
||||
apiClient: apiClient,
|
||||
llmProvider: llmProvider,
|
||||
apiKey: apiKey,
|
||||
outputDir: outputDir,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute processes all persons using LLM agent
|
||||
func (uc *AgentProcessorUseCase) Execute(ctx context.Context, inputFile string) error {
|
||||
// Load persons from file
|
||||
log.Printf("Loading persons from: %s", inputFile)
|
||||
persons, err := uc.personRepo.LoadPersons(ctx, inputFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading persons: %w", err)
|
||||
}
|
||||
log.Printf("Loaded %d persons", len(persons))
|
||||
|
||||
// Load power plant locations
|
||||
log.Printf("Loading power plant locations...")
|
||||
locations, err := uc.locationRepo.LoadLocations(ctx, uc.apiKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading locations: %w", err)
|
||||
}
|
||||
log.Printf("Loaded %d power plant locations", len(locations))
|
||||
|
||||
// Process each person with agent
|
||||
for i, person := range persons {
|
||||
log.Printf("\n[%d/%d] Processing: %s %s", i+1, len(persons), person.Name, person.Surname)
|
||||
|
||||
if err := uc.processPerson(ctx, person, locations); err != nil {
|
||||
log.Printf("Error processing %s %s: %v", person.Name, person.Surname, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("\nProcessing completed!")
|
||||
return nil
|
||||
}
|
||||
|
||||
// processPerson uses LLM agent to gather data for a person
|
||||
func (uc *AgentProcessorUseCase) processPerson(ctx context.Context, person domain.Person, powerPlants []domain.Location) error {
|
||||
tools := domain.GetToolDefinitions()
|
||||
|
||||
// Initial system message
|
||||
systemPrompt := fmt.Sprintf(`You are an agent that gathers information about people.
|
||||
For the person %s %s (born: %d), you need to:
|
||||
1. Call get_location to get their current location coordinates
|
||||
2. Call get_access_level to get their access level (remember: birth_year parameter must be only the year as integer, e.g., %d)
|
||||
|
||||
After gathering the data, respond with "DONE".`, person.Name, person.Surname, person.Born, person.Born)
|
||||
|
||||
messages := []domain.LLMMessage{
|
||||
{
|
||||
Role: "system",
|
||||
Content: systemPrompt,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("Please gather information for %s %s.", person.Name, person.Surname),
|
||||
},
|
||||
}
|
||||
|
||||
maxIterations := 10
|
||||
var personLocation *domain.PersonLocation
|
||||
|
||||
for iteration := 0; iteration < maxIterations; iteration++ {
|
||||
log.Printf(" [Iteration %d] Calling LLM...", iteration+1)
|
||||
|
||||
resp, err := uc.llmProvider.Chat(ctx, domain.LLMRequest{
|
||||
Messages: messages,
|
||||
Tools: tools,
|
||||
ToolChoice: "auto",
|
||||
Temperature: 0.0,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("LLM chat error: %w", err)
|
||||
}
|
||||
|
||||
messages = append(messages, resp.Message)
|
||||
|
||||
// Check if LLM wants to call functions
|
||||
if len(resp.Message.ToolCalls) > 0 {
|
||||
log.Printf(" → LLM requested %d tool call(s)", len(resp.Message.ToolCalls))
|
||||
|
||||
for _, toolCall := range resp.Message.ToolCalls {
|
||||
result, loc, err := uc.executeToolCall(ctx, person, toolCall)
|
||||
if err != nil {
|
||||
return fmt.Errorf("executing tool call: %w", err)
|
||||
}
|
||||
|
||||
// Store person location if we got it from get_location
|
||||
if loc != nil {
|
||||
personLocation = loc
|
||||
}
|
||||
|
||||
messages = append(messages, domain.LLMMessage{
|
||||
Role: "tool",
|
||||
Content: result,
|
||||
ToolCallID: toolCall.ID,
|
||||
})
|
||||
}
|
||||
} else if resp.FinishReason == "stop" {
|
||||
log.Printf(" ✓ Agent completed gathering data")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate distances if we have location
|
||||
if personLocation != nil && len(powerPlants) > 0 {
|
||||
log.Printf(" → Calculating distances to power plants...")
|
||||
closest := domain.FindClosestLocation(*personLocation, powerPlants)
|
||||
if closest != nil {
|
||||
log.Printf(" ✓ Closest power plant: %s (%.2f km)", closest.Location, closest.DistanceKm)
|
||||
|
||||
// Save distance result
|
||||
distanceFile := filepath.Join(uc.outputDir, "distances", fmt.Sprintf("%s_%s.json", person.Name, person.Surname))
|
||||
distanceData, _ := json.MarshalIndent(closest, "", " ")
|
||||
os.WriteFile(distanceFile, distanceData, 0644)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeToolCall executes a tool call from the LLM
|
||||
func (uc *AgentProcessorUseCase) executeToolCall(ctx context.Context, person domain.Person, toolCall domain.ToolCall) (string, *domain.PersonLocation, error) {
|
||||
log.Printf(" → Executing: %s", toolCall.Function.Name)
|
||||
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||||
return "", nil, fmt.Errorf("parsing arguments: %w", err)
|
||||
}
|
||||
|
||||
switch toolCall.Function.Name {
|
||||
case "get_location":
|
||||
name, _ := args["name"].(string)
|
||||
surname, _ := args["surname"].(string)
|
||||
|
||||
req := domain.LocationRequest{
|
||||
APIKey: uc.apiKey,
|
||||
Name: name,
|
||||
Surname: surname,
|
||||
}
|
||||
|
||||
response, err := uc.apiClient.GetLocation(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error: %v", err), nil, nil
|
||||
}
|
||||
|
||||
// Save response
|
||||
fileName := fmt.Sprintf("%s_%s.json", name, surname)
|
||||
filePath := filepath.Join(uc.outputDir, "locations", fileName)
|
||||
os.WriteFile(filePath, response, 0644)
|
||||
log.Printf(" ✓ Saved to: %s", filePath)
|
||||
|
||||
// Parse location to get coordinates (API returns array of locations)
|
||||
var locationData []map[string]interface{}
|
||||
if err := json.Unmarshal(response, &locationData); err == nil && len(locationData) > 0 {
|
||||
// Take first location from the array
|
||||
firstLoc := locationData[0]
|
||||
if lat, ok := firstLoc["latitude"].(float64); ok {
|
||||
if lon, ok := firstLoc["longitude"].(float64); ok {
|
||||
personLoc := &domain.PersonLocation{
|
||||
Name: name,
|
||||
Surname: surname,
|
||||
Latitude: lat,
|
||||
Longitude: lon,
|
||||
}
|
||||
return string(response), personLoc, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return string(response), nil, nil
|
||||
|
||||
case "get_access_level":
|
||||
name, _ := args["name"].(string)
|
||||
surname, _ := args["surname"].(string)
|
||||
birthYear, _ := args["birth_year"].(float64)
|
||||
|
||||
req := domain.AccessLevelRequest{
|
||||
APIKey: uc.apiKey,
|
||||
Name: name,
|
||||
Surname: surname,
|
||||
BirthYear: int(birthYear),
|
||||
}
|
||||
|
||||
response, err := uc.apiClient.GetAccessLevel(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error: %v", err), nil, nil
|
||||
}
|
||||
|
||||
// Save response
|
||||
fileName := fmt.Sprintf("%s_%s.json", name, surname)
|
||||
filePath := filepath.Join(uc.outputDir, "accesslevel", fileName)
|
||||
os.WriteFile(filePath, response, 0644)
|
||||
log.Printf(" ✓ Saved to: %s", filePath)
|
||||
|
||||
return string(response), nil, nil
|
||||
|
||||
default:
|
||||
return fmt.Sprintf("Unknown function: %s", toolCall.Function.Name), nil, nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user