Files
s01e02/internal/usecase/optimized_agent_processor.go
2026-03-12 02:10:57 +01:00

462 lines
13 KiB
Go

package usecase
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"github.com/paramah/ai_devs4/s01e02/internal/domain"
)
// OptimizedAgentProcessorUseCase handles optimized processing with location reports.
//
// It combines LLM-driven data gathering (tool calls answered through the hub
// API client) with local distance calculations, and writes its intermediate
// and final results below outputDir.
type OptimizedAgentProcessorUseCase struct {
	personRepo   domain.PersonRepository   // loads the persons to investigate
	locationRepo domain.LocationRepository // loads power plant coordinates
	apiClient    domain.APIClient          // serves the LLM's get_location / get_access_level calls
	llmProvider  domain.LLMProvider        // chat model driving per-person data gathering

	// apiKey is sent with hub requests and embedded in the final answer;
	// outputDir is the root directory for every file this use case writes.
	apiKey, outputDir string
}
// NewOptimizedAgentProcessorUseCase wires the given dependencies into a new
// OptimizedAgentProcessorUseCase. The arguments are stored as-is; no
// validation is performed.
func NewOptimizedAgentProcessorUseCase(
	personRepo domain.PersonRepository,
	locationRepo domain.LocationRepository,
	apiClient domain.APIClient,
	llmProvider domain.LLMProvider,
	apiKey string,
	outputDir string,
) *OptimizedAgentProcessorUseCase {
	uc := new(OptimizedAgentProcessorUseCase)
	uc.personRepo, uc.locationRepo = personRepo, locationRepo
	uc.apiClient, uc.llmProvider = apiClient, llmProvider
	uc.apiKey, uc.outputDir = apiKey, outputDir
	return uc
}
// Execute processes all persons and generates location reports.
//
// Pipeline:
//  1. load persons from inputFile and power plant locations via locationRepo,
//  2. download the raw power plant data (which carries plant codes) and save
//     it to <outputDir>/findhim_locations.json (best effort),
//  3. let the LLM agent gather locations and access level for every person,
//  4. write one distance report per power plant,
//  5. pick the closest person/plant pair locally and write
//     <outputDir>/final_answer.json.
//
// Failures for individual persons, reports or auxiliary files are logged and
// skipped. Loading errors, context cancellation during data gathering and a
// failure to save the final answer are returned.
func (uc *OptimizedAgentProcessorUseCase) Execute(ctx context.Context, inputFile string) error {
	// Tool calls and report saving write into these subdirectories with
	// os.WriteFile, which does not create missing parents. MkdirAll also
	// creates outputDir itself. Failures are only logged, matching the
	// best-effort handling of those writes.
	for _, sub := range []string{"locations", "accesslevel", "reports"} {
		dir := filepath.Join(uc.outputDir, sub)
		if err := os.MkdirAll(dir, 0755); err != nil {
			log.Printf("Warning: Failed to create output directory %s: %v", dir, err)
		}
	}

	// Load persons from file
	log.Printf("Loading persons from: %s", inputFile)
	persons, err := uc.personRepo.LoadPersons(ctx, inputFile)
	if err != nil {
		return fmt.Errorf("loading persons: %w", err)
	}
	log.Printf("Loaded %d persons", len(persons))

	// Load power plant locations (coordinates used for distance calculations)
	log.Printf("Loading power plant locations...")
	powerPlants, err := uc.locationRepo.LoadLocations(ctx, uc.apiKey)
	if err != nil {
		return fmt.Errorf("loading locations: %w", err)
	}
	log.Printf("Loaded %d power plant locations", len(powerPlants))

	// The raw power plant data is needed as well: it carries the plant codes
	// required by the final answer.
	plantsData, err := uc.loadPowerPlantsData(ctx)
	if err != nil {
		return fmt.Errorf("loading power plants data: %w", err)
	}

	// Save power plants data to output (best effort)
	plantsDataJSON, err := json.MarshalIndent(plantsData, "", " ")
	if err != nil {
		log.Printf("Warning: Failed to marshal power plants data: %v", err)
	} else {
		plantsFilePath := filepath.Join(uc.outputDir, "findhim_locations.json")
		if err := os.WriteFile(plantsFilePath, plantsDataJSON, 0644); err != nil {
			log.Printf("Warning: Failed to save power plants data: %v", err)
		} else {
			log.Printf("✓ Power plants data saved to: %s", plantsFilePath)
		}
	}

	// Phase 1: Gather all data for all persons
	log.Printf("\n========== Phase 1: Gathering person data ==========")
	personDataMap := make(map[string]*domain.PersonData, len(persons))
	for i, person := range persons {
		// Stop on cancellation instead of logging one failure per remaining
		// person and carrying on with incomplete data.
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("gathering person data: %w", err)
		}
		log.Printf("[%d/%d] Gathering data for: %s %s", i+1, len(persons), person.Name, person.Surname)
		personData, err := uc.gatherPersonData(ctx, person)
		if err != nil {
			log.Printf("Error gathering data for %s %s: %v", person.Name, person.Surname, err)
			continue
		}
		key := fmt.Sprintf("%s_%s", person.Name, person.Surname)
		personDataMap[key] = personData
	}

	// Phase 2: Calculate all distances and generate location reports
	log.Printf("\n========== Phase 2: Generating location reports ==========")
	locationReports := uc.generateLocationReports(personDataMap, powerPlants)

	// Phase 3: Save reports
	log.Printf("\n========== Phase 3: Saving reports ==========")
	for _, report := range locationReports {
		uc.saveLocationReport(report)
	}

	// Phase 4: Find closest power plant (locally, no LLM needed)
	log.Printf("\n========== Phase 4: Finding closest power plant ==========")
	closestPerson, closestPlantName, distance, plantCode := uc.findClosestPowerPlantLocally(personDataMap, powerPlants, plantsData)
	if closestPerson == nil {
		// Previously this case ended silently with no final answer.
		log.Printf("Warning: No person with a known location was found; final answer not saved")
	} else {
		log.Printf(" ✓ Closest power plant: %s (%s)", closestPlantName, plantCode)
		log.Printf(" ✓ Closest person: %s %s", closestPerson.Person.Name, closestPerson.Person.Surname)
		log.Printf(" ✓ Distance: %.2f km", distance)
		log.Printf(" ✓ Access level: %d", closestPerson.AccessLevel)
		if plantCode == "" {
			log.Printf("Warning: No code found for power plant %s", closestPlantName)
		}

		// Save final answer
		finalAnswer := domain.FinalAnswer{
			APIKey: uc.apiKey,
			Task:   "findhim",
			Answer: domain.AnswerDetail{
				Name:        closestPerson.Person.Name,
				Surname:     closestPerson.Person.Surname,
				AccessLevel: closestPerson.AccessLevel,
				PowerPlant:  plantCode,
			},
		}
		answerJSON, err := json.MarshalIndent(finalAnswer, "", " ")
		if err != nil {
			return fmt.Errorf("marshaling final answer: %w", err)
		}
		answerPath := filepath.Join(uc.outputDir, "final_answer.json")
		if err := os.WriteFile(answerPath, answerJSON, 0644); err != nil {
			return fmt.Errorf("saving final answer: %w", err)
		}
		log.Printf(" ✓ Final answer saved to: %s", answerPath)
	}

	log.Printf("\nProcessing completed!")
	return nil
}
// gatherPersonData runs a short tool-calling conversation in which the LLM is
// told to call get_location and get_access_level for person.
//
// Tool results are taken directly from executeToolCall rather than parsed
// from the model's text. The returned PersonData therefore holds every
// location parsed from get_location responses and the last positive access
// level seen. The conversation ends when the model replies without requesting
// a tool, or after maxIterations rounds. In the latter case partial data is
// returned and a warning is logged.
//
// An error is returned when the LLM call fails or a tool call's arguments
// cannot be parsed.
func (uc *OptimizedAgentProcessorUseCase) gatherPersonData(ctx context.Context, person domain.Person) (*domain.PersonData, error) {
	tools := domain.GetToolDefinitions()

	// Minimal, optimized system prompt
	systemPrompt := fmt.Sprintf(`Gather data for person: %s %s (born: %d).
Tasks:
1. Call get_location
2. Call get_access_level (use birth_year: %d)
Then respond "DONE".`, person.Name, person.Surname, person.Born, person.Born)

	messages := []domain.LLMMessage{
		{Role: "system", Content: systemPrompt},
		{Role: "user", Content: "Start gathering data."},
	}

	const maxIterations = 5
	var (
		personLocations []domain.PersonLocation
		accessLevel     int
		finished        bool
	)

	for iteration := 0; iteration < maxIterations; iteration++ {
		resp, err := uc.llmProvider.Chat(ctx, domain.LLMRequest{
			Messages:    messages,
			Tools:       tools,
			ToolChoice:  "auto",
			Temperature: 0.0,
		})
		if err != nil {
			return nil, fmt.Errorf("LLM chat error: %w", err)
		}
		messages = append(messages, resp.Message)

		if len(resp.Message.ToolCalls) == 0 {
			// A reply without tool calls ends the exchange whatever the
			// finish reason. Re-sending the unchanged conversation would not
			// make progress and would only burn requests until maxIterations.
			if resp.FinishReason == "stop" {
				log.Printf(" ✓ Data gathered (locations: %d, access level: %d)", len(personLocations), accessLevel)
			} else {
				log.Printf(" ! LLM ended without tool calls (finish reason: %q)", resp.FinishReason)
			}
			finished = true
			break
		}

		for _, toolCall := range resp.Message.ToolCalls {
			result, locations, level, err := uc.executeToolCall(ctx, person, toolCall)
			if err != nil {
				return nil, fmt.Errorf("executing tool call: %w", err)
			}
			personLocations = append(personLocations, locations...)
			if level > 0 {
				accessLevel = level
			}
			// Every tool call must be answered with a "tool" message
			// referencing its ID before the next Chat round.
			messages = append(messages, domain.LLMMessage{
				Role:       "tool",
				Content:    result,
				ToolCallID: toolCall.ID,
			})
		}
	}

	if !finished {
		log.Printf(" ! Gave up after %d LLM rounds; returning partial data", maxIterations)
	}

	return &domain.PersonData{
		Person:      person,
		Locations:   personLocations,
		AccessLevel: accessLevel,
	}, nil
}
// executeToolCall runs a single LLM-requested tool call against the hub API.
//
// It returns:
//   - the raw API response, to be sent back to the model as the tool result,
//   - the locations parsed from a get_location response (entries without
//     numeric latitude/longitude are skipped),
//   - the access level parsed from a get_access_level response (0 when not
//     applicable or not parseable).
//
// API failures and unknown tool names are reported to the model as result
// strings, not as Go errors. The returned error is non-nil only when the
// tool-call arguments are not valid JSON.
//
// name, surname and birth_year come from the tool arguments. When the model
// omits them, the corresponding values of person are used instead.
func (uc *OptimizedAgentProcessorUseCase) executeToolCall(
	ctx context.Context,
	person domain.Person,
	toolCall domain.ToolCall,
) (string, []domain.PersonLocation, int, error) {
	var args map[string]interface{}
	if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
		return "", nil, 0, fmt.Errorf("parsing arguments: %w", err)
	}

	// Identify the person from the tool arguments, falling back to the person
	// being processed when the model leaves them out.
	name, _ := args["name"].(string)
	if name == "" {
		name = person.Name
	}
	surname, _ := args["surname"].(string)
	if surname == "" {
		surname = person.Surname
	}

	// saveResponse stores a raw API response as
	// <outputDir>/<subdir>/<name>_<surname>.json. filepath.Base keeps
	// model-supplied names that contain path separators inside subdir.
	// Failures are logged rather than returned so the tool result still
	// reaches the model.
	saveResponse := func(subdir string, data []byte) {
		fileName := filepath.Base(fmt.Sprintf("%s_%s.json", name, surname))
		filePath := filepath.Join(uc.outputDir, subdir, fileName)
		if err := os.WriteFile(filePath, data, 0644); err != nil {
			log.Printf(" ! Failed to save %s: %v", filePath, err)
		}
	}

	switch toolCall.Function.Name {
	case "get_location":
		response, err := uc.apiClient.GetLocation(ctx, domain.LocationRequest{
			APIKey:  uc.apiKey,
			Name:    name,
			Surname: surname,
		})
		if err != nil {
			return fmt.Sprintf("Error: %v", err), nil, 0, nil
		}
		saveResponse("locations", response)

		// Parse all locations from array. A response that is not a JSON
		// array simply yields no locations.
		var locationData []map[string]interface{}
		var locations []domain.PersonLocation
		if err := json.Unmarshal(response, &locationData); err == nil {
			for _, loc := range locationData {
				lat, okLat := loc["latitude"].(float64)
				lon, okLon := loc["longitude"].(float64)
				if !okLat || !okLon {
					continue
				}
				locations = append(locations, domain.PersonLocation{
					Name:      name,
					Surname:   surname,
					Latitude:  lat,
					Longitude: lon,
				})
			}
		}
		log.Printf(" → get_location: %d locations found", len(locations))
		return string(response), locations, 0, nil

	case "get_access_level":
		birthYear := int(person.Born)
		if v, ok := args["birth_year"].(float64); ok {
			birthYear = int(v)
		}
		response, err := uc.apiClient.GetAccessLevel(ctx, domain.AccessLevelRequest{
			APIKey:    uc.apiKey,
			Name:      name,
			Surname:   surname,
			BirthYear: birthYear,
		})
		if err != nil {
			return fmt.Sprintf("Error: %v", err), nil, 0, nil
		}
		saveResponse("accesslevel", response)

		// Parse access level; an unparseable response leaves level at 0.
		var accessData struct {
			AccessLevel int `json:"accessLevel"`
		}
		var level int
		if err := json.Unmarshal(response, &accessData); err == nil {
			level = accessData.AccessLevel
		}
		log.Printf(" → get_access_level: level %d", level)
		return string(response), nil, level, nil

	default:
		return fmt.Sprintf("Unknown function: %s", toolCall.Function.Name), nil, 0, nil
	}
}
// generateLocationReports builds one LocationReport per power plant.
//
// Each report lists every person who has at least one known location. For
// each such person it records the smallest Haversine distance (km) between
// any of their locations and the plant, and the list is ordered with
// SortPersonsByDistance. Persons without locations are left out.
func (uc *OptimizedAgentProcessorUseCase) generateLocationReports(
	personDataMap map[string]*domain.PersonData,
	powerPlants []domain.Location,
) []domain.LocationReport {
	reports := make([]domain.LocationReport, 0, len(powerPlants))
	for _, powerPlant := range powerPlants {
		report := domain.LocationReport{
			LocationName: powerPlant.Name,
			// Non-nil so an empty report encodes as [] rather than null.
			Persons: []domain.PersonWithDistance{},
		}

		for _, personData := range personDataMap {
			if len(personData.Locations) == 0 {
				continue
			}
			// A person may have been seen in several places. Report their
			// closest approach to this plant, not just the first recorded
			// location.
			first := personData.Locations[0]
			distance := domain.Haversine(first.Latitude, first.Longitude, powerPlant.Latitude, powerPlant.Longitude)
			for _, loc := range personData.Locations[1:] {
				if d := domain.Haversine(loc.Latitude, loc.Longitude, powerPlant.Latitude, powerPlant.Longitude); d < distance {
					distance = d
				}
			}

			report.Persons = append(report.Persons, domain.PersonWithDistance{
				Name:         personData.Person.Name,
				Surname:      personData.Person.Surname,
				LocationName: powerPlant.Name,
				DistanceKm:   distance,
				AccessLevel:  personData.AccessLevel,
			})
		}

		// Sort by distance
		report.SortPersonsByDistance()
		reports = append(reports, report)
		log.Printf(" ✓ %s: %d persons", powerPlant.Name, len(report.Persons))
	}
	return reports
}
// saveLocationReport writes report as indented JSON to
// <outputDir>/reports/<LocationName>_report.json.
//
// Errors are logged and swallowed, so one failing report does not abort the
// remaining ones.
func (uc *OptimizedAgentProcessorUseCase) saveLocationReport(report domain.LocationReport) {
	payload, marshalErr := json.MarshalIndent(report, "", " ")
	if marshalErr != nil {
		log.Printf("Error marshaling report for %s: %v", report.LocationName, marshalErr)
		return
	}

	target := filepath.Join(
		uc.outputDir,
		"reports",
		fmt.Sprintf("%s_report.json", report.LocationName),
	)
	if writeErr := os.WriteFile(target, payload, 0644); writeErr != nil {
		log.Printf("Error saving report for %s: %v", report.LocationName, writeErr)
		return
	}

	log.Printf(" ✓ Saved: %s", target)
}
// loadPowerPlantsData downloads the raw findhim_locations.json document from
// the hub and decodes it into domain.PowerPlantsData. Unlike the locations
// returned by locationRepo, this data carries the power plant codes needed
// for the final answer.
//
// The request honours ctx. The client sets no timeout of its own, so callers
// that need a deadline should put one on ctx. A non-2xx response is returned
// as an error, with a short snippet of the body, instead of being handed to
// the JSON decoder. The URL, which contains the API key, is kept out of that
// status error message.
func (uc *OptimizedAgentProcessorUseCase) loadPowerPlantsData(ctx context.Context) (*domain.PowerPlantsData, error) {
	// Fetched directly here because the location repository does not expose
	// the raw document.
	endpoint := fmt.Sprintf("https://hub.ag3nts.org/data/%s/findhim_locations.json", uc.apiKey)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("creating request: %w", err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("fetching data: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("reading response: %w", err)
	}

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		const maxSnippet = 200
		snippet := body
		if len(snippet) > maxSnippet {
			snippet = snippet[:maxSnippet]
		}
		return nil, fmt.Errorf("fetching data: unexpected status %s: %q", resp.Status, string(snippet))
	}

	var plantsData domain.PowerPlantsData
	if err := json.Unmarshal(body, &plantsData); err != nil {
		return nil, fmt.Errorf("parsing JSON: %w", err)
	}
	return &plantsData, nil
}
// findClosestPowerPlantLocally finds the closest power plant without using
// the LLM.
//
// It checks every (person location, power plant) pair and returns:
//   - the person with the smallest Haversine distance,
//   - the matching plant's name,
//   - that distance,
//   - the plant's code from plantsData.
//
// The code is "" when the plant has no entry or plantsData is nil. On an
// exact distance tie, the person with the higher access level wins. When no
// person has any location, the returned person is nil, the plant name and
// code are empty, and the distance is 1e10.
func (uc *OptimizedAgentProcessorUseCase) findClosestPowerPlantLocally(
	personDataMap map[string]*domain.PersonData,
	powerPlants []domain.Location,
	plantsData *domain.PowerPlantsData,
) (*domain.PersonData, string, float64, string) {
	var (
		closestPerson    *domain.PersonData
		closestPlantName string
	)
	minDistance := 1e10

	// Consider every recorded location, not only the first one: a person
	// may have been seen near a plant at any of them.
	for _, personData := range personDataMap {
		for _, personLoc := range personData.Locations {
			for _, plant := range powerPlants {
				distance := domain.Haversine(
					personLoc.Latitude,
					personLoc.Longitude,
					plant.Latitude,
					plant.Longitude,
				)
				switch {
				case distance < minDistance:
					minDistance = distance
					closestPerson = personData
					closestPlantName = plant.Name
				case distance == minDistance && closestPerson != nil && personData.AccessLevel > closestPerson.AccessLevel:
					// Same distance, but higher access level
					closestPerson = personData
					closestPlantName = plant.Name
				}
			}
		}
	}

	// Resolve the code once, for the winning plant only. Updating it inside
	// the loop let a previous plant's code survive when the new closest
	// plant had no entry in plantsData.
	var plantCode string
	if closestPerson != nil && plantsData != nil {
		if info, ok := plantsData.PowerPlants[closestPlantName]; ok {
			plantCode = info.Code
		}
	}

	return closestPerson, closestPlantName, minDistance, plantCode
}