# Hugging Face Spaces: Running on Zero
# ---------------------------------------------------------------------- | |
# IMPORTS | |
# ---------------------------------------------------------------------- | |
import os | |
import sys | |
import logging | |
import threading | |
import torch | |
import warnings | |
import time | |
# Silence two harmless but noisy warnings that appear while models load.
warnings.filterwarnings(action="ignore", message=".*copying from a non-meta parameter.*")
warnings.filterwarnings(action="ignore", message=".*Torch was not compiled with flash attention.*")
# Make the deployment root importable. Note that parent_dir is two directory
# levels above the directory holding this file, not its immediate parent.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(current_dir))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)
from transformers import ( | |
AutoProcessor, | |
AutoImageProcessor, | |
AutoModelForObjectDetection, | |
DetrImageProcessor, | |
DetrForObjectDetection, | |
AutoModelForImageSegmentation, | |
YolosImageProcessor, | |
YolosForObjectDetection | |
) | |
# ---------------------------------------------------------------------- | |
# HARDWARE CONFIGURATION | |
# ---------------------------------------------------------------------- | |
def setup_device():
    """Return the default torch device string for this process.

    Inside a Hugging Face Space (``SPACE_ID`` set and non-empty) this is always
    ``"cpu"``. Otherwise it is ``"cuda"`` when CUDA is available and reports
    at least one device, and ``"cpu"`` in every other case.
    """
    if os.getenv("SPACE_ID"):
        return "cpu"
    # Short-circuit keeps device_count() from being queried when CUDA is absent.
    cuda_usable = torch.cuda.is_available() and torch.cuda.device_count() >= 1
    return "cuda" if cuda_usable else "cpu"
def check_cuda_availability():
    """Report whether a usable CUDA GPU is present, logging what was found.

    Returns:
        False inside a Hugging Face Space (``SPACE_ID`` set), where the GPU is
        only attached inside decorated functions; False when CUDA is not
        available or reports zero devices; True otherwise. For the zero-device
        case this agrees with setup_device(), which picks "cpu".
    """
    if os.getenv("SPACE_ID"):
        logging.info("Running in Hugging Face Spaces (Zero GPU) - GPU will be available in decorated functions")
        return False
    if not torch.cuda.is_available():
        logging.warning("\n" + "="*60 + "\n" +
                        "WARNING: CUDA NOT AVAILABLE!\n" +
                        "Running on CPU. Performance will be significantly reduced.\n" +
                        "="*60 + "\n")
        return False
    device_count = torch.cuda.device_count()
    if device_count == 0:
        # Previously this case still returned True although no GPU can be used.
        logging.info("CUDA available but no GPUs detected")
        return False
    for i in range(device_count):
        props = torch.cuda.get_device_properties(i)
        logging.info(f"GPU {i}: {props.name} (Memory: {props.total_memory / (1024**3):.1f} GB)")
    return True
def check_hardware_environment():
    """Log the hardware this process will run on.

    check_cuda_availability() always runs first (it does its own logging).
    Inside a Hugging Face Space the work is then handed to ensure_zerogpu();
    elsewhere the selected device is logged.
    """
    has_gpu = check_cuda_availability()
    if os.getenv("SPACE_ID"):
        ensure_zerogpu()
        return
    logging.info(f"Running on {setup_device().upper()}" if has_gpu else "Running on CPU")
# ---------------------------------------------------------------------- | |
# ZERO GPU CONFIGURATION | |
# ---------------------------------------------------------------------- | |
def ensure_zerogpu():
    """Check the Space's hardware and, when possible, request zero-a10g.

    Best effort: every failure is logged and nothing is raised. Outside a
    Space (no ``SPACE_ID``) it only logs and returns. Any hardware name
    containing "a10g" is treated as acceptable. A hardware change is only
    requested when ``HF_TOKEN`` is set.
    """
    space_id = os.getenv("SPACE_ID")
    hf_token = os.getenv("HF_TOKEN")
    if not space_id:
        logging.info("Not running in Hugging Face Spaces")
        return
    try:
        from huggingface_hub import HfApi
    except ImportError:
        logging.warning("huggingface_hub not available, cannot verify space hardware")
        return
    try:
        api = HfApi(token=hf_token) if hf_token else HfApi()
        space_info = api.get_space_runtime(space_id)
        current_hardware = getattr(space_info, 'hardware', None)
        logging.info(f"Current space hardware: {current_hardware}")
        if not current_hardware:
            # Unknown hardware used to fall through to the "already running" message.
            logging.warning("Could not determine current space hardware")
            return
        if "a10g" in current_hardware.lower():
            logging.info("Space is already running on zero-a10g")
            return
        logging.warning(f"Space is running on {current_hardware}, not zero-a10g")
        if not hf_token:
            logging.warning("Cannot request hardware change without HF_TOKEN")
            return
        try:
            api.request_space_hardware(repo_id=space_id, hardware="zero-a10g")
            logging.info("Requested hardware change to zero-a10g")
        except Exception as e:
            logging.error(f"Failed to request hardware change: {e}")
    except Exception as e:
        logging.error(f"Unexpected error in ensure_zerogpu: {str(e)}")
