# ----------------------------------------------------------------------
# IMPORTS
# ----------------------------------------------------------------------
import os
import sys
import logging
import threading
import torch
import warnings
import time

# Suppress the model loading warnings about non-meta parameters
warnings.filterwarnings("ignore", message=".*copying from a non-meta parameter.*")
warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")

# Make the directory two levels above this file's directory importable
# (needed for imports during deployment).
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(current_dir))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from transformers import (
    AutoProcessor,
    AutoImageProcessor,
    AutoModelForObjectDetection,
    DetrImageProcessor,
    DetrForObjectDetection,
    AutoModelForImageSegmentation,
    YolosImageProcessor,
    YolosForObjectDetection
)


# ----------------------------------------------------------------------
# HARDWARE CONFIGURATION
# ----------------------------------------------------------------------
def setup_device():
    """Return the device string models should be loaded on.

    Returns:
        "cpu" when the SPACE_ID environment variable is set or no CUDA
        device is usable, otherwise "cuda".
    """
    if os.getenv("SPACE_ID"):
        return "cpu"
    if torch.cuda.is_available() and torch.cuda.device_count() >= 1:
        return "cuda"
    return "cpu"


def check_cuda_availability():
    """Log the CUDA situation and report whether CUDA can be used right now.

    Returns:
        False when SPACE_ID is set or CUDA is unavailable, True otherwise
        (after logging name and memory of every detected GPU).
    """
    if os.getenv("SPACE_ID"):
        logging.info("Running in Hugging Face Spaces (Zero GPU) - GPU will be available in decorated functions")
        return False

    if not torch.cuda.is_available():
        logging.warning(
            "\n" + "="*60 + "\n" +
            "WARNING: CUDA NOT AVAILABLE!\n" +
            "Running on CPU. Performance will be significantly reduced.\n" +
            "="*60 + "\n"
        )
        return False

    device_count = torch.cuda.device_count()
    if device_count > 0:
        for i in range(device_count):
            props = torch.cuda.get_device_properties(i)
            logging.info(f"GPU {i}: {props.name} (Memory: {props.total_memory / (1024**3):.1f} GB)")
    else:
        logging.info("CUDA available but no GPUs detected")
    return True


def check_hardware_environment():
    """Log the hardware in use; inside a Space also verify the Space hardware."""
    gpu_available = check_cuda_availability()
    if os.getenv("SPACE_ID"):
        ensure_zerogpu()
    elif gpu_available:
        logging.info(f"Running on {setup_device().upper()}")
    else:
        logging.info("Running on CPU")


# ----------------------------------------------------------------------
# ZERO GPU CONFIGURATION
# ----------------------------------------------------------------------
def ensure_zerogpu():
    """Check the Space's runtime hardware and request zero-a10g if it differs.

    Reads SPACE_ID and HF_TOKEN from the environment. Does nothing outside a
    Space. A hardware change is only requested when HF_TOKEN is set. All
    failures are logged and swallowed; nothing is raised.
    """
    space_id = os.getenv("SPACE_ID")
    hf_token = os.getenv("HF_TOKEN")

    if not space_id:
        logging.info("Not running in Hugging Face Spaces")
        return

    try:
        from huggingface_hub import HfApi

        api = HfApi(token=hf_token) if hf_token else HfApi()
        space_info = api.get_space_runtime(space_id)
        current_hardware = getattr(space_info, 'hardware', None)
        logging.info(f"Current space hardware: {current_hardware}")

        if not current_hardware:
            # Unknown hardware must not be reported as "already on zero-a10g".
            logging.warning("Could not determine current space hardware; skipping hardware check")
            return

        if "a10g" in current_hardware.lower():
            logging.info("Space is already running on zero-a10g")
            return

        logging.warning(f"Space is running on {current_hardware}, not zero-a10g")
        if not hf_token:
            logging.warning("Cannot request hardware change without HF_TOKEN")
            return

        try:
            api.request_space_hardware(repo_id=space_id, hardware="zero-a10g")
            logging.info("Requested hardware change to zero-a10g")
        except Exception as e:
            logging.error(f"Failed to request hardware change: {e}")
    except ImportError:
        logging.warning("huggingface_hub not available, cannot verify space hardware")
    except Exception as e:
        logging.error(f"Unexpected error in ensure_zerogpu: {str(e)}")


