462 lines
13 KiB
Go
462 lines
13 KiB
Go
package usecase
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
|
|
"github.com/paramah/ai_devs4/s01e02/internal/domain"
|
|
)
|
|
|
|
// OptimizedAgentProcessorUseCase handles optimized processing with location reports
|
|
type OptimizedAgentProcessorUseCase struct {
|
|
personRepo domain.PersonRepository
|
|
locationRepo domain.LocationRepository
|
|
apiClient domain.APIClient
|
|
llmProvider domain.LLMProvider
|
|
apiKey string
|
|
outputDir string
|
|
}
|
|
|
|
// NewOptimizedAgentProcessorUseCase creates a new optimized use case instance
|
|
func NewOptimizedAgentProcessorUseCase(
|
|
personRepo domain.PersonRepository,
|
|
locationRepo domain.LocationRepository,
|
|
apiClient domain.APIClient,
|
|
llmProvider domain.LLMProvider,
|
|
apiKey string,
|
|
outputDir string,
|
|
) *OptimizedAgentProcessorUseCase {
|
|
return &OptimizedAgentProcessorUseCase{
|
|
personRepo: personRepo,
|
|
locationRepo: locationRepo,
|
|
apiClient: apiClient,
|
|
llmProvider: llmProvider,
|
|
apiKey: apiKey,
|
|
outputDir: outputDir,
|
|
}
|
|
}
|
|
|
|
// Execute processes all persons and generates location reports
|
|
func (uc *OptimizedAgentProcessorUseCase) Execute(ctx context.Context, inputFile string) error {
|
|
// Load persons from file
|
|
log.Printf("Loading persons from: %s", inputFile)
|
|
persons, err := uc.personRepo.LoadPersons(ctx, inputFile)
|
|
if err != nil {
|
|
return fmt.Errorf("loading persons: %w", err)
|
|
}
|
|
log.Printf("Loaded %d persons", len(persons))
|
|
|
|
// Load power plant locations
|
|
log.Printf("Loading power plant locations...")
|
|
powerPlants, err := uc.locationRepo.LoadLocations(ctx, uc.apiKey)
|
|
if err != nil {
|
|
return fmt.Errorf("loading locations: %w", err)
|
|
}
|
|
log.Printf("Loaded %d power plant locations", len(powerPlants))
|
|
|
|
// Also load raw power plants data for codes
|
|
plantsData, err := uc.loadPowerPlantsData(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("loading power plants data: %w", err)
|
|
}
|
|
|
|
// Save power plants data to output
|
|
plantsDataJSON, err := json.MarshalIndent(plantsData, "", " ")
|
|
if err != nil {
|
|
log.Printf("Warning: Failed to marshal power plants data: %v", err)
|
|
} else {
|
|
plantsFilePath := filepath.Join(uc.outputDir, "findhim_locations.json")
|
|
if err := os.WriteFile(plantsFilePath, plantsDataJSON, 0644); err != nil {
|
|
log.Printf("Warning: Failed to save power plants data: %v", err)
|
|
} else {
|
|
log.Printf("✓ Power plants data saved to: %s", plantsFilePath)
|
|
}
|
|
}
|
|
|
|
// Phase 1: Gather all data for all persons
|
|
log.Printf("\n========== Phase 1: Gathering person data ==========")
|
|
personDataMap := make(map[string]*domain.PersonData)
|
|
|
|
for i, person := range persons {
|
|
log.Printf("[%d/%d] Gathering data for: %s %s", i+1, len(persons), person.Name, person.Surname)
|
|
|
|
personData, err := uc.gatherPersonData(ctx, person)
|
|
if err != nil {
|
|
log.Printf("Error gathering data for %s %s: %v", person.Name, person.Surname, err)
|
|
continue
|
|
}
|
|
|
|
key := fmt.Sprintf("%s_%s", person.Name, person.Surname)
|
|
personDataMap[key] = personData
|
|
}
|
|
|
|
// Phase 2: Calculate all distances and generate location reports
|
|
log.Printf("\n========== Phase 2: Generating location reports ==========")
|
|
locationReports := uc.generateLocationReports(personDataMap, powerPlants)
|
|
|
|
// Phase 3: Save reports
|
|
log.Printf("\n========== Phase 3: Saving reports ==========")
|
|
for _, report := range locationReports {
|
|
uc.saveLocationReport(report)
|
|
}
|
|
|
|
// Phase 4: Find closest power plant (locally, no LLM needed)
|
|
log.Printf("\n========== Phase 4: Finding closest power plant ==========")
|
|
|
|
closestPerson, closestPlantName, distance, plantCode := uc.findClosestPowerPlantLocally(personDataMap, powerPlants, plantsData)
|
|
|
|
if closestPerson != nil {
|
|
log.Printf(" ✓ Closest power plant: %s (%s)", closestPlantName, plantCode)
|
|
log.Printf(" ✓ Closest person: %s %s", closestPerson.Person.Name, closestPerson.Person.Surname)
|
|
log.Printf(" ✓ Distance: %.2f km", distance)
|
|
log.Printf(" ✓ Access level: %d", closestPerson.AccessLevel)
|
|
|
|
// Save final answer
|
|
finalAnswer := domain.FinalAnswer{
|
|
APIKey: uc.apiKey,
|
|
Task: "findhim",
|
|
Answer: domain.AnswerDetail{
|
|
Name: closestPerson.Person.Name,
|
|
Surname: closestPerson.Person.Surname,
|
|
AccessLevel: closestPerson.AccessLevel,
|
|
PowerPlant: plantCode,
|
|
},
|
|
}
|
|
|
|
answerJSON, err := json.MarshalIndent(finalAnswer, "", " ")
|
|
if err != nil {
|
|
return fmt.Errorf("marshaling final answer: %w", err)
|
|
}
|
|
|
|
answerPath := filepath.Join(uc.outputDir, "final_answer.json")
|
|
if err := os.WriteFile(answerPath, answerJSON, 0644); err != nil {
|
|
return fmt.Errorf("saving final answer: %w", err)
|
|
}
|
|
|
|
log.Printf(" ✓ Final answer saved to: %s", answerPath)
|
|
}
|
|
|
|
log.Printf("\nProcessing completed!")
|
|
return nil
|
|
}
|
|
|
|
// gatherPersonData uses LLM agent to gather all data for a person (optimized)
|
|
func (uc *OptimizedAgentProcessorUseCase) gatherPersonData(ctx context.Context, person domain.Person) (*domain.PersonData, error) {
|
|
tools := domain.GetToolDefinitions()
|
|
|
|
// Minimal, optimized system prompt
|
|
systemPrompt := fmt.Sprintf(`Gather data for person: %s %s (born: %d).
|
|
Tasks:
|
|
1. Call get_location
|
|
2. Call get_access_level (use birth_year: %d)
|
|
Then respond "DONE".`, person.Name, person.Surname, person.Born, person.Born)
|
|
|
|
messages := []domain.LLMMessage{
|
|
{
|
|
Role: "system",
|
|
Content: systemPrompt,
|
|
},
|
|
{
|
|
Role: "user",
|
|
Content: "Start gathering data.",
|
|
},
|
|
}
|
|
|
|
maxIterations := 5
|
|
var personLocations []domain.PersonLocation
|
|
var accessLevel int
|
|
|
|
for iteration := 0; iteration < maxIterations; iteration++ {
|
|
resp, err := uc.llmProvider.Chat(ctx, domain.LLMRequest{
|
|
Messages: messages,
|
|
Tools: tools,
|
|
ToolChoice: "auto",
|
|
Temperature: 0.0,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("LLM chat error: %w", err)
|
|
}
|
|
|
|
messages = append(messages, resp.Message)
|
|
|
|
if len(resp.Message.ToolCalls) > 0 {
|
|
for _, toolCall := range resp.Message.ToolCalls {
|
|
result, locations, level, err := uc.executeToolCall(ctx, person, toolCall)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("executing tool call: %w", err)
|
|
}
|
|
|
|
personLocations = append(personLocations, locations...)
|
|
if level > 0 {
|
|
accessLevel = level
|
|
}
|
|
|
|
messages = append(messages, domain.LLMMessage{
|
|
Role: "tool",
|
|
Content: result,
|
|
ToolCallID: toolCall.ID,
|
|
})
|
|
}
|
|
} else if resp.FinishReason == "stop" {
|
|
log.Printf(" ✓ Data gathered (locations: %d, access level: %d)", len(personLocations), accessLevel)
|
|
break
|
|
}
|
|
}
|
|
|
|
return &domain.PersonData{
|
|
Person: person,
|
|
Locations: personLocations,
|
|
AccessLevel: accessLevel,
|
|
}, nil
|
|
}
|
|
|
|
// executeToolCall executes a tool call and returns locations and access level
|
|
func (uc *OptimizedAgentProcessorUseCase) executeToolCall(
|
|
ctx context.Context,
|
|
person domain.Person,
|
|
toolCall domain.ToolCall,
|
|
) (string, []domain.PersonLocation, int, error) {
|
|
var args map[string]interface{}
|
|
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
|
return "", nil, 0, fmt.Errorf("parsing arguments: %w", err)
|
|
}
|
|
|
|
switch toolCall.Function.Name {
|
|
case "get_location":
|
|
name, _ := args["name"].(string)
|
|
surname, _ := args["surname"].(string)
|
|
|
|
req := domain.LocationRequest{
|
|
APIKey: uc.apiKey,
|
|
Name: name,
|
|
Surname: surname,
|
|
}
|
|
|
|
response, err := uc.apiClient.GetLocation(ctx, req)
|
|
if err != nil {
|
|
return fmt.Sprintf("Error: %v", err), nil, 0, nil
|
|
}
|
|
|
|
// Save response
|
|
fileName := fmt.Sprintf("%s_%s.json", name, surname)
|
|
filePath := filepath.Join(uc.outputDir, "locations", fileName)
|
|
os.WriteFile(filePath, response, 0644)
|
|
|
|
// Parse all locations from array
|
|
var locationData []map[string]interface{}
|
|
var locations []domain.PersonLocation
|
|
|
|
if err := json.Unmarshal(response, &locationData); err == nil {
|
|
for _, loc := range locationData {
|
|
if lat, ok := loc["latitude"].(float64); ok {
|
|
if lon, ok := loc["longitude"].(float64); ok {
|
|
locations = append(locations, domain.PersonLocation{
|
|
Name: name,
|
|
Surname: surname,
|
|
Latitude: lat,
|
|
Longitude: lon,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
log.Printf(" → get_location: %d locations found", len(locations))
|
|
return string(response), locations, 0, nil
|
|
|
|
case "get_access_level":
|
|
name, _ := args["name"].(string)
|
|
surname, _ := args["surname"].(string)
|
|
birthYear, _ := args["birth_year"].(float64)
|
|
|
|
req := domain.AccessLevelRequest{
|
|
APIKey: uc.apiKey,
|
|
Name: name,
|
|
Surname: surname,
|
|
BirthYear: int(birthYear),
|
|
}
|
|
|
|
response, err := uc.apiClient.GetAccessLevel(ctx, req)
|
|
if err != nil {
|
|
return fmt.Sprintf("Error: %v", err), nil, 0, nil
|
|
}
|
|
|
|
// Save response
|
|
fileName := fmt.Sprintf("%s_%s.json", name, surname)
|
|
filePath := filepath.Join(uc.outputDir, "accesslevel", fileName)
|
|
os.WriteFile(filePath, response, 0644)
|
|
|
|
// Parse access level
|
|
var accessData struct {
|
|
AccessLevel int `json:"accessLevel"`
|
|
}
|
|
var level int
|
|
if err := json.Unmarshal(response, &accessData); err == nil {
|
|
level = accessData.AccessLevel
|
|
}
|
|
|
|
log.Printf(" → get_access_level: level %d", level)
|
|
return string(response), nil, level, nil
|
|
|
|
default:
|
|
return fmt.Sprintf("Unknown function: %s", toolCall.Function.Name), nil, 0, nil
|
|
}
|
|
}
|
|
|
|
// generateLocationReports creates reports for each power plant location
|
|
func (uc *OptimizedAgentProcessorUseCase) generateLocationReports(
|
|
personDataMap map[string]*domain.PersonData,
|
|
powerPlants []domain.Location,
|
|
) []domain.LocationReport {
|
|
var reports []domain.LocationReport
|
|
|
|
for _, powerPlant := range powerPlants {
|
|
report := domain.LocationReport{
|
|
LocationName: powerPlant.Name,
|
|
Persons: []domain.PersonWithDistance{},
|
|
}
|
|
|
|
// Calculate distances for all persons to this power plant
|
|
for _, personData := range personDataMap {
|
|
if len(personData.Locations) == 0 {
|
|
continue
|
|
}
|
|
|
|
// Use first location (primary location)
|
|
personLoc := personData.Locations[0]
|
|
distance := domain.Haversine(
|
|
personLoc.Latitude,
|
|
personLoc.Longitude,
|
|
powerPlant.Latitude,
|
|
powerPlant.Longitude,
|
|
)
|
|
|
|
report.Persons = append(report.Persons, domain.PersonWithDistance{
|
|
Name: personData.Person.Name,
|
|
Surname: personData.Person.Surname,
|
|
LocationName: powerPlant.Name,
|
|
DistanceKm: distance,
|
|
AccessLevel: personData.AccessLevel,
|
|
})
|
|
}
|
|
|
|
// Sort by distance
|
|
report.SortPersonsByDistance()
|
|
reports = append(reports, report)
|
|
|
|
log.Printf(" ✓ %s: %d persons", powerPlant.Name, len(report.Persons))
|
|
}
|
|
|
|
return reports
|
|
}
|
|
|
|
// saveLocationReport saves a location report to file
|
|
func (uc *OptimizedAgentProcessorUseCase) saveLocationReport(report domain.LocationReport) {
|
|
fileName := fmt.Sprintf("%s_report.json", report.LocationName)
|
|
filePath := filepath.Join(uc.outputDir, "reports", fileName)
|
|
|
|
data, err := json.MarshalIndent(report, "", " ")
|
|
if err != nil {
|
|
log.Printf("Error marshaling report for %s: %v", report.LocationName, err)
|
|
return
|
|
}
|
|
|
|
if err := os.WriteFile(filePath, data, 0644); err != nil {
|
|
log.Printf("Error saving report for %s: %v", report.LocationName, err)
|
|
return
|
|
}
|
|
|
|
log.Printf(" ✓ Saved: %s", filePath)
|
|
}
|
|
|
|
// loadPowerPlantsData loads raw power plants data with codes
|
|
func (uc *OptimizedAgentProcessorUseCase) loadPowerPlantsData(ctx context.Context) (*domain.PowerPlantsData, error) {
|
|
// Use the location repository to get the raw data
|
|
// We need to fetch from the API directly
|
|
url := fmt.Sprintf("https://hub.ag3nts.org/data/%s/findhim_locations.json", uc.apiKey)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating request: %w", err)
|
|
}
|
|
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("fetching data: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading response: %w", err)
|
|
}
|
|
|
|
var plantsData domain.PowerPlantsData
|
|
if err := json.Unmarshal(body, &plantsData); err != nil {
|
|
return nil, fmt.Errorf("parsing JSON: %w", err)
|
|
}
|
|
|
|
return &plantsData, nil
|
|
}
|
|
|
|
// findClosestPowerPlantLocally finds the closest power plant without using LLM
|
|
func (uc *OptimizedAgentProcessorUseCase) findClosestPowerPlantLocally(
|
|
personDataMap map[string]*domain.PersonData,
|
|
powerPlants []domain.Location,
|
|
plantsData *domain.PowerPlantsData,
|
|
) (*domain.PersonData, string, float64, string) {
|
|
var closestPerson *domain.PersonData
|
|
var closestPlantName string
|
|
var plantCode string
|
|
minDistance := 1e10
|
|
|
|
// Check all person-plant combinations
|
|
for _, personData := range personDataMap {
|
|
if len(personData.Locations) == 0 {
|
|
continue
|
|
}
|
|
|
|
// Use first location (primary)
|
|
personLoc := personData.Locations[0]
|
|
|
|
for _, plant := range powerPlants {
|
|
distance := domain.Haversine(
|
|
personLoc.Latitude,
|
|
personLoc.Longitude,
|
|
plant.Latitude,
|
|
plant.Longitude,
|
|
)
|
|
|
|
// If distance is smaller, or same distance but higher access level
|
|
if distance < minDistance {
|
|
minDistance = distance
|
|
closestPerson = personData
|
|
closestPlantName = plant.Name
|
|
|
|
// Get power plant code
|
|
if info, ok := plantsData.PowerPlants[plant.Name]; ok {
|
|
plantCode = info.Code
|
|
}
|
|
} else if distance == minDistance && closestPerson != nil && personData.AccessLevel > closestPerson.AccessLevel {
|
|
// Same distance, but higher access level
|
|
closestPerson = personData
|
|
closestPlantName = plant.Name
|
|
|
|
// Get power plant code
|
|
if info, ok := plantsData.PowerPlants[plant.Name]; ok {
|
|
plantCode = info.Code
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return closestPerson, closestPlantName, minDistance, plantCode
|
|
}
|