468 lines
14 KiB
Go
468 lines
14 KiB
Go
package usecase
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"github.com/paramah/ai_devs4/s01e01/internal/domain"
)
|
|
|
|
// CategorizeUseCase handles the categorization of people
type CategorizeUseCase struct {
	// personRepo supplies the raw Person records (fetched from a data URL).
	personRepo domain.PersonRepository
	// llmProvider is the LLM used to assign tags to each person.
	llmProvider domain.LLMProvider
	// outputDir is where per-batch result files and filtered_people.json are written.
	outputDir string
	// batchSize is the number of people sent to the LLM in one request.
	batchSize int
}
|
|
|
|
// NewCategorizeUseCase creates a new categorize use case
|
|
func NewCategorizeUseCase(repo domain.PersonRepository, llm domain.LLMProvider, outputDir string, batchSize int) *CategorizeUseCase {
|
|
return &CategorizeUseCase{
|
|
personRepo: repo,
|
|
llmProvider: llm,
|
|
outputDir: outputDir,
|
|
batchSize: batchSize,
|
|
}
|
|
}
|
|
|
|
// Execute fetches people from dataURL, filters them down to the target
// population, and categorizes the remainder in batches via the LLM.
// Each completed batch is persisted to uc.outputDir as batch_<start>_<end>.json;
// batches whose file already exists are skipped, so an interrupted run can be
// resumed without repeating (and re-paying for) LLM calls. The final result is
// assembled from all batch files on disk.
func (uc *CategorizeUseCase) Execute(ctx context.Context, dataURL string) ([]domain.Person, error) {
	// Create output directory if it doesn't exist
	if err := os.MkdirAll(uc.outputDir, 0755); err != nil {
		return nil, fmt.Errorf("creating output directory: %w", err)
	}

	// Fetch people from data source
	allPeople, err := uc.personRepo.FetchPeople(ctx, dataURL)
	if err != nil {
		return nil, fmt.Errorf("fetching people: %w", err)
	}

	originalCount := len(allPeople)
	fmt.Printf("\n[%s] ========== DATA FILTERING ==========\n", time.Now().Format("2006-01-02 15:04:05"))
	fmt.Printf("[%s] Original CSV entries: %d\n", time.Now().Format("2006-01-02 15:04:05"), originalCount)

	// Nothing fetched: return an empty (non-nil) slice so callers can range safely.
	if originalCount == 0 {
		return []domain.Person{}, nil
	}

	fmt.Printf("[%s] Applying filters:\n", time.Now().Format("2006-01-02 15:04:05"))
	fmt.Printf(" - Gender: M (male)\n")
	fmt.Printf(" - Age in 2026: 20-40 years (born 1986-2006)\n")
	fmt.Printf(" - City: Grudziądz\n")
	fmt.Printf(" - Industry: ALL (will be categorized by LLM)\n\n")

	// Filter people - keep only those matching criteria
	people := uc.filterCompletePeople(allPeople)
	filteredCount := len(people)

	fmt.Printf("[%s] Filtered entries (matching criteria): %d\n", time.Now().Format("2006-01-02 15:04:05"), filteredCount)
	fmt.Printf("[%s] Removed entries: %d\n", time.Now().Format("2006-01-02 15:04:05"), originalCount-filteredCount)
	fmt.Printf("[%s] =====================================\n\n", time.Now().Format("2006-01-02 15:04:05"))

	// Save filtered data to file (audit trail of what went to the LLM).
	if err := uc.saveFilteredData(people); err != nil {
		return nil, fmt.Errorf("saving filtered data: %w", err)
	}

	if filteredCount == 0 {
		fmt.Printf("[%s] No complete entries to process\n", time.Now().Format("2006-01-02 15:04:05"))
		return []domain.Person{}, nil
	}

	// Process in batches of uc.batchSize; totalBatches is ceil(total/batchSize).
	totalPeople := len(people)
	totalBatches := (totalPeople + uc.batchSize - 1) / uc.batchSize
	startTime := time.Now()
	processedBatches := 0
	skippedBatches := 0

	for i := 0; i < totalPeople; i += uc.batchSize {
		batchNum := i/uc.batchSize + 1
		batchStart := time.Now()

		// Clamp the final (possibly short) batch.
		end := i + uc.batchSize
		if end > totalPeople {
			end = totalPeople
		}

		batch := people[i:end]

		// Generate filename for this batch; the name encodes the inclusive
		// entry range so a file's presence marks that range as done.
		batchFilename := fmt.Sprintf("batch_%d_%d.json", i, end-1)
		batchFilepath := filepath.Join(uc.outputDir, batchFilename)

		// Check if batch already processed (resume support).
		if _, err := os.Stat(batchFilepath); err == nil {
			skippedBatches++
			fmt.Printf("[%s] Skipping batch %d/%d (entries %d-%d, already processed)\n",
				time.Now().Format("2006-01-02 15:04:05"),
				batchNum,
				totalBatches,
				i,
				end-1)
			continue
		}

		// Calculate ETA from the average duration of batches processed
		// this run (skipped batches don't skew the estimate).
		var etaStr string
		if processedBatches > 0 {
			elapsed := time.Since(startTime)
			avgTimePerBatch := elapsed / time.Duration(processedBatches)
			remainingBatches := totalBatches - batchNum
			eta := avgTimePerBatch * time.Duration(remainingBatches)
			etaStr = fmt.Sprintf(" (ETA: %s)", eta.Round(time.Second))
		}

		fmt.Printf("[%s] Processing batch %d/%d (entries %d-%d, %d people)...%s\n",
			time.Now().Format("2006-01-02 15:04:05"),
			batchNum,
			totalBatches,
			i,
			end-1,
			len(batch),
			etaStr)

		// Prepare prompt for LLM
		prompt, schema := uc.buildPrompt(batch)

		// Send to LLM for categorization
		response, err := uc.llmProvider.Complete(ctx, domain.LLMRequest{
			Prompt: prompt,
			Schema: schema,
		})
		if err != nil {
			return nil, fmt.Errorf("LLM completion (batch %d): %w", batchNum, err)
		}

		// Log the raw model response for debugging
		fmt.Printf("[%s] Raw response:\n%s\n",
			time.Now().Format("2006-01-02 15:04:05"),
			response.Content)

		// Parse the response
		categorizedBatch, err := uc.parseBatchResponse(response.Content)
		if err != nil {
			return nil, fmt.Errorf("parsing LLM response (batch %d): %w\nRaw response: %s", batchNum, err, response.Content)
		}

		// Save batch to file (checkpoint: marks this range as done).
		if err := uc.saveBatchToFile(categorizedBatch, batchFilepath); err != nil {
			return nil, fmt.Errorf("saving batch to file: %w", err)
		}

		processedBatches++
		batchDuration := time.Since(batchStart)
		fmt.Printf("[%s] Batch %d completed in %s (%d people)\n",
			time.Now().Format("2006-01-02 15:04:05"),
			batchNum,
			batchDuration.Round(time.Second),
			len(categorizedBatch))
	}

	fmt.Printf("\n[%s] Summary: Processed %d batches, Skipped %d (already done), Total %d batches\n",
		time.Now().Format("2006-01-02 15:04:05"),
		processedBatches,
		skippedBatches,
		totalBatches)

	// Collect all results from files (includes batches skipped this run).
	return uc.collectResults()
}
|
|
|
|
// saveBatchToFile saves a batch of people to a JSON file
|
|
func (uc *CategorizeUseCase) saveBatchToFile(people []domain.Person, filepath string) error {
|
|
data, err := json.MarshalIndent(people, "", " ")
|
|
if err != nil {
|
|
return fmt.Errorf("marshaling batch: %w", err)
|
|
}
|
|
|
|
if err := os.WriteFile(filepath, data, 0644); err != nil {
|
|
return fmt.Errorf("writing file: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// filterCompletePeople filters people based on specific criteria
|
|
func (uc *CategorizeUseCase) filterCompletePeople(people []domain.Person) []domain.Person {
|
|
const currentYear = 2026
|
|
const minAge = 20
|
|
const maxAge = 40
|
|
|
|
// Calculate birth year range
|
|
minBornYear := currentYear - maxAge // 1986
|
|
maxBornYear := currentYear - minAge // 2006
|
|
|
|
var filtered []domain.Person
|
|
for _, person := range people {
|
|
// Basic data completeness check
|
|
if person.Gender == "" || person.Born == 0 || person.City == "" || person.Job == "" {
|
|
continue
|
|
}
|
|
|
|
// Apply specific filters:
|
|
// 1. Gender: Male (M)
|
|
if person.Gender != "M" {
|
|
continue
|
|
}
|
|
|
|
// 2. Age: between 20-40 years in 2026 (born between 1986-2006)
|
|
if person.Born < minBornYear || person.Born > maxBornYear {
|
|
continue
|
|
}
|
|
|
|
// 3. City: Grudziądz
|
|
if person.City != "Grudziądz" {
|
|
continue
|
|
}
|
|
|
|
// 4. Industry: transport-related (DISABLED - LLM will categorize)
|
|
// if !uc.isTransportJob(person.Job) {
|
|
// continue
|
|
// }
|
|
|
|
filtered = append(filtered, person)
|
|
}
|
|
return filtered
|
|
}
|
|
|
|
// isTransportJob checks if job description is related to transport industry
|
|
func (uc *CategorizeUseCase) isTransportJob(jobDescription string) bool {
|
|
jobLower := strings.ToLower(jobDescription)
|
|
|
|
transportKeywords := []string{
|
|
"transport",
|
|
"pojazd",
|
|
"samochód",
|
|
"kierow",
|
|
"prowadz", // prowadzenie pojazdu
|
|
"dostaw",
|
|
"przewóz",
|
|
"logistyk",
|
|
"auto",
|
|
"ciężar", // ciężarówka
|
|
"bus",
|
|
"tir",
|
|
"wagon",
|
|
"pojazd",
|
|
"ruch",
|
|
"droga",
|
|
"trasa",
|
|
"przesył",
|
|
}
|
|
|
|
for _, keyword := range transportKeywords {
|
|
if strings.Contains(jobLower, keyword) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// saveFilteredData saves filtered people data to a JSON file
func (uc *CategorizeUseCase) saveFilteredData(people []domain.Person) error {
	// Use an explicit local struct to pin the saved JSON field names and
	// order, independent of any tags on domain.Person. NOTE(review): despite
	// the original comment claiming the Job field is excluded, it IS saved.
	type PersonForSave struct {
		Name    string `json:"name"`
		Surname string `json:"surname"`
		Gender  string `json:"gender"`
		Born    int    `json:"born"`
		City    string `json:"city"`
		Job     string `json:"job"`
	}

	// Copy the domain records into the serialization shape.
	peopleForSave := make([]PersonForSave, len(people))
	for i, p := range people {
		peopleForSave[i] = PersonForSave{
			Name:    p.Name,
			Surname: p.Surname,
			Gender:  p.Gender,
			Born:    p.Born,
			City:    p.City,
			Job:     p.Job,
		}
	}

	data, err := json.MarshalIndent(peopleForSave, "", " ")
	if err != nil {
		return fmt.Errorf("marshaling filtered data: %w", err)
	}

	// filtered_people.json is explicitly skipped by collectResults.
	filePath := filepath.Join(uc.outputDir, "filtered_people.json")
	if err := os.WriteFile(filePath, data, 0644); err != nil {
		return fmt.Errorf("writing filtered data file: %w", err)
	}

	fmt.Printf("[%s] Filtered data saved to: %s\n\n", time.Now().Format("2006-01-02 15:04:05"), filePath)
	return nil
}
|
|
|
|
// collectResults reads all batch files and returns them as a slice
|
|
func (uc *CategorizeUseCase) collectResults() ([]domain.Person, error) {
|
|
files, err := os.ReadDir(uc.outputDir)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading output directory: %w", err)
|
|
}
|
|
|
|
var allPeople []domain.Person
|
|
for _, file := range files {
|
|
// Skip filtered_people.json and only process batch files
|
|
if file.IsDir() || filepath.Ext(file.Name()) != ".json" || file.Name() == "filtered_people.json" {
|
|
continue
|
|
}
|
|
|
|
data, err := os.ReadFile(filepath.Join(uc.outputDir, file.Name()))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading file %s: %w", file.Name(), err)
|
|
}
|
|
|
|
var batch []domain.Person
|
|
if err := json.Unmarshal(data, &batch); err != nil {
|
|
return nil, fmt.Errorf("unmarshaling file %s: %w", file.Name(), err)
|
|
}
|
|
|
|
allPeople = append(allPeople, batch...)
|
|
}
|
|
|
|
return allPeople, nil
|
|
}
|
|
|
|
// buildPrompt renders the categorization prompt for one batch and the JSON
// schema constraining the LLM's structured output. The returned schema is a
// map shaped for a "json_schema" style response format: an array of person
// objects, each carrying the original identity fields plus a "tags" array
// (the job description is deliberately absent from the output shape).
func (uc *CategorizeUseCase) buildPrompt(batch []domain.Person) (string, interface{}) {
	// Local struct pins the JSON field names shown to the model, including
	// the job description it must categorize.
	type PersonWithJob struct {
		Name    string `json:"name"`
		Surname string `json:"surname"`
		Gender  string `json:"gender"`
		Born    int    `json:"born"`
		City    string `json:"city"`
		Job     string `json:"job"`
	}

	batchWithJobs := make([]PersonWithJob, len(batch))
	for i, p := range batch {
		batchWithJobs[i] = PersonWithJob{
			Name:    p.Name,
			Surname: p.Surname,
			Gender:  p.Gender,
			Born:    p.Born,
			City:    p.City,
			Job:     p.Job,
		}
	}

	// Marshal errors are ignored: these locally-defined structs and the tag
	// string slice cannot fail to marshal.
	batchJSON, _ := json.Marshal(batchWithJobs)

	availableTags := domain.AvailableTags()
	tagsJSON, _ := json.Marshal(availableTags)

	// The batch size is interpolated twice to push the model to emit ALL
	// entries rather than truncating the array.
	prompt := fmt.Sprintf(`Categorize the following people based on their job descriptions.
Each person should be assigned one or more appropriate tags from the available list based on their job field.

Available tags: %s

People to categorize (each person has name, surname, gender, born year, city, and job description):
%s

CRITICAL INSTRUCTIONS:
1. YOU MUST PROCESS ALL %d PEOPLE IN THE INPUT - NOT JUST A FEW!
2. Read EVERY person's job description carefully
3. Assign 1-3 relevant tags from the available list based on what the job description says
4. Tag mapping guidelines:
   - "IT" - for programming, software development, algorithms, data structures, technology
   - "transport" - for driving, vehicle operation, logistics
   - "edukacja" - for teaching, training, education, development of skills
   - "medycyna" - for healthcare, doctors, nurses, medical diagnosis, treatment
   - "praca z ludźmi" - for jobs involving direct work with people (teaching, healthcare, consulting, etc.)
   - "praca z pojazdami" - for mechanics, vehicle repair, automotive work
   - "praca fizyczna" - for manual labor, construction, carpentry, physical work
5. Many jobs can have multiple tags (e.g., a teacher = "edukacja" + "praca z ludźmi")
6. Return a complete JSON array with ALL %d people
7. Each person object must have: name, surname, gender, born, city, tags
8. Do NOT include the job description in the output
9. No explanations, no markdown formatting, just the JSON array`, string(tagsJSON), string(batchJSON), len(batch), len(batch))

	// JSON Schema for structured output: array of objects with all six
	// fields required, "tags" being an array of strings.
	schema := map[string]interface{}{
		"name": "categorize_people",
		"schema": map[string]interface{}{
			"type": "array",
			"items": map[string]interface{}{
				"type": "object",
				"properties": map[string]interface{}{
					"name": map[string]interface{}{
						"type": "string",
					},
					"surname": map[string]interface{}{
						"type": "string",
					},
					"gender": map[string]interface{}{
						"type": "string",
					},
					"born": map[string]interface{}{
						"type": "integer",
					},
					"city": map[string]interface{}{
						"type": "string",
					},
					"tags": map[string]interface{}{
						"type": "array",
						"items": map[string]interface{}{
							"type": "string",
						},
					},
				},
				"required": []string{"name", "surname", "gender", "born", "city", "tags"},
			},
		},
	}

	return prompt, schema
}
|
|
|
|
func (uc *CategorizeUseCase) parseBatchResponse(content string) ([]domain.Person, error) {
|
|
// Clean up the response - extract JSON
|
|
cleanContent := uc.cleanJSONResponse(content)
|
|
|
|
var people []domain.Person
|
|
if err := json.Unmarshal([]byte(cleanContent), &people); err != nil {
|
|
return nil, fmt.Errorf("unmarshaling response: %w", err)
|
|
}
|
|
|
|
return people, nil
|
|
}
|
|
|
|
// cleanJSONResponse extracts JSON from response that may contain extra text
|
|
func (uc *CategorizeUseCase) cleanJSONResponse(content string) string {
|
|
// Trim whitespace
|
|
content = strings.TrimSpace(content)
|
|
|
|
// Try to find JSON array first [...]
|
|
arrayStart := strings.Index(content, "[")
|
|
arrayEnd := strings.LastIndex(content, "]")
|
|
|
|
// Try to find JSON object {...}
|
|
objectStart := strings.Index(content, "{")
|
|
objectEnd := strings.LastIndex(content, "}")
|
|
|
|
// Use whichever comes first
|
|
if arrayStart != -1 && (objectStart == -1 || arrayStart < objectStart) {
|
|
if arrayEnd != -1 && arrayEnd > arrayStart {
|
|
return strings.TrimSpace(content[arrayStart : arrayEnd+1])
|
|
}
|
|
}
|
|
|
|
if objectStart != -1 && objectEnd != -1 && objectStart < objectEnd {
|
|
return strings.TrimSpace(content[objectStart : objectEnd+1])
|
|
}
|
|
|
|
// No valid JSON found, return as is
|
|
return content
|
|
}
|