# Default device string ("cuda" or "cpu"), chosen once at import time by
# setup_device(). move_models_to_gpu() may later switch it to "cuda".
DEVICE: str = setup_device()
# ---------------------------------------------------------------------- | |
# MODEL PRECISION SETTINGS | |
# ---------------------------------------------------------------------- | |
# Per-model precision flags: True -> weights loaded as float32,
# False -> float16 (see the torch_dtype choice in load_model_with_precision).
RTDETR_FULL_PRECISION: bool = True
HEAD_DETECTION_FULL_PRECISION: bool = True
RMBG_FULL_PRECISION: bool = True
YOLOS_FASHIONPEDIA_FULL_PRECISION: bool = True
# ----------------------------------------------------------------------
# OPTIMIZATION SETTINGS
# ----------------------------------------------------------------------
# In the code shown here, torch.compile is applied only to the RMBG model,
# and only when DEVICE is "cuda".
USE_TORCH_COMPILE: bool = True
TORCH_COMPILE_MODE: str = "reduce-overhead"
TORCH_COMPILE_BACKEND: str = "inductor"
# NOTE(review): the next three flags are not read in this section; confirm
# they are consumed elsewhere before relying on them.
ENABLE_CHANNELS_LAST: bool = True
ENABLE_CUDA_GRAPHS: bool = True
USE_MIXED_PRECISION: bool = True
# ----------------------------------------------------------------------
# MODEL REPOSITORY IDENTIFIERS
# ----------------------------------------------------------------------
# Hugging Face Hub repo ids passed to from_pretrained / hf_hub_download.
RTDETR_REPO: str = "PekingU/rtdetr_r50vd"
HEAD_DETECTION_REPO: str = "sanali209/DT_face_head_char"
RMBG_REPO: str = "briaai/RMBG-2.0"
YOLOS_FASHIONPEDIA_REPO: str = "valentinafeve/yolos-fashionpedia"
# ---------------------------------------------------------------------- | |
# BIREFNET CONFIGURATION | |
# ---------------------------------------------------------------------- | |
# Text of the BiRefNet_config.py file that create_config_files() writes into
# the local RMBG directory when that file does not exist yet. It is written
# to disk verbatim, so the exact characters of this literal matter.
BIREFNET_CONFIG_PYTHON_TEMPLATE = """from transformers.configuration_utils import PretrainedConfig
class BiRefNetConfig(PretrainedConfig):
    model_type = "SegformerForSemanticSegmentation"
    num_channels = 3
    backbone = "mit_b5"
    hidden_size = 768
    num_hidden_layers = 12
    num_attention_heads = 12
    bb_pretrained = False
"""
# config.json written next to it. "auto_map" points AutoConfig and
# AutoModelForImageSegmentation at the local BiRefNet_config.py / birefnet.py
# modules; load_rmbg_model() loads with trust_remote_code=True so they run.
BIREFNET_CONFIG_JSON = """{
  "_name_or_path": "briaai/RMBG-2.0",
  "architectures": ["BiRefNet"],
  "auto_map": {
    "AutoConfig": "BiRefNet_config.BiRefNetConfig",
    "AutoModelForImageSegmentation": "birefnet.BiRefNet"
  },
  "bb_pretrained": false
}"""
# Filename -> content for files generated locally by create_config_files()
# (these are not downloaded).
BIREFNET_CONFIG_FILES = {
    "BiRefNet_config.py": BIREFNET_CONFIG_PYTHON_TEMPLATE,
    "config.json": BIREFNET_CONFIG_JSON
}
# Files fetched from RMBG_REPO by download_birefnet_files() when missing.
BIREFNET_DOWNLOAD_FILES = ["birefnet.py", "preprocessor_config.json"]
# Weight files; if any of them already exists locally,
# download_model_weights() skips downloading.
BIREFNET_WEIGHT_FILES = ["model.safetensors", "pytorch_model.bin"]
# Local directory that load_rmbg_model() stages files in and loads from.
DEFAULT_LOCAL_RMBG_DIR = "models/rmbg2"
# ---------------------------------------------------------------------- | |
# ERROR MESSAGES | |
# ---------------------------------------------------------------------- | |
# Messages logged/raised when the gated RMBG-2.0 model cannot be obtained.
ERROR_NO_HF_TOKEN: str = "HF_TOKEN environment variable not set. Please set it in your Space secrets."
ERROR_ACCESS_DENIED: str = "Access denied to RMBG-2.0. Please request access at https://huggingface.co/briaai/RMBG-2.0 and try again."
ERROR_AUTH_FAILED: str = "Authentication failed. Please set HF_TOKEN environment variable."
# ----------------------------------------------------------------------
# GLOBAL MODEL INSTANCES
# ----------------------------------------------------------------------
# Filled in by the load_*_model() functions; None until a load succeeds.
RTDETR_PROCESSOR = RTDETR_MODEL = None
HEAD_PROCESSOR = HEAD_MODEL = None
RMBG_MODEL = None
YOLOS_PROCESSOR = YOLOS_MODEL = None
# ----------------------------------------------------------------------
# GLOBAL STATE VARIABLES
# ----------------------------------------------------------------------
# MODELS_LOADED becomes True once load_models() succeeds; LOAD_ERROR holds the
# last failure message ("" when there is none). load_models() holds LOAD_LOCK
# while loading. It is a plain, non-reentrant threading.Lock, so code already
# holding it must not call load_models().
MODELS_LOADED: bool = False
LOAD_ERROR: str = ""
LOAD_LOCK = threading.Lock()
# ---------------------------------------------------------------------- | |
# MODEL LOADING WORKAROUNDS FOR SPACES ENVIRONMENT | |
# ---------------------------------------------------------------------- | |
def patch_spaces_device_handling():
    """Make the ``spaces`` ZeroGPU storage hook accept string devices.

    Wraps ``spaces.zero.torch.patching._untyped_storage_new_register`` so that
    a ``device`` keyword given as a string is converted to ``torch.device``
    before the original hook runs. The patch is idempotent: the wrapper is
    tagged, and later calls (load_model_with_precision calls this once per
    model) return True without wrapping it again. Previously each call
    stacked one more wrapper.

    Returns:
        True if the hook is patched (now or earlier); False if ``spaces`` is
        unavailable or patching failed (logged at debug level only).
    """
    try:
        import spaces.zero.torch.patching as spaces_patching
        original_untyped_storage_new = spaces_patching._untyped_storage_new_register
        if getattr(original_untyped_storage_new, "_device_str_patched", False):
            return True

        def patched_untyped_storage_new_register(storage_cls):
            def wrapper(*args, **kwargs):
                device = kwargs.get('device')
                if isinstance(device, str):
                    kwargs['device'] = torch.device(device)
                return original_untyped_storage_new(storage_cls)(*args, **kwargs)
            return wrapper

        patched_untyped_storage_new_register._device_str_patched = True
        spaces_patching._untyped_storage_new_register = patched_untyped_storage_new_register
        logging.info("Successfully patched spaces device handling")
        return True
    except Exception as e:
        # Deliberately best effort: outside Spaces this module does not exist.
        logging.debug(f"Spaces patching not available or failed: {e}")
        return False