# Device chosen once at import time; move_models_to_gpu() may switch it later.
DEVICE = setup_device()

# ----------------------------------------------------------------------
# MODEL PRECISION SETTINGS
# ----------------------------------------------------------------------
# Per-model precision switch: True loads weights as float32, False as float16.
RTDETR_FULL_PRECISION = True
HEAD_DETECTION_FULL_PRECISION = True
RMBG_FULL_PRECISION = True
YOLOS_FASHIONPEDIA_FULL_PRECISION = True

# ----------------------------------------------------------------------
# OPTIMIZATION SETTINGS
# ----------------------------------------------------------------------
USE_TORCH_COMPILE = True
TORCH_COMPILE_MODE = "reduce-overhead"
TORCH_COMPILE_BACKEND = "inductor"
ENABLE_CHANNELS_LAST = True
ENABLE_CUDA_GRAPHS = True
USE_MIXED_PRECISION = True

# ----------------------------------------------------------------------
# MODEL REPOSITORY IDENTIFIERS
# ----------------------------------------------------------------------
RTDETR_REPO = "PekingU/rtdetr_r50vd"
HEAD_DETECTION_REPO = "sanali209/DT_face_head_char"
RMBG_REPO = "briaai/RMBG-2.0"
YOLOS_FASHIONPEDIA_REPO = "valentinafeve/yolos-fashionpedia"

# ----------------------------------------------------------------------
# BIREFNET CONFIGURATION
# ----------------------------------------------------------------------
# Source text of BiRefNet_config.py. It is written to disk verbatim, so it
# must be valid Python: one statement per line, class body indented.
BIREFNET_CONFIG_PYTHON_TEMPLATE = """from transformers.configuration_utils import PretrainedConfig

class BiRefNetConfig(PretrainedConfig):
    model_type = "SegformerForSemanticSegmentation"
    num_channels = 3
    backbone = "mit_b5"
    hidden_size = 768
    num_hidden_layers = 12
    num_attention_heads = 12
    bb_pretrained = False
"""

# Contents of config.json (JSON is whitespace-insensitive; text kept as-is).
BIREFNET_CONFIG_JSON = (
    '{ "_name_or_path": "briaai/RMBG-2.0", "architectures": ["BiRefNet"], '
    '"auto_map": { "AutoConfig": "BiRefNet_config.BiRefNetConfig", '
    '"AutoModelForImageSegmentation": "birefnet.BiRefNet" }, '
    '"bb_pretrained": false }'
)

# Files generated locally: file name -> file content.
BIREFNET_CONFIG_FILES = {
    "BiRefNet_config.py": BIREFNET_CONFIG_PYTHON_TEMPLATE,
    "config.json": BIREFNET_CONFIG_JSON
}

# Files fetched from RMBG_REPO.
BIREFNET_DOWNLOAD_FILES = ["birefnet.py", "preprocessor_config.json"]
# Weight files; the presence of any one of them counts as "weights available".
BIREFNET_WEIGHT_FILES = ["model.safetensors", "pytorch_model.bin"]
DEFAULT_LOCAL_RMBG_DIR = "models/rmbg2"
# ----------------------------------------------------------------------
# ERROR MESSAGES
# ----------------------------------------------------------------------
ERROR_NO_HF_TOKEN = "HF_TOKEN environment variable not set. Please set it in your Space secrets."
ERROR_ACCESS_DENIED = "Access denied to RMBG-2.0. Please request access at https://huggingface.co/briaai/RMBG-2.0 and try again."
ERROR_AUTH_FAILED = "Authentication failed. Please set HF_TOKEN environment variable."

# ----------------------------------------------------------------------
# GLOBAL MODEL INSTANCES
# ----------------------------------------------------------------------
RTDETR_PROCESSOR = None
RTDETR_MODEL = None
HEAD_PROCESSOR = None
HEAD_MODEL = None
RMBG_MODEL = None
YOLOS_PROCESSOR = None
YOLOS_MODEL = None

# ----------------------------------------------------------------------
# GLOBAL STATE VARIABLES
# ----------------------------------------------------------------------
MODELS_LOADED = False
LOAD_ERROR = ""
# Re-entrant on purpose: ensure_models_loaded() calls load_models() while
# holding this lock and load_models() acquires it again. A plain Lock
# deadlocks on that path.
LOAD_LOCK = threading.RLock()


