165 lines
4.5 KiB
Go
165 lines
4.5 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/paramah/ai_devs4/s01e01/internal/config"
|
|
"github.com/paramah/ai_devs4/s01e01/internal/domain"
|
|
"github.com/paramah/ai_devs4/s01e01/internal/infrastructure/csv"
|
|
"github.com/paramah/ai_devs4/s01e01/internal/infrastructure/llm"
|
|
"github.com/paramah/ai_devs4/s01e01/internal/usecase"
|
|
)
|
|
|
|
func main() {
|
|
configPath := flag.String("config", "config.json", "Path to configuration file")
|
|
flag.Parse()
|
|
|
|
// Load configuration
|
|
cfg, err := config.Load(*configPath)
|
|
if err != nil {
|
|
log.Fatalf("Failed to load configuration: %v", err)
|
|
}
|
|
|
|
if err := cfg.Validate(); err != nil {
|
|
log.Fatalf("Invalid configuration: %v", err)
|
|
}
|
|
|
|
// Create repositories and providers
|
|
personRepo := csv.NewRepository()
|
|
|
|
var llmProvider domain.LLMProvider
|
|
switch cfg.LLM.Provider {
|
|
case "openrouter":
|
|
llmProvider = llm.NewOpenRouterProvider(cfg.LLM.APIKey, cfg.LLM.Model)
|
|
log.Printf("Using OpenRouter with model: %s", cfg.LLM.Model)
|
|
case "lmstudio":
|
|
llmProvider = llm.NewLMStudioProvider(cfg.LLM.BaseURL, cfg.LLM.Model)
|
|
log.Printf("Using LM Studio at %s with model: %s", cfg.LLM.BaseURL, cfg.LLM.Model)
|
|
default:
|
|
log.Fatalf("Unknown provider: %s", cfg.LLM.Provider)
|
|
}
|
|
|
|
// Create use case
|
|
categorizeUC := usecase.NewCategorizeUseCase(personRepo, llmProvider, cfg.OutputDir, cfg.BatchSize)
|
|
|
|
// Execute categorization
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
|
defer cancel()
|
|
|
|
log.Printf("Output directory: %s", cfg.OutputDir)
|
|
log.Printf("Batch size: %d", cfg.BatchSize)
|
|
log.Printf("Fetching people from: %s", cfg.DataSource.URL)
|
|
people, err := categorizeUC.Execute(ctx, cfg.DataSource.URL)
|
|
if err != nil {
|
|
log.Fatalf("Failed to categorize people: %v", err)
|
|
}
|
|
|
|
// Prepare full results
|
|
output := map[string]interface{}{
|
|
"answer": people,
|
|
}
|
|
|
|
result, err := json.MarshalIndent(output, "", " ")
|
|
if err != nil {
|
|
log.Fatalf("Failed to marshal results: %v", err)
|
|
}
|
|
|
|
fmt.Println(string(result))
|
|
|
|
// Save all results to file
|
|
if err := os.WriteFile("output.json", result, 0644); err != nil {
|
|
log.Printf("Warning: Failed to write output file: %v", err)
|
|
} else {
|
|
log.Printf("Results saved to output.json")
|
|
}
|
|
|
|
// Filter only people with "transport" tag for verification
|
|
transportPeople := filterTransportPeople(people)
|
|
log.Printf("\n[INFO] Total categorized people: %d", len(people))
|
|
log.Printf("[INFO] People with 'transport' tag: %d", len(transportPeople))
|
|
|
|
// Send to verification endpoint (only transport people)
|
|
log.Printf("\n========== SENDING TO VERIFICATION ENDPOINT ==========")
|
|
responseData, err := sendVerification(cfg.APIKey, transportPeople)
|
|
if err != nil {
|
|
log.Fatalf("Failed to send verification: %v", err)
|
|
}
|
|
|
|
// Save response to file
|
|
if err := os.WriteFile("response.json", responseData, 0644); err != nil {
|
|
log.Printf("Warning: Failed to write response file: %v", err)
|
|
} else {
|
|
log.Printf("Response saved to response.json")
|
|
}
|
|
log.Printf("======================================================\n")
|
|
}
|
|
|
|
func filterTransportPeople(people []domain.Person) []domain.Person {
|
|
var filtered []domain.Person
|
|
for _, person := range people {
|
|
// Check if person has "transport" tag
|
|
hasTransport := false
|
|
for _, tag := range person.Tags {
|
|
if tag == "transport" {
|
|
hasTransport = true
|
|
break
|
|
}
|
|
}
|
|
if hasTransport {
|
|
filtered = append(filtered, person)
|
|
}
|
|
}
|
|
return filtered
|
|
}
|
|
|
|
func sendVerification(apiKey string, people []domain.Person) ([]byte, error) {
|
|
verifyURL := "https://hub.ag3nts.org/verify"
|
|
|
|
// Prepare request body
|
|
requestBody := map[string]interface{}{
|
|
"apikey": apiKey,
|
|
"task": "people",
|
|
"answer": people,
|
|
}
|
|
|
|
requestJSON, err := json.MarshalIndent(requestBody, "", " ")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshaling request: %w", err)
|
|
}
|
|
|
|
// Log the request
|
|
log.Printf("POST %s", verifyURL)
|
|
log.Printf("Request body:\n%s", string(requestJSON))
|
|
|
|
// Send request
|
|
resp, err := http.Post(verifyURL, "application/json", bytes.NewBuffer(requestJSON))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sending request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Read response
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading response: %w", err)
|
|
}
|
|
|
|
log.Printf("\nResponse status: %d", resp.StatusCode)
|
|
log.Printf("Response body:\n%s", string(responseBody))
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return responseBody, fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(responseBody))
|
|
}
|
|
|
|
return responseBody, nil
|
|
}
|