def is_spaces_environment():
    """Return True when running under Hugging Face Spaces.

    Detected either by the ``SPACE_ID`` variable being present (even if empty)
    or by the ``spaces`` package already having been imported.
    """
    if os.getenv("SPACE_ID") is not None:
        return True
    return "spaces" in sys.modules
# ---------------------------------------------------------------------- | |
# BIREFNET FILE MANAGEMENT | |
# ---------------------------------------------------------------------- | |
def create_config_files(local_dir: str) -> None:
    """Write each generated BiRefNet config file into ``local_dir`` if absent.

    Creates ``local_dir`` when needed. Existing files are left untouched, so
    local edits survive repeated calls.
    """
    os.makedirs(local_dir, exist_ok=True)
    for name, text in BIREFNET_CONFIG_FILES.items():
        target = os.path.join(local_dir, name)
        if os.path.exists(target):
            continue
        with open(target, "w") as handle:
            handle.write(text)
        logging.info(f"Created {name} in {local_dir}")
def download_birefnet_files(local_dir: str, token: str) -> None:
    """Fetch BiRefNet support files from RMBG_REPO into ``local_dir`` if missing.

    Args:
        local_dir: destination directory.
        token: Hugging Face access token (RMBG-2.0 is a gated repo).

    Raises:
        RuntimeError: when a download fails. The message now carries the
            underlying hub error text (e.g. "403 Client Error ... gated repo"
            or "401 Client Error"). load_rmbg_model() passes str() of this
            exception to handle_rmbg_access_error(), which inspects exactly
            those substrings. With the old generic message, the
            access-denied / auth-failed branches could never trigger.
    """
    from huggingface_hub import hf_hub_download
    for file in BIREFNET_DOWNLOAD_FILES:
        file_path = os.path.join(local_dir, file)
        if os.path.exists(file_path):
            continue
        try:
            hf_hub_download(
                repo_id=RMBG_REPO,
                filename=file,
                token=token,
                local_dir=local_dir,
                local_dir_use_symlinks=False
            )
        except Exception as e:
            logging.error(f"Failed to download {file}: {e}")
            raise RuntimeError(f"Failed to download {file} from {RMBG_REPO}: {e}") from e
        logging.info(f"Downloaded {file} to {local_dir}")