# ----------------------------------------------------------------------
# MODEL LOADING WORKAROUNDS FOR SPACES ENVIRONMENT
# ----------------------------------------------------------------------
def patch_spaces_device_handling():
    """Wrap the `spaces` storage hook so string devices become torch.device.

    Safe to call repeatedly: an already-installed wrapper is detected and
    left alone instead of being wrapped again.

    Returns:
        True when the patch is (or already was) installed, False when the
        `spaces` internals could not be imported or patched.
    """
    try:
        import spaces.zero.torch.patching as spaces_patching

        original_untyped_storage_new = spaces_patching._untyped_storage_new_register
        if getattr(original_untyped_storage_new, "_device_str_patched", False):
            return True

        def patched_untyped_storage_new_register(storage_cls):
            def wrapper(*args, **kwargs):
                device = kwargs.get('device')
                if device is not None and isinstance(device, str):
                    kwargs['device'] = torch.device(device)
                return original_untyped_storage_new(storage_cls)(*args, **kwargs)
            return wrapper

        # Marker checked above so later calls do not stack wrappers.
        patched_untyped_storage_new_register._device_str_patched = True
        spaces_patching._untyped_storage_new_register = patched_untyped_storage_new_register
        logging.info("Successfully patched spaces device handling")
        return True
    except Exception as e:
        logging.debug(f"Spaces patching not available or failed: {e}")
        return False


def is_spaces_environment():
    """Return True when SPACE_ID is set or the `spaces` module is imported."""
    return os.getenv("SPACE_ID") is not None or "spaces" in sys.modules


# ----------------------------------------------------------------------
# BIREFNET FILE MANAGEMENT
# ----------------------------------------------------------------------
def create_config_files(local_dir: str) -> None:
    """Write every missing BIREFNET_CONFIG_FILES entry into `local_dir`.

    The directory is created if needed; existing files are never overwritten.
    """
    os.makedirs(local_dir, exist_ok=True)
    for filename, content in BIREFNET_CONFIG_FILES.items():
        file_path = os.path.join(local_dir, filename)
        if not os.path.exists(file_path):
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            logging.info(f"Created {filename} in {local_dir}")


def download_birefnet_files(local_dir: str, token: str) -> None:
    """Download every missing BIREFNET_DOWNLOAD_FILES entry from RMBG_REPO.

    Raises:
        RuntimeError: when a download fails. The message carries the
            underlying error text (e.g. the HTTP 401/403 details) so
            handle_rmbg_access_error() can classify it.
    """
    from huggingface_hub import hf_hub_download

    for file in BIREFNET_DOWNLOAD_FILES:
        file_path = os.path.join(local_dir, file)
        if os.path.exists(file_path):
            continue
        try:
            hf_hub_download(
                repo_id=RMBG_REPO,
                filename=file,
                token=token,
                local_dir=local_dir,
                local_dir_use_symlinks=False
            )
            logging.info(f"Downloaded {file} to {local_dir}")
        except Exception as e:
            logging.error(f"Failed to download {file}: {e}")
            raise RuntimeError(f"Failed to download {file} from {RMBG_REPO}: {e}") from e


def download_model_weights(local_dir: str, token: str) -> None:
    """Make sure one weight file exists in `local_dir`.

    Returns immediately when any BIREFNET_WEIGHT_FILES entry is present.
    Otherwise tries model.safetensors first and pytorch_model.bin second.

    Raises:
        RuntimeError: when both downloads fail; the message carries the
            underlying error text of the last attempt.
    """
    from huggingface_hub import hf_hub_download

    weights_exist = any(
        os.path.exists(os.path.join(local_dir, weight_file))
        for weight_file in BIREFNET_WEIGHT_FILES
    )
    if weights_exist:
        return

    try:
        hf_hub_download(
            repo_id=RMBG_REPO,
            filename="model.safetensors",
            token=token,
            local_dir=local_dir,
            local_dir_use_symlinks=False
        )
        logging.info(f"Downloaded model.safetensors to {local_dir}")
        return
    except Exception as e:
        logging.warning(f"Failed to download model.safetensors: {e}")

    try:
        hf_hub_download(
            repo_id=RMBG_REPO,
            filename="pytorch_model.bin",
            token=token,
            local_dir=local_dir,
            local_dir_use_symlinks=False
        )
        logging.info(f"Downloaded pytorch_model.bin to {local_dir}")
    except Exception as e:
        logging.error(f"Failed to download pytorch_model.bin: {e}")
        raise RuntimeError(f"Failed to download model weights from {RMBG_REPO}: {e}") from e


