import os
from abc import ABC, abstractmethod

import torch

from diffusers_helper import lora_utils


class BaseModelGenerator(ABC):
    """
    Base class for model generators.
    This defines the common interface that all model generators must implement.
    """

    def __init__(self,
                 text_encoder,
                 text_encoder_2,
                 tokenizer,
                 tokenizer_2,
                 vae,
                 image_encoder,
                 feature_extractor,
                 high_vram=False,
                 prompt_embedding_cache=None,
                 settings=None,
                 offline=False):
        """
        Initialize the base model generator.

        Args:
            text_encoder: The text encoder model
            text_encoder_2: The second text encoder model
            tokenizer: The tokenizer for the first text encoder
            tokenizer_2: The tokenizer for the second text encoder
            vae: The VAE model
            image_encoder: The image encoder model
            feature_extractor: The feature extractor
            high_vram: Whether high VRAM mode is enabled
            prompt_embedding_cache: Cache for prompt embeddings
            settings: Application settings
            offline: Whether to run in offline mode for model loading
        """
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.tokenizer = tokenizer
        self.tokenizer_2 = tokenizer_2
        self.vae = vae
        self.image_encoder = image_encoder
        self.feature_extractor = feature_extractor
        self.high_vram = high_vram
        self.prompt_embedding_cache = prompt_embedding_cache or {}
        self.settings = settings
        self.offline = offline
        self.transformer = None
        self.gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.cpu = torch.device("cpu")

    @abstractmethod
    def load_model(self):
        """
        Load the transformer model.
        This method should be implemented by each specific model generator.
        """
        pass

    @abstractmethod
    def get_model_name(self):
        """
        Get the name of the model.
        This method should be implemented by each specific model generator.
        """
        pass

    @staticmethod
    def _get_snapshot_hash_from_refs(model_repo_id_for_cache: str) -> str | None:
        """
        Reads the commit hash from the refs/main file for a given model in the HF cache.

        Args:
            model_repo_id_for_cache (str): The model ID formatted for cache directory names
                (e.g., "models--lllyasviel--FramePackI2V_HY").

        Returns:
            str | None: The commit hash if found, otherwise None.
        """
        hf_home_dir = os.environ.get('HF_HOME')
        if not hf_home_dir:
            print("Warning: HF_HOME environment variable not set. Cannot determine snapshot hash.")
            return None

        refs_main_path = os.path.join(hf_home_dir, 'hub', model_repo_id_for_cache, 'refs', 'main')
        if os.path.exists(refs_main_path):
            try:
                with open(refs_main_path, 'r') as f:
                    print(f"Offline mode: Reading snapshot hash from: {refs_main_path}")
                    return f.read().strip()
            except Exception as e:
                print(f"Warning: Could not read snapshot hash from {refs_main_path}: {e}")
                return None
        else:
            print(f"Warning: refs/main file not found at {refs_main_path}. Cannot determine snapshot hash.")
            return None

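    # Usage sketch (hypothetical return value), assuming the model is already
    # present in the local Hugging Face cache and HF_HOME points at it:
    #
    #   BaseModelGenerator._get_snapshot_hash_from_refs(
    #       "models--lllyasviel--FramePackI2V_HY"
    #   )
    #   # reads $HF_HOME/hub/models--lllyasviel--FramePackI2V_HY/refs/main
    #   # and returns e.g. "abc123..." (or None if the file is missing)
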
    def _get_offline_load_path(self) -> str | None:
        """
        Returns the local snapshot path for offline loading if available.
        Falls back to the default self.model_path if a local snapshot can't be found.
        Relies on self.model_repo_id_for_cache and self.model_path being set by subclasses.
        """
        if not hasattr(self, 'model_repo_id_for_cache') or not self.model_repo_id_for_cache:
            print(f"Warning: model_repo_id_for_cache not set in {self.__class__.__name__}. Cannot determine offline path.")
            return getattr(self, 'model_path', None)

        if not hasattr(self, 'model_path') or not self.model_path:
            print(f"Warning: model_path not set in {self.__class__.__name__}. Cannot determine fallback for offline path.")
            return None

        snapshot_hash = self._get_snapshot_hash_from_refs(self.model_repo_id_for_cache)
        hf_home = os.environ.get('HF_HOME')

        if snapshot_hash and hf_home:
            specific_snapshot_path = os.path.join(
                hf_home, 'hub', self.model_repo_id_for_cache, 'snapshots', snapshot_hash
            )
            if os.path.isdir(specific_snapshot_path):
                return specific_snapshot_path

        return self.model_path

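    # For reference, the offline snapshot directory resolved above follows the
    # standard HF cache layout (sketch):
    #   $HF_HOME/hub/models--<org>--<name>/snapshots/<commit_hash>/
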
    def unload_loras(self):
        """
        Unload all LoRAs from the transformer model.
        """
        if self.transformer is not None:
            print(f"Unloading all LoRAs from {self.get_model_name()} model")
            self.transformer = lora_utils.unload_all_loras(self.transformer)
            self.verify_lora_state("After unloading LoRAs")
            # Reclaim memory held by the detached adapters
            import gc
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def verify_lora_state(self, label=""):
        """
        Debug function to verify the state of LoRAs in the transformer model.
        """
        if self.transformer is None:
            print(f"[{label}] Transformer is None, cannot verify LoRA state")
            return

        has_loras = False
        if hasattr(self.transformer, 'peft_config'):
            adapter_names = list(self.transformer.peft_config.keys()) if self.transformer.peft_config else []
            if adapter_names:
                has_loras = True
                print(f"[{label}] Transformer has LoRAs: {', '.join(adapter_names)}")
            else:
                print(f"[{label}] Transformer has no LoRAs in peft_config")
        else:
            print(f"[{label}] Transformer has no peft_config attribute")

        # Also check for LoRA weights attached directly to individual modules
        for name, module in self.transformer.named_modules():
            if hasattr(module, 'lora_A') and module.lora_A:
                has_loras = True

            if hasattr(module, 'lora_B') and module.lora_B:
                has_loras = True

        if not has_loras:
            print(f"[{label}] No LoRA components found in transformer")

    def move_lora_adapters_to_device(self, target_device):
        """
        Move all LoRA adapters in the transformer model to the specified device.
        This handles the PEFT implementation of LoRA.
        """
        if self.transformer is None:
            return

        print(f"Moving all LoRA adapters to {target_device}")

        # Collect every module that carries LoRA weights
        lora_modules = []
        for name, module in self.transformer.named_modules():
            if hasattr(module, 'active_adapter') and hasattr(module, 'lora_A') and hasattr(module, 'lora_B'):
                lora_modules.append((name, module))

        for name, module in lora_modules:
            active_adapter = module.active_adapter

            if active_adapter is not None:
                if isinstance(module.lora_A, torch.nn.ModuleDict):
                    # Adapters stored in ModuleDicts keyed by adapter name
                    for adapter_name in list(module.lora_A.keys()):
                        if adapter_name in module.lora_A:
                            module.lora_A[adapter_name] = module.lora_A[adapter_name].to(target_device)

                        if adapter_name in module.lora_B:
                            module.lora_B[adapter_name] = module.lora_B[adapter_name].to(target_device)

                        if hasattr(module, 'scaling') and isinstance(module.scaling, dict) and adapter_name in module.scaling:
                            if isinstance(module.scaling[adapter_name], torch.Tensor):
                                module.scaling[adapter_name] = module.scaling[adapter_name].to(target_device)
                else:
                    # Single adapter stored directly on the module
                    if hasattr(module, 'lora_A') and module.lora_A is not None:
                        module.lora_A = module.lora_A.to(target_device)
                    if hasattr(module, 'lora_B') and module.lora_B is not None:
                        module.lora_B = module.lora_B.to(target_device)
                    if hasattr(module, 'scaling') and module.scaling is not None:
                        if isinstance(module.scaling, torch.Tensor):
                            module.scaling = module.scaling.to(target_device)

        print(f"Moved all LoRA adapters to {target_device}")

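    # Usage sketch (not prescribed by this module): callers would typically move
    # the adapters together with the transformer when offloading, e.g.
    # generator.move_lora_adapters_to_device(generator.gpu) before sampling and
    # generator.move_lora_adapters_to_device(generator.cpu) afterwards.
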
    def load_loras(self, selected_loras, lora_folder, lora_loaded_names, lora_values=None):
        """
        Load LoRAs into the transformer model.

        Args:
            selected_loras: List of LoRA names to load
            lora_folder: Folder containing the LoRA files
            lora_loaded_names: List of loaded LoRA names
            lora_values: Optional list of LoRA strength values
        """
        if self.transformer is None:
            print("Cannot load LoRAs: Transformer model is not loaded")
            return

        # Start from a clean adapter state before applying the requested LoRAs
        self.unload_loras()

        if isinstance(selected_loras, list):
            for lora_name in selected_loras:
                try:
                    idx = lora_loaded_names.index(lora_name)
                    # Find the LoRA file on disk (prefer .safetensors, fall back to .pt)
                    lora_file = None
                    for ext in [".safetensors", ".pt"]:
                        candidate_path_relative = f"{lora_name}{ext}"
                        candidate_path_full = os.path.join(lora_folder, candidate_path_relative)
                        if os.path.isfile(candidate_path_full):
                            lora_file = candidate_path_relative
                            break

                    if lora_file:
                        print(f"Loading LoRA '{lora_file}' to {self.get_model_name()} model")
                        self.transformer = lora_utils.load_lora(self.transformer, lora_folder, lora_file)

                        # Apply the requested strength by rewriting the scaling values
                        if lora_values and idx < len(lora_values):
                            lora_strength = float(lora_values[idx])
                            print(f"Setting LoRA '{lora_name}' strength to {lora_strength}")

                            for name, module in self.transformer.named_modules():
                                if hasattr(module, 'scaling'):
                                    if isinstance(module.scaling, dict):
                                        # Scaling stored per adapter name
                                        if lora_name in module.scaling:
                                            if isinstance(module.scaling[lora_name], torch.Tensor):
                                                module.scaling[lora_name] = torch.tensor(
                                                    lora_strength, device=module.scaling[lora_name].device
                                                )
                                            else:
                                                module.scaling[lora_name] = lora_strength
                                    else:
                                        # Scaling stored as a single value
                                        if isinstance(module.scaling, torch.Tensor):
                                            module.scaling = torch.tensor(
                                                lora_strength, device=module.scaling.device
                                            )
                                        else:
                                            module.scaling = lora_strength
                    else:
                        print(f"LoRA file for {lora_name} not found!")
                except Exception as e:
                    print(f"Error loading LoRA {lora_name}: {e}")
        else:
            print(f"Warning: selected_loras is not a list (type: {type(selected_loras)}), skipping LoRA loading.")

        self.verify_lora_state("After loading LoRAs")
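
# ---------------------------------------------------------------------------
# Illustrative sketch only: a minimal concrete subclass showing the contract
# the base class expects (load_model/get_model_name plus the model_path and
# model_repo_id_for_cache attributes that _get_offline_load_path relies on).
# The class name and attribute values below are hypothetical, not part of
# this module.
#
# class ExampleModelGenerator(BaseModelGenerator):
#     def __init__(self, **kwargs):
#         super().__init__(**kwargs)
#         self.model_path = "lllyasviel/FramePackI2V_HY"                        # hypothetical
#         self.model_repo_id_for_cache = "models--lllyasviel--FramePackI2V_HY"  # hypothetical
#
#     def load_model(self):
#         path = self._get_offline_load_path() if self.offline else self.model_path
#         # Construct the transformer from `path` here (loader is project-specific),
#         # then assign it to self.transformer.
#         ...
#
#     def get_model_name(self):
#         return "Example"
# ---------------------------------------------------------------------------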