# ----------------------------------------------------------------------
# GPU PROCESSING FUNCTIONS
# ----------------------------------------------------------------------
import functools
import gc
import logging
from typing import Any, Dict, Iterator, List

import torch

from src.utils import (
    ProcessingContext,
    cleanup_memory,
    LOG_LEVEL_MAP,
    EMOJI_MAP
)
# ----------------------------------------------------------------------
# GPU MEMORY MANAGEMENT
# ----------------------------------------------------------------------
def clear_gpu_memory():
    """Release cached CUDA memory, run garbage collection, and log current usage."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        gc.collect()
        logging.info(
            f"{EMOJI_MAP['INFO']} GPU memory cleared - "
            f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f}MB, "
            f"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f}MB"
        )


def get_gpu_memory_info() -> Dict[str, Any]:
    """Return current CUDA memory usage in MB, or {"available": False} if no GPU is present."""
    if not torch.cuda.is_available():
        return {"available": False}
    return {
        "available": True,
        "allocated_mb": torch.cuda.memory_allocated() / 1024**2,
        "reserved_mb": torch.cuda.memory_reserved() / 1024**2,
        "max_allocated_mb": torch.cuda.max_memory_allocated() / 1024**2
    }
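# Illustrative use of get_gpu_memory_info (not part of this module's API):
# snapshot usage before and after loading a model to estimate its footprint.
# `load_model` below is a hypothetical placeholder.
#
#     before = get_gpu_memory_info()
#     model = load_model()
#     after = get_gpu_memory_info()
#     if after.get("available"):
#         logging.info(
#             f"Model footprint: {after['allocated_mb'] - before['allocated_mb']:.2f}MB"
#         )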
# ----------------------------------------------------------------------
# BATCH PROCESSING
# ----------------------------------------------------------------------
def process_batch_with_gpu(
    contexts: List[ProcessingContext],
    batch_size: int = 4
) -> Iterator[ProcessingContext]:
    """Yield contexts in batches, clearing GPU memory between batches.

    Exceptions raised while the caller handles a yielded context are only
    recorded here if the caller throws them back into the generator
    (see the usage sketch below this function).
    """
    total_contexts = len(contexts)
    processed = 0
    for i in range(0, total_contexts, batch_size):
        batch = contexts[i:i + batch_size]
        batch_nums = f"{i+1}-{min(i+batch_size, total_contexts)}/{total_contexts}"
        logging.info(
            f"{EMOJI_MAP['PROCESSING']} Processing batch {batch_nums}"
        )
        for ctx in batch:
            if ctx.skip_processing:
                continue
            try:
                yield ctx
                processed += 1
            except Exception as e:
                logging.error(
                    f"{EMOJI_MAP['ERROR']} Error processing {ctx.url}: {str(e)}"
                )
                ctx.error = str(e)
                ctx.skip_processing = True
        # Free cached CUDA memory between batches (skip after the last one).
        if torch.cuda.is_available() and i + batch_size < total_contexts:
            clear_gpu_memory()
    logging.info(
        f"{EMOJI_MAP['SUCCESS']} Processed {processed}/{total_contexts} images"
    )
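# Usage sketch (illustrative, not part of this module): drive the generator
# with next()/throw() so failures are reported back to the except branch
# above. `run_inference` is a hypothetical per-context function.
#
#     gen = process_batch_with_gpu(contexts, batch_size=4)
#     try:
#         ctx = next(gen)
#         while True:
#             try:
#                 run_inference(ctx)
#             except Exception as exc:
#                 # Throwing resumes the generator at `yield`, which marks the
#                 # context as failed and hands back the next one.
#                 ctx = gen.throw(exc)
#             else:
#                 ctx = next(gen)
#     except StopIteration:
#         pass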
# ----------------------------------------------------------------------
# GPU DECORATORS
# ----------------------------------------------------------------------
def with_gpu_memory_management(func):
    """Clear GPU memory before the wrapped call and again after it returns or raises."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            clear_gpu_memory()
            return func(*args, **kwargs)
        finally:
            clear_gpu_memory()
    return wrapper
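# Usage sketch (illustrative; `generate_caption` is a hypothetical function):
#
#     @with_gpu_memory_management
#     def generate_caption(ctx: ProcessingContext) -> str:
#         ...
#
# The CUDA cache is cleared before the call and again in the finally block,
# so memory is released even when generate_caption raises.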