def ensure_birefnet_files(local_dir: str, token: str) -> None:
    """Create config files, then download code files and weights as needed."""
    create_config_files(local_dir)
    download_birefnet_files(local_dir, token)
    download_model_weights(local_dir, token)


def ensure_models_loaded() -> None:
    """Load all models once, unless running in a Spaces environment.

    In a Spaces environment nothing is loaded here: a notice is printed and
    logged and the function returns. Elsewhere load_models() runs under
    LOAD_LOCK; a failure is remembered in LOAD_ERROR and later calls raise
    immediately instead of retrying.

    Raises:
        RuntimeError: when an earlier attempt already failed.
        Exception: whatever load_models() raises on the first failure.
    """
    global MODELS_LOADED, LOAD_ERROR

    if MODELS_LOADED:
        return

    if is_spaces_environment():
        # ------------------------------------------------------------------
        # ZERO GPU MODEL LOADING: 1. Models NOT loaded at startup
        # ------------------------------------------------------------------
        time.sleep(1)
        print("="*70)
        print("ZERO GPU MODEL LOADING: 1. Models NOT loaded at startup")
        print("="*70)
        logging.info("ZERO GPU MODEL LOADING: Models NOT loaded at startup")
        logging.info("ZERO GPU MODEL LOADING: Models will be loaded on-demand in GPU context")
        return

    with LOAD_LOCK:
        # Re-check under the lock: another thread may have finished loading.
        if MODELS_LOADED:
            return
        if LOAD_ERROR:
            raise RuntimeError(f"Models failed to load: {LOAD_ERROR}")
        try:
            load_models()
        except Exception as e:
            LOAD_ERROR = str(e)
            raise


# ----------------------------------------------------------------------
# MODEL LOADING WITH PRECISION
# ----------------------------------------------------------------------
def _is_corrupt_checkpoint_error(error) -> bool:
    """Return True when `error` looks like a damaged or undecodable checkpoint."""
    message = str(error)
    return (
        isinstance(error, UnicodeDecodeError)
        or "Unable to load weights from pytorch checkpoint" in message
        or "UnicodeDecodeError" in message
    )


def _clear_cached_repo(repo_id: str) -> None:
    """Best-effort removal of every cached revision of model repo `repo_id`."""
    try:
        from huggingface_hub import scan_cache_dir

        cache_info = scan_cache_dir()
        commit_hashes = [
            revision.commit_hash
            for repo in cache_info.repos
            if repo.repo_id == repo_id and repo.repo_type == "model"
            for revision in repo.revisions
        ]
        if commit_hashes:
            cache_info.delete_revisions(*commit_hashes).execute()
            logging.info(f"Cleared cache for {repo_id}")
    except Exception as cache_e:
        logging.warning(f"Cache clearing failed: {cache_e}")


def _load_with_fallbacks(model_class, repo_id, trust_remote_code, torch_device, error):
    """Reload `repo_id` with progressively simpler settings after `error`.

    A non-corruption error gets one plain float32 retry whose exception
    propagates. A corruption error clears the cache and then tries, in order,
    a forced re-download, a TensorFlow checkpoint, and a basic load.

    Raises:
        RuntimeError: when every corruption fallback fails.
    """
    base_kwargs = {
        "torch_dtype": torch.float32,
        "trust_remote_code": trust_remote_code,
        "device_map": None
    }

    if not _is_corrupt_checkpoint_error(error):
        return model_class.from_pretrained(repo_id, **base_kwargs).to(torch_device)

    logging.info(f"Attempting to clear cache and retry for {repo_id}")
    _clear_cached_repo(repo_id)

    # (extra kwargs, message logged on success, prefix logged on failure)
    attempts = (
        ({"force_download": True, "low_cpu_mem_usage": True},
         None,
         "Retry with force_download failed"),
        ({"from_tf": True, "low_cpu_mem_usage": True},
         f"Successfully loaded {repo_id} from TensorFlow checkpoint",
         "TensorFlow fallback failed"),
        ({"use_safetensors": False, "local_files_only": False},
         f"Successfully loaded {repo_id} with basic configuration",
         None),
    )
    last_index = len(attempts) - 1
    for index, (extra_kwargs, success_message, failure_prefix) in enumerate(attempts):
        try:
            model = model_class.from_pretrained(repo_id, **base_kwargs, **extra_kwargs)
            model = model.to(torch_device)
        except Exception as attempt_error:
            if index == last_index:
                logging.error(f"All fallback strategies failed for {repo_id}: {attempt_error}")
                raise RuntimeError(
                    f"Unable to load model {repo_id} after all retry attempts: {attempt_error}"
                ) from attempt_error
            logging.warning(f"{failure_prefix}: {attempt_error}")
        else:
            if success_message:
                logging.info(success_message)
            return model


