final
This commit is contained in:
461
internal/usecase/optimized_agent_processor.go
Normal file
461
internal/usecase/optimized_agent_processor.go
Normal file
@@ -0,0 +1,461 @@
|
||||
package usecase
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/paramah/ai_devs4/s01e02/internal/domain"
|
||||
)
|
||||
|
||||
// OptimizedAgentProcessorUseCase handles optimized processing with location reports
|
||||
type OptimizedAgentProcessorUseCase struct {
|
||||
personRepo domain.PersonRepository
|
||||
locationRepo domain.LocationRepository
|
||||
apiClient domain.APIClient
|
||||
llmProvider domain.LLMProvider
|
||||
apiKey string
|
||||
outputDir string
|
||||
}
|
||||
|
||||
// NewOptimizedAgentProcessorUseCase creates a new optimized use case instance
|
||||
func NewOptimizedAgentProcessorUseCase(
|
||||
personRepo domain.PersonRepository,
|
||||
locationRepo domain.LocationRepository,
|
||||
apiClient domain.APIClient,
|
||||
llmProvider domain.LLMProvider,
|
||||
apiKey string,
|
||||
outputDir string,
|
||||
) *OptimizedAgentProcessorUseCase {
|
||||
return &OptimizedAgentProcessorUseCase{
|
||||
personRepo: personRepo,
|
||||
locationRepo: locationRepo,
|
||||
apiClient: apiClient,
|
||||
llmProvider: llmProvider,
|
||||
apiKey: apiKey,
|
||||
outputDir: outputDir,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute processes all persons and generates location reports
|
||||
func (uc *OptimizedAgentProcessorUseCase) Execute(ctx context.Context, inputFile string) error {
|
||||
// Load persons from file
|
||||
log.Printf("Loading persons from: %s", inputFile)
|
||||
persons, err := uc.personRepo.LoadPersons(ctx, inputFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading persons: %w", err)
|
||||
}
|
||||
log.Printf("Loaded %d persons", len(persons))
|
||||
|
||||
// Load power plant locations
|
||||
log.Printf("Loading power plant locations...")
|
||||
powerPlants, err := uc.locationRepo.LoadLocations(ctx, uc.apiKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading locations: %w", err)
|
||||
}
|
||||
log.Printf("Loaded %d power plant locations", len(powerPlants))
|
||||
|
||||
// Also load raw power plants data for codes
|
||||
plantsData, err := uc.loadPowerPlantsData(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading power plants data: %w", err)
|
||||
}
|
||||
|
||||
// Save power plants data to output
|
||||
plantsDataJSON, err := json.MarshalIndent(plantsData, "", " ")
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to marshal power plants data: %v", err)
|
||||
} else {
|
||||
plantsFilePath := filepath.Join(uc.outputDir, "findhim_locations.json")
|
||||
if err := os.WriteFile(plantsFilePath, plantsDataJSON, 0644); err != nil {
|
||||
log.Printf("Warning: Failed to save power plants data: %v", err)
|
||||
} else {
|
||||
log.Printf("✓ Power plants data saved to: %s", plantsFilePath)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 1: Gather all data for all persons
|
||||
log.Printf("\n========== Phase 1: Gathering person data ==========")
|
||||
personDataMap := make(map[string]*domain.PersonData)
|
||||
|
||||
for i, person := range persons {
|
||||
log.Printf("[%d/%d] Gathering data for: %s %s", i+1, len(persons), person.Name, person.Surname)
|
||||
|
||||
personData, err := uc.gatherPersonData(ctx, person)
|
||||
if err != nil {
|
||||
log.Printf("Error gathering data for %s %s: %v", person.Name, person.Surname, err)
|
||||
continue
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s_%s", person.Name, person.Surname)
|
||||
personDataMap[key] = personData
|
||||
}
|
||||
|
||||
// Phase 2: Calculate all distances and generate location reports
|
||||
log.Printf("\n========== Phase 2: Generating location reports ==========")
|
||||
locationReports := uc.generateLocationReports(personDataMap, powerPlants)
|
||||
|
||||
// Phase 3: Save reports
|
||||
log.Printf("\n========== Phase 3: Saving reports ==========")
|
||||
for _, report := range locationReports {
|
||||
uc.saveLocationReport(report)
|
||||
}
|
||||
|
||||
// Phase 4: Find closest power plant (locally, no LLM needed)
|
||||
log.Printf("\n========== Phase 4: Finding closest power plant ==========")
|
||||
|
||||
closestPerson, closestPlantName, distance, plantCode := uc.findClosestPowerPlantLocally(personDataMap, powerPlants, plantsData)
|
||||
|
||||
if closestPerson != nil {
|
||||
log.Printf(" ✓ Closest power plant: %s (%s)", closestPlantName, plantCode)
|
||||
log.Printf(" ✓ Closest person: %s %s", closestPerson.Person.Name, closestPerson.Person.Surname)
|
||||
log.Printf(" ✓ Distance: %.2f km", distance)
|
||||
log.Printf(" ✓ Access level: %d", closestPerson.AccessLevel)
|
||||
|
||||
// Save final answer
|
||||
finalAnswer := domain.FinalAnswer{
|
||||
APIKey: uc.apiKey,
|
||||
Task: "findhim",
|
||||
Answer: domain.AnswerDetail{
|
||||
Name: closestPerson.Person.Name,
|
||||
Surname: closestPerson.Person.Surname,
|
||||
AccessLevel: closestPerson.AccessLevel,
|
||||
PowerPlant: plantCode,
|
||||
},
|
||||
}
|
||||
|
||||
answerJSON, err := json.MarshalIndent(finalAnswer, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling final answer: %w", err)
|
||||
}
|
||||
|
||||
answerPath := filepath.Join(uc.outputDir, "final_answer.json")
|
||||
if err := os.WriteFile(answerPath, answerJSON, 0644); err != nil {
|
||||
return fmt.Errorf("saving final answer: %w", err)
|
||||
}
|
||||
|
||||
log.Printf(" ✓ Final answer saved to: %s", answerPath)
|
||||
}
|
||||
|
||||
log.Printf("\nProcessing completed!")
|
||||
return nil
|
||||
}
|
||||
|
||||
// gatherPersonData uses LLM agent to gather all data for a person (optimized)
|
||||
func (uc *OptimizedAgentProcessorUseCase) gatherPersonData(ctx context.Context, person domain.Person) (*domain.PersonData, error) {
|
||||
tools := domain.GetToolDefinitions()
|
||||
|
||||
// Minimal, optimized system prompt
|
||||
systemPrompt := fmt.Sprintf(`Gather data for person: %s %s (born: %d).
|
||||
Tasks:
|
||||
1. Call get_location
|
||||
2. Call get_access_level (use birth_year: %d)
|
||||
Then respond "DONE".`, person.Name, person.Surname, person.Born, person.Born)
|
||||
|
||||
messages := []domain.LLMMessage{
|
||||
{
|
||||
Role: "system",
|
||||
Content: systemPrompt,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Start gathering data.",
|
||||
},
|
||||
}
|
||||
|
||||
maxIterations := 5
|
||||
var personLocations []domain.PersonLocation
|
||||
var accessLevel int
|
||||
|
||||
for iteration := 0; iteration < maxIterations; iteration++ {
|
||||
resp, err := uc.llmProvider.Chat(ctx, domain.LLMRequest{
|
||||
Messages: messages,
|
||||
Tools: tools,
|
||||
ToolChoice: "auto",
|
||||
Temperature: 0.0,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("LLM chat error: %w", err)
|
||||
}
|
||||
|
||||
messages = append(messages, resp.Message)
|
||||
|
||||
if len(resp.Message.ToolCalls) > 0 {
|
||||
for _, toolCall := range resp.Message.ToolCalls {
|
||||
result, locations, level, err := uc.executeToolCall(ctx, person, toolCall)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("executing tool call: %w", err)
|
||||
}
|
||||
|
||||
personLocations = append(personLocations, locations...)
|
||||
if level > 0 {
|
||||
accessLevel = level
|
||||
}
|
||||
|
||||
messages = append(messages, domain.LLMMessage{
|
||||
Role: "tool",
|
||||
Content: result,
|
||||
ToolCallID: toolCall.ID,
|
||||
})
|
||||
}
|
||||
} else if resp.FinishReason == "stop" {
|
||||
log.Printf(" ✓ Data gathered (locations: %d, access level: %d)", len(personLocations), accessLevel)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return &domain.PersonData{
|
||||
Person: person,
|
||||
Locations: personLocations,
|
||||
AccessLevel: accessLevel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// executeToolCall executes a tool call and returns locations and access level
|
||||
func (uc *OptimizedAgentProcessorUseCase) executeToolCall(
|
||||
ctx context.Context,
|
||||
person domain.Person,
|
||||
toolCall domain.ToolCall,
|
||||
) (string, []domain.PersonLocation, int, error) {
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||||
return "", nil, 0, fmt.Errorf("parsing arguments: %w", err)
|
||||
}
|
||||
|
||||
switch toolCall.Function.Name {
|
||||
case "get_location":
|
||||
name, _ := args["name"].(string)
|
||||
surname, _ := args["surname"].(string)
|
||||
|
||||
req := domain.LocationRequest{
|
||||
APIKey: uc.apiKey,
|
||||
Name: name,
|
||||
Surname: surname,
|
||||
}
|
||||
|
||||
response, err := uc.apiClient.GetLocation(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error: %v", err), nil, 0, nil
|
||||
}
|
||||
|
||||
// Save response
|
||||
fileName := fmt.Sprintf("%s_%s.json", name, surname)
|
||||
filePath := filepath.Join(uc.outputDir, "locations", fileName)
|
||||
os.WriteFile(filePath, response, 0644)
|
||||
|
||||
// Parse all locations from array
|
||||
var locationData []map[string]interface{}
|
||||
var locations []domain.PersonLocation
|
||||
|
||||
if err := json.Unmarshal(response, &locationData); err == nil {
|
||||
for _, loc := range locationData {
|
||||
if lat, ok := loc["latitude"].(float64); ok {
|
||||
if lon, ok := loc["longitude"].(float64); ok {
|
||||
locations = append(locations, domain.PersonLocation{
|
||||
Name: name,
|
||||
Surname: surname,
|
||||
Latitude: lat,
|
||||
Longitude: lon,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf(" → get_location: %d locations found", len(locations))
|
||||
return string(response), locations, 0, nil
|
||||
|
||||
case "get_access_level":
|
||||
name, _ := args["name"].(string)
|
||||
surname, _ := args["surname"].(string)
|
||||
birthYear, _ := args["birth_year"].(float64)
|
||||
|
||||
req := domain.AccessLevelRequest{
|
||||
APIKey: uc.apiKey,
|
||||
Name: name,
|
||||
Surname: surname,
|
||||
BirthYear: int(birthYear),
|
||||
}
|
||||
|
||||
response, err := uc.apiClient.GetAccessLevel(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error: %v", err), nil, 0, nil
|
||||
}
|
||||
|
||||
// Save response
|
||||
fileName := fmt.Sprintf("%s_%s.json", name, surname)
|
||||
filePath := filepath.Join(uc.outputDir, "accesslevel", fileName)
|
||||
os.WriteFile(filePath, response, 0644)
|
||||
|
||||
// Parse access level
|
||||
var accessData struct {
|
||||
AccessLevel int `json:"accessLevel"`
|
||||
}
|
||||
var level int
|
||||
if err := json.Unmarshal(response, &accessData); err == nil {
|
||||
level = accessData.AccessLevel
|
||||
}
|
||||
|
||||
log.Printf(" → get_access_level: level %d", level)
|
||||
return string(response), nil, level, nil
|
||||
|
||||
default:
|
||||
return fmt.Sprintf("Unknown function: %s", toolCall.Function.Name), nil, 0, nil
|
||||
}
|
||||
}
|
||||
|
||||
// generateLocationReports creates reports for each power plant location
|
||||
func (uc *OptimizedAgentProcessorUseCase) generateLocationReports(
|
||||
personDataMap map[string]*domain.PersonData,
|
||||
powerPlants []domain.Location,
|
||||
) []domain.LocationReport {
|
||||
var reports []domain.LocationReport
|
||||
|
||||
for _, powerPlant := range powerPlants {
|
||||
report := domain.LocationReport{
|
||||
LocationName: powerPlant.Name,
|
||||
Persons: []domain.PersonWithDistance{},
|
||||
}
|
||||
|
||||
// Calculate distances for all persons to this power plant
|
||||
for _, personData := range personDataMap {
|
||||
if len(personData.Locations) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Use first location (primary location)
|
||||
personLoc := personData.Locations[0]
|
||||
distance := domain.Haversine(
|
||||
personLoc.Latitude,
|
||||
personLoc.Longitude,
|
||||
powerPlant.Latitude,
|
||||
powerPlant.Longitude,
|
||||
)
|
||||
|
||||
report.Persons = append(report.Persons, domain.PersonWithDistance{
|
||||
Name: personData.Person.Name,
|
||||
Surname: personData.Person.Surname,
|
||||
LocationName: powerPlant.Name,
|
||||
DistanceKm: distance,
|
||||
AccessLevel: personData.AccessLevel,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort by distance
|
||||
report.SortPersonsByDistance()
|
||||
reports = append(reports, report)
|
||||
|
||||
log.Printf(" ✓ %s: %d persons", powerPlant.Name, len(report.Persons))
|
||||
}
|
||||
|
||||
return reports
|
||||
}
|
||||
|
||||
// saveLocationReport saves a location report to file
|
||||
func (uc *OptimizedAgentProcessorUseCase) saveLocationReport(report domain.LocationReport) {
|
||||
fileName := fmt.Sprintf("%s_report.json", report.LocationName)
|
||||
filePath := filepath.Join(uc.outputDir, "reports", fileName)
|
||||
|
||||
data, err := json.MarshalIndent(report, "", " ")
|
||||
if err != nil {
|
||||
log.Printf("Error marshaling report for %s: %v", report.LocationName, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.WriteFile(filePath, data, 0644); err != nil {
|
||||
log.Printf("Error saving report for %s: %v", report.LocationName, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf(" ✓ Saved: %s", filePath)
|
||||
}
|
||||
|
||||
// loadPowerPlantsData loads raw power plants data with codes
|
||||
func (uc *OptimizedAgentProcessorUseCase) loadPowerPlantsData(ctx context.Context) (*domain.PowerPlantsData, error) {
|
||||
// Use the location repository to get the raw data
|
||||
// We need to fetch from the API directly
|
||||
url := fmt.Sprintf("https://hub.ag3nts.org/data/%s/findhim_locations.json", uc.apiKey)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetching data: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading response: %w", err)
|
||||
}
|
||||
|
||||
var plantsData domain.PowerPlantsData
|
||||
if err := json.Unmarshal(body, &plantsData); err != nil {
|
||||
return nil, fmt.Errorf("parsing JSON: %w", err)
|
||||
}
|
||||
|
||||
return &plantsData, nil
|
||||
}
|
||||
|
||||
// findClosestPowerPlantLocally finds the closest power plant without using LLM
|
||||
func (uc *OptimizedAgentProcessorUseCase) findClosestPowerPlantLocally(
|
||||
personDataMap map[string]*domain.PersonData,
|
||||
powerPlants []domain.Location,
|
||||
plantsData *domain.PowerPlantsData,
|
||||
) (*domain.PersonData, string, float64, string) {
|
||||
var closestPerson *domain.PersonData
|
||||
var closestPlantName string
|
||||
var plantCode string
|
||||
minDistance := 1e10
|
||||
|
||||
// Check all person-plant combinations
|
||||
for _, personData := range personDataMap {
|
||||
if len(personData.Locations) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Use first location (primary)
|
||||
personLoc := personData.Locations[0]
|
||||
|
||||
for _, plant := range powerPlants {
|
||||
distance := domain.Haversine(
|
||||
personLoc.Latitude,
|
||||
personLoc.Longitude,
|
||||
plant.Latitude,
|
||||
plant.Longitude,
|
||||
)
|
||||
|
||||
// If distance is smaller, or same distance but higher access level
|
||||
if distance < minDistance {
|
||||
minDistance = distance
|
||||
closestPerson = personData
|
||||
closestPlantName = plant.Name
|
||||
|
||||
// Get power plant code
|
||||
if info, ok := plantsData.PowerPlants[plant.Name]; ok {
|
||||
plantCode = info.Code
|
||||
}
|
||||
} else if distance == minDistance && closestPerson != nil && personData.AccessLevel > closestPerson.AccessLevel {
|
||||
// Same distance, but higher access level
|
||||
closestPerson = personData
|
||||
closestPlantName = plant.Name
|
||||
|
||||
// Get power plant code
|
||||
if info, ok := plantsData.PowerPlants[plant.Name]; ok {
|
||||
plantCode = info.Code
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return closestPerson, closestPlantName, minDistance, plantCode
|
||||
}
|
||||
Reference in New Issue
Block a user