"""RIFE frame interpolation: download, load, run and unload the RIFE model."""

import gc
import os
import traceback
from pathlib import Path

import torch
import numpy as np
from torchvision.transforms.functional import to_tensor, to_pil_image
from huggingface_hub import snapshot_download
import devicetorch

from .RIFE.RIFE_HDv3 import Model as RIFEBaseModel
from .message_manager import MessageManager

# Directory containing this module. The weights directory is resolved relative
# to it (not to the current working directory) so that the model is found no
# matter where the application is launched from.
_MODULE_DIR = Path(os.path.dirname(os.path.abspath(__file__)))
MODEL_RIFE_PATH = _MODULE_DIR / "model_rife"
RIFE_MODEL_FILENAME = "flownet.pkl"

# Hugging Face repository the weights are fetched from.
_RIFE_REPO_ID = "AlexWortega/RIFE"
# Frame height/width are padded up to a multiple of this value before
# inference (presumably a requirement of the network's downsampling stages).
_PAD_MULTIPLE = 32


class RIFEHandler:
    """Owns the lifecycle of a RIFE model and interpolates between two frames.

    Typical use: call ``_ensure_model_downloaded_and_loaded()``, then
    ``interpolate_between_frames()`` as often as needed, then
    ``unload_model()`` to release memory. All status and error reporting goes
    through the supplied ``MessageManager``; the methods report failure via
    their return value instead of raising.
    """

    def __init__(self, message_manager: MessageManager | None = None):
        """Create a handler.

        Args:
            message_manager: Sink for log messages. A fresh ``MessageManager``
                is created only when ``None`` is passed.
        """
        # Compare against None explicitly: a caller-supplied manager must be
        # kept even if it happens to evaluate as falsy (e.g. an empty one).
        self.message_manager = (
            message_manager if message_manager is not None else MessageManager()
        )
        self.model_dir = Path(MODEL_RIFE_PATH)
        self.model_file_path = self.model_dir / RIFE_MODEL_FILENAME
        self.rife_model = None

    def _log(self, message, level="INFO"):
        """Forward ``message`` to the message manager, prefixed with the class name.

        Args:
            message: Text to log.
            level: ``"ERROR"`` or ``"WARNING"`` (case-insensitive) select the
                matching manager method; any other value is logged as a
                plain message.
        """
        text = f"RIFEHandler: {message}"
        severity = level.upper()
        if severity == "ERROR":
            self.message_manager.add_error(text)
        elif severity == "WARNING":
            self.message_manager.add_warning(text)
        else:
            self.message_manager.add_message(text)

    def _download_weights(self) -> bool:
        """Download the RIFE weights into ``self.model_dir``.

        Returns:
            True if the expected weights file exists after the download,
            False if the download raised or the file is still missing.
        """
        self._log("RIFE model weights not found. Downloading...")
        try:
            snapshot_download(
                repo_id=_RIFE_REPO_ID,
                allow_patterns=["*.pkl", "*.pth"],
                local_dir=self.model_dir,
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            self._log(f"Failed to download RIFE model weights: {e}", "ERROR")
            return False

        if self.model_file_path.exists():
            self._log("RIFE model weights downloaded successfully.")
            return True

        self._log(
            f"RIFE model download completed, but {RIFE_MODEL_FILENAME} not found in {self.model_dir}. Check allow_patterns and repo structure.",
            "ERROR",
        )
        return False

    def _ensure_model_downloaded_and_loaded(self) -> bool:
        """Make sure the RIFE model is available in ``self.rife_model``.

        Creates the model directory if needed, downloads the weights when the
        weights file is missing, then loads the model and puts it in eval mode.

        Returns:
            True if the model is loaded (or already was), False on any
            download or load failure. Failures are logged, never raised.
        """
        if self.rife_model is not None:
            self._log("RIFE model already loaded.")
            return True

        if not self.model_dir.exists():
            self.model_dir.mkdir(parents=True, exist_ok=True)
            self._log(f"Created RIFE model directory: {self.model_dir}")

        if not self.model_file_path.exists() and not self._download_weights():
            return False

        # Final guard in case the file vanished between download and load.
        if not self.model_file_path.exists():
            self._log(
                f"RIFE model file {self.model_file_path} does not exist. Cannot load model.",
                "ERROR",
            )
            return False

        try:
            self._log(f"Loading RIFE model from {self.model_dir}...")
            model = RIFEBaseModel(local_rank=-1)
            model.load_model(str(self.model_dir), -1)
            model.eval()
        except Exception as e:
            self._log(f"Failed to load RIFE model: {e}", "ERROR")
            self._log(f"Traceback: {traceback.format_exc()}", "ERROR")
            self.rife_model = None
            return False

        # Publish the model only once it is fully loaded, so a half-initialised
        # instance is never visible through self.rife_model.
        self.rife_model = model
        self._log("RIFE model loaded successfully to its determined device.")
        return True

    def unload_model(self):
        """Drop the loaded model (if any) and release the memory it held."""
        if self.rife_model is None:
            self._log("RIFE model not loaded, no need to unload.")
            return

        self._log("Unloading RIFE model...")
        self.rife_model = None
        # Collect garbage first: tensors that are only kept alive by reference
        # cycles must be freed before the device cache can actually hand
        # their memory back.
        gc.collect()
        devicetorch.empty_cache(torch)
        self._log("RIFE model unloaded and memory cleared.")

    def interpolate_between_frames(
        self, frame1_np: np.ndarray, frame2_np: np.ndarray
    ) -> np.ndarray | None:
        """Synthesize the frame halfway between ``frame1_np`` and ``frame2_np``.

        Both frames are converted with torchvision's ``to_tensor``, padded at
        the bottom/right (edge replication) so height and width are multiples
        of 32, run through the model, and cropped back to the original size.

        Args:
            frame1_np: First frame, in a layout accepted by ``to_tensor``
                (assumed H x W x C image array - confirm against callers).
            frame2_np: Second frame; must have the same dimensions as the first.

        Returns:
            The interpolated frame as a NumPy array built from a PIL image, or
            ``None`` if the model is not loaded, the frame sizes differ, or
            inference fails. Errors are logged, never raised.
        """
        if self.rife_model is None:
            self._log(
                "RIFE model not loaded. Call _ensure_model_downloaded_and_loaded() before interpolation.",
                "ERROR",
            )
            return None

        try:
            img0 = to_tensor(frame1_np).unsqueeze(0)
            img1 = to_tensor(frame2_np).unsqueeze(0)

            # Fail early with a clear message instead of an opaque error from
            # inside the network when the two frames do not line up.
            if img0.shape != img1.shape:
                self._log(
                    f"Frame shape mismatch: {tuple(img0.shape)} vs {tuple(img1.shape)}. "
                    "Both frames must have identical dimensions.",
                    "ERROR",
                )
                return None

            img0 = devicetorch.to(torch, img0)
            img1 = devicetorch.to(torch, img1)

            h_orig, w_orig = img0.shape[2], img0.shape[3]
            pad_h = -h_orig % _PAD_MULTIPLE
            pad_w = -w_orig % _PAD_MULTIPLE
            if pad_h > 0 or pad_w > 0:
                padding = (0, pad_w, 0, pad_h)  # (left, right, top, bottom)
                img0 = torch.nn.functional.pad(img0, padding, mode='replicate')
                img1 = torch.nn.functional.pad(img1, padding, mode='replicate')

            with torch.no_grad():
                middle = self.rife_model.inference(img0, img1, scale=1.0)

            # Remove the padding again (a no-op slice when nothing was padded).
            middle = middle[:, :, :h_orig, :w_orig].squeeze(0)
            if middle.is_floating_point():
                # to_pil_image scales floats by 255 and casts to uint8; values
                # slightly outside [0, 1] would wrap around instead of saturating.
                middle = middle.clamp(0.0, 1.0)
            return np.array(to_pil_image(middle.cpu()))
        except Exception as e:
            self._log(f"Error during RIFE frame interpolation: {e}", "ERROR")
            self._log(f"Traceback: {traceback.format_exc()}", "ERROR")
            if "out of memory" in str(e).lower():
                devicetorch.empty_cache(torch)
            return None