def load_model_with_precision(model_class, repo_id: str, full_precision: bool, device_map: bool = True, trust_remote_code: bool = False):
    """Load `model_class` from `repo_id` and return it in eval mode.

    Args:
        model_class: class exposing `from_pretrained`.
        repo_id: hub repository id or local directory.
        full_precision: True loads float32, False loads float16.
        device_map: allow device_map="auto"; only applied on CUDA with more
            than one GPU and outside a Spaces environment.
        trust_remote_code: forwarded to `from_pretrained`.

    Returns:
        The loaded model. In a Spaces environment it is kept on CPU;
        otherwise it is placed on DEVICE.

    Raises:
        Exception: the last loading error, after it has been logged.
    """
    try:
        spaces_env = is_spaces_environment()

        if spaces_env:
            torch_device = torch.device("cpu")
            patch_spaces_device_handling()
        else:
            if DEVICE == "cuda":
                torch.cuda.empty_cache()
            torch_device = torch.device(DEVICE)

        load_kwargs = {
            "torch_dtype": torch.float32 if full_precision else torch.float16,
            "trust_remote_code": trust_remote_code,
            "low_cpu_mem_usage": True,
            "use_safetensors": True
        }
        if spaces_env:
            load_kwargs["device_map"] = None
        elif DEVICE == "cuda" and device_map and torch.cuda.device_count() > 1:
            load_kwargs["device_map"] = "auto"

        try:
            model = model_class.from_pretrained(repo_id, **load_kwargs)
            # Skip the explicit move when the model reports an hf_device_map.
            if not spaces_env and not hasattr(model, 'hf_device_map'):
                model = model.to(torch_device)
            if not full_precision and DEVICE == "cuda":
                model = model.half()
        except (ValueError, RuntimeError, OSError, UnicodeDecodeError) as e:
            logging.warning(f"Failed to load model with initial configuration: {e}")
            model = _load_with_fallbacks(model_class, repo_id, trust_remote_code, torch_device, e)

        model.eval()

        if not spaces_env:
            with torch.no_grad():
                logging.info(f"Verifying model {repo_id} is on correct device")
                param = next(model.parameters())
                if DEVICE == "cuda" and not param.is_cuda:
                    model = model.to(torch_device)
                    logging.warning(f"Forced model {repo_id} to {DEVICE}")
                logging.info(f"Model {repo_id} device: {param.device}")
        else:
            logging.info(f"Model {repo_id} loaded on CPU (Zero GPU environment)")

        return model
    except Exception as e:
        logging.error(f"Failed to load model from {repo_id} on {DEVICE}: {e}")
        raise


def handle_rmbg_access_error(error_msg: str) -> None:
    """Translate an RMBG loading error message into a RuntimeError.

    Never returns normally.

    Raises:
        RuntimeError: ERROR_ACCESS_DENIED when the message mentions both
            "403" and "gated repo", ERROR_AUTH_FAILED when it mentions
            "401", otherwise the message itself.
    """
    separator = "=" * 60
    if "403" in error_msg and "gated repo" in error_msg:
        logging.error(
            "\n" + separator + "\n" +
            "ERROR: Access denied to RMBG-2.0 model!\n" +
            "You need to request access at: https://huggingface.co/briaai/RMBG-2.0\n" +
            separator + "\n"
        )
        raise RuntimeError(ERROR_ACCESS_DENIED)
    if "401" in error_msg:
        logging.error(
            "\n" + separator + "\n" +
            "ERROR: Authentication failed!\n" +
            "Please set your HF_TOKEN environment variable.\n" +
            separator + "\n"
        )
        raise RuntimeError(ERROR_AUTH_FAILED)
    raise RuntimeError(error_msg)


