# FPS-Studio / modules / toolbox / esrgan_core.py
# Migrated from GitHub (rahul7star, commit 05fcd0f).
import os
import torch
import gc
import devicetorch
import warnings
import traceback
from pathlib import Path
from huggingface_hub import snapshot_download
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from basicsr.utils.download_util import load_file_from_url # Import for direct downloads
# Conditional import for GFPGAN
try:
from gfpgan import GFPGANer
GFPGAN_AVAILABLE = True
except ImportError:
GFPGAN_AVAILABLE = False
from .message_manager import MessageManager
# Model storage directories are resolved relative to this module's location,
# so downloaded weights travel with the toolbox package.
_MODULE_DIR = Path(os.path.dirname(os.path.abspath(__file__)))
MODEL_ESRGAN_PATH = _MODULE_DIR / "model_esrgan"
# Define a path for GFPGAN models, can be within MODEL_ESRGAN_PATH or separate
MODEL_GFPGAN_PATH = _MODULE_DIR / "model_gfpgan"
class ESRGANUpscaler:
def __init__(self, message_manager: MessageManager, device: torch.device):
self.message_manager = message_manager
self.device = device
self.model_dir = Path(MODEL_ESRGAN_PATH)
self.gfpgan_model_dir = Path(MODEL_GFPGAN_PATH) # GFPGAN model directory
os.makedirs(self.model_dir, exist_ok=True)
os.makedirs(self.gfpgan_model_dir, exist_ok=True) # Ensure GFPGAN model dir exists
self.supported_models = {
"RealESRGAN_x2plus": {
"filename": "RealESRGAN_x2plus.pth",
"file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
"hf_repo_id": None,
"scale": 2,
"model_class": RRDBNet,
"model_params": dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2),
"description": "General purpose. Faster than x4 models due to smaller native output. Good for moderate upscaling."
},
"RealESRGAN_x4plus": {
"filename": "RealESRGAN_x4plus.pth",
"file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
"hf_repo_id": None,
"scale": 4,
"model_class": RRDBNet,
"model_params": dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
"description": "General purpose. Prioritizes sharpness & detail. Good default for most videos."
},
"RealESRNet_x4plus": {
"filename": "RealESRNet_x4plus.pth",
"file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
"hf_repo_id": None,
"scale": 4,
"model_class": RRDBNet,
"model_params": dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
"description": "Similar to RealESRGAN_x4plus, but trained for higher fidelity, often yielding smoother results."
},
"RealESR-general-x4v3": {
"filename": "realesr-general-x4v3.pth", # Main model
"file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
"wdn_filename": "realesr-general-wdn-x4v3.pth", # Companion WDN model
"wdn_file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
"scale": 4, "model_class": SRVGGNetCompact,
"model_params": dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'),
"description": "Versatile SRVGG-based. Balances detail & naturalness. Has adjustable denoise strength." # Updated description
},
"RealESRGAN_x4plus_anime_6B": {
"filename": "RealESRGAN_x4plus_anime_6B.pth",
"file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
"hf_repo_id": None,
"scale": 4,
"model_class": RRDBNet,
"model_params": dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4),
"description": "Optimized for anime. Lighter 6-block version of x4plus for faster anime upscaling."
},
"RealESR_AnimeVideo_v3": {
"filename": "realesr-animevideov3.pth",
"file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
"hf_repo_id": None,
"scale": 4,
"model_class": SRVGGNetCompact,
"model_params": dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'),
"description": "Specialized SRVGG-based model for anime. Often excels with animated content."
}
}
self.upsamplers: dict[str, dict[str, RealESRGANer | int | None]] = {}
self.face_enhancer: GFPGANer | None = None # For GFPGAN
    def _ensure_model_downloaded(self, model_key: str, target_dir: Path | None = None, is_gfpgan: bool = False, is_wdn_companion: bool = False) -> bool:
        """Ensure a model's weight file exists locally, downloading it on first use.

        Resolution order for a missing file: each direct URL in ``file_url``
        (a string or a list of strings), then — only if all URLs failed — the
        model's Hugging Face repo (``hf_repo_id``), when one is configured.

        Args:
            model_key: Key into ``self.supported_models`` (ignored when ``is_gfpgan``).
            target_dir: Optional override for the destination directory.
            is_gfpgan: Download the hard-coded GFPGAN v1.4 weights instead of an ESRGAN model.
            is_wdn_companion: Download the WDN denoise companion weights
                (used by RealESR-general-x4v3 for adjustable denoise strength).

        Returns:
            True if the file is present after the call, False otherwise.
        """
        # Modified to handle WDN companion model download for RealESR-general-x4v3
        if target_dir is None:
            # GFPGAN weights live in their own directory by default.
            current_model_dir = self.gfpgan_model_dir if is_gfpgan else self.model_dir
        else:
            current_model_dir = target_dir
        model_info_source = {}
        actual_model_filename = ""
        if is_gfpgan:
            # Single fixed source for GFPGAN; model_key is not consulted here.
            model_info_source = {
                "filename": "GFPGANv1.4.pth",
                "file_url": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth",
                "hf_repo_id": None
            }
            actual_model_filename = model_info_source["filename"]
        else:
            if model_key not in self.supported_models:
                self.message_manager.add_error(f"ESRGAN model key '{model_key}' not supported.")
                return False
            model_details = self.supported_models[model_key]
            if is_wdn_companion:
                if "wdn_filename" not in model_details or "wdn_file_url" not in model_details:
                    self.message_manager.add_error(f"WDN companion model details missing for '{model_key}'.")
                    return False
                model_info_source = {
                    "filename": model_details["wdn_filename"],
                    "file_url": model_details["wdn_file_url"],
                    "hf_repo_id": None  # Assuming direct URL for WDN for now
                }
                actual_model_filename = model_details["wdn_filename"]
            else:  # Regular ESRGAN model
                model_info_source = model_details
                actual_model_filename = model_details["filename"]
        model_path = current_model_dir / actual_model_filename
        if not model_path.exists():
            log_prefix = "WDN " if is_wdn_companion else ""
            self.message_manager.add_message(f"{log_prefix}Model '{actual_model_filename}' not found. Downloading...")
            try:
                downloaded_successfully = False
                if "file_url" in model_info_source and model_info_source["file_url"]:
                    urls_to_try = model_info_source["file_url"]
                    # Normalize to a list so single-URL entries share the loop below.
                    if isinstance(urls_to_try, str): urls_to_try = [urls_to_try]
                    for url in urls_to_try:
                        self.message_manager.add_message(f"Attempting download from URL: {url}")
                        try:
                            load_file_from_url(
                                url=url, model_dir=str(current_model_dir),
                                progress=True, file_name=actual_model_filename
                            )
                            # load_file_from_url may not raise on every failure mode,
                            # so trust the filesystem to confirm success.
                            if model_path.exists():
                                downloaded_successfully = True
                                self.message_manager.add_success(f"{log_prefix}Model '{actual_model_filename}' downloaded from URL.")
                                break
                        except Exception as e_url:
                            # Non-fatal: continue to the next candidate URL.
                            self.message_manager.add_warning(f"Failed to download from {url}: {e_url}. Trying next source.")
                            continue
                # Hugging Face Hub is the fallback, tried only when all URLs failed.
                if not downloaded_successfully and "hf_repo_id" in model_info_source and model_info_source["hf_repo_id"]:
                    self.message_manager.add_message(f"Attempting download from Hugging Face Hub: {model_info_source['hf_repo_id']}")
                    snapshot_download(
                        repo_id=model_info_source["hf_repo_id"], allow_patterns=[actual_model_filename],
                        local_dir=current_model_dir, local_dir_use_symlinks=False
                    )
                    if model_path.exists():
                        downloaded_successfully = True
                        self.message_manager.add_success(f"{log_prefix}Model '{actual_model_filename}' downloaded from Hugging Face Hub.")
                if not downloaded_successfully:
                    self.message_manager.add_error(f"All download attempts failed for '{actual_model_filename}'.")
                    return False
            except Exception as e:
                self.message_manager.add_error(f"Failed to download {log_prefix}model '{actual_model_filename}': {e}")
                self.message_manager.add_error(traceback.format_exc())
                return False
        return True
    def load_model(self, model_key: str, tile_size: int = 0, denoise_strength: float | None = None) -> RealESRGANer | None:
        """Load a RealESRGANer for ``model_key``, reusing the cached one when the config matches.

        Args:
            model_key: Key into ``self.supported_models``.
            tile_size: Tile size for tiled inference; 0 means "Auto" (no explicit tiling).
            denoise_strength: Only honoured for "RealESR-general-x4v3". A value in
                [0.0, 1.0) blends the main and WDN companion weights via DNI;
                1.0 (or None) uses the main model alone.

        Returns:
            The loaded (or cached) ``RealESRGANer``, or None on any failure.
        """
        if model_key not in self.supported_models:
            self.message_manager.add_error(f"ESRGAN model key '{model_key}' not supported.")
            return None
        # Check if model is already loaded with the same configuration.
        # denoise_strength participates in the signature only for the DNI-capable model.
        current_config_signature = (tile_size, denoise_strength if model_key == "RealESR-general-x4v3" else None)
        if model_key in self.upsamplers:
            existing_config = self.upsamplers[model_key]
            existing_config_signature = (
                existing_config.get('tile_size', 0),
                existing_config.get('denoise_strength') if model_key == "RealESR-general-x4v3" else None
            )
            if existing_config.get("upsampler") is not None and existing_config_signature == current_config_signature:
                # Cache hit with identical settings: reuse as-is.
                log_tile = f"Tile: {str(tile_size) if tile_size > 0 else 'Auto'}"
                log_dni = f", DNI: {denoise_strength:.2f}" if denoise_strength is not None and model_key == "RealESR-general-x4v3" else ""
                self.message_manager.add_message(f"ESRGAN model '{model_key}' ({log_tile}{log_dni}) already loaded.")
                return existing_config["upsampler"]
            elif existing_config.get("upsampler") is not None and existing_config_signature != current_config_signature:
                # Same model, different settings: drop the old instance first.
                self.message_manager.add_message(
                    f"ESRGAN model '{model_key}' config changed. Unloading to reload with new settings."
                )
                self.unload_model(model_key)
        # Ensure main model is downloaded
        if not self._ensure_model_downloaded(model_key):
            return None
        model_info = self.supported_models[model_key]
        model_path_for_upsampler = str(self.model_dir / model_info["filename"])
        dni_weight_for_upsampler = None
        log_msg_parts = [
            f"Loading ESRGAN model '{model_info['filename']}' (Key: {model_key}, Scale: {model_info['scale']}x",
            f"Tile: {str(tile_size) if tile_size > 0 else 'Auto'}"
        ]
        # Specific handling for RealESR-general-x4v3 with denoise_strength
        if model_key == "RealESR-general-x4v3" and denoise_strength is not None and 0.0 <= denoise_strength < 1.0:
            # Denoise strength 1.0 means use only the main model, so no DNI.
            # Denoise strength < 0.0 is invalid.
            if "wdn_filename" not in model_info or "wdn_file_url" not in model_info:
                self.message_manager.add_error(f"WDN companion model details missing for '{model_key}'. Cannot apply denoise strength.")
                return None  # Or fallback to no DNI? For now, error.
            # Ensure WDN companion model is downloaded
            if not self._ensure_model_downloaded(model_key, is_wdn_companion=True):
                self.message_manager.add_error(f"Failed to download WDN companion for '{model_key}'. Cannot apply denoise strength.")
                return None
            wdn_model_path_str = str(self.model_dir / model_info["wdn_filename"])
            model_path_for_upsampler = [model_path_for_upsampler, wdn_model_path_str]  # Pass list of paths
            dni_weight_for_upsampler = [denoise_strength, 1.0 - denoise_strength]  # [main_model_strength, wdn_model_strength]
            log_msg_parts.append(f"DNI Strength: {denoise_strength:.2f}")
        log_msg_parts.append(f") to device: {self.device}...")
        self.message_manager.add_message(" ".join(log_msg_parts))
        try:
            # The arch constructors disagree on the scale kwarg name (RRDBNet uses
            # "scale", SRVGGNetCompact uses "upscale"); overwrite whichever one the
            # params dict carries with the registry's native scale.
            model_params_with_correct_scale = model_info["model_params"].copy()
            if "scale" in model_params_with_correct_scale: model_params_with_correct_scale["scale"] = model_info["scale"]
            elif "upscale" in model_params_with_correct_scale: model_params_with_correct_scale["upscale"] = model_info["scale"]
            else: model_params_with_correct_scale["scale"] = model_info["scale"]
            model_arch = model_info["model_class"](**model_params_with_correct_scale)
            # RealESRGANer takes an explicit GPU index; None selects default/CPU behavior.
            gpu_id_for_realesrgan = self.device.index if self.device.type == 'cuda' and self.device.index is not None else None
            # fp16 inference only makes sense on CUDA.
            use_half_precision = True if self.device.type == 'cuda' else False
            with warnings.catch_warnings():
                # Suppress the TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD warning from RealESRGANer/basicsr
                warnings.filterwarnings(
                    "ignore",
                    category=UserWarning,
                    message=".*Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected.*"
                )
                # Suppress torchvision pretrained/weights warnings potentially triggered by basicsr
                warnings.filterwarnings("ignore", category=UserWarning, message="The parameter 'pretrained' is deprecated.*")
                warnings.filterwarnings("ignore", category=UserWarning, message="Arguments other than a weight enum or `None` for 'weights' are deprecated.*")
                upsampler = RealESRGANer(
                    scale=model_info["scale"],
                    model_path=model_path_for_upsampler,
                    dni_weight=dni_weight_for_upsampler,
                    model=model_arch,
                    tile=tile_size,
                    tile_pad=10,
                    pre_pad=0,
                    half=use_half_precision,
                    gpu_id=gpu_id_for_realesrgan
                )
            # Record the configuration so later calls can detect cache hits/misses.
            self.upsamplers[model_key] = {
                "upsampler": upsampler,
                "tile_size": tile_size,
                "native_scale": model_info["scale"],
                "denoise_strength": denoise_strength if model_key == "RealESR-general-x4v3" else None
            }
            self.message_manager.add_success(f"ESRGAN model '{model_info['filename']}' (Key: {model_key}) loaded successfully.")
            return upsampler
        except Exception as e:
            self.message_manager.add_error(f"Failed to load ESRGAN model '{model_info['filename']}' (Key: {model_key}): {e}")
            self.message_manager.add_error(traceback.format_exc())
            # Drop any partially-initialized cache entry so a retry starts clean.
            if model_key in self.upsamplers: del self.upsamplers[model_key]
            return None
    def _load_face_enhancer(self, model_name="GFPGANv1.4.pth", bg_upsampler=None) -> bool:
        """Load the GFPGAN face enhancer, downloading its weights if necessary.

        Args:
            model_name: GFPGAN weights filename (passed as ``model_key`` to the
                downloader, which ignores it in GFPGAN mode and uses v1.4).
            bg_upsampler: Optional RealESRGANer GFPGAN should use for the image
                background; if the enhancer is already loaded but bound to a
                different upsampler, it is reloaded with the new one.

        Returns:
            True if ``self.face_enhancer`` is ready after the call, False otherwise.
        """
        if not GFPGAN_AVAILABLE:
            self.message_manager.add_warning("GFPGAN library not available. Cannot load face enhancer.")
            return False
        if self.face_enhancer is not None:
            # If bg_upsampler changed, we might need to re-init. For now, assume if loaded, it's fine or will be handled by caller.
            if bg_upsampler is not None and hasattr(self.face_enhancer, 'bg_upsampler') and self.face_enhancer.bg_upsampler != bg_upsampler:
                self.message_manager.add_message("GFPGAN face enhancer already loaded, but with a different background upsampler. Re-initializing GFPGAN...")
                self._unload_face_enhancer()  # Unload to reload with new bg_upsampler
            else:
                self.message_manager.add_message("GFPGAN face enhancer already loaded.")
                return True
        if not self._ensure_model_downloaded(model_key=model_name, is_gfpgan=True):
            self.message_manager.add_error(f"Failed to download GFPGAN model '{model_name}'.")
            return False
        gfpgan_model_path = str(self.gfpgan_model_dir / model_name)
        self.message_manager.add_message(f"Loading GFPGAN face enhancer from {gfpgan_model_path}...")
        try:
            # --- ADDED: warnings.catch_warnings() context manager ---
            with warnings.catch_warnings():
                # Suppress warnings from GFPGANer and its dependencies (facexlib)
                warnings.filterwarnings(
                    "ignore",
                    category=UserWarning,
                    message=".*Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected.*"
                )
                warnings.filterwarnings("ignore", category=UserWarning, message="The parameter 'pretrained' is deprecated.*")
                warnings.filterwarnings("ignore", category=UserWarning, message="Arguments other than a weight enum or `None` for 'weights' are deprecated.*")
                # upscale=1: GFPGAN only cleans faces here; scaling is ESRGAN's job
                # in this pipeline (see upscale_frame).
                self.face_enhancer = GFPGANer(
                    model_path=gfpgan_model_path,
                    upscale=1,
                    arch='clean',
                    channel_multiplier=2,
                    bg_upsampler=bg_upsampler,
                    device=self.device
                )
            self.message_manager.add_success("GFPGAN face enhancer loaded.")
            return True
        except Exception as e:
            self.message_manager.add_error(f"Failed to load GFPGAN face enhancer: {e}")
            self.message_manager.add_error(traceback.format_exc())
            self.face_enhancer = None
            return False
