File size: 5,880 Bytes
05fcd0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import torch
import numpy as np
from torchvision.transforms.functional import to_tensor, to_pil_image
from pathlib import Path
import os
import gc 
from huggingface_hub import snapshot_download 

from .RIFE.RIFE_HDv3 import Model as RIFEBaseModel 
from .message_manager import MessageManager 
import devicetorch

# Anchor the model directory to this module's own location (not the current
# working directory) so the weights are found no matter where the process
# was launched from.
_MODULE_DIR = Path(os.path.abspath(__file__)).parent

# Absolute directory holding the RIFE weights, and the weight file expected in it.
MODEL_RIFE_PATH = _MODULE_DIR / "model_rife"
RIFE_MODEL_FILENAME = "flownet.pkl"

class RIFEHandler: 
    def __init__(self, message_manager: MessageManager = None):
        self.message_manager = message_manager if message_manager else MessageManager()
        self.model_dir = Path(MODEL_RIFE_PATH) # Path() constructor handles Path objects correctly
        self.model_file_path = self.model_dir / RIFE_MODEL_FILENAME
        self.rife_model = None 

    def _log(self, message, level="INFO"):
        # Helper for logging using the MessageManager
        if level.upper() == "ERROR":
            self.message_manager.add_error(f"RIFEHandler: {message}")
        elif level.upper() == "WARNING":
            self.message_manager.add_warning(f"RIFEHandler: {message}")
        else:
            self.message_manager.add_message(f"RIFEHandler: {message}")

    def _ensure_model_downloaded_and_loaded(self) -> bool:
        if self.rife_model is not None:
            self._log("RIFE model already loaded.")
            return True

        # self.model_dir is now an absolute path
        if not self.model_dir.exists():
             os.makedirs(self.model_dir, exist_ok=True)
             self._log(f"Created RIFE model directory: {self.model_dir}")

        # self.model_file_path is now an absolute path
        if not self.model_file_path.exists():
            self._log("RIFE model weights not found. Downloading...")
            try:
                snapshot_download(
                    repo_id="AlexWortega/RIFE", 
                    allow_patterns=["*.pkl", "*.pth"], 
                    local_dir=self.model_dir, # Pass the absolute path
                    local_dir_use_symlinks=False
                )
                if self.model_file_path.exists():
                    self._log("RIFE model weights downloaded successfully.")
                else:
                    self._log(f"RIFE model download completed, but {RIFE_MODEL_FILENAME} not found in {self.model_dir}. Check allow_patterns and repo structure.", "ERROR")
                    return False 
            except Exception as e:
                self._log(f"Failed to download RIFE model weights: {e}", "ERROR")
                return False 

        if not self.model_file_path.exists(): 
            self._log(f"RIFE model file {self.model_file_path} does not exist. Cannot load model.", "ERROR")
            return False

        try:
            self._log(f"Loading RIFE model from {self.model_dir}...") # self.model_dir is absolute
            current_device_str = devicetorch.get(torch) 
            self.rife_model = RIFEBaseModel(local_rank=-1) 
            
            self.rife_model.load_model(str(self.model_dir), -1) # str(self.model_dir) is absolute
            self.rife_model.eval()
            self._log(f"RIFE model loaded successfully to its determined device.")
            return True
        except Exception as e:
            self._log(f"Failed to load RIFE model: {e}", "ERROR")
            import traceback
            self._log(f"Traceback: {traceback.format_exc()}", "ERROR")
            self.rife_model = None 
            return False

    def unload_model(self):
        if self.rife_model is not None:
            self._log("Unloading RIFE model...")
            del self.rife_model 
            self.rife_model = None
            devicetorch.empty_cache(torch) 
            gc.collect() 
            self._log("RIFE model unloaded and memory cleared.")
        else:
            self._log("RIFE model not loaded, no need to unload.")

    def interpolate_between_frames(self, frame1_np: np.ndarray, frame2_np: np.ndarray) -> np.ndarray | None:
        if self.rife_model is None:
            self._log("RIFE model not loaded. Call _ensure_model_downloaded_and_loaded() before interpolation.", "ERROR")
            return None

        try:
            img0_tensor = to_tensor(frame1_np).unsqueeze(0)
            img1_tensor = to_tensor(frame2_np).unsqueeze(0)
            
            img0 = devicetorch.to(torch, img0_tensor)
            img1 = devicetorch.to(torch, img1_tensor)


            required_multiple = 32 
            h_orig, w_orig = img0.shape[2], img0.shape[3]
            pad_h = (required_multiple - h_orig % required_multiple) % required_multiple
            pad_w = (required_multiple - w_orig % required_multiple) % required_multiple

            if pad_h > 0 or pad_w > 0:
                img0 = torch.nn.functional.pad(img0, (0, pad_w, 0, pad_h), mode='replicate')
                img1 = torch.nn.functional.pad(img1, (0, pad_w, 0, pad_h), mode='replicate')

            with torch.no_grad():
                middle_frame_tensor = self.rife_model.inference(img0, img1, scale=1.0) 

            if pad_h > 0 or pad_w > 0:
                middle_frame_tensor = middle_frame_tensor[:, :, :h_orig, :w_orig]

            middle_frame_pil = to_pil_image(middle_frame_tensor.squeeze(0).cpu())
            return np.array(middle_frame_pil)

        except Exception as e:
            self._log(f"Error during RIFE frame interpolation: {e}", "ERROR")
            import traceback
            self._log(f"Traceback: {traceback.format_exc()}", "ERROR")
            if "out of memory" in str(e).lower():
                devicetorch.empty_cache(torch)
            return None