# FPS-Studio / diffusers_helper/lora_utils.py
# Source: rahul7star — migrated from GitHub, commit 05fcd0f (verified).
from pathlib import Path, PurePath
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from diffusers.loaders.lora_conversion_utils import _convert_hunyuan_video_lora_to_diffusers
from diffusers.loaders.lora_pipeline import _fetch_state_dict
from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
from diffusers.utils.peft_utils import set_weights_and_activate_adapters
# Alternate class names for scale-expansion lookup. When a transformer's own
# class name has no entry in diffusers' _SET_ADAPTER_SCALE_FN_MAPPING,
# resolve_expansion_class_name() looks it up here and uses the mapped name
# instead (set_adapters() passes this table in as ``fallback_aliases``).
FALLBACK_CLASS_ALIASES: Dict[str, str] = dict(
    HunyuanVideoTransformer3DModelPacked="HunyuanVideoTransformer3DModel",
)
def load_lora(transformer: torch.nn.Module, lora_path: Path, weight_name: str) -> Tuple[torch.nn.Module, str]:
    """
    Load LoRA weights from a file and register them on the transformer as an adapter.

    The process has three steps:

    1. The state dict is fetched with diffusers' ``_fetch_state_dict``.
    2. It is passed through ``_convert_hunyuan_video_lora_to_diffusers``.
    3. It is registered via ``transformer.load_lora_adapter``.

    If an adapter with the same name already exists in
    ``transformer.peft_config``, it is deleted first. Re-loading the same file
    therefore replaces the previously loaded weights.

    Args:
        transformer: The transformer model to which LoRA weights will be applied.
        lora_path: Path to the folder containing the LoRA weights file.
        weight_name: Filename of the weight to load.

    Returns:
        A tuple of the same ``transformer`` object and the adapter name the
        weights were registered under. The name is ``weight_name`` with its
        final suffix removed and every '.' replaced by '_DOT_'.
    """
    # NOTE(review): these positional arguments follow the signature of diffusers'
    # private _fetch_state_dict at the time of writing; re-check on upgrade.
    state_dict = _fetch_state_dict(
        lora_path,    # pretrained_model_name_or_path_or_dict
        weight_name,  # weight_name
        True,         # use_safetensors
        True,         # local_files_only
        None,         # cache_dir
        None,         # force_download
        None,         # proxies
        None,         # token
        None,         # revision
        None,         # subfolder
        None,         # user_agent
        None,         # allow_pickle
    )
    state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)

    # The adapter name ends up in module names, which must not contain a '.'.
    # See https://github.com/pytorch/pytorch/pull/6639/files#diff-4be56271f7bfe650e3521c81fd363da58f109cd23ee80d243156d2d6ccda6263R133-R134
    stem = str(PurePath(weight_name).with_suffix(''))
    adapter_name = stem.replace('.', '_DOT_')
    # Warn only when a '.' was actually replaced. Checking for '_DOT_' in the
    # result would also fire for filenames that already contain that text.
    if '.' in stem:
        print(
            f"LoRA file '{weight_name}' contains a '.' in the name. " +
            'This may cause issues. Consider renaming the file.' +
            f" Using '{adapter_name}' as the adapter name to be safe."
        )

    # Drop any adapter already registered under this name so the file's
    # weights replace it instead of colliding with it.
    if hasattr(transformer, 'peft_config') and adapter_name in transformer.peft_config:
        print(f"Adapter '{adapter_name}' already exists. Removing it before loading again.")
        transformer.delete_adapters([adapter_name])

    # Register under the sanitized name computed above (not the raw filename).
    transformer.load_lora_adapter(state_dict, network_alphas=None, adapter_name=adapter_name)
    print(f"LoRA weights '{adapter_name}' loaded successfully.")
    return transformer, adapter_name