def _unload_face_enhancer(self):
if self.face_enhancer is not None:
self.message_manager.add_message("Unloading GFPGAN face enhancer...")
del self.face_enhancer
self.face_enhancer = None
gc.collect()
if self.device.type == 'cuda':
torch.cuda.empty_cache()
self.message_manager.add_success("GFPGAN face enhancer unloaded.")
else:
self.message_manager.add_message("GFPGAN face enhancer not loaded.")
def unload_model(self, model_key: str):
if model_key in self.upsamplers and self.upsamplers[model_key].get("upsampler") is not None:
config = self.upsamplers.pop(model_key)
upsampler_instance = config["upsampler"]
tile_s = config.get("tile_size", 0)
native_scale = config.get("native_scale", "N/A") # Get native_scale for logging
log_tile_size = str(tile_s) if tile_s > 0 else "Auto"
self.message_manager.add_message(f"Unloading ESRGAN model '{model_key}' (Scale: {native_scale}x, Tile: {log_tile_size})...")
if self.face_enhancer and hasattr(self.face_enhancer, 'bg_upsampler') and self.face_enhancer.bg_upsampler == upsampler_instance:
self.message_manager.add_message("Unloading associated GFPGAN as its BG upsampler is being removed.")
self._unload_face_enhancer()
del upsampler_instance
devicetorch.empty_cache(torch)
gc.collect()
self.message_manager.add_success(f"ESRGAN model '{model_key}' unloaded and memory cleared.")
else:
self.message_manager.add_message(f"ESRGAN model '{model_key}' not loaded, no need to unload.")
def unload_all_models(self):
if not self.upsamplers and not self.face_enhancer:
self.message_manager.add_message("No ESRGAN or GFPGAN models currently loaded.")
return
self.message_manager.add_message("Unloading all ESRGAN models...")
model_keys_to_unload = list(self.upsamplers.keys())
for key in model_keys_to_unload:
if key in self.upsamplers:
config = self.upsamplers.pop(key)
upsampler_instance = config["upsampler"]
del upsampler_instance # type: ignore
self._unload_face_enhancer()
devicetorch.empty_cache(torch)
gc.collect()
self.message_manager.add_success("All ESRGAN and GFPGAN models unloaded and memory cleared.")
    def upscale_frame(self, frame_np_array, model_key: str, target_outscale_factor: float, enhance_face: bool = False):
        """
        Upscales a single frame using the specified model and target output scale.

        Args:
            frame_np_array: RGB frame as an (H, W, 3) numpy array — presumably
                uint8, as with decoded video frames (TODO confirm against callers).
                RGB<->BGR conversion is handled internally.
            model_key: Key into ``self.supported_models``; the model is
                auto-loaded (Tile: Auto) if not already resident.
            target_outscale_factor: Desired output scale; values outside
                [0.25, native_scale] are clamped to the model's native scale.
            enhance_face: If True, run GFPGAN (clean-only, upscale=1) on the
                frame before ESRGAN upscaling.

        Returns:
            The upscaled RGB numpy array, or None on failure.
        """
        config = self.upsamplers.get(model_key)
        upsampler: RealESRGANer | None = None
        current_tile_size = 0
        model_native_scale = 0
        if config and config.get("upsampler"):
            upsampler = config["upsampler"]  # type: ignore
            current_tile_size = config.get("tile_size", 0)  # type: ignore
            model_native_scale = config.get("native_scale", 0)  # type: ignore
            if model_native_scale == 0:
                self.message_manager.add_error(f"Error: Native scale for model '{model_key}' is 0 or not found in config.")
                return None
        if upsampler is None:
            # Lazy path: model was not pre-loaded by the caller.
            self.message_manager.add_warning(
                f"ESRGAN model '{model_key}' not pre-loaded. Attempting to load now (with default Tile: Auto)..."
            )
            tile_to_load_with = config.get("tile_size", 0) if config else 0
            upsampler = self.load_model(model_key, tile_size=tile_to_load_with)
            if upsampler is None:
                self.message_manager.add_error(f"Failed to auto-load ESRGAN model '{model_key}'. Cannot upscale.")
                return None
            loaded_config = self.upsamplers.get(model_key)  # Re-fetch config after load
            if loaded_config:
                current_tile_size = loaded_config.get("tile_size", 0)  # type: ignore
                model_native_scale = loaded_config.get("native_scale", 0)  # type: ignore
                if model_native_scale == 0:
                    self.message_manager.add_error(f"Error: Native scale for auto-loaded model '{model_key}' is 0.")
                    return None
            else:
                self.message_manager.add_error(f"Error: Config for auto-loaded model '{model_key}' not found.")
                return None
        # Validate target_outscale_factor against model's native scale.
        # Allow outscale from a small factor up to the model's native scale.
        # You could allow slightly more (e.g., model_native_scale * 1.1) if you want to permit minor bicubic post-upscale.
        # For now, strictly <= native_scale.
        if not (0.25 <= target_outscale_factor <= model_native_scale):
            self.message_manager.add_warning(
                f"Target outscale factor {target_outscale_factor:.2f}x is outside the recommended range "
                f"(0.25x to {model_native_scale:.2f}x) for model '{model_key}' (native {model_native_scale}x). "
                f"Adjusting to model's native scale {model_native_scale:.2f}x."
            )
            target_outscale_factor = float(model_native_scale)
        if enhance_face:
            # (Re)load GFPGAN when absent, or when it is bound to a different
            # background upsampler than the one about to be used.
            if not self.face_enhancer or (hasattr(self.face_enhancer, 'bg_upsampler') and self.face_enhancer.bg_upsampler != upsampler):
                self.message_manager.add_message("Face enhancement requested, loading/re-configuring GFPGAN...")
                self._load_face_enhancer(bg_upsampler=upsampler)
            if not self.face_enhancer:
                # Degrade gracefully rather than failing the whole frame.
                self.message_manager.add_warning("GFPGAN could not be loaded. Proceeding without face enhancement.")
                enhance_face = False
        try:
            # RGB -> BGR channel flip (zero-copy view); the upscaler stack works
            # on cv2-style BGR images.
            img_bgr = frame_np_array[:, :, ::-1]
            outscale_for_enhance = float(target_outscale_factor)
            if enhance_face and self.face_enhancer:
                if self.face_enhancer.upscale != 1:  # Ensure GFPGAN is only cleaning, not upscaling itself in this pipeline path
                    self.message_manager.add_warning(
                        f"GFPGANer's internal upscale is {self.face_enhancer.upscale}, but for the 'Clean Face -> ESRGAN Upscale' pipeline, "
                        f"it should be 1. RealESRGAN will handle the main scaling to {target_outscale_factor:.2f}x."
                    )
                _, _, cleaned_img_bgr = self.face_enhancer.enhance(img_bgr, has_aligned=False, only_center_face=False, paste_back=True)
                output_bgr, _ = upsampler.enhance(cleaned_img_bgr, outscale=outscale_for_enhance)
            else:
                output_bgr, _ = upsampler.enhance(img_bgr, outscale=outscale_for_enhance)
            # BGR -> RGB for the caller.
            output_rgb = output_bgr[:, :, ::-1]
            return output_rgb
        except Exception as e:
            tile_size_msg_part = str(current_tile_size) if current_tile_size > 0 else 'Auto'
            face_msg_part = " + Face Enhance" if enhance_face else ""
            self.message_manager.add_error(
                f"Error during ESRGAN frame upscaling (Model: {model_key}{face_msg_part}, "
                f"Target Scale: {target_outscale_factor:.2f}x, Native: {model_native_scale}x, Tile: {tile_size_msg_part}): {e}"
            )
            self.message_manager.add_error(traceback.format_exc())
            # Best-effort recovery for CUDA OOM: empty the cache so a retry
            # (possibly with a smaller tile size) has a chance.
            if "out of memory" in str(e).lower() and self.device.type == 'cuda':
                self.message_manager.add_warning(
                    "CUDA OOM during upscaling. Emptying cache. "
                    f"Current model (Model: {model_key}, Tile: {tile_size_msg_part}) may need reloading. "
                    "Consider using a smaller tile size or a smaller input video if issues persist."
                )
                devicetorch.empty_cache(torch)
            return None