Files
s01e02/internal/usecase/find_closest_agent.go
2026-03-12 02:10:57 +01:00

186 lines
5.2 KiB
Go

package usecase
import (
"context"
"encoding/json"
"fmt"
"log"
"github.com/paramah/ai_devs4/s01e02/internal/domain"
)
// FindClosestPowerPlantUseCase asks an LLM agent, driven through tool calls,
// which power plant lies closest to any of a given set of persons.
type FindClosestPowerPlantUseCase struct {
	llmProvider domain.LLMProvider // chat backend used for every agent turn
}

// NewFindClosestPowerPlantUseCase returns a use case that talks to the model
// through llmProvider. The provider is stored as-is without validation.
func NewFindClosestPowerPlantUseCase(llmProvider domain.LLMProvider) *FindClosestPowerPlantUseCase {
	uc := new(FindClosestPowerPlantUseCase)
	uc.llmProvider = llmProvider
	return uc
}
// Execute uses an LLM agent to find the power plant closest to any person.
//
// The agent is shown every person that has at least one known location (only
// the first entry, pd.Locations[0], is used) and every plant city from
// plantsData.PowerPlants. It is told to call the get_power_plant_location and
// calculate_distance tools, which are served locally by executeToolCall. It
// must finish with a JSON object naming the person, the plant city and the
// minimum distance.
//
// Return values, in order:
//   - the matched person's data
//   - the plant city
//   - the minimum distance as reported by the agent (it is not recomputed here)
//   - the plant code from plantsData
//   - an error
//
// An error is returned when:
//   - there is nothing to compare
//   - the LLM call fails or ctx is done
//   - the agent gives no final answer within the iteration budget
//   - the answer contains no parseable JSON object
//   - the named person or plant is not present in the inputs
func (uc *FindClosestPowerPlantUseCase) Execute(
	ctx context.Context,
	personDataMap map[string]*domain.PersonData,
	plantsData *domain.PowerPlantsData,
) (*domain.PersonData, string, float64, string, error) {
	if plantsData == nil {
		return nil, "", 0, "", fmt.Errorf("no power plant data provided")
	}

	systemPrompt, err := uc.buildSystemPrompt(personDataMap, plantsData)
	if err != nil {
		return nil, "", 0, "", err
	}

	messages := []domain.LLMMessage{
		{
			Role:    "system",
			Content: systemPrompt,
		},
		{
			Role:    "user",
			Content: "Find the power plant that is closest to any person. Check all combinations.",
		},
	}

	resultJSON, err := uc.runAgent(ctx, messages, plantsData)
	if err != nil {
		return nil, "", 0, "", err
	}

	jsonStr, err := uc.extractJSONObject(resultJSON)
	if err != nil {
		return nil, "", 0, "", err
	}

	var result struct {
		PersonName    string  `json:"person_name"`
		PersonSurname string  `json:"person_surname"`
		PlantCity     string  `json:"plant_city"`
		MinDistance   float64 `json:"min_distance"`
	}
	if err := json.Unmarshal([]byte(jsonStr), &result); err != nil {
		return nil, "", 0, "", fmt.Errorf("parsing result: %w (json: %s)", err, jsonStr)
	}

	personData, err := uc.resolvePerson(personDataMap, result.PersonName, result.PersonSurname)
	if err != nil {
		return nil, "", 0, "", err
	}

	plantInfo, ok := plantsData.PowerPlants[result.PlantCity]
	if !ok {
		return nil, "", 0, "", fmt.Errorf("plant not found: %s", result.PlantCity)
	}

	return personData, result.PlantCity, result.MinDistance, plantInfo.Code, nil
}