# ----------------------------------------------------------------------
# INDIVIDUAL MODEL LOADING FUNCTIONS
# ----------------------------------------------------------------------
def load_rtdetr_model() -> None:
    """Load the RT-DETR processor and model into the module globals."""
    global RTDETR_PROCESSOR, RTDETR_MODEL

    logging.info("Loading RT-DETR model...")
    RTDETR_PROCESSOR = AutoProcessor.from_pretrained(RTDETR_REPO)
    RTDETR_MODEL = load_model_with_precision(
        AutoModelForObjectDetection,
        RTDETR_REPO,
        RTDETR_FULL_PRECISION,
        device_map=False
    )
    logging.info("RT-DETR model loaded successfully")


def load_head_detection_model() -> None:
    """Load the head-detection processor and model into the module globals."""
    global HEAD_PROCESSOR, HEAD_MODEL

    logging.info("Loading Head Detection model...")
    HEAD_PROCESSOR = AutoImageProcessor.from_pretrained(HEAD_DETECTION_REPO)
    HEAD_MODEL = load_model_with_precision(
        DetrForObjectDetection,
        HEAD_DETECTION_REPO,
        HEAD_DETECTION_FULL_PRECISION,
        device_map=False
    )
    logging.info("Head Detection model loaded successfully")


def load_rmbg_model() -> None:
    """Load RMBG-2.0 from DEFAULT_LOCAL_RMBG_DIR into RMBG_MODEL.

    Without HF_TOKEN the model is skipped and RMBG_MODEL is set to None.
    On CUDA with USE_TORCH_COMPILE enabled the model is compiled; a compile
    failure is only logged.

    Raises:
        RuntimeError: via handle_rmbg_access_error() when fetching the files
            or loading the model fails.
    """
    global RMBG_MODEL

    logging.info("Loading RMBG model...")
    token = os.getenv("HF_TOKEN", "")
    if not token:
        logging.error(ERROR_NO_HF_TOKEN)
        logging.warning("RMBG model requires HF_TOKEN. Skipping RMBG model loading...")
        RMBG_MODEL = None
        return

    local_dir = DEFAULT_LOCAL_RMBG_DIR
    try:
        ensure_birefnet_files(local_dir, token)
    except RuntimeError as e:
        handle_rmbg_access_error(str(e))

    os.environ["HF_HOME"] = os.path.dirname(local_dir)

    try:
        RMBG_MODEL = load_model_with_precision(
            AutoModelForImageSegmentation,
            local_dir,
            RMBG_FULL_PRECISION,
            trust_remote_code=True,
            device_map=False
        )
        if USE_TORCH_COMPILE and DEVICE == "cuda":
            try:
                RMBG_MODEL = torch.compile(
                    RMBG_MODEL,
                    mode=TORCH_COMPILE_MODE,
                    backend=TORCH_COMPILE_BACKEND,
                    fullgraph=False,
                    dynamic=False
                )
                logging.info(f"RMBG model compiled with mode={TORCH_COMPILE_MODE}, backend={TORCH_COMPILE_BACKEND}")
            except Exception as e:
                logging.warning(f"Failed to compile RMBG model: {e}")
        logging.info("RMBG-2.0 model loaded successfully from local directory")
    except Exception as e:
        error_msg = str(e)
        handle_rmbg_access_error(error_msg)


def load_yolos_fashionpedia_model() -> None:
    """Load the YOLOS FashionPedia processor and model into the module globals.

    The processor is first requested with a 512x512 size; if that fails the
    default processor configuration is used.
    """
    global YOLOS_PROCESSOR, YOLOS_MODEL

    logging.info("Loading YOLOS FashionPedia model...")
    try:
        YOLOS_PROCESSOR = AutoImageProcessor.from_pretrained(
            YOLOS_FASHIONPEDIA_REPO,
            size={"height": 512, "width": 512}
        )
    except Exception:
        logging.warning("Failed to set custom size for YOLOS processor, using default")
        YOLOS_PROCESSOR = AutoImageProcessor.from_pretrained(YOLOS_FASHIONPEDIA_REPO)

    YOLOS_MODEL = load_model_with_precision(
        YolosForObjectDetection,
        YOLOS_FASHIONPEDIA_REPO,
        YOLOS_FASHIONPEDIA_FULL_PRECISION,
        device_map=False
    )
    logging.info("YOLOS FashionPedia model loaded successfully")