def download_model_weights(local_dir: str, token: str) -> None:
    """Ensure ``local_dir`` holds RMBG weights, downloading them if needed.

    Nothing is downloaded when any file in BIREFNET_WEIGHT_FILES already
    exists. Otherwise each candidate ("model.safetensors", then
    "pytorch_model.bin") is tried in list order until one succeeds.

    Raises:
        RuntimeError: when every candidate fails. The message includes the
            last hub error text, so handle_rmbg_access_error() can recognise
            401/403 responses. The old message dropped that text.
    """
    from huggingface_hub import hf_hub_download
    weights_exist = any(
        os.path.exists(os.path.join(local_dir, weight_file))
        for weight_file in BIREFNET_WEIGHT_FILES
    )
    if weights_exist:
        return
    last_error = None
    for weight_file in BIREFNET_WEIGHT_FILES:
        try:
            hf_hub_download(
                repo_id=RMBG_REPO,
                filename=weight_file,
                token=token,
                local_dir=local_dir,
                local_dir_use_symlinks=False
            )
        except Exception as e:
            logging.warning(f"Failed to download {weight_file}: {e}")
            last_error = e
            continue
        logging.info(f"Downloaded {weight_file} to {local_dir}")
        return
    logging.error(f"Failed to download model weights from {RMBG_REPO}: {last_error}")
    raise RuntimeError(f"Failed to download model weights from {RMBG_REPO}: {last_error}") from last_error
def ensure_birefnet_files(local_dir: str, token: str) -> None:
    """Stage everything needed to load the RMBG/BiRefNet model from ``local_dir``.

    Steps, in order: write the generated config files, fetch the support
    files, then fetch the weights. The download steps raise RuntimeError on
    failure.
    """
    create_config_files(local_dir)
    for fetch in (download_birefnet_files, download_model_weights):
        fetch(local_dir, token)
def ensure_models_loaded() -> None:
    """Make sure the models are available before serving a request.

    In a Spaces (Zero GPU) environment nothing is loaded here; models are
    loaded later on demand and this call only prints/logs a notice.
    Elsewhere, load_models() is called. A failure is cached in LOAD_ERROR and
    re-raised on later calls instead of retrying.

    Raises:
        RuntimeError: an earlier load attempt failed (cached LOAD_ERROR).
        Exception: whatever load_models() raises on this attempt.
    """
    global LOAD_ERROR
    if MODELS_LOADED:
        return
    if is_spaces_environment():
        # ----------------------------------------------------------------------
        # ZERO GPU MODEL LOADING: 1. Models NOT loaded at startup
        # ----------------------------------------------------------------------
        # NOTE(review): this 1 s pause runs on every call while models are
        # unloaded; its purpose is not evident from this file.
        time.sleep(1)
        print("="*70)
        print("ZERO GPU MODEL LOADING: 1. Models NOT loaded at startup")
        print("="*70)
        logging.info("ZERO GPU MODEL LOADING: Models NOT loaded at startup")
        logging.info("ZERO GPU MODEL LOADING: Models will be loaded on-demand in GPU context")
        return
    # load_models() acquires LOAD_LOCK itself and re-checks MODELS_LOADED under
    # it. LOAD_LOCK is a non-reentrant threading.Lock, so holding it here as
    # well (as this function used to) deadlocks the calling thread as soon as
    # load_models() tries to take it.
    if LOAD_ERROR:
        raise RuntimeError(f"Models failed to load: {LOAD_ERROR}")
    try:
        load_models()
    except Exception as e:
        LOAD_ERROR = str(e)
        raise
