import gc
import os
import torch
from abc import ABC, abstractmethod
from diffusers_helper import lora_utils

class BaseModelGenerator(ABC):
    """
    Base class for model generators.
    This defines the common interface that all model generators must implement.
    """
    
    def __init__(self, 
                 text_encoder, 
                 text_encoder_2, 
                 tokenizer, 
                 tokenizer_2, 
                 vae, 
                 image_encoder, 
                 feature_extractor, 
                 high_vram=False,
                 prompt_embedding_cache=None,
                 settings=None,
                 offline=False):
        """
        Initialize the base model generator.
        
        Args:
            text_encoder: The text encoder model
            text_encoder_2: The second text encoder model
            tokenizer: The tokenizer for the first text encoder
            tokenizer_2: The tokenizer for the second text encoder
            vae: The VAE model
            image_encoder: The image encoder model
            feature_extractor: The feature extractor
            high_vram: Whether high VRAM mode is enabled
            prompt_embedding_cache: Cache for prompt embeddings
            settings: Application settings
            offline: Whether to run in offline mode for model loading
        """
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.tokenizer = tokenizer
        self.tokenizer_2 = tokenizer_2
        self.vae = vae
        self.image_encoder = image_encoder
        self.feature_extractor = feature_extractor
        self.high_vram = high_vram
        self.prompt_embedding_cache = prompt_embedding_cache or {}
        self.settings = settings
        self.offline = offline
        self.transformer = None
        self.gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.cpu = torch.device("cpu")

    @abstractmethod
    def load_model(self):
        """
        Load the transformer model.
        This method should be implemented by each specific model generator.
        """
        pass
    
    @abstractmethod
    def get_model_name(self):
        """
        Get the name of the model.
        This method should be implemented by each specific model generator.
        """
        pass

    @staticmethod
    def _get_snapshot_hash_from_refs(model_repo_id_for_cache: str) -> str | None:
        """
        Reads the commit hash from the refs/main file for a given model in the HF cache.
        Args:
            model_repo_id_for_cache (str): The model ID formatted for cache directory names
                                           (e.g., "models--lllyasviel--FramePackI2V_HY").
        Returns:
            str: The commit hash if found, otherwise None.
        """
        hf_home_dir = os.environ.get('HF_HOME')
        if not hf_home_dir:
            print("Warning: HF_HOME environment variable not set. Cannot determine snapshot hash.")
            return None
            
        refs_main_path = os.path.join(hf_home_dir, 'hub', model_repo_id_for_cache, 'refs', 'main')
        if os.path.exists(refs_main_path):
            try:
                with open(refs_main_path, 'r') as f:
                    print(f"Offline mode: Reading snapshot hash from: {refs_main_path}")
                    return f.read().strip()
            except Exception as e:
                print(f"Warning: Could not read snapshot hash from {refs_main_path}: {e}")
                return None
        else:
            print(f"Warning: refs/main file not found at {refs_main_path}. Cannot determine snapshot hash.")
            return None
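
    # For reference: the Hugging Face cache layout this helper walks (a sketch;
    # the exact tree is managed by huggingface_hub, and the repo id shown is
    # the example used in the docstring above):
    #
    #   $HF_HOME/hub/models--lllyasviel--FramePackI2V_HY/
    #       refs/main                 <- text file holding the current commit hash
    #       snapshots/<commit_hash>/  <- directory containing the actual model files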

    def _get_offline_load_path(self) -> str | None:
        """
        Returns the local snapshot path for offline loading, if available.
        Falls back to the default self.model_path if a local snapshot cannot be found.
        Relies on self.model_repo_id_for_cache and self.model_path being set by subclasses.
        """
        # Ensure necessary attributes are set by the subclass
        if not hasattr(self, 'model_repo_id_for_cache') or not self.model_repo_id_for_cache:
            print(f"Warning: model_repo_id_for_cache not set in {self.__class__.__name__}. Cannot determine offline path.")
            # Fallback to model_path if it exists, otherwise None
            return getattr(self, 'model_path', None) 

        if not hasattr(self, 'model_path') or not self.model_path:
            print(f"Warning: model_path not set in {self.__class__.__name__}. Cannot determine fallback for offline path.")
            return None

        snapshot_hash = self._get_snapshot_hash_from_refs(self.model_repo_id_for_cache)
        hf_home = os.environ.get('HF_HOME')

        if snapshot_hash and hf_home:
            specific_snapshot_path = os.path.join(
                hf_home, 'hub', self.model_repo_id_for_cache, 'snapshots', snapshot_hash
            )
            if os.path.isdir(specific_snapshot_path):
                return specific_snapshot_path
                
        # If snapshot logic fails or path is not a dir, fallback to the default model path
        return self.model_path
        
    def unload_loras(self):
        """
        Unload all LoRAs from the transformer model.
        """
        if self.transformer is not None:
            print(f"Unloading all LoRAs from {self.get_model_name()} model")
            self.transformer = lora_utils.unload_all_loras(self.transformer)
            self.verify_lora_state("After unloading LoRAs")
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
    def verify_lora_state(self, label=""):
        """
        Debug function to verify the state of LoRAs in the transformer model.
        """
        if self.transformer is None:
            print(f"[{label}] Transformer is None, cannot verify LoRA state")
            return
            
        has_loras = False
        if hasattr(self.transformer, 'peft_config'):
            adapter_names = list(self.transformer.peft_config.keys()) if self.transformer.peft_config else []
            if adapter_names:
                has_loras = True
                print(f"[{label}] Transformer has LoRAs: {', '.join(adapter_names)}")
            else:
                print(f"[{label}] Transformer has no LoRAs in peft_config")
        else:
            print(f"[{label}] Transformer has no peft_config attribute")
            
        # Check for any LoRA modules
        for name, module in self.transformer.named_modules():
            if hasattr(module, 'lora_A') and module.lora_A:
                has_loras = True
                # print(f"[{label}] Found lora_A in module {name}")
            if hasattr(module, 'lora_B') and module.lora_B:
                has_loras = True
                # print(f"[{label}] Found lora_B in module {name}")
                
        if not has_loras:
            print(f"[{label}] No LoRA components found in transformer")
    
    def move_lora_adapters_to_device(self, target_device):
        """
        Move all LoRA adapters in the transformer model to the specified device.
        This handles the PEFT implementation of LoRA.
        """
        if self.transformer is None:
            return
            
        print(f"Moving all LoRA adapters to {target_device}")
        
        # First, find all modules with LoRA adapters
        lora_modules = []
        for name, module in self.transformer.named_modules():
            if hasattr(module, 'active_adapter') and hasattr(module, 'lora_A') and hasattr(module, 'lora_B'):
                lora_modules.append((name, module))
        
        # Now move all LoRA components to the target device
        for name, module in lora_modules:
            # Get the active adapter name
            active_adapter = module.active_adapter
            
            # Move the LoRA layers to the target device
            if active_adapter is not None:
                if isinstance(module.lora_A, torch.nn.ModuleDict):
                    # Handle ModuleDict case (PEFT implementation)
                    for adapter_name in list(module.lora_A.keys()):
                        # Move lora_A (iterating its own keys, so no membership check needed)
                        module.lora_A[adapter_name] = module.lora_A[adapter_name].to(target_device)
                        
                        # Move lora_B
                        if adapter_name in module.lora_B:
                            module.lora_B[adapter_name] = module.lora_B[adapter_name].to(target_device)
                        
                        # Move scaling
                        if hasattr(module, 'scaling') and isinstance(module.scaling, dict) and adapter_name in module.scaling:
                            if isinstance(module.scaling[adapter_name], torch.Tensor):
                                module.scaling[adapter_name] = module.scaling[adapter_name].to(target_device)
                else:
                    # Handle direct attribute case
                    if hasattr(module, 'lora_A') and module.lora_A is not None:
                        module.lora_A = module.lora_A.to(target_device)
                    if hasattr(module, 'lora_B') and module.lora_B is not None:
                        module.lora_B = module.lora_B.to(target_device)
                    if hasattr(module, 'scaling') and module.scaling is not None:
                        if isinstance(module.scaling, torch.Tensor):
                            module.scaling = module.scaling.to(target_device)
        
        print(f"Moved all LoRA adapters to {target_device}")
    
    def load_loras(self, selected_loras, lora_folder, lora_loaded_names, lora_values=None):
        """
        Load LoRAs into the transformer model.
        
        Args:
            selected_loras: List of LoRA names to load
            lora_folder: Folder containing the LoRA files
            lora_loaded_names: List of loaded LoRA names
            lora_values: Optional list of LoRA strength values
        """
        if self.transformer is None:
            print("Cannot load LoRAs: Transformer model is not loaded")
            return
            
        # Ensure all LoRAs are unloaded first
        self.unload_loras()
        
        # Load each selected LoRA
        if isinstance(selected_loras, list):
            for lora_name in selected_loras:
                try:
                    idx = lora_loaded_names.index(lora_name)
                    lora_file = None
                    for ext in [".safetensors", ".pt"]:
                        candidate_path_relative = f"{lora_name}{ext}"
                        candidate_path_full = os.path.join(lora_folder, candidate_path_relative)
                        if os.path.isfile(candidate_path_full):
                            lora_file = candidate_path_relative
                            break
                            
                    if lora_file:
                        print(f"Loading LoRA '{lora_file}' to {self.get_model_name()} model")
                        self.transformer = lora_utils.load_lora(self.transformer, lora_folder, lora_file)
                        
                        # Set LoRA strength if provided
                        if lora_values and idx < len(lora_values):
                            lora_strength = float(lora_values[idx])
                            print(f"Setting LoRA '{lora_name}' strength to {lora_strength}")
                            
                            # Set scaling for this LoRA by iterating through modules
                            for name, module in self.transformer.named_modules():
                                if hasattr(module, 'scaling'):
                                    if isinstance(module.scaling, dict):
                                        # Handle ModuleDict case (PEFT implementation)
                                        if lora_name in module.scaling:
                                            if isinstance(module.scaling[lora_name], torch.Tensor):
                                                module.scaling[lora_name] = torch.tensor(
                                                    lora_strength, device=module.scaling[lora_name].device
                                                )
                                            else:
                                                module.scaling[lora_name] = lora_strength
                                    else:
                                        # Handle direct attribute case for scaling if needed
                                        if isinstance(module.scaling, torch.Tensor):
                                            module.scaling = torch.tensor(
                                                lora_strength, device=module.scaling.device
                                            )
                                        else:
                                            module.scaling = lora_strength
                    else:
                        print(f"LoRA file for {lora_name} not found!")
                except Exception as e:
                    print(f"Error loading LoRA {lora_name}: {e}")
        else:
            print(f"Warning: selected_loras is not a list (type: {type(selected_loras)}), skipping LoRA loading.")
        
        # Verify LoRA state after loading
        self.verify_lora_state("After loading LoRAs")
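
# Usage sketch (illustrative values; the folder path and LoRA names below are
# hypothetical): a caller typically refreshes adapters between jobs like this:
#
#   generator.load_loras(
#       selected_loras=["my_style_lora"],    # file stems, without extension
#       lora_folder="/path/to/loras",
#       lora_loaded_names=["my_style_lora", "other_lora"],
#       lora_values=[0.8, 1.0],              # strengths, aligned with lora_loaded_names
#   )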
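
# A minimal sketch of a concrete subclass (illustration only, not used by the
# application): it shows the two abstract methods plus the class attributes
# that _get_offline_load_path() expects subclasses to define. The class name
# is hypothetical; the repo id matches the example in the docstrings above.
class _ExampleModelGenerator(BaseModelGenerator):
    model_path = "lllyasviel/FramePackI2V_HY"
    model_repo_id_for_cache = "models--lllyasviel--FramePackI2V_HY"

    def load_model(self):
        # A real implementation would construct self.transformer from load_path.
        load_path = self._get_offline_load_path() if self.offline else self.model_path
        print(f"Would load transformer weights from: {load_path}")

    def get_model_name(self):
        return "example"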