Files
s01e01/internal/usecase/categorize.go
2026-03-11 22:51:42 +01:00

468 lines
14 KiB
Go

package usecase
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/paramah/ai_devs4/s01e01/internal/domain"
)
// CategorizeUseCase handles the categorization of people: it fetches them via
// personRepo, filters them, asks llmProvider to tag them in batches, and
// persists the intermediate JSON results under outputDir.
type CategorizeUseCase struct {
personRepo domain.PersonRepository // source of raw person records (FetchPeople)
llmProvider domain.LLMProvider // model used to tag each batch (Complete)
outputDir string // directory for filtered_people.json and batch_*.json files
batchSize int // people per LLM request; must be positive (Execute divides by it)
}
// NewCategorizeUseCase wires a CategorizeUseCase from its dependencies.
// outputDir is where the filtered input and per-batch JSON results are
// written; batchSize is how many people are sent in each LLM request.
// No validation is performed here.
func NewCategorizeUseCase(repo domain.PersonRepository, llm domain.LLMProvider, outputDir string, batchSize int) *CategorizeUseCase {
	uc := new(CategorizeUseCase)
	uc.personRepo = repo
	uc.llmProvider = llm
	uc.outputDir = outputDir
	uc.batchSize = batchSize
	return uc
}
// Execute fetches people from dataURL, keeps only those matching the fixed
// target criteria (see filterCompletePeople) and asks the LLM to tag them in
// batches of uc.batchSize people.
//
// Files written under uc.outputDir (created if missing):
//   - filtered_people.json: the filtered input;
//   - batch_<first>_<last>.json: the tagged result of one batch, where
//     first/last are 0-based inclusive indices into the filtered list.
//
// A batch whose file already exists is skipped, so an interrupted run can be
// resumed. NOTE(review): that check is keyed only by the index range, so
// changing the input data or batch size between runs can mix stale results
// into the output.
//
// Returns an empty slice when nothing matches the filters; otherwise the
// people read back from disk by collectResults. Repository, LLM, parsing and
// file errors abort the run; batch files saved before the error remain.
func (uc *CategorizeUseCase) Execute(ctx context.Context, dataURL string) ([]domain.Person, error) {
	// A non-positive batch size would panic with a division by zero when
	// computing totalBatches (and the loop step would never advance).
	if uc.batchSize <= 0 {
		return nil, fmt.Errorf("invalid batch size %d: must be positive", uc.batchSize)
	}
	// stamp formats the current time for the log-line prefix.
	stamp := func() string { return time.Now().Format("2006-01-02 15:04:05") }

	// Create output directory if it doesn't exist
	if err := os.MkdirAll(uc.outputDir, 0755); err != nil {
		return nil, fmt.Errorf("creating output directory: %w", err)
	}

	// Fetch people from data source
	allPeople, err := uc.personRepo.FetchPeople(ctx, dataURL)
	if err != nil {
		return nil, fmt.Errorf("fetching people: %w", err)
	}
	originalCount := len(allPeople)
	fmt.Printf("\n[%s] ========== DATA FILTERING ==========\n", stamp())
	fmt.Printf("[%s] Original CSV entries: %d\n", stamp(), originalCount)
	if originalCount == 0 {
		return []domain.Person{}, nil
	}
	fmt.Printf("[%s] Applying filters:\n", stamp())
	fmt.Printf(" - Gender: M (male)\n")
	fmt.Printf(" - Age in 2026: 20-40 years (born 1986-2006)\n")
	fmt.Printf(" - City: Grudziądz\n")
	fmt.Printf(" - Industry: ALL (will be categorized by LLM)\n\n")

	// Keep only the people matching the criteria listed above.
	people := uc.filterCompletePeople(allPeople)
	filteredCount := len(people)
	fmt.Printf("[%s] Filtered entries (matching criteria): %d\n", stamp(), filteredCount)
	fmt.Printf("[%s] Removed entries: %d\n", stamp(), originalCount-filteredCount)
	fmt.Printf("[%s] =====================================\n\n", stamp())

	// Persist the filtered input for inspection.
	if err := uc.saveFilteredData(people); err != nil {
		return nil, fmt.Errorf("saving filtered data: %w", err)
	}
	if filteredCount == 0 {
		fmt.Printf("[%s] No complete entries to process\n", stamp())
		return []domain.Person{}, nil
	}

	totalBatches := (filteredCount + uc.batchSize - 1) / uc.batchSize
	startTime := time.Now()
	processedBatches := 0
	skippedBatches := 0
	for i := 0; i < filteredCount; i += uc.batchSize {
		batchNum := i/uc.batchSize + 1
		// Stop between batches once the caller cancels, instead of carrying
		// on until the next LLM call happens to notice.
		if err := ctx.Err(); err != nil {
			return nil, fmt.Errorf("categorization cancelled before batch %d: %w", batchNum, err)
		}
		batchStart := time.Now()
		end := i + uc.batchSize
		if end > filteredCount {
			end = filteredCount
		}
		batch := people[i:end]

		// The file name encodes the batch's inclusive index range; its
		// existence marks the batch as done (see the resume note above).
		batchFilepath := filepath.Join(uc.outputDir, fmt.Sprintf("batch_%d_%d.json", i, end-1))
		if _, err := os.Stat(batchFilepath); err == nil {
			skippedBatches++
			fmt.Printf("[%s] Skipping batch %d/%d (entries %d-%d, already processed)\n",
				stamp(), batchNum, totalBatches, i, end-1)
			continue
		}

		// ETA: average duration of the batches processed so far in this run,
		// times the number of batches still ahead (some may end up skipped).
		var etaStr string
		if processedBatches > 0 {
			avgTimePerBatch := time.Since(startTime) / time.Duration(processedBatches)
			eta := avgTimePerBatch * time.Duration(totalBatches-batchNum)
			etaStr = fmt.Sprintf(" (ETA: %s)", eta.Round(time.Second))
		}
		fmt.Printf("[%s] Processing batch %d/%d (entries %d-%d, %d people)...%s\n",
			stamp(), batchNum, totalBatches, i, end-1, len(batch), etaStr)

		prompt, schema := uc.buildPrompt(batch)
		response, err := uc.llmProvider.Complete(ctx, domain.LLMRequest{
			Prompt: prompt,
			Schema: schema,
		})
		if err != nil {
			return nil, fmt.Errorf("LLM completion (batch %d): %w", batchNum, err)
		}

		// Log the raw model response for debugging
		fmt.Printf("[%s] Raw response:\n%s\n", stamp(), response.Content)

		categorizedBatch, err := uc.parseBatchResponse(response.Content)
		if err != nil {
			return nil, fmt.Errorf("parsing LLM response (batch %d): %w\nRaw response: %s", batchNum, err, response.Content)
		}
		// The prompt demands every person back, but models can drop or
		// duplicate entries. Warn instead of failing; the saved file makes
		// later runs skip this batch, so it must be deleted to retry.
		if len(categorizedBatch) != len(batch) {
			fmt.Printf("[%s] WARNING: batch %d returned %d people, expected %d (delete %s to retry)\n",
				stamp(), batchNum, len(categorizedBatch), len(batch), batchFilepath)
		}

		if err := uc.saveBatchToFile(categorizedBatch, batchFilepath); err != nil {
			return nil, fmt.Errorf("saving batch to file: %w", err)
		}
		processedBatches++
		fmt.Printf("[%s] Batch %d completed in %s (%d people)\n",
			stamp(), batchNum, time.Since(batchStart).Round(time.Second), len(categorizedBatch))
	}

	fmt.Printf("\n[%s] Summary: Processed %d batches, Skipped %d (already done), Total %d batches\n",
		stamp(), processedBatches, skippedBatches, totalBatches)

	// Collect all results from files
	return uc.collectResults()
}
// saveBatchToFile writes people to path as an indented JSON array.
//
// Execute treats an existing batch file as "already processed", and the
// results are later decoded from every batch file. Writing in place could
// leave a truncated file behind if the process dies mid-write; that file
// would be skipped on resume and then fail to decode. To avoid this, the data
// is written to a temporary file in the same directory and renamed onto path
// only after it has been fully written and closed.
func (uc *CategorizeUseCase) saveBatchToFile(people []domain.Person, path string) error {
	data, err := json.MarshalIndent(people, "", " ")
	if err != nil {
		return fmt.Errorf("marshaling batch: %w", err)
	}

	// The ".tmp" suffix keeps a leftover temp file from being mistaken for a
	// ".json" batch result.
	tmp, err := os.CreateTemp(filepath.Dir(path), ".batch-*.tmp")
	if err != nil {
		return fmt.Errorf("creating temp file: %w", err)
	}
	tmpName := tmp.Name()
	// Clean up on every error path; after a successful rename the temp name
	// no longer exists and the (ignored) removal error is irrelevant.
	defer os.Remove(tmpName)

	if _, err := tmp.Write(data); err != nil {
		tmp.Close()
		return fmt.Errorf("writing file: %w", err)
	}
	if err := tmp.Close(); err != nil {
		return fmt.Errorf("writing file: %w", err)
	}
	// CreateTemp creates the file with mode 0600; widen it to 0644 so batch
	// files get the same permissions as the other output files.
	if err := os.Chmod(tmpName, 0644); err != nil {
		return fmt.Errorf("setting file mode: %w", err)
	}
	if err := os.Rename(tmpName, path); err != nil {
		return fmt.Errorf("renaming file: %w", err)
	}
	return nil
}
// filterCompletePeople keeps only people who have every field the pipeline
// needs (gender, birth year, city, job) and who belong to the target group:
// men ("M") living in "Grudziądz" who are 20-40 years old in 2026, i.e. born
// between 1986 and 2006 inclusive.
//
// Industry is deliberately not filtered here (the isTransportJob check is
// disabled); the LLM assigns tags instead. Returns nil when nobody matches.
func (uc *CategorizeUseCase) filterCompletePeople(people []domain.Person) []domain.Person {
	const (
		referenceYear = 2026
		youngestAge   = 20
		oldestAge     = 40
	)
	earliestBirth := referenceYear - oldestAge // 1986
	latestBirth := referenceYear - youngestAge // 2006

	hasRequiredFields := func(p domain.Person) bool {
		return p.Gender != "" && p.Born != 0 && p.City != "" && p.Job != ""
	}
	inTargetGroup := func(p domain.Person) bool {
		return p.Gender == "M" &&
			p.Born >= earliestBirth && p.Born <= latestBirth &&
			p.City == "Grudziądz"
	}

	var kept []domain.Person
	for _, p := range people {
		if hasRequiredFields(p) && inTargetGroup(p) {
			kept = append(kept, p)
		}
	}
	return kept
}
// isTransportJob reports whether jobDescription looks transport-related. It
// lower-cases the text and returns true if any keyword stem below (mostly
// Polish, e.g. "kierow" for kierowca/kierowanie) occurs as a substring.
//
// Apparently unreferenced at the moment: its only call site, in
// filterCompletePeople, is commented out in favour of LLM tagging.
//
// NOTE(review): plain substring matching is loose; short stems such as
// "auto", "bus", "tir" or "ruch" also match unrelated words (e.g. "autor").
func (uc *CategorizeUseCase) isTransportJob(jobDescription string) bool {
	jobLower := strings.ToLower(jobDescription)
	// Each stem is listed exactly once.
	transportKeywords := []string{
		"transport",
		"pojazd", // pojazd, pojazdy, pojazdów
		"samochód",
		"kierow",
		"prowadz", // prowadzenie pojazdu
		"dostaw",
		"przewóz",
		"logistyk",
		"auto",
		"ciężar", // ciężarówka
		"bus",
		"tir",
		"wagon",
		"ruch",
		"droga",
		"trasa",
		"przesył",
	}
	for _, keyword := range transportKeywords {
		if strings.Contains(jobLower, keyword) {
			return true
		}
	}
	return false
}
// saveFilteredData writes the filtered people to
// <outputDir>/filtered_people.json as an indented JSON array and logs the
// path. An empty input still produces "[]", never "null".
//
// Each person is copied into a local record type, including the job
// description. The JSON keys (name, surname, gender, born, city, job) are
// therefore fixed here rather than by domain.Person's own tags.
func (uc *CategorizeUseCase) saveFilteredData(people []domain.Person) error {
	type filteredRecord struct {
		Name    string `json:"name"`
		Surname string `json:"surname"`
		Gender  string `json:"gender"`
		Born    int    `json:"born"`
		City    string `json:"city"`
		Job     string `json:"job"`
	}

	records := make([]filteredRecord, 0, len(people))
	for _, p := range people {
		records = append(records, filteredRecord{
			Name:    p.Name,
			Surname: p.Surname,
			Gender:  p.Gender,
			Born:    p.Born,
			City:    p.City,
			Job:     p.Job,
		})
	}

	data, err := json.MarshalIndent(records, "", " ")
	if err != nil {
		return fmt.Errorf("marshaling filtered data: %w", err)
	}

	outPath := filepath.Join(uc.outputDir, "filtered_people.json")
	if err = os.WriteFile(outPath, data, 0644); err != nil {
		return fmt.Errorf("writing filtered data file: %w", err)
	}
	fmt.Printf("[%s] Filtered data saved to: %s\n\n", time.Now().Format("2006-01-02 15:04:05"), outPath)
	return nil
}
// collectResults reads every batch result file in uc.outputDir and
// concatenates their contents.
//
// Only files named batch_<first>_<last>.json (the pattern Execute uses) are
// read. Everything else is ignored, including filtered_people.json,
// directories and unrelated JSON files. Batches are concatenated in ascending
// order of <first>, which is the order of the filtered input. Lexical
// file-name order would instead put batch_10_19 before batch_2_....
//
// NOTE(review): batch files left over from a run with a different batch size
// or dataset are still picked up; clear outputDir between such runs.
//
// Returns nil and no error when no batch files exist. A read or decode error
// aborts with the offending file name.
func (uc *CategorizeUseCase) collectResults() ([]domain.Person, error) {
	entries, err := os.ReadDir(uc.outputDir)
	if err != nil {
		return nil, fmt.Errorf("reading output directory: %w", err)
	}

	type batchFile struct {
		first  int
		people []domain.Person
	}
	var batches []batchFile
	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() || filepath.Ext(name) != ".json" {
			continue
		}
		// Sscanf also matches the literal text, so names that merely end in
		// .json (e.g. filtered_people.json) fail here and are skipped.
		var first, last int
		if _, err := fmt.Sscanf(name, "batch_%d_%d.json", &first, &last); err != nil {
			continue
		}

		data, err := os.ReadFile(filepath.Join(uc.outputDir, name))
		if err != nil {
			return nil, fmt.Errorf("reading file %s: %w", name, err)
		}
		var batch []domain.Person
		if err := json.Unmarshal(data, &batch); err != nil {
			return nil, fmt.Errorf("unmarshaling file %s: %w", name, err)
		}
		batches = append(batches, batchFile{first: first, people: batch})
	}

	// Stable insertion sort by start index. Batch counts are small, and this
	// keeps the file's import list unchanged.
	for i := 1; i < len(batches); i++ {
		for j := i; j > 0 && batches[j].first < batches[j-1].first; j-- {
			batches[j], batches[j-1] = batches[j-1], batches[j]
		}
	}

	var allPeople []domain.Person
	for _, b := range batches {
		allPeople = append(allPeople, b.people...)
	}
	return allPeople, nil
}
// buildPrompt renders the LLM prompt for one batch. It returns the prompt
// together with a JSON Schema, wrapped as {"name": ..., "schema": ...}. The
// schema describes the expected reply: an array of objects whose name,
// surname, gender, born, city and tags fields are all required.
//
// The prompt embeds the batch as compact JSON, including each job
// description, plus the list from domain.AvailableTags(). The schema has no
// job field because the model is told not to echo it back.
func (uc *CategorizeUseCase) buildPrompt(batch []domain.Person) (string, interface{}) {
	type promptPerson struct {
		Name    string `json:"name"`
		Surname string `json:"surname"`
		Gender  string `json:"gender"`
		Born    int    `json:"born"`
		City    string `json:"city"`
		Job     string `json:"job"`
	}

	people := make([]promptPerson, 0, len(batch))
	for _, p := range batch {
		people = append(people, promptPerson{
			Name:    p.Name,
			Surname: p.Surname,
			Gender:  p.Gender,
			Born:    p.Born,
			City:    p.City,
			Job:     p.Job,
		})
	}

	// Marshal errors are deliberately ignored: the batch is plain strings and
	// ints, and the tag list is presumably a []string (TODO confirm).
	batchJSON, _ := json.Marshal(people)
	tagsJSON, _ := json.Marshal(domain.AvailableTags())

	count := len(batch)
	prompt := fmt.Sprintf(`Categorize the following people based on their job descriptions.
Each person should be assigned one or more appropriate tags from the available list based on their job field.
Available tags: %s
People to categorize (each person has name, surname, gender, born year, city, and job description):
%s
CRITICAL INSTRUCTIONS:
1. YOU MUST PROCESS ALL %d PEOPLE IN THE INPUT - NOT JUST A FEW!
2. Read EVERY person's job description carefully
3. Assign 1-3 relevant tags from the available list based on what the job description says
4. Tag mapping guidelines:
- "IT" - for programming, software development, algorithms, data structures, technology
- "transport" - for driving, vehicle operation, logistics
- "edukacja" - for teaching, training, education, development of skills
- "medycyna" - for healthcare, doctors, nurses, medical diagnosis, treatment
- "praca z ludźmi" - for jobs involving direct work with people (teaching, healthcare, consulting, etc.)
- "praca z pojazdami" - for mechanics, vehicle repair, automotive work
- "praca fizyczna" - for manual labor, construction, carpentry, physical work
5. Many jobs can have multiple tags (e.g., a teacher = "edukacja" + "praca z ludźmi")
6. Return a complete JSON array with ALL %d people
7. Each person object must have: name, surname, gender, born, city, tags
8. Do NOT include the job description in the output
9. No explanations, no markdown formatting, just the JSON array`, string(tagsJSON), string(batchJSON), count, count)

	// typed builds a fresh {"type": t} schema node, so no map is shared
	// between properties.
	typed := func(t string) map[string]interface{} {
		return map[string]interface{}{"type": t}
	}
	personSchema := map[string]interface{}{
		"type": "object",
		"properties": map[string]interface{}{
			"name":    typed("string"),
			"surname": typed("string"),
			"gender":  typed("string"),
			"born":    typed("integer"),
			"city":    typed("string"),
			"tags": map[string]interface{}{
				"type":  "array",
				"items": typed("string"),
			},
		},
		"required": []string{"name", "surname", "gender", "born", "city", "tags"},
	}
	schema := map[string]interface{}{
		"name": "categorize_people",
		"schema": map[string]interface{}{
			"type":  "array",
			"items": personSchema,
		},
	}
	return prompt, schema
}
// parseBatchResponse decodes an LLM reply into people. The reply is first
// narrowed with cleanJSONResponse to strip surrounding prose or markdown
// fences. What remains must decode as a JSON array into []domain.Person;
// anything else yields a wrapped unmarshal error.
func (uc *CategorizeUseCase) parseBatchResponse(content string) ([]domain.Person, error) {
	payload := []byte(uc.cleanJSONResponse(content))

	var decoded []domain.Person
	if err := json.Unmarshal(payload, &decoded); err != nil {
		return nil, fmt.Errorf("unmarshaling response: %w", err)
	}
	return decoded, nil
}
// cleanJSONResponse extracts the JSON payload from an LLM reply that may be
// wrapped in prose or markdown fences. After trimming whitespace:
//   - if a '[' appears before any '{', it takes the span from the first '['
//     to the last ']', provided that ']' comes after the '[';
//   - otherwise, or when that array span is invalid, it takes the span from
//     the first '{' to the last '}', provided the '}' comes after the '{';
//   - if neither span exists, it returns the trimmed input unchanged.
//
// The result is not validated as JSON.
func (uc *CategorizeUseCase) cleanJSONResponse(content string) string {
	trimmed := strings.TrimSpace(content)

	openSquare := strings.Index(trimmed, "[")
	openCurly := strings.Index(trimmed, "{")

	arrayComesFirst := openSquare >= 0 && (openCurly < 0 || openSquare < openCurly)
	if arrayComesFirst {
		if closeSquare := strings.LastIndex(trimmed, "]"); closeSquare > openSquare {
			return strings.TrimSpace(trimmed[openSquare : closeSquare+1])
		}
	}

	if openCurly >= 0 {
		if closeCurly := strings.LastIndex(trimmed, "}"); closeCurly > openCurly {
			return strings.TrimSpace(trimmed[openCurly : closeCurly+1])
		}
	}

	// Nothing JSON-shaped found; hand back the trimmed text as-is.
	return trimmed
}