# FPS-Studio / modules/toolbox/rife_core.py
# rahul7star — migrated from GitHub (commit 05fcd0f, verified)
import torch
import numpy as np
from torchvision.transforms.functional import to_tensor, to_pil_image
from pathlib import Path
import os
import gc
from huggingface_hub import snapshot_download
from .RIFE.RIFE_HDv3 import Model as RIFEBaseModel
from .message_manager import MessageManager
import devicetorch
# Anchor the model directory to this module's own location (not the process CWD),
# so the weights are found no matter which directory the app is launched from.
_MODULE_DIR = Path(os.path.abspath(__file__)).parent
MODEL_RIFE_PATH = _MODULE_DIR / "model_rife"
# Name of the weight file RIFEHandler checks for before loading the model.
RIFE_MODEL_FILENAME = "flownet.pkl"
class RIFEHandler:
    """Manage the lifecycle of a RIFE frame-interpolation model.

    Downloads the weights into ``MODEL_RIFE_PATH`` on first use, loads them into a
    ``RIFEBaseModel``, produces the middle frame between two frames, and unloads
    the model on request. Status and errors are reported through a
    ``MessageManager``; the public methods signal failure by returning
    ``False``/``None`` rather than raising.
    """

    def __init__(self, message_manager: MessageManager | None = None):
        """Create a handler. The model itself is not loaded until
        ``_ensure_model_downloaded_and_loaded()`` is called.

        Args:
            message_manager: Destination for log messages. A new
                ``MessageManager`` is created when none (or a falsy value) is given.
        """
        self.message_manager = message_manager if message_manager else MessageManager()
        # Path() accepts either a str or a Path, so this works for either form of the constant.
        self.model_dir = Path(MODEL_RIFE_PATH)
        self.model_file_path = self.model_dir / RIFE_MODEL_FILENAME
        self.rife_model = None

    def _log(self, message, level="INFO"):
        """Send ``message`` to the MessageManager, prefixed with ``"RIFEHandler: "``.

        ``level`` is matched case-insensitively: "ERROR" and "WARNING" go to the
        corresponding channel; any other value is logged as a plain message.
        """
        text = f"RIFEHandler: {message}"
        level = level.upper()
        if level == "ERROR":
            self.message_manager.add_error(text)
        elif level == "WARNING":
            self.message_manager.add_warning(text)
        else:
            self.message_manager.add_message(text)

    def _log_traceback(self):
        """Log the traceback of the exception currently being handled, as an ERROR."""
        import traceback
        self._log(f"Traceback: {traceback.format_exc()}", "ERROR")

    def _ensure_model_downloaded_and_loaded(self) -> bool:
        """Make sure the RIFE weights are on disk and the model is in memory.

        If ``RIFE_MODEL_FILENAME`` is missing from ``self.model_dir``, the
        ``*.pkl``/``*.pth`` files of the ``AlexWortega/RIFE`` Hugging Face repo
        are downloaded there first. Then a ``RIFEBaseModel`` is built, loaded
        from that directory, and switched to eval mode.

        Returns:
            True if the model is loaded (already, or as a result of this call);
            False if the download or the load failed. The reason is logged.
        """
        if self.rife_model is not None:
            self._log("RIFE model already loaded.")
            return True

        if not self.model_dir.exists():
            self.model_dir.mkdir(parents=True, exist_ok=True)
            self._log(f"Created RIFE model directory: {self.model_dir}")

        if not self.model_file_path.exists():
            self._log("RIFE model weights not found. Downloading...")
            try:
                snapshot_download(
                    repo_id="AlexWortega/RIFE",
                    allow_patterns=["*.pkl", "*.pth"],
                    local_dir=self.model_dir,
                    # NOTE(review): deprecated/ignored by recent huggingface_hub
                    # releases; kept so older versions still copy real files.
                    local_dir_use_symlinks=False,
                )
            except Exception as e:
                self._log(f"Failed to download RIFE model weights: {e}", "ERROR")
                return False
            if not self.model_file_path.exists():
                self._log(f"RIFE model download completed, but {RIFE_MODEL_FILENAME} not found in {self.model_dir}. Check allow_patterns and repo structure.", "ERROR")
                return False
            self._log("RIFE model weights downloaded successfully.")

        # Re-check right before loading in case the file disappeared in the meantime.
        if not self.model_file_path.exists():
            self._log(f"RIFE model file {self.model_file_path} does not exist. Cannot load model.", "ERROR")
            return False

        try:
            self._log(f"Loading RIFE model from {self.model_dir}...")
            # Device placement is not done here; presumably RIFEBaseModel moves
            # itself to a device internally — confirm in RIFE_HDv3.
            self.rife_model = RIFEBaseModel(local_rank=-1)
            # The loader receives the directory, not the file; it presumably
            # resolves RIFE_MODEL_FILENAME inside it — confirm in RIFE_HDv3.
            self.rife_model.load_model(str(self.model_dir), -1)
            self.rife_model.eval()
            self._log("RIFE model loaded successfully to its determined device.")
            return True
        except Exception as e:
            self._log(f"Failed to load RIFE model: {e}", "ERROR")
            self._log_traceback()
            self.rife_model = None
            return False

    def unload_model(self):
        """Drop the loaded model and release cached accelerator memory.

        Safe to call when nothing is loaded; in that case it only logs.
        """
        if self.rife_model is None:
            self._log("RIFE model not loaded, no need to unload.")
            return
        self._log("Unloading RIFE model...")
        self.rife_model = None
        # Collect first: tensors kept alive only by reference cycles are freed by
        # gc, and only then can empty_cache() hand their blocks back to the device.
        gc.collect()
        devicetorch.empty_cache(torch)
        self._log("RIFE model unloaded and memory cleared.")

    def interpolate_between_frames(self, frame1_np: np.ndarray, frame2_np: np.ndarray) -> np.ndarray | None:
        """Synthesize the frame halfway between two frames.

        Args:
            frame1_np: First frame, in a form accepted by torchvision's
                ``to_tensor`` (e.g. an HxWxC uint8 array; presumably 3-channel
                RGB is what the model expects — confirm against RIFEBaseModel).
            frame2_np: Second frame, with the same shape as ``frame1_np``.

        Returns:
            The middle frame as a numpy array built from a PIL image (HxWx3
            uint8 for RGB input), or None if the model is not loaded, the
            frames differ in shape, or inference fails. The reason is logged.
        """
        if self.rife_model is None:
            self._log("RIFE model not loaded. Call _ensure_model_downloaded_and_loaded() before interpolation.", "ERROR")
            return None
        try:
            img0_cpu = to_tensor(frame1_np).unsqueeze(0)
            img1_cpu = to_tensor(frame2_np).unsqueeze(0)
            # Padding below is derived from img0 only, so both frames must match.
            if img0_cpu.shape != img1_cpu.shape:
                self._log(f"Cannot interpolate frames of different shapes: {tuple(img0_cpu.shape)} vs {tuple(img1_cpu.shape)}.", "ERROR")
                return None
            img0 = devicetorch.to(torch, img0_cpu)
            img1 = devicetorch.to(torch, img1_cpu)

            # Pad H and W up to the next multiple of 32 with replicated edge
            # pixels (presumably required by the network's downsampling),
            # then crop the result back to the original size.
            required_multiple = 32
            h_orig, w_orig = img0.shape[2], img0.shape[3]
            pad_h = (required_multiple - h_orig % required_multiple) % required_multiple
            pad_w = (required_multiple - w_orig % required_multiple) % required_multiple
            if pad_h > 0 or pad_w > 0:
                img0 = torch.nn.functional.pad(img0, (0, pad_w, 0, pad_h), mode='replicate')
                img1 = torch.nn.functional.pad(img1, (0, pad_w, 0, pad_h), mode='replicate')

            with torch.no_grad():
                middle = self.rife_model.inference(img0, img1, scale=1.0)
            middle = middle[:, :, :h_orig, :w_orig]

            # to_pil_image multiplies float input by 255 and casts to uint8, so
            # any value outside [0, 1] would overflow; clamp first.
            middle = middle.squeeze(0).clamp(0.0, 1.0).cpu()
            return np.array(to_pil_image(middle))
        except Exception as e:
            self._log(f"Error during RIFE frame interpolation: {e}", "ERROR")
            self._log_traceback()
            if "out of memory" in str(e).lower():
                devicetorch.empty_cache(torch)
            return None