"""Model configuration and lazy loaders.

Configures devices and precision, and lazily loads the RT-DETR object
detector, the DETR face/head detector, the RMBG-2.0 (BiRefNet) background
remover, and the YOLOS Fashionpedia clothing detector, with special handling
for Hugging Face Zero GPU Spaces (models load on CPU and are moved to the
GPU inside decorated functions).
"""

import logging
import os
import sys
import threading
import warnings

import torch
|
|
|
|
|
# Silence known-noisy warnings emitted while loading weights.
warnings.filterwarnings("ignore", message=".*copying from a non-meta parameter.*")
warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")
|
|
|
|
|
# Make the repository root importable when this module is run from a subpackage.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(current_dir))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)
|
|
|
from transformers import (
    AutoProcessor,
    AutoImageProcessor,
    AutoModelForObjectDetection,
    DetrForObjectDetection,
    AutoModelForImageSegmentation,
    YolosForObjectDetection,
)
|
|
|
|
|
|
|
|
|
def setup_device() -> str:
    """Pick the default device: CPU on Zero GPU Spaces (the GPU only exists
    inside decorated functions), otherwise CUDA when available."""
    if os.getenv("SPACE_ID"):
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
|
|
def check_cuda_availability() -> bool:
    """Log the GPU situation and return True when CUDA is usable right now."""
    if os.getenv("SPACE_ID"):
        logging.info("Running in Hugging Face Spaces (Zero GPU) - GPU will be available in decorated functions")
        return False

    if not torch.cuda.is_available():
        logging.warning(
            "\n" + "=" * 60 + "\n"
            "WARNING: CUDA NOT AVAILABLE!\n"
            "Running on CPU. Performance will be significantly reduced.\n"
            + "=" * 60 + "\n"
        )
        return False

    # torch.cuda.is_available() guarantees at least one device here.
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        logging.info(f"GPU {i}: {props.name} (Memory: {props.total_memory / (1024**3):.1f} GB)")
    return True
|
|
|
def check_hardware_environment() -> None:
    """Log where we are running and, on Spaces, verify/request Zero GPU hardware."""
    gpu_available = check_cuda_availability()

    if os.getenv("SPACE_ID"):
        ensure_zerogpu()
    elif gpu_available:
        logging.info(f"Running on {setup_device().upper()}")
    else:
        logging.info("Running on CPU")
|
|
|
|
|
|
|
|
|
def ensure_zerogpu() -> None:
    """Verify that the Space runs on zero-a10g hardware and request it otherwise.

    Requesting a hardware change requires HF_TOKEN; outside Spaces this is a no-op.
    """
    space_id = os.getenv("SPACE_ID")
    hf_token = os.getenv("HF_TOKEN")

    if not space_id:
        logging.info("Not running in Hugging Face Spaces")
        return

    try:
        from huggingface_hub import HfApi

        api = HfApi(token=hf_token) if hf_token else HfApi()
        space_info = api.get_space_runtime(space_id)

        current_hardware = getattr(space_info, "hardware", None)
        logging.info(f"Current space hardware: {current_hardware}")

        if current_hardware and "a10g" not in current_hardware.lower():
            logging.warning(f"Space is running on {current_hardware}, not zero-a10g")

            if hf_token:
                try:
                    api.request_space_hardware(repo_id=space_id, hardware="zero-a10g")
                    logging.info("Requested hardware change to zero-a10g")
                except Exception as e:
                    logging.error(f"Failed to request hardware change: {e}")
            else:
                logging.warning("Cannot request hardware change without HF_TOKEN")
        else:
            logging.info("Space is already running on zero-a10g")

    except ImportError:
        logging.warning("huggingface_hub not available, cannot verify space hardware")
    except Exception as e:
        logging.error(f"Unexpected error in ensure_zerogpu: {e}")
|
|
|
# Default device, resolved once at import time; move_models_to_gpu() may switch it later.
DEVICE = setup_device()
|
|
|
|
|
|
|
|
|
# Per-model precision switches: True keeps float32; False allows float16 on CUDA.
RTDETR_FULL_PRECISION = True
HEAD_DETECTION_FULL_PRECISION = True
RMBG_FULL_PRECISION = True
YOLOS_FASHIONPEDIA_FULL_PRECISION = True
|
|
|
|
|
|
|
|
|
# torch.compile and inference-speed settings. USE_TORCH_COMPILE is applied to the
# RMBG model below; the remaining flags are presumably consumed by inference code
# elsewhere in the project.
USE_TORCH_COMPILE = True
TORCH_COMPILE_MODE = "reduce-overhead"
TORCH_COMPILE_BACKEND = "inductor"
ENABLE_CHANNELS_LAST = True
ENABLE_CUDA_GRAPHS = True
USE_MIXED_PRECISION = True
|
|
|
|
|
|
|
|
|
RTDETR_REPO = "PekingU/rtdetr_r50vd"
HEAD_DETECTION_REPO = "sanali209/DT_face_head_char"
RMBG_REPO = "briaai/RMBG-2.0"
YOLOS_FASHIONPEDIA_REPO = "valentinafeve/yolos-fashionpedia"
|
|
|
|
|
|
|
|
|
BIREFNET_CONFIG_PYTHON_TEMPLATE = """from transformers.configuration_utils import PretrainedConfig

class BiRefNetConfig(PretrainedConfig):
    model_type = "SegformerForSemanticSegmentation"
    num_channels = 3
    backbone = "mit_b5"
    hidden_size = 768
    num_hidden_layers = 12
    num_attention_heads = 12
    bb_pretrained = False
"""

BIREFNET_CONFIG_JSON = """{
    "_name_or_path": "briaai/RMBG-2.0",
    "architectures": ["BiRefNet"],
    "auto_map": {
        "AutoConfig": "BiRefNet_config.BiRefNetConfig",
        "AutoModelForImageSegmentation": "birefnet.BiRefNet"
    },
    "bb_pretrained": false
}"""

BIREFNET_CONFIG_FILES = {
    "BiRefNet_config.py": BIREFNET_CONFIG_PYTHON_TEMPLATE,
    "config.json": BIREFNET_CONFIG_JSON,
}

BIREFNET_DOWNLOAD_FILES = ["birefnet.py", "preprocessor_config.json"]
BIREFNET_WEIGHT_FILES = ["model.safetensors", "pytorch_model.bin"]
DEFAULT_LOCAL_RMBG_DIR = "models/rmbg2"
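# The templated and downloaded files above make DEFAULT_LOCAL_RMBG_DIR a
# self-contained snapshot that AutoModelForImageSegmentation can load with
# trust_remote_code=True. Resulting layout (filenames from the constants above):
#
#     models/rmbg2/
#         BiRefNet_config.py        # written from template
#         config.json               # written from template
#         birefnet.py               # downloaded from briaai/RMBG-2.0
#         preprocessor_config.json  # downloaded from briaai/RMBG-2.0
#         model.safetensors         # or pytorch_model.bin as a fallback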
|
|
|
|
|
|
|
|
|
ERROR_NO_HF_TOKEN = "HF_TOKEN environment variable not set. Please set it in your Space secrets."
ERROR_ACCESS_DENIED = "Access denied to RMBG-2.0. Please request access at https://huggingface.co/briaai/RMBG-2.0 and try again."
ERROR_AUTH_FAILED = "Authentication failed. Please set HF_TOKEN environment variable."
|
|
|
|
|
|
|
|
|
# Lazily populated model/processor singletons (see load_models()).
RTDETR_PROCESSOR = None
RTDETR_MODEL = None
HEAD_PROCESSOR = None
HEAD_MODEL = None
RMBG_MODEL = None
YOLOS_PROCESSOR = None
YOLOS_MODEL = None

# Loading state shared across threads.
MODELS_LOADED = False
LOAD_ERROR = ""
LOAD_LOCK = threading.Lock()
|
|
|
|
|
|
|
|
|
def patch_spaces_device_handling() -> bool:
    """Best-effort Zero GPU workaround: wrap the private storage hook in
    `spaces.zero.torch.patching` so string devices are normalized to
    torch.device. Relies on a private attribute of the `spaces` package,
    so any failure is logged at debug level and ignored.
    """
    try:
        import spaces.zero.torch.patching as spaces_patching
        original_untyped_storage_new = spaces_patching._untyped_storage_new_register

        def patched_untyped_storage_new_register(storage_cls):
            def wrapper(*args, **kwargs):
                device = kwargs.get("device")
                if device is not None and isinstance(device, str):
                    kwargs["device"] = torch.device(device)
                return original_untyped_storage_new(storage_cls)(*args, **kwargs)
            return wrapper

        spaces_patching._untyped_storage_new_register = patched_untyped_storage_new_register
        logging.info("Successfully patched spaces device handling")
        return True
    except Exception as e:
        logging.debug(f"Spaces patching not available or failed: {e}")
        return False


def is_spaces_environment() -> bool:
    """True when running under Hugging Face Spaces (or the spaces package is loaded)."""
    return os.getenv("SPACE_ID") is not None or "spaces" in sys.modules
|
|
|
|
|
|
|
|
|
def create_config_files(local_dir: str) -> None:
    """Write the BiRefNet config shims into local_dir if they are missing."""
    os.makedirs(local_dir, exist_ok=True)

    for filename, content in BIREFNET_CONFIG_FILES.items():
        file_path = os.path.join(local_dir, filename)
        if not os.path.exists(file_path):
            with open(file_path, "w") as f:
                f.write(content)
            logging.info(f"Created {filename} in {local_dir}")
|
|
|
def download_birefnet_files(local_dir: str, token: str) -> None:
    """Download the BiRefNet code and preprocessor files that cannot be templated."""
    from huggingface_hub import hf_hub_download

    for file in BIREFNET_DOWNLOAD_FILES:
        file_path = os.path.join(local_dir, file)
        if not os.path.exists(file_path):
            try:
                hf_hub_download(
                    repo_id=RMBG_REPO,
                    filename=file,
                    token=token,
                    local_dir=local_dir,
                    local_dir_use_symlinks=False,
                )
                logging.info(f"Downloaded {file} to {local_dir}")
            except Exception as e:
                logging.error(f"Failed to download {file}: {e}")
                raise RuntimeError(f"Failed to download {file} from {RMBG_REPO}") from e
|
|
|
def download_model_weights(local_dir: str, token: str) -> None:
    """Fetch model weights, preferring model.safetensors over pytorch_model.bin."""
    from huggingface_hub import hf_hub_download

    weights_exist = any(
        os.path.exists(os.path.join(local_dir, weight_file))
        for weight_file in BIREFNET_WEIGHT_FILES
    )
    if weights_exist:
        return

    try:
        hf_hub_download(
            repo_id=RMBG_REPO,
            filename="model.safetensors",
            token=token,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
        )
        logging.info(f"Downloaded model.safetensors to {local_dir}")
        return
    except Exception as e:
        logging.warning(f"Failed to download model.safetensors: {e}")

    try:
        hf_hub_download(
            repo_id=RMBG_REPO,
            filename="pytorch_model.bin",
            token=token,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
        )
        logging.info(f"Downloaded pytorch_model.bin to {local_dir}")
    except Exception as e:
        logging.error(f"Failed to download pytorch_model.bin: {e}")
        raise RuntimeError(f"Failed to download model weights from {RMBG_REPO}") from e


def ensure_birefnet_files(local_dir: str, token: str) -> None:
    """Make local_dir a complete, loadable snapshot of RMBG-2.0."""
    create_config_files(local_dir)
    download_birefnet_files(local_dir, token)
    download_model_weights(local_dir, token)
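# Usage sketch (assumes HF_TOKEN grants access to the gated repo):
#
#     ensure_birefnet_files(DEFAULT_LOCAL_RMBG_DIR, os.environ["HF_TOKEN"])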
|
|
|
def ensure_models_loaded() -> None:
    """Thread-safe lazy entry point: load all models exactly once.

    On Zero GPU Spaces this is a no-op; models are loaded on demand inside
    GPU-decorated request handlers instead.
    """
    global MODELS_LOADED, LOAD_ERROR

    if not MODELS_LOADED:
        if is_spaces_environment():
            logging.info("ZERO GPU MODEL LOADING: Models NOT loaded at startup")
            logging.info("ZERO GPU MODEL LOADING: Models will be loaded on-demand in GPU context")
            return

        with LOAD_LOCK:
            # Re-check under the lock (double-checked locking).
            if not MODELS_LOADED:
                if LOAD_ERROR:
                    raise RuntimeError(f"Models failed to load: {LOAD_ERROR}")

                try:
                    load_models()
                except Exception as e:
                    LOAD_ERROR = str(e)
                    raise
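# Intended call pattern (a sketch, not part of this module's API; the
# `detect` handler is an illustrative assumption, `spaces.GPU` is the real
# Zero GPU decorator from the spaces package):
#
#     import spaces
#
#     @spaces.GPU
#     def detect(image):
#         load_models()          # lazy, runs once, inside the GPU context
#         move_models_to_gpu()   # weights were loaded on CPU; move them now
#         ...
#
# Outside Spaces, calling ensure_models_loaded() once at startup suffices.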
|
|
|
|
|
|
|
|
|
def load_model_with_precision(model_class, repo_id: str, full_precision: bool,
                              device_map: bool = True, trust_remote_code: bool = False):
    """Load a transformers model at the requested precision, with fallbacks.

    On Zero GPU Spaces the model stays on CPU (the GPU only exists inside
    decorated functions). On corrupt-checkpoint errors the loader clears the
    local cache and retries with force_download, then tries a TensorFlow
    checkpoint, then a minimal configuration.
    """
    try:
        spaces_env = is_spaces_environment()

        if spaces_env:
            torch_device = torch.device("cpu")
            patch_spaces_device_handling()
        else:
            if DEVICE == "cuda":
                torch.cuda.empty_cache()
            torch_device = torch.device(DEVICE)

        load_kwargs = {
            "torch_dtype": torch.float32 if full_precision else torch.float16,
            "trust_remote_code": trust_remote_code,
            "low_cpu_mem_usage": True,
            "use_safetensors": True,
        }

        if spaces_env:
            load_kwargs["device_map"] = None
        elif DEVICE == "cuda" and device_map and torch.cuda.device_count() > 1:
            load_kwargs["device_map"] = "auto"

        try:
            model = model_class.from_pretrained(repo_id, **load_kwargs)

            if not spaces_env and not hasattr(model, "hf_device_map"):
                model = model.to(torch_device)

            if not full_precision and DEVICE == "cuda":
                model = model.half()

        except (ValueError, RuntimeError, OSError, UnicodeDecodeError) as e:
            logging.warning(f"Failed to load model with initial configuration: {e}")

            if "Unable to load weights from pytorch checkpoint" in str(e) or isinstance(e, UnicodeDecodeError):
                # The cached checkpoint is likely corrupt: clear it and re-download.
                logging.info(f"Attempting to clear cache and retry for {repo_id}")

                try:
                    from huggingface_hub import scan_cache_dir
                    cache_info = scan_cache_dir()
                    for repo in cache_info.repos:
                        if repo.repo_id == repo_id:
                            revisions = [rev.commit_hash for rev in repo.revisions]
                            cache_info.delete_revisions(*revisions).execute()
                            logging.info(f"Cleared cache for {repo_id}")
                            break
                except Exception as cache_e:
                    logging.warning(f"Cache clearing failed: {cache_e}")

                try:
                    model = model_class.from_pretrained(
                        repo_id,
                        torch_dtype=torch.float32,
                        trust_remote_code=trust_remote_code,
                        force_download=True,
                        device_map=None,
                        low_cpu_mem_usage=True,
                    )
                    model = model.to(torch_device)

                except Exception as retry_e:
                    logging.warning(f"Retry with force_download failed: {retry_e}")

                    try:
                        model = model_class.from_pretrained(
                            repo_id,
                            from_tf=True,
                            torch_dtype=torch.float32,
                            trust_remote_code=trust_remote_code,
                            device_map=None,
                            low_cpu_mem_usage=True,
                        )
                        model = model.to(torch_device)
                        logging.info(f"Successfully loaded {repo_id} from TensorFlow checkpoint")

                    except Exception as tf_e:
                        logging.warning(f"TensorFlow fallback failed: {tf_e}")

                        try:
                            model = model_class.from_pretrained(
                                repo_id,
                                torch_dtype=torch.float32,
                                trust_remote_code=trust_remote_code,
                                device_map=None,
                                use_safetensors=False,
                                local_files_only=False,
                            )
                            model = model.to(torch_device)
                            logging.info(f"Successfully loaded {repo_id} with basic configuration")

                        except Exception as basic_e:
                            logging.error(f"All fallback strategies failed for {repo_id}: {basic_e}")
                            raise RuntimeError(
                                f"Unable to load model {repo_id} after all retry attempts: {basic_e}"
                            ) from basic_e
            else:
                model = model_class.from_pretrained(
                    repo_id,
                    torch_dtype=torch.float32,
                    trust_remote_code=trust_remote_code,
                    device_map=None,
                )
                model = model.to(torch_device)

        model.eval()

        if not spaces_env:
            with torch.no_grad():
                logging.info(f"Verifying model {repo_id} is on correct device")
                param = next(model.parameters())

                if DEVICE == "cuda" and not param.is_cuda:
                    model = model.to(torch_device)
                    param = next(model.parameters())
                    logging.warning(f"Forced model {repo_id} to {DEVICE}")

                logging.info(f"Model {repo_id} device: {param.device}")
        else:
            logging.info(f"Model {repo_id} loaded on CPU (Zero GPU environment)")

        return model

    except Exception as e:
        logging.error(f"Failed to load model from {repo_id} on {DEVICE}: {e}")
        raise
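# Example (a sketch): load a detector at half precision, sharding across GPUs
# when more than one is available.
#
#     model = load_model_with_precision(
#         AutoModelForObjectDetection,
#         RTDETR_REPO,
#         full_precision=False,   # allow float16 on CUDA
#         device_map=True,        # "auto" sharding on multi-GPU boxes
#     )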
|
|
|
def handle_rmbg_access_error(error_msg: str) -> None:
    """Translate HTTP errors from the gated RMBG-2.0 repo into actionable RuntimeErrors."""
    if "403" in error_msg and "gated repo" in error_msg:
        logging.error(
            "\n" + "=" * 60 + "\n"
            "ERROR: Access denied to RMBG-2.0 model!\n"
            "You need to request access at: https://huggingface.co/briaai/RMBG-2.0\n"
            + "=" * 60 + "\n"
        )
        raise RuntimeError(ERROR_ACCESS_DENIED)
    elif "401" in error_msg:
        logging.error(
            "\n" + "=" * 60 + "\n"
            "ERROR: Authentication failed!\n"
            "Please set your HF_TOKEN environment variable.\n"
            + "=" * 60 + "\n"
        )
        raise RuntimeError(ERROR_AUTH_FAILED)
    else:
        raise RuntimeError(error_msg)
|
|
|
|
|
|
|
|
|
def load_rtdetr_model() -> None:
    """Load the RT-DETR object detector and its processor."""
    global RTDETR_PROCESSOR, RTDETR_MODEL
    logging.info("Loading RT-DETR model...")
    RTDETR_PROCESSOR = AutoProcessor.from_pretrained(RTDETR_REPO)
    RTDETR_MODEL = load_model_with_precision(
        AutoModelForObjectDetection,
        RTDETR_REPO,
        RTDETR_FULL_PRECISION,
        device_map=False,
    )
    logging.info("RT-DETR model loaded successfully")


def load_head_detection_model() -> None:
    """Load the DETR-based face/head detector and its image processor."""
    global HEAD_PROCESSOR, HEAD_MODEL
    logging.info("Loading Head Detection model...")
    HEAD_PROCESSOR = AutoImageProcessor.from_pretrained(HEAD_DETECTION_REPO)
    HEAD_MODEL = load_model_with_precision(
        DetrForObjectDetection,
        HEAD_DETECTION_REPO,
        HEAD_DETECTION_FULL_PRECISION,
        device_map=False,
    )
    logging.info("Head Detection model loaded successfully")
|
|
|
def load_rmbg_model() -> None:
    """Load the RMBG-2.0 background-removal model from a local snapshot.

    Requires HF_TOKEN (gated repo). Without a token the model is skipped
    rather than aborting startup.
    """
    global RMBG_MODEL
    logging.info("Loading RMBG model...")

    token = os.getenv("HF_TOKEN", "")
    if not token:
        logging.error(ERROR_NO_HF_TOKEN)
        logging.warning("RMBG model requires HF_TOKEN. Skipping RMBG model loading...")
        RMBG_MODEL = None
        return

    local_dir = DEFAULT_LOCAL_RMBG_DIR

    try:
        ensure_birefnet_files(local_dir, token)
    except RuntimeError as e:
        handle_rmbg_access_error(str(e))

    # Point the Hugging Face cache at the parent of the local snapshot.
    os.environ["HF_HOME"] = os.path.dirname(local_dir)

    try:
        RMBG_MODEL = load_model_with_precision(
            AutoModelForImageSegmentation,
            local_dir,
            RMBG_FULL_PRECISION,
            trust_remote_code=True,
            device_map=False,
        )

        if USE_TORCH_COMPILE and DEVICE == "cuda":
            try:
                RMBG_MODEL = torch.compile(
                    RMBG_MODEL,
                    mode=TORCH_COMPILE_MODE,
                    backend=TORCH_COMPILE_BACKEND,
                    fullgraph=False,
                    dynamic=False,
                )
                logging.info(f"RMBG model compiled with mode={TORCH_COMPILE_MODE}, backend={TORCH_COMPILE_BACKEND}")
            except Exception as e:
                logging.warning(f"Failed to compile RMBG model: {e}")

        logging.info("RMBG-2.0 model loaded successfully from local directory")
    except Exception as e:
        handle_rmbg_access_error(str(e))
|
|
|
def load_yolos_fashionpedia_model() -> None:
    """Load the YOLOS Fashionpedia clothing detector; fall back to the default
    processor size if the 512x512 override is rejected."""
    global YOLOS_PROCESSOR, YOLOS_MODEL
    logging.info("Loading YOLOS FashionPedia model...")

    try:
        YOLOS_PROCESSOR = AutoImageProcessor.from_pretrained(
            YOLOS_FASHIONPEDIA_REPO,
            size={"height": 512, "width": 512},
        )
    except Exception:
        logging.warning("Failed to set custom size for YOLOS processor, using default")
        YOLOS_PROCESSOR = AutoImageProcessor.from_pretrained(YOLOS_FASHIONPEDIA_REPO)

    YOLOS_MODEL = load_model_with_precision(
        YolosForObjectDetection,
        YOLOS_FASHIONPEDIA_REPO,
        YOLOS_FASHIONPEDIA_FULL_PRECISION,
        device_map=False,
    )
    logging.info("YOLOS FashionPedia model loaded successfully")
|
|
|
|
|
|
|
|
|
def load_models() -> None:
    """Load all models, tolerating partial failure.

    RMBG is optional; the run counts as successful if at least one of RT-DETR
    or YOLOS loads. Otherwise the accumulated errors are raised.
    """
    global MODELS_LOADED, LOAD_ERROR

    with LOAD_LOCK:
        if MODELS_LOADED:
            logging.info("Models already loaded")
            return

        if is_spaces_environment():
            logging.info("ZERO GPU MODEL LOADING: User request triggered model loading")

        check_hardware_environment()

        models_status = {
            "rtdetr": False,
            "head_detection": False,
            "rmbg": False,
            "yolos": False,
        }
        critical_errors = []

        try:
            load_rtdetr_model()
            models_status["rtdetr"] = True
        except Exception as e:
            critical_errors.append(f"RT-DETR: {e}")
            logging.error(f"Failed to load RT-DETR model: {e}")

        try:
            load_head_detection_model()
            models_status["head_detection"] = True
        except Exception as e:
            critical_errors.append(f"Head Detection: {e}")
            logging.error(f"Failed to load Head Detection model: {e}")

        try:
            load_rmbg_model()
            models_status["rmbg"] = RMBG_MODEL is not None
        except Exception as e:
            logging.warning(f"Failed to load RMBG model: {e}")
            models_status["rmbg"] = False

        try:
            load_yolos_fashionpedia_model()
            models_status["yolos"] = True
        except Exception as e:
            critical_errors.append(f"YOLOS: {e}")
            logging.error(f"Failed to load YOLOS model: {e}")

        if models_status["rtdetr"] or models_status["yolos"]:
            MODELS_LOADED = True
            LOAD_ERROR = ""

            loaded = [k for k, v in models_status.items() if v]
            failed = [k for k, v in models_status.items() if not v]

            logging.info(f"Models loaded: {', '.join(loaded)}")
            if failed:
                logging.warning(f"Models failed: {', '.join(failed)}")
        else:
            error_msg = "Failed to load critical models. " + "; ".join(critical_errors)
            logging.error(error_msg)
            LOAD_ERROR = error_msg
            raise RuntimeError(error_msg)
|
|
|
|
|
|
|
|
|
def move_models_to_gpu() -> None:
    """Move any loaded models to CUDA (used inside Zero GPU decorated functions),
    downcasting to float16 where the precision flags allow; resets DEVICE on failure."""
    global RMBG_MODEL, RTDETR_MODEL, HEAD_MODEL, YOLOS_MODEL, DEVICE

    if not torch.cuda.is_available():
        logging.warning("CUDA not available, cannot move models to GPU")
        return

    original_device = DEVICE
    DEVICE = "cuda"

    try:
        if RMBG_MODEL is not None:
            logging.info("Moving RMBG model to GPU...")
            RMBG_MODEL = RMBG_MODEL.to("cuda")
            if not RMBG_FULL_PRECISION:
                RMBG_MODEL = RMBG_MODEL.half()
            logging.info("RMBG model moved to GPU")

        if RTDETR_MODEL is not None:
            logging.info("Moving RT-DETR model to GPU...")
            RTDETR_MODEL = RTDETR_MODEL.to("cuda")
            if not RTDETR_FULL_PRECISION:
                RTDETR_MODEL = RTDETR_MODEL.half()
            logging.info("RT-DETR model moved to GPU")

        if HEAD_MODEL is not None:
            logging.info("Moving Head Detection model to GPU...")
            HEAD_MODEL = HEAD_MODEL.to("cuda")
            if not HEAD_DETECTION_FULL_PRECISION:
                HEAD_MODEL = HEAD_MODEL.half()
            logging.info("Head Detection model moved to GPU")

        if YOLOS_MODEL is not None:
            logging.info("Moving YOLOS model to GPU...")
            YOLOS_MODEL = YOLOS_MODEL.to("cuda")
            if not YOLOS_FASHIONPEDIA_FULL_PRECISION:
                YOLOS_MODEL = YOLOS_MODEL.half()
            logging.info("YOLOS model moved to GPU")

        logging.info("All models moved to GPU successfully")

    except Exception as e:
        logging.error(f"Failed to move models to GPU: {e}")
        DEVICE = original_device
        raise
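
if __name__ == "__main__":
    # Minimal smoke test (a sketch): configure logging, report the hardware
    # environment, and attempt a full model load. On Zero GPU Spaces,
    # ensure_models_loaded() intentionally returns without loading.
    logging.basicConfig(level=logging.INFO)
    check_hardware_environment()
    ensure_models_loaded()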
|
|