// buildSystemPrompt renders the agent instructions. It lists each person that
// has a location (first location only) and each plant city. It fails when
// either list would be empty, because the agent could not answer anyway.
func (uc *FindClosestPowerPlantUseCase) buildSystemPrompt(
	personDataMap map[string]*domain.PersonData,
	plantsData *domain.PowerPlantsData,
) (string, error) {
	personsInfo := ""
	for _, pd := range personDataMap {
		if pd == nil || len(pd.Locations) == 0 {
			continue
		}
		loc := pd.Locations[0]
		personsInfo += fmt.Sprintf("- %s %s: lat=%.4f, lon=%.4f, access_level=%d\n",
			pd.Person.Name, pd.Person.Surname, loc.Latitude, loc.Longitude, pd.AccessLevel)
	}
	if personsInfo == "" {
		return "", fmt.Errorf("no persons with known locations")
	}

	plantsInfo := ""
	for cityName := range plantsData.PowerPlants {
		plantsInfo += fmt.Sprintf("- %s\n", cityName)
	}
	if plantsInfo == "" {
		return "", fmt.Errorf("no power plants available")
	}

	return fmt.Sprintf(`You are an agent that finds the closest power plant to any person.
Available persons and their locations:
%s
Available power plants (cities):
%s
Your task:
1. For each power plant, call get_power_plant_location to get its coordinates
2. For each person-plant pair, call calculate_distance to find the distance
3. Track the minimum distance and corresponding person/plant
4. After checking all combinations, respond with JSON containing the result:
{
"person_name": "Name",
"person_surname": "Surname",
"plant_city": "City",
"min_distance": 123.45
}`, personsInfo, plantsInfo), nil
}

// runAgent drives the tool-calling loop and returns the text of the first
// reply that has no tool calls and finish reason "stop". It fails if ctx is
// done, an LLM call fails, or no such reply arrives within maxIterations.
func (uc *FindClosestPowerPlantUseCase) runAgent(
	ctx context.Context,
	messages []domain.LLMMessage,
	plantsData *domain.PowerPlantsData,
) (string, error) {
	// One iteration is one LLM round-trip. Checking every person×plant pair
	// can take many of them when the model issues tool calls one at a time.
	const maxIterations = 100

	tools := domain.GetToolDefinitions()
	for iteration := 1; iteration <= maxIterations; iteration++ {
		if err := ctx.Err(); err != nil {
			return "", fmt.Errorf("agent cancelled after %d iteration(s): %w", iteration-1, err)
		}

		resp, err := uc.llmProvider.Chat(ctx, domain.LLMRequest{
			Messages:    messages,
			Tools:       tools,
			ToolChoice:  "auto",
			Temperature: 0.0,
		})
		if err != nil {
			return "", fmt.Errorf("LLM chat error: %w", err)
		}
		messages = append(messages, resp.Message)

		if len(resp.Message.ToolCalls) == 0 {
			if resp.FinishReason == "stop" {
				log.Printf(" ✓ Agent completed analysis")
				log.Printf(" → Raw response: %s", resp.Message.Content)
				return resp.Message.Content, nil
			}
			// A non-final reply without tool calls (e.g. cut off by a length
			// limit) stays in the history, and the model is asked again.
			log.Printf(" [Iteration %d] Agent returned no tool calls (finish reason %v); asking again", iteration, resp.FinishReason)
			continue
		}

		log.Printf(" [Iteration %d] Agent requested %d tool call(s)", iteration, len(resp.Message.ToolCalls))
		for _, toolCall := range resp.Message.ToolCalls {
			messages = append(messages, domain.LLMMessage{
				Role:       "tool",
				Content:    uc.executeToolCall(toolCall, plantsData),
				ToolCallID: toolCall.ID,
			})
		}
	}
	return "", fmt.Errorf("agent gave no final answer within %d iterations", maxIterations)
}

// extractJSONObject returns the span from the first '{' to the last '}' that
// follows it. This tolerates prose or code fences around the agent's JSON.
// The search is byte-wise, which is safe for UTF-8 because both braces are
// ASCII. Requiring the closing brace to come after the opening one avoids
// the inverted slice bounds that input such as "} {" would otherwise produce.
func (uc *FindClosestPowerPlantUseCase) extractJSONObject(text string) (string, error) {
	start := -1
	for i := 0; i < len(text) && start < 0; i++ {
		if text[i] == '{' {
			start = i
		}
	}
	end := -1
	for i := len(text) - 1; i > start && end < 0; i-- {
		if text[i] == '}' {
			end = i + 1
		}
	}
	if start < 0 || end < 0 {
		return "", fmt.Errorf("no JSON found in response: %s", text)
	}
	return text[start:end], nil
}