# ---------------------------------------------------------------------- | |
# MODEL LOADING WITH PRECISION | |
# ---------------------------------------------------------------------- | |
def _clear_cached_repo(repo_id: str) -> None:
    """Best-effort removal of every cached revision of ``repo_id`` from the HF hub cache."""
    try:
        from huggingface_hub import scan_cache_dir
        cache_info = scan_cache_dir()
        for repo in cache_info.repos:
            # CachedRepoInfo.repo_id uses the "owner/name" form, so compare directly.
            # The old "owner--name" substring test never matched, and
            # CachedRepoInfo has no delete() method anyway.
            if repo.repo_id == repo_id:
                revisions = [revision.commit_hash for revision in repo.revisions]
                cache_info.delete_revisions(*revisions).execute()
                logging.info(f"Cleared cache for {repo_id}")
                break
    except Exception as cache_e:
        logging.warning(f"Cache clearing failed: {cache_e}")


def _load_after_cache_reset(model_class, repo_id: str, trust_remote_code: bool, torch_device):
    """Retry a float32 load of ``repo_id`` with progressively more basic options.

    Attempts, in order: forced re-download, TensorFlow checkpoint, then a plain
    load without safetensors. Each successful attempt moves the model to
    ``torch_device``.

    Raises:
        RuntimeError: when all three attempts fail.
    """
    common = {
        "torch_dtype": torch.float32,
        "trust_remote_code": trust_remote_code,
        "device_map": None,
    }
    try:
        model = model_class.from_pretrained(repo_id, force_download=True, low_cpu_mem_usage=True, **common)
        return model.to(torch_device)
    except Exception as retry_e:
        logging.warning(f"Retry with force_download failed: {retry_e}")
    try:
        model = model_class.from_pretrained(repo_id, from_tf=True, low_cpu_mem_usage=True, **common)
        model = model.to(torch_device)
        logging.info(f"Successfully loaded {repo_id} from TensorFlow checkpoint")
        return model
    except Exception as tf_e:
        logging.warning(f"TensorFlow fallback failed: {tf_e}")
    try:
        model = model_class.from_pretrained(repo_id, use_safetensors=False, local_files_only=False, **common)
        model = model.to(torch_device)
    except Exception as basic_e:
        logging.error(f"All fallback strategies failed for {repo_id}: {basic_e}")
        raise RuntimeError(f"Unable to load model {repo_id} after all retry attempts: {basic_e}") from basic_e
    logging.info(f"Successfully loaded {repo_id} with basic configuration")
    return model