def unload_all_loras(transformer: torch.nn.Module) -> torch.nn.Module:
    """
    Delete every LoRA adapter registered on ``transformer`` and release memory.

    The steps are:

    1. Pass every key of ``transformer.peft_config`` to
       ``transformer.delete_adapters``.
    2. Reset ``transformer.active_adapter`` to None (if that attribute exists).
    3. Empty any plain-``dict`` ``lora_A`` / ``lora_B`` / ``scaling``
       attributes found on submodules.
    4. Run the garbage collector, and empty the CUDA cache when CUDA is
       available. This happens regardless of whether any adapter was removed.

    Args:
        transformer: The model to strip of LoRA adapters (modified in place).

    Returns:
        The same ``transformer`` object.
    """
    import gc

    config = getattr(transformer, 'peft_config', None)
    if not config:
        print("Model doesn't have any LoRA adapters or peft_config.")
    else:
        names = list(config.keys())
        if not names:
            print("No LoRA adapters found to remove.")
        else:
            print(f"Removing all LoRA adapters: {', '.join(names)}")
            transformer.delete_adapters(names)

            # Drop any lingering reference to the (now deleted) active adapter.
            if hasattr(transformer, 'active_adapter'):
                transformer.active_adapter = None

            # NOTE(review): only genuine dict instances are cleared. If lora_A/lora_B
            # are torch.nn.ModuleDict (not a dict subclass), they are left as-is;
            # confirm this is intended.
            for submodule in transformer.modules():
                for attr_name in ('lora_A', 'lora_B', 'scaling'):
                    cached = getattr(submodule, attr_name, None)
                    if isinstance(cached, dict):
                        cached.clear()

            print("All LoRA adapters have been completely removed.")

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return transformer
def resolve_expansion_class_name(
    transformer: torch.nn.Module,
    fallback_aliases: Dict[str, str],
    fn_mapping: Dict[str, Callable[..., Any]],
) -> Optional[str]:
    """
    Pick the ``fn_mapping`` key whose scale-expansion function applies to ``transformer``.

    The transformer's own class name is used if ``fn_mapping`` has an entry for
    it. Otherwise the class name is looked up in ``fallback_aliases``; the alias
    is used (with a printed warning) if ``fn_mapping`` has an entry for it.

    Args:
        transformer: The transformer model instance.
        fallback_aliases: A dictionary mapping model class names to fallback class names.
        fn_mapping: A dictionary mapping class names to their respective scale
            expansion functions.

    Returns:
        The resolved key into ``fn_mapping``, or None if neither the class name
        nor its alias has an entry.
    """
    class_name = transformer.__class__.__name__
    if class_name in fn_mapping:
        return class_name

    fallback_class = fallback_aliases.get(class_name)
    # Check explicitly for a missing alias instead of relying on None not being a key.
    if fallback_class is not None and fallback_class in fn_mapping:
        print(f"Warning: No scale function for '{class_name}'. Falling back to '{fallback_class}'")
        return fallback_class
    return None
def set_adapters(
    transformer: torch.nn.Module,
    adapter_names: Union[List[str], str],
    weights: Optional[Union[float, List[float]]] = None,
):
    """
    Activate one or more LoRA adapters on the transformer and set their weights.

    Args:
        transformer: The transformer model to which LoRA adapters are applied.
        adapter_names: A single adapter name, or a list/tuple of adapter names, to activate.
        weights: Optional. How each adapter is weighted:
            - A single float is applied to every adapter.
            - A list/tuple supplies one float per adapter name.
            - ``None`` (the whole argument, or any individual entry) means a
              weight of 1.0.

    Raises:
        ValueError: If a list/tuple of weights is given whose length differs
            from the number of adapter names.
    """
    # Normalise to a fresh list so tuples of names behave the same as lists.
    adapter_names = [adapter_names] if isinstance(adapter_names, str) else list(adapter_names)

    # A per-adapter sequence is taken as-is. Anything else (a scalar or None) is
    # broadcast to every adapter. Tuples previously fell into the broadcast
    # branch, giving each adapter the whole tuple as its weight.
    if isinstance(weights, (list, tuple)):
        weights = list(weights)
    else:
        weights = [weights] * len(adapter_names)

    if len(adapter_names) != len(weights):
        raise ValueError(
            f"The number of adapter names ({len(adapter_names)}) does not match the number of weights ({len(weights)})."
        )

    # Replace any None weights with a default value of 1.0
    sanitized_weights = [1.0 if w is None else w for w in weights]

    resolved_class_name = resolve_expansion_class_name(
        transformer,
        fallback_aliases=FALLBACK_CLASS_ALIASES,
        fn_mapping=_SET_ADAPTER_SCALE_FN_MAPPING,
    )
    transformer_class_name = transformer.__class__.__name__
    if resolved_class_name:
        scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[resolved_class_name]
        print(f"Using scale expansion function for model class '{resolved_class_name}' (original: '{transformer_class_name}')")
        # The expansion function takes a list of weights. Each weight is expanded
        # on its own, via a one-element list, and the single result is kept.
        final_weights = [scale_expansion_fn(transformer, [weight])[0] for weight in sanitized_weights]
    else:
        print(f"Warning: No scale expansion function found for '{transformer_class_name}'. Using raw weights.")
        final_weights = sanitized_weights

    set_weights_and_activate_adapters(transformer, adapter_names, final_weights)
    print(f"Adapters {adapter_names} activated with weights {final_weights}.")