// resolvePerson maps the agent's name and surname back to an entry of
// personDataMap.
//
// It first tries the "Name_Surname" key. If that fails, it scans the persons
// that were listed in the prompt, comparing the same formatted name. This
// covers maps keyed some other way — presumably rare, but the key format is
// not established here.
func (uc *FindClosestPowerPlantUseCase) resolvePerson(
	personDataMap map[string]*domain.PersonData,
	name, surname string,
) (*domain.PersonData, error) {
	key := fmt.Sprintf("%s_%s", name, surname)
	if pd, ok := personDataMap[key]; ok && pd != nil {
		return pd, nil
	}
	for _, pd := range personDataMap {
		if pd == nil || len(pd.Locations) == 0 {
			continue
		}
		if fmt.Sprintf("%s_%s", pd.Person.Name, pd.Person.Surname) == key {
			return pd, nil
		}
	}
	return nil, fmt.Errorf("person not found: %s", key)
}
// executeToolCall runs a single tool call requested by the agent. It returns
// the JSON text that is sent back to the model as the tool message content.
//
// Supported tools:
//   - get_power_plant_location {"city_name": string}: looks the city up in
//     domain.CityCoordinates and returns {"city", "lat", "lon"}.
//   - calculate_distance {"lat1", "lon1", "lat2", "lon2": number}: returns
//     {"distance_km"}, computed with domain.Haversine.
//
// Malformed or missing arguments, unknown cities and unknown tools are
// reported to the model as {"error": "..."}. Previously a bad argument payload
// was ignored and every coordinate defaulted to 0. The agent could then
// receive a bogus 0 km distance and select it as the minimum.
//
// plantsData is currently unused; coordinates come from domain.CityCoordinates.
func (uc *FindClosestPowerPlantUseCase) executeToolCall(toolCall domain.ToolCall, plantsData *domain.PowerPlantsData) string {
	// errorJSON builds {"error": msg} with msg JSON-escaped, so names that
	// contain quotes or backslashes cannot produce invalid JSON.
	errorJSON := func(msg string) string {
		quoted, _ := json.Marshal(msg) // marshaling a plain string cannot fail
		return `{"error": ` + string(quoted) + `}`
	}
	rawArgs := []byte(toolCall.Function.Arguments)

	switch toolCall.Function.Name {
	case "get_power_plant_location":
		var args struct {
			CityName string `json:"city_name"`
		}
		if err := json.Unmarshal(rawArgs, &args); err != nil {
			return errorJSON("invalid arguments for get_power_plant_location: " + err.Error())
		}
		coords, ok := domain.CityCoordinates[args.CityName]
		if !ok {
			return errorJSON("City not found: " + args.CityName)
		}
		city, _ := json.Marshal(args.CityName) // escaped JSON string literal
		return fmt.Sprintf(`{"city": %s, "lat": %.6f, "lon": %.6f}`, city, coords.Lat, coords.Lon)

	case "calculate_distance":
		// Pointer fields distinguish an omitted (or null) coordinate from a
		// genuine 0.
		var args struct {
			Lat1 *float64 `json:"lat1"`
			Lon1 *float64 `json:"lon1"`
			Lat2 *float64 `json:"lat2"`
			Lon2 *float64 `json:"lon2"`
		}
		if err := json.Unmarshal(rawArgs, &args); err != nil {
			return errorJSON("invalid arguments for calculate_distance: " + err.Error())
		}
		if args.Lat1 == nil || args.Lon1 == nil || args.Lat2 == nil || args.Lon2 == nil {
			return errorJSON("calculate_distance requires numeric lat1, lon1, lat2 and lon2")
		}
		distance := domain.Haversine(*args.Lat1, *args.Lon1, *args.Lat2, *args.Lon2)
		return fmt.Sprintf(`{"distance_km": %.2f}`, distance)

	default:
		return errorJSON(fmt.Sprintf("Unknown function: %s", toolCall.Function.Name))
	}
}