186 lines
5.2 KiB
Go
186 lines
5.2 KiB
Go
package usecase
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
|
|
"github.com/paramah/ai_devs4/s01e02/internal/domain"
|
|
)
|
|
|
|
// FindClosestPowerPlantUseCase uses LLM agent to find the closest power plant
|
|
type FindClosestPowerPlantUseCase struct {
|
|
llmProvider domain.LLMProvider
|
|
}
|
|
|
|
// NewFindClosestPowerPlantUseCase creates a new use case
|
|
func NewFindClosestPowerPlantUseCase(llmProvider domain.LLMProvider) *FindClosestPowerPlantUseCase {
|
|
return &FindClosestPowerPlantUseCase{
|
|
llmProvider: llmProvider,
|
|
}
|
|
}
|
|
|
|
// Execute uses agent to find the closest power plant to any person
|
|
func (uc *FindClosestPowerPlantUseCase) Execute(
|
|
ctx context.Context,
|
|
personDataMap map[string]*domain.PersonData,
|
|
plantsData *domain.PowerPlantsData,
|
|
) (*domain.PersonData, string, float64, string, error) {
|
|
|
|
// Prepare data summary for agent
|
|
personsInfo := ""
|
|
for _, pd := range personDataMap {
|
|
if len(pd.Locations) > 0 {
|
|
loc := pd.Locations[0]
|
|
personsInfo += fmt.Sprintf("- %s %s: lat=%.4f, lon=%.4f, access_level=%d\n",
|
|
pd.Person.Name, pd.Person.Surname, loc.Latitude, loc.Longitude, pd.AccessLevel)
|
|
}
|
|
}
|
|
|
|
plantsInfo := ""
|
|
for cityName := range plantsData.PowerPlants {
|
|
plantsInfo += fmt.Sprintf("- %s\n", cityName)
|
|
}
|
|
|
|
systemPrompt := fmt.Sprintf(`You are an agent that finds the closest power plant to any person.
|
|
|
|
Available persons and their locations:
|
|
%s
|
|
|
|
Available power plants (cities):
|
|
%s
|
|
|
|
Your task:
|
|
1. For each power plant, call get_power_plant_location to get its coordinates
|
|
2. For each person-plant pair, call calculate_distance to find the distance
|
|
3. Track the minimum distance and corresponding person/plant
|
|
4. After checking all combinations, respond with JSON containing the result:
|
|
{
|
|
"person_name": "Name",
|
|
"person_surname": "Surname",
|
|
"plant_city": "City",
|
|
"min_distance": 123.45
|
|
}`, personsInfo, plantsInfo)
|
|
|
|
tools := domain.GetToolDefinitions()
|
|
messages := []domain.LLMMessage{
|
|
{
|
|
Role: "system",
|
|
Content: systemPrompt,
|
|
},
|
|
{
|
|
Role: "user",
|
|
Content: "Find the power plant that is closest to any person. Check all combinations.",
|
|
},
|
|
}
|
|
|
|
maxIterations := 100 // Need many iterations for all combinations
|
|
var resultJSON string
|
|
|
|
for iteration := 0; iteration < maxIterations; iteration++ {
|
|
resp, err := uc.llmProvider.Chat(ctx, domain.LLMRequest{
|
|
Messages: messages,
|
|
Tools: tools,
|
|
ToolChoice: "auto",
|
|
Temperature: 0.0,
|
|
})
|
|
if err != nil {
|
|
return nil, "", 0, "", fmt.Errorf("LLM chat error: %w", err)
|
|
}
|
|
|
|
messages = append(messages, resp.Message)
|
|
|
|
if len(resp.Message.ToolCalls) > 0 {
|
|
log.Printf(" [Iteration %d] Agent requested %d tool call(s)", iteration+1, len(resp.Message.ToolCalls))
|
|
|
|
for _, toolCall := range resp.Message.ToolCalls {
|
|
result := uc.executeToolCall(toolCall, plantsData)
|
|
messages = append(messages, domain.LLMMessage{
|
|
Role: "tool",
|
|
Content: result,
|
|
ToolCallID: toolCall.ID,
|
|
})
|
|
}
|
|
} else if resp.FinishReason == "stop" {
|
|
resultJSON = resp.Message.Content
|
|
log.Printf(" ✓ Agent completed analysis")
|
|
log.Printf(" → Raw response: %s", resultJSON)
|
|
break
|
|
}
|
|
}
|
|
|
|
// Extract JSON from response (agent might wrap it in text)
|
|
start := -1
|
|
end := -1
|
|
for i, ch := range resultJSON {
|
|
if ch == '{' && start == -1 {
|
|
start = i
|
|
}
|
|
if ch == '}' {
|
|
end = i + 1
|
|
}
|
|
}
|
|
|
|
if start == -1 || end == -1 {
|
|
return nil, "", 0, "", fmt.Errorf("no JSON found in response: %s", resultJSON)
|
|
}
|
|
|
|
jsonStr := resultJSON[start:end]
|
|
|
|
// Parse result
|
|
var result struct {
|
|
PersonName string `json:"person_name"`
|
|
PersonSurname string `json:"person_surname"`
|
|
PlantCity string `json:"plant_city"`
|
|
MinDistance float64 `json:"min_distance"`
|
|
}
|
|
|
|
if err := json.Unmarshal([]byte(jsonStr), &result); err != nil {
|
|
return nil, "", 0, "", fmt.Errorf("parsing result: %w (json: %s)", err, jsonStr)
|
|
}
|
|
|
|
// Find the person data
|
|
key := fmt.Sprintf("%s_%s", result.PersonName, result.PersonSurname)
|
|
personData, ok := personDataMap[key]
|
|
if !ok {
|
|
return nil, "", 0, "", fmt.Errorf("person not found: %s", key)
|
|
}
|
|
|
|
// Get plant code
|
|
plantInfo, ok := plantsData.PowerPlants[result.PlantCity]
|
|
if !ok {
|
|
return nil, "", 0, "", fmt.Errorf("plant not found: %s", result.PlantCity)
|
|
}
|
|
|
|
return personData, result.PlantCity, result.MinDistance, plantInfo.Code, nil
|
|
}
|
|
|
|
// executeToolCall handles tool execution
|
|
func (uc *FindClosestPowerPlantUseCase) executeToolCall(toolCall domain.ToolCall, plantsData *domain.PowerPlantsData) string {
|
|
var args map[string]interface{}
|
|
json.Unmarshal([]byte(toolCall.Function.Arguments), &args)
|
|
|
|
switch toolCall.Function.Name {
|
|
case "get_power_plant_location":
|
|
cityName, _ := args["city_name"].(string)
|
|
coords, ok := domain.CityCoordinates[cityName]
|
|
if !ok {
|
|
return fmt.Sprintf(`{"error": "City not found: %s"}`, cityName)
|
|
}
|
|
return fmt.Sprintf(`{"city": "%s", "lat": %.6f, "lon": %.6f}`, cityName, coords.Lat, coords.Lon)
|
|
|
|
case "calculate_distance":
|
|
lat1, _ := args["lat1"].(float64)
|
|
lon1, _ := args["lon1"].(float64)
|
|
lat2, _ := args["lat2"].(float64)
|
|
lon2, _ := args["lon2"].(float64)
|
|
|
|
distance := domain.Haversine(lat1, lon1, lat2, lon2)
|
|
return fmt.Sprintf(`{"distance_km": %.2f}`, distance)
|
|
|
|
default:
|
|
return fmt.Sprintf(`{"error": "Unknown function: %s"}`, toolCall.Function.Name)
|
|
}
|
|
}
|