# ----------------------------------------------------------------------
# GPU PROCESSING FUNCTIONS
# ----------------------------------------------------------------------
import functools
import gc
import logging
from typing import Dict, Iterator, List

import torch

from src.utils import (
    ProcessingContext,
    cleanup_memory,
    LOG_LEVEL_MAP,
    EMOJI_MAP
)

# ----------------------------------------------------------------------
# GPU MEMORY MANAGEMENT
# ----------------------------------------------------------------------
def clear_gpu_memory():
    """Release cached CUDA memory and log the resulting usage."""
    if torch.cuda.is_available():
        # Release cached blocks back to the driver, wait for pending kernels,
        # then run Python garbage collection.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        gc.collect()
        
        logging.info(
            f"{EMOJI_MAP['INFO']} GPU memory cleared - "
            f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f}MB, "
            f"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f}MB"
        )

def get_gpu_memory_info() -> Dict[str, float]:
    """Return current CUDA memory usage in MB, or just availability if no GPU."""
    if not torch.cuda.is_available():
        return {"available": False}
    
    return {
        "available": True,
        "allocated_mb": torch.cuda.memory_allocated() / 1024**2,
        "reserved_mb": torch.cuda.memory_reserved() / 1024**2,
        "max_allocated_mb": torch.cuda.max_memory_allocated() / 1024**2
    }
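
# Illustrative usage sketch (not part of this module's API): snapshot memory
# around a model call to spot leaks. `model` and `inputs` are placeholders,
# not names defined in this project.
#
#   before = get_gpu_memory_info()
#   outputs = model(**inputs)
#   after = get_gpu_memory_info()
#   if before["available"]:
#       logging.debug(
#           f"Allocated delta: {after['allocated_mb'] - before['allocated_mb']:.2f}MB"
#       )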

# ----------------------------------------------------------------------
# BATCH PROCESSING
# ----------------------------------------------------------------------
def process_batch_with_gpu(
    contexts: List[ProcessingContext],
    batch_size: int = 4
) -> Iterator[ProcessingContext]:
    """Yield non-skipped contexts batch by batch, clearing GPU memory between batches."""
    total_contexts = len(contexts)
    processed = 0
    
    for i in range(0, total_contexts, batch_size):
        batch = contexts[i:i + batch_size]
        batch_nums = f"{i+1}-{min(i+batch_size, total_contexts)}/{total_contexts}"
        
        logging.info(
            f"{EMOJI_MAP['PROCESSING']} Processing batch {batch_nums}"
        )
        
        for ctx in batch:
            if ctx.skip_processing:
                continue

            # Yield the context to the caller; an exception thrown back into
            # the generator is recorded on the context, which is then skipped.
            try:
                yield ctx
                processed += 1
            except Exception as e:
                logging.error(
                    f"{EMOJI_MAP['ERROR']} Error processing {ctx.url}: {str(e)}"
                )
                ctx.error = str(e)
                ctx.skip_processing = True
        
        if torch.cuda.is_available() and i + batch_size < total_contexts:
            clear_gpu_memory()
    
    logging.info(
        f"{EMOJI_MAP['SUCCESS']} Processed {processed}/{total_contexts} images"
    )

# ----------------------------------------------------------------------
# GPU DECORATORS
# ----------------------------------------------------------------------
def with_gpu_memory_management(func):
    """Decorator that clears GPU memory before and after the wrapped call."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            clear_gpu_memory()
            return func(*args, **kwargs)
        finally:
            clear_gpu_memory()

    return wrapper
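
# ----------------------------------------------------------------------
# USAGE SKETCH (ILLUSTRATIVE)
# ----------------------------------------------------------------------
# A minimal sketch of how the helpers above might fit together. The
# `describe_image` function and the `contexts` list are hypothetical; the
# real per-context GPU work lives elsewhere in the project.
#
#   @with_gpu_memory_management
#   def describe_image(ctx: ProcessingContext) -> None:
#       ...  # run GPU inference for ctx here
#
#   for ctx in process_batch_with_gpu(contexts, batch_size=4):
#       describe_image(ctx)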