final
This commit is contained in:
185
internal/usecase/find_closest_agent.go
Normal file
185
internal/usecase/find_closest_agent.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package usecase
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/paramah/ai_devs4/s01e02/internal/domain"
|
||||
)
|
||||
|
||||
// FindClosestPowerPlantUseCase uses LLM agent to find the closest power plant
|
||||
type FindClosestPowerPlantUseCase struct {
|
||||
llmProvider domain.LLMProvider
|
||||
}
|
||||
|
||||
// NewFindClosestPowerPlantUseCase creates a new use case
|
||||
func NewFindClosestPowerPlantUseCase(llmProvider domain.LLMProvider) *FindClosestPowerPlantUseCase {
|
||||
return &FindClosestPowerPlantUseCase{
|
||||
llmProvider: llmProvider,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute uses agent to find the closest power plant to any person
|
||||
func (uc *FindClosestPowerPlantUseCase) Execute(
|
||||
ctx context.Context,
|
||||
personDataMap map[string]*domain.PersonData,
|
||||
plantsData *domain.PowerPlantsData,
|
||||
) (*domain.PersonData, string, float64, string, error) {
|
||||
|
||||
// Prepare data summary for agent
|
||||
personsInfo := ""
|
||||
for _, pd := range personDataMap {
|
||||
if len(pd.Locations) > 0 {
|
||||
loc := pd.Locations[0]
|
||||
personsInfo += fmt.Sprintf("- %s %s: lat=%.4f, lon=%.4f, access_level=%d\n",
|
||||
pd.Person.Name, pd.Person.Surname, loc.Latitude, loc.Longitude, pd.AccessLevel)
|
||||
}
|
||||
}
|
||||
|
||||
plantsInfo := ""
|
||||
for cityName := range plantsData.PowerPlants {
|
||||
plantsInfo += fmt.Sprintf("- %s\n", cityName)
|
||||
}
|
||||
|
||||
systemPrompt := fmt.Sprintf(`You are an agent that finds the closest power plant to any person.
|
||||
|
||||
Available persons and their locations:
|
||||
%s
|
||||
|
||||
Available power plants (cities):
|
||||
%s
|
||||
|
||||
Your task:
|
||||
1. For each power plant, call get_power_plant_location to get its coordinates
|
||||
2. For each person-plant pair, call calculate_distance to find the distance
|
||||
3. Track the minimum distance and corresponding person/plant
|
||||
4. After checking all combinations, respond with JSON containing the result:
|
||||
{
|
||||
"person_name": "Name",
|
||||
"person_surname": "Surname",
|
||||
"plant_city": "City",
|
||||
"min_distance": 123.45
|
||||
}`, personsInfo, plantsInfo)
|
||||
|
||||
tools := domain.GetToolDefinitions()
|
||||
messages := []domain.LLMMessage{
|
||||
{
|
||||
Role: "system",
|
||||
Content: systemPrompt,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Find the power plant that is closest to any person. Check all combinations.",
|
||||
},
|
||||
}
|
||||
|
||||
maxIterations := 100 // Need many iterations for all combinations
|
||||
var resultJSON string
|
||||
|
||||
for iteration := 0; iteration < maxIterations; iteration++ {
|
||||
resp, err := uc.llmProvider.Chat(ctx, domain.LLMRequest{
|
||||
Messages: messages,
|
||||
Tools: tools,
|
||||
ToolChoice: "auto",
|
||||
Temperature: 0.0,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, "", 0, "", fmt.Errorf("LLM chat error: %w", err)
|
||||
}
|
||||
|
||||
messages = append(messages, resp.Message)
|
||||
|
||||
if len(resp.Message.ToolCalls) > 0 {
|
||||
log.Printf(" [Iteration %d] Agent requested %d tool call(s)", iteration+1, len(resp.Message.ToolCalls))
|
||||
|
||||
for _, toolCall := range resp.Message.ToolCalls {
|
||||
result := uc.executeToolCall(toolCall, plantsData)
|
||||
messages = append(messages, domain.LLMMessage{
|
||||
Role: "tool",
|
||||
Content: result,
|
||||
ToolCallID: toolCall.ID,
|
||||
})
|
||||
}
|
||||
} else if resp.FinishReason == "stop" {
|
||||
resultJSON = resp.Message.Content
|
||||
log.Printf(" ✓ Agent completed analysis")
|
||||
log.Printf(" → Raw response: %s", resultJSON)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Extract JSON from response (agent might wrap it in text)
|
||||
start := -1
|
||||
end := -1
|
||||
for i, ch := range resultJSON {
|
||||
if ch == '{' && start == -1 {
|
||||
start = i
|
||||
}
|
||||
if ch == '}' {
|
||||
end = i + 1
|
||||
}
|
||||
}
|
||||
|
||||
if start == -1 || end == -1 {
|
||||
return nil, "", 0, "", fmt.Errorf("no JSON found in response: %s", resultJSON)
|
||||
}
|
||||
|
||||
jsonStr := resultJSON[start:end]
|
||||
|
||||
// Parse result
|
||||
var result struct {
|
||||
PersonName string `json:"person_name"`
|
||||
PersonSurname string `json:"person_surname"`
|
||||
PlantCity string `json:"plant_city"`
|
||||
MinDistance float64 `json:"min_distance"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(jsonStr), &result); err != nil {
|
||||
return nil, "", 0, "", fmt.Errorf("parsing result: %w (json: %s)", err, jsonStr)
|
||||
}
|
||||
|
||||
// Find the person data
|
||||
key := fmt.Sprintf("%s_%s", result.PersonName, result.PersonSurname)
|
||||
personData, ok := personDataMap[key]
|
||||
if !ok {
|
||||
return nil, "", 0, "", fmt.Errorf("person not found: %s", key)
|
||||
}
|
||||
|
||||
// Get plant code
|
||||
plantInfo, ok := plantsData.PowerPlants[result.PlantCity]
|
||||
if !ok {
|
||||
return nil, "", 0, "", fmt.Errorf("plant not found: %s", result.PlantCity)
|
||||
}
|
||||
|
||||
return personData, result.PlantCity, result.MinDistance, plantInfo.Code, nil
|
||||
}
|
||||
|
||||
// executeToolCall handles tool execution
|
||||
func (uc *FindClosestPowerPlantUseCase) executeToolCall(toolCall domain.ToolCall, plantsData *domain.PowerPlantsData) string {
|
||||
var args map[string]interface{}
|
||||
json.Unmarshal([]byte(toolCall.Function.Arguments), &args)
|
||||
|
||||
switch toolCall.Function.Name {
|
||||
case "get_power_plant_location":
|
||||
cityName, _ := args["city_name"].(string)
|
||||
coords, ok := domain.CityCoordinates[cityName]
|
||||
if !ok {
|
||||
return fmt.Sprintf(`{"error": "City not found: %s"}`, cityName)
|
||||
}
|
||||
return fmt.Sprintf(`{"city": "%s", "lat": %.6f, "lon": %.6f}`, cityName, coords.Lat, coords.Lon)
|
||||
|
||||
case "calculate_distance":
|
||||
lat1, _ := args["lat1"].(float64)
|
||||
lon1, _ := args["lon1"].(float64)
|
||||
lat2, _ := args["lat2"].(float64)
|
||||
lon2, _ := args["lon2"].(float64)
|
||||
|
||||
distance := domain.Haversine(lat1, lon1, lat2, lon2)
|
||||
return fmt.Sprintf(`{"distance_km": %.2f}`, distance)
|
||||
|
||||
default:
|
||||
return fmt.Sprintf(`{"error": "Unknown function: %s"}`, toolCall.Function.Name)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user