import gc

from pathlib import Path, PurePath
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from diffusers.loaders.lora_pipeline import _fetch_state_dict
from diffusers.loaders.lora_conversion_utils import _convert_hunyuan_video_lora_to_diffusers
from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
from diffusers.utils.peft_utils import set_weights_and_activate_adapters
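
# Transformer classes that have no dedicated adapter-scale expansion function in
# diffusers' mapping are resolved to a compatible class that does (here, the
# packed HunyuanVideo variant reuses the stock HunyuanVideo behavior).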
FALLBACK_CLASS_ALIASES = {
    "HunyuanVideoTransformer3DModelPacked": "HunyuanVideoTransformer3DModel",
}


def load_lora(transformer: torch.nn.Module, lora_path: Path, weight_name: str) -> Tuple[torch.nn.Module, str]:
    """
    Load LoRA weights into the transformer model.

    Args:
        transformer: The transformer model to which LoRA weights will be applied.
        lora_path: Path to the folder containing the LoRA weights file.
        weight_name: Filename of the weights file to load.

    Returns:
        A tuple containing the modified transformer and the canonical adapter name.
    """
    # The positional arguments follow diffusers' private _fetch_state_dict
    # signature; only use_safetensors and local_files_only are set, the rest
    # are left at None.
    state_dict = _fetch_state_dict(
        lora_path,
        weight_name,
        True,   # use_safetensors
        True,   # local_files_only
        None,   # cache_dir
        None,   # force_download
        None,   # proxies
        None,   # token
        None,   # revision
        None,   # subfolder
        None,   # user_agent
        None,   # allow_pickle
    )
    state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)

    # weight_name is typed as str rather than Optional[str]; we assume it is never None.
    # PyTorch forbids '.' in module names, and the adapter name becomes part of the
    # module name, so replace any dots that remain after stripping the file extension.
    # See https://github.com/pytorch/pytorch/pull/6639/files#diff-4be56271f7bfe650e3521c81fd363da58f109cd23ee80d243156d2d6ccda6263R133-R134
    adapter_name = str(PurePath(weight_name).with_suffix('')).replace('.', '_DOT_')
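    # Example (hypothetical filename): "style.v2.safetensors" has stem "style.v2",
    # which becomes the adapter name "style_DOT_v2".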
    if '_DOT_' in adapter_name:
        print(
            f"LoRA file '{weight_name}' contains a '.' in its name, which may cause issues; "
            f"consider renaming the file. Using '{adapter_name}' as the adapter name to be safe."
        )

    # If the adapter already exists, delete it before loading it again.
    if hasattr(transformer, 'peft_config') and adapter_name in transformer.peft_config:
        print(f"Adapter '{adapter_name}' already exists. Removing it before loading again.")
        # delete_adapters (plural) takes a list of adapter names.
        transformer.delete_adapters([adapter_name])

    # Load the adapter under the sanitized name.
    transformer.load_lora_adapter(state_dict, network_alphas=None, adapter_name=adapter_name)
    print(f"LoRA weights '{adapter_name}' loaded successfully.")
    return transformer, adapter_name
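

# Usage sketch (illustrative only; the folder and filename are hypothetical):
#
#   transformer, adapter_name = load_lora(
#       transformer, Path("loras"), "my_style.safetensors"
#   )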


def unload_all_loras(transformer: torch.nn.Module) -> torch.nn.Module:
    """
    Completely unload all LoRA adapters from the transformer model.

    Args:
        transformer: The transformer model from which LoRA adapters will be removed.

    Returns:
        The transformer model after all LoRA adapters have been removed.
    """
    if hasattr(transformer, 'peft_config') and transformer.peft_config:
        adapter_names = list(transformer.peft_config.keys())
        if adapter_names:
            print(f"Removing all LoRA adapters: {', '.join(adapter_names)}")
            transformer.delete_adapters(adapter_names)

            # Force cleanup of any remaining adapter references.
            if hasattr(transformer, 'active_adapter'):
                transformer.active_adapter = None

            # Clear any cached per-adapter state on the LoRA submodules.
            for module in transformer.modules():
                if hasattr(module, 'lora_A') and isinstance(module.lora_A, dict):
                    module.lora_A.clear()
                if hasattr(module, 'lora_B') and isinstance(module.lora_B, dict):
                    module.lora_B.clear()
                if hasattr(module, 'scaling') and isinstance(module.scaling, dict):
                    module.scaling.clear()

            print("All LoRA adapters have been completely removed.")
        else:
            print("No LoRA adapters found to remove.")
    else:
        print("Model doesn't have any LoRA adapters or peft_config.")
    # Release any memory freed by deleting the adapters.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return transformer
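

# Usage sketch: drop every loaded adapter, e.g. before switching to a
# different set of LoRAs:
#
#   transformer = unload_all_loras(transformer)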


def resolve_expansion_class_name(
    transformer: torch.nn.Module,
    fallback_aliases: Dict[str, str],
    fn_mapping: Dict[str, Callable],
) -> Optional[str]:
""" | |
Resolves the canonical class name for adapter scale expansion functions, | |
considering potential fallback aliases. | |
Args: | |
transformer: The transformer model instance. | |
fallback_aliases: A dictionary mapping model class names to fallback class names. | |
fn_mapping: A dictionary mapping class names to their respective scale expansion functions. | |
Returns: | |
The resolved class name as a string if a matching scale function is found, | |
otherwise None. | |
""" | |
    class_name = transformer.__class__.__name__
    if class_name in fn_mapping:
        return class_name

    fallback_class = fallback_aliases.get(class_name)
    if fallback_class in fn_mapping:
        print(f"Warning: No scale function for '{class_name}'. Falling back to '{fallback_class}'.")
        return fallback_class

    return None
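

# Example: a class absent from diffusers' mapping, such as
# HunyuanVideoTransformer3DModelPacked, resolves to "HunyuanVideoTransformer3DModel"
# through FALLBACK_CLASS_ALIASES; a class with no alias resolves to None.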


def set_adapters(
    transformer: torch.nn.Module,
    adapter_names: Union[List[str], str],
    weights: Optional[Union[float, List[float]]] = None,
):
    """
    Activates and sets the weights for one or more LoRA adapters on the transformer model.

    Args:
        transformer: The transformer model to which LoRA adapters are applied.
        adapter_names: A single adapter name (str) or a list of adapter names (List[str]) to activate.
        weights: Optional. A single float weight or a list of float weights
            corresponding to each adapter name. If None, defaults to 1.0 for each adapter.
            If a single float, it is applied to all adapters.
    """
    adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

    # Expand a single weight (or None) so there is one entry per adapter.
    if not isinstance(weights, list):
        weights = [weights] * len(adapter_names)

    if len(adapter_names) != len(weights):
        raise ValueError(
            f"The number of adapter names ({len(adapter_names)}) does not match the number of weights ({len(weights)})."
        )

    # Replace any None weights with a default value of 1.0.
    sanitized_weights = [w if w is not None else 1.0 for w in weights]
    resolved_class_name = resolve_expansion_class_name(
        transformer,
        fallback_aliases=FALLBACK_CLASS_ALIASES,
        fn_mapping=_SET_ADAPTER_SCALE_FN_MAPPING,
    )
    transformer_class_name = transformer.__class__.__name__

    if resolved_class_name:
        scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[resolved_class_name]
        print(f"Using scale expansion function for model class '{resolved_class_name}' (original: '{transformer_class_name}')")
        # Expand each scalar weight into the per-block format the model expects.
        final_weights = [
            scale_expansion_fn(transformer, [weight])[0] for weight in sanitized_weights
        ]
    else:
        print(f"Warning: No scale expansion function found for '{transformer_class_name}'. Using raw weights.")
        final_weights = sanitized_weights

    set_weights_and_activate_adapters(transformer, adapter_names, final_weights)
    print(f"Adapters {adapter_names} activated with weights {final_weights}.")