import gc
import os  # required for os.path
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional

import torch

from diffusers_helper import lora_utils


class BaseModelGenerator(ABC):
    """
    Base class for model generators.
    This defines the common interface that all model generators must implement.
    """
    def __init__(self,
                 text_encoder,
                 text_encoder_2,
                 tokenizer,
                 tokenizer_2,
                 vae,
                 image_encoder,
                 feature_extractor,
                 high_vram=False,
                 prompt_embedding_cache=None,
                 settings=None,
                 offline=False):  # NEW: offline flag
        """
        Initialize the base model generator.

        Args:
            text_encoder: The text encoder model
            text_encoder_2: The second text encoder model
            tokenizer: The tokenizer for the first text encoder
            tokenizer_2: The tokenizer for the second text encoder
            vae: The VAE model
            image_encoder: The image encoder model
            feature_extractor: The feature extractor
            high_vram: Whether high VRAM mode is enabled
            prompt_embedding_cache: Cache for prompt embeddings
            settings: Application settings
            offline: Whether to run in offline mode for model loading
        """
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.tokenizer = tokenizer
        self.tokenizer_2 = tokenizer_2
        self.vae = vae
        self.image_encoder = image_encoder
        self.feature_extractor = feature_extractor
        self.high_vram = high_vram
        self.prompt_embedding_cache = prompt_embedding_cache or {}
        self.settings = settings
        self.offline = offline
        self.transformer = None
        self.gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.cpu = torch.device("cpu")
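
    # NOTE: Concrete subclasses are also expected to set `self.model_path` and
    # `self.model_repo_id_for_cache` (e.g. "models--lllyasviel--FramePackI2V_HY"),
    # which the offline-loading helpers below rely on.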

    @abstractmethod
    def load_model(self):
        """
        Load the transformer model.
        This method must be implemented by each specific model generator.
        """
        pass

    @abstractmethod
    def get_model_name(self):
        """
        Get the name of the model.
        This method must be implemented by each specific model generator.
        """
        pass

    def _get_snapshot_hash_from_refs(self, model_repo_id_for_cache: str) -> Optional[str]:
        """
        Reads the commit hash from the refs/main file for a given model in the HF cache.

        Args:
            model_repo_id_for_cache (str): The model ID formatted for cache directory names
                                           (e.g., "models--lllyasviel--FramePackI2V_HY").

        Returns:
            Optional[str]: The commit hash if found, otherwise None.
        """
        hf_home_dir = os.environ.get('HF_HOME')
        if not hf_home_dir:
            print("Warning: HF_HOME environment variable not set. Cannot determine snapshot hash.")
            return None
        refs_main_path = os.path.join(hf_home_dir, 'hub', model_repo_id_for_cache, 'refs', 'main')
        if os.path.exists(refs_main_path):
            try:
                with open(refs_main_path, 'r') as f:
                    print(f"Offline mode: Reading snapshot hash from: {refs_main_path}")
                    return f.read().strip()
            except Exception as e:
                print(f"Warning: Could not read snapshot hash from {refs_main_path}: {e}")
                return None
        else:
            print(f"Warning: refs/main file not found at {refs_main_path}. Cannot determine snapshot hash.")
            return None
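
    # For reference, this helper assumes the standard huggingface_hub cache
    # layout (a sketch; the exact layout may vary between hub versions):
    #   $HF_HOME/hub/models--lllyasviel--FramePackI2V_HY/
    #       refs/main                <- text file containing the commit hash
    #       snapshots/<commit_hash>/ <- resolved model files for that revision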

    def _get_offline_load_path(self) -> Optional[str]:
        """
        Returns the local snapshot path for offline loading if available.
        Falls back to the default self.model_path if a local snapshot can't be found.
        Relies on self.model_repo_id_for_cache and self.model_path being set by subclasses.
        """
        # Ensure the necessary attributes have been set by the subclass
        if not hasattr(self, 'model_repo_id_for_cache') or not self.model_repo_id_for_cache:
            print(f"Warning: model_repo_id_for_cache not set in {self.__class__.__name__}. Cannot determine offline path.")
            # Fall back to model_path if it exists, otherwise None
            return getattr(self, 'model_path', None)
        if not hasattr(self, 'model_path') or not self.model_path:
            print(f"Warning: model_path not set in {self.__class__.__name__}. Cannot determine fallback for offline path.")
            return None
        snapshot_hash = self._get_snapshot_hash_from_refs(self.model_repo_id_for_cache)
        hf_home = os.environ.get('HF_HOME')
        if snapshot_hash and hf_home:
            specific_snapshot_path = os.path.join(
                hf_home, 'hub', self.model_repo_id_for_cache, 'snapshots', snapshot_hash
            )
            if os.path.isdir(specific_snapshot_path):
                return specific_snapshot_path
        # If the snapshot lookup fails or the path is not a directory, fall back to the default model path
        return self.model_path
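
    # Illustrative only: a subclass's load_model() might consume this helper
    # like the hypothetical snippet below (SomeTransformer is a placeholder,
    # not part of this codebase):
    #
    #   path = self._get_offline_load_path() if self.offline else self.model_path
    #   self.transformer = SomeTransformer.from_pretrained(path)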

    def unload_loras(self):
        """
        Unload all LoRAs from the transformer model.
        """
        if self.transformer is not None:
            print(f"Unloading all LoRAs from {self.get_model_name()} model")
            self.transformer = lora_utils.unload_all_loras(self.transformer)
            self.verify_lora_state("After unloading LoRAs")
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def verify_lora_state(self, label=""):
        """
        Debug function to verify the state of LoRAs in the transformer model.
        """
        if self.transformer is None:
            print(f"[{label}] Transformer is None, cannot verify LoRA state")
            return
        has_loras = False
        if hasattr(self.transformer, 'peft_config'):
            adapter_names = list(self.transformer.peft_config.keys()) if self.transformer.peft_config else []
            if adapter_names:
                has_loras = True
                print(f"[{label}] Transformer has LoRAs: {', '.join(adapter_names)}")
            else:
                print(f"[{label}] Transformer has no LoRAs in peft_config")
        else:
            print(f"[{label}] Transformer has no peft_config attribute")
        # Check for any LoRA modules attached directly to submodules
        for name, module in self.transformer.named_modules():
            if hasattr(module, 'lora_A') and module.lora_A:
                has_loras = True
                # print(f"[{label}] Found lora_A in module {name}")
            if hasattr(module, 'lora_B') and module.lora_B:
                has_loras = True
                # print(f"[{label}] Found lora_B in module {name}")
        if not has_loras:
            print(f"[{label}] No LoRA components found in transformer")

    def move_lora_adapters_to_device(self, target_device):
        """
        Move all LoRA adapters in the transformer model to the specified device.
        This handles the PEFT implementation of LoRA.
        """
        if self.transformer is None:
            return
        print(f"Moving all LoRA adapters to {target_device}")
        # First, find all modules with LoRA adapters
        lora_modules = []
        for name, module in self.transformer.named_modules():
            if hasattr(module, 'active_adapter') and hasattr(module, 'lora_A') and hasattr(module, 'lora_B'):
                lora_modules.append((name, module))
        # Now move all LoRA components to the target device
        for name, module in lora_modules:
            # Get the active adapter name
            active_adapter = module.active_adapter
            # Move the LoRA layers to the target device
            if active_adapter is not None:
                if isinstance(module.lora_A, torch.nn.ModuleDict):
                    # Handle the ModuleDict case (PEFT implementation)
                    for adapter_name in list(module.lora_A.keys()):
                        # Move lora_A
                        if adapter_name in module.lora_A:
                            module.lora_A[adapter_name] = module.lora_A[adapter_name].to(target_device)
                        # Move lora_B
                        if adapter_name in module.lora_B:
                            module.lora_B[adapter_name] = module.lora_B[adapter_name].to(target_device)
                        # Move scaling
                        if hasattr(module, 'scaling') and isinstance(module.scaling, dict) and adapter_name in module.scaling:
                            if isinstance(module.scaling[adapter_name], torch.Tensor):
                                module.scaling[adapter_name] = module.scaling[adapter_name].to(target_device)
                else:
                    # Handle the direct-attribute case
                    if hasattr(module, 'lora_A') and module.lora_A is not None:
                        module.lora_A = module.lora_A.to(target_device)
                    if hasattr(module, 'lora_B') and module.lora_B is not None:
                        module.lora_B = module.lora_B.to(target_device)
                    if hasattr(module, 'scaling') and module.scaling is not None:
                        if isinstance(module.scaling, torch.Tensor):
                            module.scaling = module.scaling.to(target_device)
        print(f"Moved all LoRA adapters to {target_device}")

    def load_loras(self, selected_loras: List[str], lora_folder: str, lora_loaded_names: List[str], lora_values: Optional[List[float]] = None):
        """
        Load LoRAs into the transformer model and apply their weights.

        Args:
            selected_loras: List of LoRA base names to load (e.g., ["lora_A", "lora_B"]).
            lora_folder: Path to the folder containing the LoRA files.
            lora_loaded_names: The master list of ALL available LoRA names, used for correct weight indexing.
            lora_values: A list of strength values corresponding to lora_loaded_names.
        """
        self.unload_loras()
        if not selected_loras:
            print("No LoRAs selected, skipping loading.")
            return
        lora_dir = Path(lora_folder)
        adapter_names = []
        strengths = []
        for lora_base_name in selected_loras:
            lora_file = None
            for ext in (".safetensors", ".pt"):
                candidate_path_relative = f"{lora_base_name}{ext}"
                candidate_path_full = lora_dir / candidate_path_relative
                if candidate_path_full.is_file():
                    lora_file = candidate_path_relative
                    break
            if not lora_file:
                print(f"Warning: LoRA file for base name '{lora_base_name}' not found; skipping.")
                continue
            print(f"Loading LoRA from '{lora_file}'...")
            self.transformer, adapter_name = lora_utils.load_lora(self.transformer, lora_dir, lora_file)
            adapter_names.append(adapter_name)
            weight = 1.0
            if lora_values:
                try:
                    master_list_idx = lora_loaded_names.index(lora_base_name)
                    if master_list_idx < len(lora_values):
                        weight = float(lora_values[master_list_idx])
                    else:
                        print(f"Warning: Index mismatch for '{lora_base_name}'. Defaulting to 1.0.")
                except ValueError:
                    print(f"Warning: LoRA '{lora_base_name}' not found in master list. Defaulting to 1.0.")
            strengths.append(weight)
        if adapter_names:
            print(f"Activating adapters: {adapter_names} with strengths: {strengths}")
            lora_utils.set_adapters(self.transformer, adapter_names, strengths)
        self.verify_lora_state("After completing load_loras")