# ----------------------------------------------------------------------
# MAIN MODEL
# LOADING FUNCTION
# ----------------------------------------------------------------------
def load_models() -> None:
    """Load every model once and record the outcome in the module globals.

    Runs under LOAD_LOCK. RT-DETR, Head Detection and YOLOS failures are
    collected as critical errors; an RMBG failure is only logged. Loading
    counts as successful when RT-DETR or YOLOS is available.

    Raises:
        RuntimeError: when neither RT-DETR nor YOLOS could be loaded; the
            same message is stored in LOAD_ERROR.
    """
    global MODELS_LOADED, LOAD_ERROR

    with LOAD_LOCK:
        if MODELS_LOADED:
            logging.info("Models already loaded")
            return

        # Skip the ZERO GPU step 2 print here as it's already shown in test execution flow
        if is_spaces_environment():
            logging.info("ZERO GPU MODEL LOADING: User request triggered model loading")

        check_hardware_environment()

        # (status key, label used in messages, loader, failure is critical)
        load_plan = (
            ("rtdetr", "RT-DETR", load_rtdetr_model, True),
            ("head_detection", "Head Detection", load_head_detection_model, True),
            ("rmbg", "RMBG", load_rmbg_model, False),
            ("yolos", "YOLOS", load_yolos_fashionpedia_model, True),
        )
        models_status = {key: False for key, _, _, _ in load_plan}
        critical_errors = []

        for key, label, loader, is_critical in load_plan:
            try:
                loader()
            except Exception as exc:
                if is_critical:
                    critical_errors.append(f"{label}: {str(exc)}")
                    logging.error(f"Failed to load {label} model: {exc}")
                else:
                    logging.warning(f"Failed to load {label} model: {exc}")
                continue
            # The RMBG loader can return without a model (missing token).
            models_status[key] = RMBG_MODEL is not None if key == "rmbg" else True

        if not (models_status["rtdetr"] or models_status["yolos"]):
            error_msg = "Failed to load critical models. " + "; ".join(critical_errors)
            logging.error(error_msg)
            LOAD_ERROR = error_msg
            raise RuntimeError(error_msg)

        MODELS_LOADED = True
        LOAD_ERROR = ""

        loaded, failed = [], []
        for model_key, is_ready in models_status.items():
            (loaded if is_ready else failed).append(model_key)
        logging.info(f"Models loaded: {', '.join(loaded)}")
        if failed:
            logging.warning(f"Models failed: {', '.join(failed)}")


# ----------------------------------------------------------------------
# MOVE MODELS TO GPU FUNCTION
# ----------------------------------------------------------------------
def _relocate_to_cuda(model, label, full_precision):
    """Move one model to CUDA (halving it unless full precision) and return it.

    A model of None is returned unchanged without logging.
    """
    if model is None:
        return None
    logging.info(f"Moving {label} model to GPU...")
    model = model.to("cuda")
    if not full_precision:
        model = model.half()
    logging.info(f"{label} model moved to GPU")
    return model


def move_models_to_gpu():
    """Move every loaded model to CUDA and switch DEVICE to "cuda".

    Does nothing except log a warning when CUDA is unavailable. On failure
    DEVICE is restored to its previous value and the error is re-raised.
    """
    global RMBG_MODEL, RTDETR_PROCESSOR, RTDETR_MODEL, HEAD_MODEL, YOLOS_PROCESSOR, YOLOS_MODEL, DEVICE

    if not torch.cuda.is_available():
        logging.warning("CUDA not available, cannot move models to GPU")
        return

    previous_device = DEVICE
    DEVICE = "cuda"

    try:
        RMBG_MODEL = _relocate_to_cuda(RMBG_MODEL, "RMBG", RMBG_FULL_PRECISION)
        RTDETR_MODEL = _relocate_to_cuda(RTDETR_MODEL, "RT-DETR", RTDETR_FULL_PRECISION)
        HEAD_MODEL = _relocate_to_cuda(HEAD_MODEL, "Head Detection", HEAD_DETECTION_FULL_PRECISION)
        YOLOS_MODEL = _relocate_to_cuda(YOLOS_MODEL, "YOLOS", YOLOS_FASHIONPEDIA_FULL_PRECISION)
        logging.info("All models moved to GPU successfully")
    except Exception as exc:
        logging.error(f"Failed to move models to GPU: {exc}")
        DEVICE = previous_device
        raise