# ----------------------------------------------------------------------
# GPU PROCESSING FUNCTIONS
# ----------------------------------------------------------------------
import functools
import gc
import logging
from typing import Any, Dict, Generator, List

import torch

from src.utils import (
    ProcessingContext,
    cleanup_memory,
    LOG_LEVEL_MAP,
    EMOJI_MAP,
)
# ----------------------------------------------------------------------
# GPU MEMORY MANAGEMENT
# ----------------------------------------------------------------------
def clear_gpu_memory():
    """Empty the CUDA cache, run garbage collection, and log current usage."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        gc.collect()
        logging.info(
            f"{EMOJI_MAP['INFO']} GPU memory cleared - "
            f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f}MB, "
            f"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f}MB"
        )
def get_gpu_memory_info() -> Dict[str, Any]:
    """Return current CUDA memory statistics in megabytes."""
    if not torch.cuda.is_available():
        return {"available": False}
    return {
        "available": True,
        "allocated_mb": torch.cuda.memory_allocated() / 1024**2,
        "reserved_mb": torch.cuda.memory_reserved() / 1024**2,
        "max_allocated_mb": torch.cuda.max_memory_allocated() / 1024**2,
    }
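# --- Usage sketch (illustrative, not called anywhere in this module) ---
# One way the two helpers above might be combined: snapshot memory before
# and after an arbitrary GPU-bound callable, then release the cache.
# `gpu_task` is a hypothetical stand-in for any function that allocates
# CUDA memory.
def _example_log_memory_around(gpu_task, *args, **kwargs):
    before = get_gpu_memory_info()
    result = gpu_task(*args, **kwargs)
    after = get_gpu_memory_info()
    if after.get("available"):
        logging.info(
            f"{EMOJI_MAP['INFO']} Allocated: "
            f"{before.get('allocated_mb', 0.0):.2f}MB -> {after['allocated_mb']:.2f}MB"
        )
    clear_gpu_memory()
    return result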
# ----------------------------------------------------------------------
# BATCH PROCESSING
# ----------------------------------------------------------------------
def process_batch_with_gpu(
    contexts: List[ProcessingContext],
    batch_size: int = 4
) -> Generator[ProcessingContext, None, None]:
    """Yield contexts for processing in batches of `batch_size`.

    GPU memory is cleared between batches. Errors thrown back into the
    generator by the consumer are logged and recorded on the context.
    """
    total_contexts = len(contexts)
    processed = 0
    for i in range(0, total_contexts, batch_size):
        batch = contexts[i:i + batch_size]
        batch_nums = f"{i + 1}-{min(i + batch_size, total_contexts)}/{total_contexts}"
        logging.info(
            f"{EMOJI_MAP['PROCESSING']} Processing batch {batch_nums}"
        )
        for ctx in batch:
            if ctx.skip_processing:
                continue
            try:
                yield ctx
                processed += 1
            except Exception as e:
                # Reached when the consumer routes a failure back in via
                # generator.throw(); mark the context and keep going.
                logging.error(
                    f"{EMOJI_MAP['ERROR']} Error processing {ctx.url}: {e}"
                )
                ctx.error = str(e)
                ctx.skip_processing = True
        # Free GPU memory between batches (skip after the final batch).
        if torch.cuda.is_available() and i + batch_size < total_contexts:
            clear_gpu_memory()
    logging.info(
        f"{EMOJI_MAP['SUCCESS']} Processed {processed}/{total_contexts} images"
    )
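# --- Usage sketch (illustrative, not called anywhere in this module) ---
# process_batch_with_gpu is a generator: the caller pulls one context at a
# time and can hand a per-item failure back with generator.throw(), which is
# what the try/except around the yield above catches. `run_inference` below
# is a hypothetical stand-in for the real per-context processing step.
def _example_drive_batches(contexts: List[ProcessingContext], run_inference) -> None:
    gen = process_batch_with_gpu(contexts, batch_size=4)
    ctx = next(gen, None)
    while ctx is not None:
        try:
            run_inference(ctx)
            ctx = next(gen, None)
        except Exception as exc:
            # Route the failure back into the generator, which logs it,
            # marks the context, and yields the next one (or finishes).
            try:
                ctx = gen.throw(exc)
            except StopIteration:
                ctx = None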
# ----------------------------------------------------------------------
# GPU DECORATORS
# ----------------------------------------------------------------------
def with_gpu_memory_management(func):
    """Clear GPU memory before and after the wrapped function runs."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            clear_gpu_memory()
            return func(*args, **kwargs)
        finally:
            clear_gpu_memory()
    return wrapper
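# --- Usage sketch (illustrative, not called anywhere in this module) ---
# Decorating a hypothetical per-image inference step ensures cached CUDA
# memory is released before the call and again afterwards, even when the
# wrapped function raises.
@with_gpu_memory_management
def _example_gpu_step(ctx: ProcessingContext) -> ProcessingContext:
    # Placeholder body: the real model forward pass would go here.
    return ctx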