def load_model_with_precision(model_class, repo_id: str, full_precision: bool, device_map: bool = True, trust_remote_code: bool = False):
    """Load a model via ``model_class.from_pretrained``, place it, and set eval mode.

    Args:
        model_class: class exposing ``from_pretrained`` (e.g. AutoModelForObjectDetection).
        repo_id: Hub repo id or local directory to load from.
        full_precision: True loads float32 weights. False loads float16 and
            also calls ``.half()`` when DEVICE is "cuda".
        device_map: permit ``device_map="auto"`` when DEVICE is "cuda" and
            more than one CUDA device is visible.
        trust_remote_code: forwarded to ``from_pretrained``.

    Returns:
        The model in eval mode. In a Spaces environment it stays on CPU;
        otherwise it is placed on DEVICE.

    Raises:
        Any exception from the final load attempt (logged first), including
        RuntimeError when every fallback after a corrupt-checkpoint error fails.
    """
    try:
        spaces_env = is_spaces_environment()
        if spaces_env:
            # Zero GPU: load on CPU here; weights are presumably moved to the GPU
            # later (see move_models_to_gpu).
            torch_device = torch.device("cpu")
            patch_spaces_device_handling()
        else:
            if DEVICE == "cuda":
                torch.cuda.empty_cache()
            torch_device = torch.device(DEVICE)
        load_kwargs = {
            "torch_dtype": torch.float32 if full_precision else torch.float16,
            "trust_remote_code": trust_remote_code,
            "low_cpu_mem_usage": True,
            "use_safetensors": True
        }
        if spaces_env:
            load_kwargs["device_map"] = None
        elif DEVICE == "cuda" and device_map and torch.cuda.device_count() > 1:
            load_kwargs["device_map"] = "auto"
        try:
            model = model_class.from_pretrained(repo_id, **load_kwargs)
            # Models loaded with a device_map carry hf_device_map and are already placed.
            if not spaces_env and not hasattr(model, 'hf_device_map'):
                model = model.to(torch_device)
            if not full_precision and DEVICE == "cuda":
                model = model.half()
        except (ValueError, RuntimeError, OSError, UnicodeDecodeError) as e:
            logging.warning(f"Failed to load model with initial configuration: {e}")
            error_text = str(e)
            # str() of a UnicodeDecodeError does not contain its class name,
            # so an isinstance check is required to route it here.
            looks_corrupt = (
                isinstance(e, UnicodeDecodeError)
                or "UnicodeDecodeError" in error_text
                or "Unable to load weights from pytorch checkpoint" in error_text
            )
            if looks_corrupt:
                logging.info(f"Attempting to clear cache and retry for {repo_id}")
                _clear_cached_repo(repo_id)
                model = _load_after_cache_reset(model_class, repo_id, trust_remote_code, torch_device)
            else:
                model = model_class.from_pretrained(
                    repo_id,
                    torch_dtype=torch.float32,
                    trust_remote_code=trust_remote_code,
                    device_map=None
                )
                model = model.to(torch_device)
        model.eval()
        if spaces_env:
            logging.info(f"Model {repo_id} loaded on CPU (Zero GPU environment)")
        else:
            logging.info(f"Verifying model {repo_id} is on correct device")
            param = next(model.parameters())
            if DEVICE == "cuda" and not param.is_cuda:
                model = model.to(torch_device)
                logging.warning(f"Forced model {repo_id} to {DEVICE}")
                # Re-read so the log below reports the device after the move.
                param = next(model.parameters())
            logging.info(f"Model {repo_id} device: {param.device}")
        return model
    except Exception as e:
        logging.error(f"Failed to load model from {repo_id} on {DEVICE}: {e}")
        raise
def handle_rmbg_access_error(error_msg: str) -> None:
    """Translate an RMBG download/load error message into a RuntimeError.

    This function always raises:
    - "403" together with "gated repo": logs a banner and raises ERROR_ACCESS_DENIED.
    - "401": logs a banner and raises ERROR_AUTH_FAILED.
    - anything else: raises RuntimeError carrying ``error_msg`` unchanged.
    """
    banner = "=" * 60
    if "403" in error_msg and "gated repo" in error_msg:
        logging.error(
            f"\n{banner}\n"
            "ERROR: Access denied to RMBG-2.0 model!\n"
            "You need to request access at: https://huggingface.co/briaai/RMBG-2.0\n"
            f"{banner}\n"
        )
        raise RuntimeError(ERROR_ACCESS_DENIED)
    if "401" in error_msg:
        logging.error(
            f"\n{banner}\n"
            "ERROR: Authentication failed!\n"
            "Please set your HF_TOKEN environment variable.\n"
            f"{banner}\n"
        )
        raise RuntimeError(ERROR_AUTH_FAILED)
    raise RuntimeError(error_msg)
# ---------------------------------------------------------------------- | |
# INDIVIDUAL MODEL LOADING FUNCTIONS | |
# ---------------------------------------------------------------------- | |
def load_rtdetr_model() -> None:
    """Populate RTDETR_PROCESSOR and RTDETR_MODEL from RTDETR_REPO.

    The processor global is set before the model loads, so it stays set even
    if the model load then raises.
    """
    global RTDETR_PROCESSOR, RTDETR_MODEL
    logging.info("Loading RT-DETR model...")
    RTDETR_PROCESSOR = AutoProcessor.from_pretrained(RTDETR_REPO)
    RTDETR_MODEL = load_model_with_precision(
        model_class=AutoModelForObjectDetection,
        repo_id=RTDETR_REPO,
        full_precision=RTDETR_FULL_PRECISION,
        device_map=False,
    )
    logging.info("RT-DETR model loaded successfully")
