import gc
import os  # required for os.path
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional

import torch

from diffusers_helper import lora_utils


class BaseModelGenerator(ABC):
    """
    Base class for model generators.

    Defines the common interface that all model generators must implement
    (``load_model`` and ``get_model_name``) and provides shared helpers for
    resolving an offline Hugging Face snapshot path and for loading, moving
    and unloading LoRA adapters on ``self.transformer``.
    """

    def __init__(self,
                 text_encoder,
                 text_encoder_2,
                 tokenizer,
                 tokenizer_2,
                 vae,
                 image_encoder,
                 feature_extractor,
                 high_vram=False,
                 prompt_embedding_cache=None,
                 settings=None,
                 offline=False):
        """
        Initialize the base model generator.

        Args:
            text_encoder: The text encoder model
            text_encoder_2: The second text encoder model
            tokenizer: The tokenizer for the first text encoder
            tokenizer_2: The tokenizer for the second text encoder
            vae: The VAE model
            image_encoder: The image encoder model
            feature_extractor: The feature extractor
            high_vram: Whether high VRAM mode is enabled
            prompt_embedding_cache: Cache for prompt embeddings. When None, a
                new empty dict is created; any other object is stored as-is.
            settings: Application settings
            offline: Whether to run in offline mode for model loading
        """
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.tokenizer = tokenizer
        self.tokenizer_2 = tokenizer_2
        self.vae = vae
        self.image_encoder = image_encoder
        self.feature_extractor = feature_extractor
        self.high_vram = high_vram
        # Compare against None rather than using `or {}`: an empty dict passed
        # in by the caller is falsy, and replacing it would silently detach
        # this generator from the caller's shared cache.
        self.prompt_embedding_cache = (
            prompt_embedding_cache if prompt_embedding_cache is not None else {}
        )
        self.settings = settings
        self.offline = offline
        # Subclasses are expected to populate this in load_model().
        self.transformer = None
        self.gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.cpu = torch.device("cpu")

    @abstractmethod
    def load_model(self):
        """
        Load the transformer model.
        This method should be implemented by each specific model generator.
        """
        pass

    @abstractmethod
    def get_model_name(self):
        """
        Get the name of the model.
        This method should be implemented by each specific model generator.
        """
        pass

    @staticmethod
    def _get_snapshot_hash_from_refs(model_repo_id_for_cache: str) -> Optional[str]:
        """
        Read the commit hash from the refs/main file for a model in the HF cache.

        The file is looked up at ``$HF_HOME/hub/<model_repo_id_for_cache>/refs/main``.

        Args:
            model_repo_id_for_cache (str): The model ID formatted for cache
                directory names (e.g., "models--lllyasviel--FramePackI2V_HY").

        Returns:
            The stripped file content if it is non-empty; None when HF_HOME is
            unset, the file is missing, empty, or cannot be read.
        """
        hf_home_dir = os.environ.get('HF_HOME')
        if not hf_home_dir:
            print("Warning: HF_HOME environment variable not set. Cannot determine snapshot hash.")
            return None

        refs_main_path = os.path.join(hf_home_dir, 'hub', model_repo_id_for_cache, 'refs', 'main')
        if not os.path.exists(refs_main_path):
            print(f"Warning: refs/main file not found at {refs_main_path}. Cannot determine snapshot hash.")
            return None

        try:
            with open(refs_main_path, 'r', encoding='utf-8') as f:
                print(f"Offline mode: Reading snapshot hash from: {refs_main_path}")
                # Normalize an empty file to None so the result matches the
                # documented "hash or None" contract.
                return f.read().strip() or None
        except (OSError, ValueError) as e:
            # OSError covers I/O and permission failures; ValueError covers
            # decoding errors (UnicodeDecodeError is a ValueError subclass).
            print(f"Warning: Could not read snapshot hash from {refs_main_path}: {e}")
            return None

    def _get_offline_load_path(self) -> Optional[str]:
        """
        Return the local snapshot path for offline loading if available.

        Falls back to ``self.model_path`` if the local snapshot cannot be
        found. Relies on ``self.model_repo_id_for_cache`` and
        ``self.model_path`` being set by subclasses.

        Returns:
            The snapshot directory under ``$HF_HOME/hub/.../snapshots/<hash>``
            when it exists; otherwise ``self.model_path``. None when
            ``model_path`` is missing or empty.
        """
        # Ensure necessary attributes are set by the subclass.
        repo_id = getattr(self, 'model_repo_id_for_cache', None)
        if not repo_id:
            print(f"Warning: model_repo_id_for_cache not set in {self.__class__.__name__}. Cannot determine offline path.")
            # Fall back to model_path if it exists, otherwise None.
            return getattr(self, 'model_path', None)

        model_path = getattr(self, 'model_path', None)
        if not model_path:
            print(f"Warning: model_path not set in {self.__class__.__name__}. Cannot determine fallback for offline path.")
            return None

        snapshot_hash = self._get_snapshot_hash_from_refs(repo_id)
        hf_home = os.environ.get('HF_HOME')
        if snapshot_hash and hf_home:
            specific_snapshot_path = os.path.join(
                hf_home, 'hub', repo_id, 'snapshots', snapshot_hash
            )
            if os.path.isdir(specific_snapshot_path):
                return specific_snapshot_path

        # Snapshot lookup failed or the path is not a directory: use the
        # default model path.
        return model_path

    def unload_loras(self):
        """
        Unload all LoRAs from the transformer model.

        Does nothing when no transformer is loaded. Otherwise replaces
        ``self.transformer`` with the result of ``lora_utils.unload_all_loras``,
        reports the resulting LoRA state, runs the garbage collector and, when
        CUDA is available, empties the CUDA cache.
        """
        if self.transformer is None:
            return

        print(f"Unloading all LoRAs from {self.get_model_name()} model")
        self.transformer = lora_utils.unload_all_loras(self.transformer)
        self.verify_lora_state("After unloading LoRAs")

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def _lora_component_present(component) -> bool:
        """Return True when a ``lora_A``/``lora_B`` attribute holds something."""
        if component is None:
            return False
        if isinstance(component, torch.Tensor):
            # bool() on a multi-element tensor raises, so test size explicitly.
            return component.numel() > 0
        try:
            # Containers such as torch.nn.ModuleDict report their entry count.
            return len(component) > 0
        except TypeError:
            # Objects without a length (e.g. a single layer module).
            return bool(component)

    def verify_lora_state(self, label=""):
        """
        Debug function to verify the state of LoRAs in the transformer model.

        Prints whether ``peft_config`` lists any adapters and, if no LoRA
        component is found anywhere in the transformer, says so.

        Args:
            label: Text placed in square brackets at the start of every message.
        """
        if self.transformer is None:
            print(f"[{label}] Transformer is None, cannot verify LoRA state")
            return

        has_loras = False

        if hasattr(self.transformer, 'peft_config'):
            peft_config = self.transformer.peft_config
            adapter_names = list(peft_config.keys()) if peft_config else []
            if adapter_names:
                has_loras = True
                print(f"[{label}] Transformer has LoRAs: {', '.join(adapter_names)}")
            else:
                print(f"[{label}] Transformer has no LoRAs in peft_config")
        else:
            print(f"[{label}] Transformer has no peft_config attribute")

        # Only scan the modules when peft_config did not already settle the
        # question; the scan has no other effect than setting has_loras.
        if not has_loras:
            has_loras = any(
                self._lora_component_present(getattr(module, 'lora_A', None))
                or self._lora_component_present(getattr(module, 'lora_B', None))
                for _, module in self.transformer.named_modules()
            )

        if not has_loras:
            print(f"[{label}] No LoRA components found in transformer")

    def move_lora_adapters_to_device(self, target_device):
        """
        Move all LoRA adapters in the transformer model to the specified device.

        Handles modules whose ``lora_A``/``lora_B`` are ``torch.nn.ModuleDict``
        containers keyed by adapter name (the PEFT layout) as well as modules
        that expose them as direct attributes. Modules whose
        ``active_adapter`` is None are left untouched.

        Args:
            target_device: Device passed to ``.to()`` for each LoRA component.
        """
        if self.transformer is None:
            return

        print(f"Moving all LoRA adapters to {target_device}")

        # First, collect all modules that carry LoRA adapters.
        lora_modules = [
            module
            for _, module in self.transformer.named_modules()
            if hasattr(module, 'active_adapter')
            and hasattr(module, 'lora_A')
            and hasattr(module, 'lora_B')
        ]

        # Now move every LoRA component to the target device.
        for module in lora_modules:
            if module.active_adapter is None:
                continue

            if isinstance(module.lora_A, torch.nn.ModuleDict):
                # ModuleDict case (PEFT implementation): one entry per adapter.
                scaling = getattr(module, 'scaling', None)
                for adapter_name in list(module.lora_A.keys()):
                    module.lora_A[adapter_name] = module.lora_A[adapter_name].to(target_device)

                    if adapter_name in module.lora_B:
                        module.lora_B[adapter_name] = module.lora_B[adapter_name].to(target_device)

                    # Scaling entries are moved only when stored as tensors.
                    if isinstance(scaling, dict) and adapter_name in scaling:
                        if isinstance(scaling[adapter_name], torch.Tensor):
                            scaling[adapter_name] = scaling[adapter_name].to(target_device)
            else:
                # Direct attribute case.
                if module.lora_A is not None:
                    module.lora_A = module.lora_A.to(target_device)
                if module.lora_B is not None:
                    module.lora_B = module.lora_B.to(target_device)
                if isinstance(getattr(module, 'scaling', None), torch.Tensor):
                    module.scaling = module.scaling.to(target_device)

        print(f"Moved all LoRA adapters to {target_device}")

    def load_loras(self,
                   selected_loras: List[str],
                   lora_folder: str,
                   lora_loaded_names: List[str],
                   lora_values: Optional[List[float]] = None):
        """
        Load LoRAs into the transformer model and apply their weights.

        Any previously loaded LoRAs are unloaded first. For each selected base
        name the first existing file among ``<name>.safetensors`` and
        ``<name>.pt`` in ``lora_folder`` is loaded; names without a file are
        skipped with a warning.

        Args:
            selected_loras: List of LoRA base names to load (e.g., ["lora_A", "lora_B"]).
            lora_folder: Path to the folder containing the LoRA files.
            lora_loaded_names: The master list of ALL available LoRA names, used
                for correct weight indexing.
            lora_values: A list of strength values corresponding to
                lora_loaded_names. A LoRA whose strength cannot be looked up
                gets 1.0.
        """
        self.unload_loras()

        if not selected_loras:
            print("No LoRAs selected, skipping loading.")
            return

        lora_dir = Path(lora_folder)
        adapter_names = []
        strengths = []

        for lora_base_name in selected_loras:
            lora_file = None
            for ext in (".safetensors", ".pt"):
                candidate_path_relative = f"{lora_base_name}{ext}"
                if (lora_dir / candidate_path_relative).is_file():
                    lora_file = candidate_path_relative
                    break

            if not lora_file:
                print(f"Warning: LoRA file for base name '{lora_base_name}' not found; skipping.")
                continue

            print(f"Loading LoRA from '{lora_file}'...")
            self.transformer, adapter_name = lora_utils.load_lora(self.transformer, lora_dir, lora_file)
            adapter_names.append(adapter_name)

            weight = 1.0
            if lora_values:
                # Strengths are indexed by position in the master list, not by
                # position in selected_loras.
                try:
                    master_list_idx = lora_loaded_names.index(lora_base_name)
                except ValueError:
                    print(f"Warning: LoRA '{lora_base_name}' not found in master list. Defaulting to 1.0.")
                else:
                    if master_list_idx < len(lora_values):
                        weight = float(lora_values[master_list_idx])
                    else:
                        print(f"Warning: Index mismatch for '{lora_base_name}'. Defaulting to 1.0.")
            strengths.append(weight)

        if adapter_names:
            print(f"Activating adapters: {adapter_names} with strengths: {strengths}")
            lora_utils.set_adapters(self.transformer, adapter_names, strengths)

        self.verify_lora_state("After completing load_loras")