def load_head_detection_model() -> None:
    """Populate HEAD_PROCESSOR and HEAD_MODEL (a DETR detector) from HEAD_DETECTION_REPO.

    The processor global is set before the model loads, so it stays set even
    if the model load then raises.
    """
    global HEAD_PROCESSOR, HEAD_MODEL
    logging.info("Loading Head Detection model...")
    HEAD_PROCESSOR = AutoImageProcessor.from_pretrained(HEAD_DETECTION_REPO)
    HEAD_MODEL = load_model_with_precision(
        model_class=DetrForObjectDetection,
        repo_id=HEAD_DETECTION_REPO,
        full_precision=HEAD_DETECTION_FULL_PRECISION,
        device_map=False,
    )
    logging.info("Head Detection model loaded successfully")
def load_rmbg_model() -> None:
    """Load the gated briaai/RMBG-2.0 background-removal model into RMBG_MODEL.

    Without HF_TOKEN the model is skipped: RMBG_MODEL is set to None and the
    function returns without raising. Otherwise the support files are staged
    in DEFAULT_LOCAL_RMBG_DIR and the model is loaded from there with
    trust_remote_code=True. When USE_TORCH_COMPILE is set and DEVICE is
    "cuda", the model is wrapped with torch.compile; a compile failure is only
    logged.

    Raises:
        RuntimeError: via handle_rmbg_access_error() for staging or load failures.
    """
    global RMBG_MODEL
    logging.info("Loading RMBG model...")
    token = os.getenv("HF_TOKEN", "")
    if not token:
        logging.error(ERROR_NO_HF_TOKEN)
        logging.warning("RMBG model requires HF_TOKEN. Skipping RMBG model loading...")
        RMBG_MODEL = None
        return
    local_dir = DEFAULT_LOCAL_RMBG_DIR
    try:
        ensure_birefnet_files(local_dir, token)
    except RuntimeError as staging_error:
        handle_rmbg_access_error(str(staging_error))
    # NOTE(review): this mutates the process-wide environment; whether it still
    # changes the hub cache location depends on when huggingface_hub reads it.
    os.environ["HF_HOME"] = os.path.dirname(local_dir)
    try:
        RMBG_MODEL = load_model_with_precision(
            model_class=AutoModelForImageSegmentation,
            repo_id=local_dir,
            full_precision=RMBG_FULL_PRECISION,
            device_map=False,
            trust_remote_code=True,
        )
        should_compile = USE_TORCH_COMPILE and DEVICE == "cuda"
        if should_compile:
            try:
                RMBG_MODEL = torch.compile(
                    RMBG_MODEL,
                    mode=TORCH_COMPILE_MODE,
                    backend=TORCH_COMPILE_BACKEND,
                    fullgraph=False,
                    dynamic=False,
                )
                logging.info(f"RMBG model compiled with mode={TORCH_COMPILE_MODE}, backend={TORCH_COMPILE_BACKEND}")
            except Exception as compile_error:
                logging.warning(f"Failed to compile RMBG model: {compile_error}")
        logging.info("RMBG-2.0 model loaded successfully from local directory")
    except Exception as load_error:
        handle_rmbg_access_error(str(load_error))
def load_yolos_fashionpedia_model() -> None:
    """Populate YOLOS_PROCESSOR and YOLOS_MODEL from YOLOS_FASHIONPEDIA_REPO.

    The processor is first requested with a 512x512 size override. If that
    raises, the repo's default processor settings are used instead.
    """
    global YOLOS_PROCESSOR, YOLOS_MODEL
    logging.info("Loading YOLOS FashionPedia model...")
    custom_size = {"height": 512, "width": 512}
    try:
        YOLOS_PROCESSOR = AutoImageProcessor.from_pretrained(YOLOS_FASHIONPEDIA_REPO, size=custom_size)
    except Exception:
        logging.warning("Failed to set custom size for YOLOS processor, using default")
        YOLOS_PROCESSOR = AutoImageProcessor.from_pretrained(YOLOS_FASHIONPEDIA_REPO)
    YOLOS_MODEL = load_model_with_precision(
        model_class=YolosForObjectDetection,
        repo_id=YOLOS_FASHIONPEDIA_REPO,
        full_precision=YOLOS_FASHIONPEDIA_FULL_PRECISION,
        device_map=False,
    )
    logging.info("YOLOS FashionPedia model loaded successfully")
# ---------------------------------------------------------------------- | |
# MAIN MODEL LOADING FUNCTION | |
# ---------------------------------------------------------------------- | |
def load_models() -> None:
    """Load all models once, under LOAD_LOCK.

    Load order: RT-DETR, Head Detection, RMBG, YOLOS. Each model is attempted
    independently. RMBG failures are only warnings, while the others are
    recorded as critical errors. Loading counts as successful when RT-DETR or
    YOLOS loaded; MODELS_LOADED is then set and LOAD_ERROR cleared.

    Raises:
        RuntimeError: neither RT-DETR nor YOLOS loaded. LOAD_ERROR is set to
            the same message first.
    """
    global MODELS_LOADED, LOAD_ERROR
    with LOAD_LOCK:
        if MODELS_LOADED:
            logging.info("Models already loaded")
            return
        # Skip the ZERO GPU step 2 print here as it's already shown in test execution flow
        if is_spaces_environment():
            logging.info("ZERO GPU MODEL LOADING: User request triggered model loading")
        check_hardware_environment()
        models_status = dict.fromkeys(("rtdetr", "head_detection", "rmbg", "yolos"), False)
        critical_errors = []

        def attempt(label, loader):
            # Run one loader; record a critical error and report False on failure.
            try:
                loader()
            except Exception as exc:
                critical_errors.append(f"{label}: {str(exc)}")
                logging.error(f"Failed to load {label} model: {exc}")
                return False
            return True

        models_status["rtdetr"] = attempt("RT-DETR", load_rtdetr_model)
        models_status["head_detection"] = attempt("Head Detection", load_head_detection_model)
        try:
            load_rmbg_model()
        except Exception as exc:
            logging.warning(f"Failed to load RMBG model: {exc}")
            models_status["rmbg"] = False
        else:
            # load_rmbg_model() returns normally but leaves None when HF_TOKEN is missing.
            models_status["rmbg"] = RMBG_MODEL is not None
        models_status["yolos"] = attempt("YOLOS", load_yolos_fashionpedia_model)

        if not (models_status["rtdetr"] or models_status["yolos"]):
            error_msg = "Failed to load critical models. " + "; ".join(critical_errors)
            logging.error(error_msg)
            LOAD_ERROR = error_msg
            raise RuntimeError(error_msg)
        MODELS_LOADED = True
        LOAD_ERROR = ""
        loaded = [name for name, ok in models_status.items() if ok]
        failed = [name for name, ok in models_status.items() if not ok]
        logging.info(f"Models loaded: {', '.join(loaded)}")
        if failed:
            logging.warning(f"Models failed: {', '.join(failed)}")
# ---------------------------------------------------------------------- | |
# MOVE MODELS TO GPU FUNCTION | |
# ---------------------------------------------------------------------- | |
def move_models_to_gpu():
    """Move every loaded model to CUDA and switch DEVICE to "cuda".

    Models that are still None are skipped. A model whose *_FULL_PRECISION
    flag is False is also cast with ``.half()``. Without CUDA this only logs a
    warning. On failure DEVICE is restored and the exception re-raised; models
    moved before the failure stay on the GPU.
    """
    global RMBG_MODEL, RTDETR_MODEL, HEAD_MODEL, YOLOS_MODEL, DEVICE
    if not torch.cuda.is_available():
        logging.warning("CUDA not available, cannot move models to GPU")
        return

    def relocate(model, full_precision, label):
        # Move one model (if present) to CUDA, halving it when requested.
        if model is None:
            return None
        logging.info(f"Moving {label} model to GPU...")
        model = model.to("cuda")
        if not full_precision:
            model = model.half()
        logging.info(f"{label} model moved to GPU")
        return model

    previous_device = DEVICE
    DEVICE = "cuda"
    try:
        RMBG_MODEL = relocate(RMBG_MODEL, RMBG_FULL_PRECISION, "RMBG")
        RTDETR_MODEL = relocate(RTDETR_MODEL, RTDETR_FULL_PRECISION, "RT-DETR")
        HEAD_MODEL = relocate(HEAD_MODEL, HEAD_DETECTION_FULL_PRECISION, "Head Detection")
        YOLOS_MODEL = relocate(YOLOS_MODEL, YOLOS_FASHIONPEDIA_FULL_PRECISION, "YOLOS")
        logging.info("All models moved to GPU successfully")
    except Exception as e:
        logging.error(f"Failed to move models to GPU: {e}")
        DEVICE = previous_device
        raise