from pathlib import Path, PurePath
from typing import Callable, Dict, List, Optional, Tuple, Union
from diffusers.loaders.lora_pipeline import _fetch_state_dict
from diffusers.loaders.lora_conversion_utils import _convert_hunyuan_video_lora_to_diffusers
from diffusers.utils.peft_utils import set_weights_and_activate_adapters
from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
import torch

# Some wrapped/packed transformer subclasses have no entry of their own in
# diffusers' adapter-scale function mapping; fall back to the base class name.
FALLBACK_CLASS_ALIASES = {
    "HunyuanVideoTransformer3DModelPacked": "HunyuanVideoTransformer3DModel",
}


def load_lora(transformer: torch.nn.Module, lora_path: Path, weight_name: str) -> Tuple[torch.nn.Module, str]:
    """
    Load LoRA weights into the transformer model.

    Args:
        transformer: The transformer model to which LoRA weights will be applied.
        lora_path: Path to the folder containing the LoRA weights file.
        weight_name: Filename of the weight file to load.

    Returns:
        A tuple containing the modified transformer and the canonical adapter name.
    """
    # Fetch the raw LoRA state dict; the remaining positional options of
    # _fetch_state_dict are left at None.
    state_dict = _fetch_state_dict(
        lora_path,
        weight_name,
        True,
        True,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None)

    # Convert the HunyuanVideo LoRA layout to the diffusers/PEFT naming scheme.
    state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)

    # weight_name is typed as str rather than Optional[str]; we assume it is never None.
    # Module names in the state_dict must not contain a '.', so dots in the
    # adapter name are replaced. See
    # https://github.com/pytorch/pytorch/pull/6639/files#diff-4be56271f7bfe650e3521c81fd363da58f109cd23ee80d243156d2d6ccda6263R133-R134
    adapter_name = str(PurePath(weight_name).with_suffix('')).replace('.', '_DOT_')
    if '_DOT_' in adapter_name:
        print(
            f"LoRA file '{weight_name}' contains a '.' in the name. "
            "This may cause issues. Consider renaming the file."
            f" Using '{adapter_name}' as the adapter name to be safe."
        )

    # If an adapter with this name already exists, remove it before reloading.
    if hasattr(transformer, 'peft_config') and adapter_name in transformer.peft_config:
        print(f"Adapter '{adapter_name}' already exists. Removing it before loading again.")
        # Use delete_adapters (plural) instead of delete_adapter
        transformer.delete_adapters([adapter_name])

    # Load the adapter under the sanitized name.
    transformer.load_lora_adapter(state_dict, network_alphas=None, adapter_name=adapter_name)
    print(f"LoRA weights '{adapter_name}' loaded successfully.")
    return transformer, adapter_name
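
# Usage sketch (illustrative only, not part of this module): the folder and file
# name below are hypothetical, and `transformer` is assumed to be a HunyuanVideo
# transformer that supports diffusers' PEFT-based LoRA loading.
#
#     transformer, adapter_name = load_lora(
#         transformer,
#         lora_path=Path("/path/to/loras"),       # folder holding the LoRA file
#         weight_name="my_style.safetensors",     # hypothetical file name
#     )
#     # adapter_name == "my_style" (any remaining '.' would become '_DOT_')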


def unload_all_loras(transformer: torch.nn.Module) -> torch.nn.Module:
    """
    Completely unload all LoRA adapters from the transformer model.

    Args:
        transformer: The transformer model from which LoRA adapters will be removed.

    Returns:
        The transformer model after all LoRA adapters have been removed.
    """
    if hasattr(transformer, 'peft_config') and transformer.peft_config:
        # Get all adapter names
        adapter_names = list(transformer.peft_config.keys())
        if adapter_names:
            print(f"Removing all LoRA adapters: {', '.join(adapter_names)}")
            # Delete all adapters
            transformer.delete_adapters(adapter_names)

            # Force cleanup of any remaining adapter references
            if hasattr(transformer, 'active_adapter'):
                transformer.active_adapter = None

            # Clear any cached states
            for module in transformer.modules():
                if hasattr(module, 'lora_A') and isinstance(module.lora_A, dict):
                    module.lora_A.clear()
                if hasattr(module, 'lora_B') and isinstance(module.lora_B, dict):
                    module.lora_B.clear()
                if hasattr(module, 'scaling') and isinstance(module.scaling, dict):
                    module.scaling.clear()

            print("All LoRA adapters have been completely removed.")
        else:
            print("No LoRA adapters found to remove.")
    else:
        print("Model doesn't have any LoRA adapters or peft_config.")

    # Force garbage collection
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return transformer
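
# Usage sketch (illustrative): calling unload_all_loras() between generations
# returns the transformer to its base weights before a different LoRA is loaded.
# The folder and file name are assumptions.
#
#     transformer = unload_all_loras(transformer)
#     transformer, adapter_name = load_lora(transformer, Path("/path/to/loras"), "other_style.safetensors")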


def resolve_expansion_class_name(
    transformer: torch.nn.Module,
    fallback_aliases: Dict[str, str],
    fn_mapping: Dict[str, Callable],
) -> Optional[str]:
    """
    Resolve the canonical class name used to look up the adapter scale expansion
    function, falling back to an aliased class name when necessary.

    Args:
        transformer: The transformer model instance.
        fallback_aliases: A dictionary mapping model class names to fallback class names.
        fn_mapping: A dictionary mapping class names to their respective scale expansion functions.

    Returns:
        The resolved class name as a string if a matching scale function is found,
        otherwise None.
    """
    class_name = transformer.__class__.__name__
    if class_name in fn_mapping:
        return class_name

    fallback_class = fallback_aliases.get(class_name)
    if fallback_class in fn_mapping:
        print(f"Warning: No scale function for '{class_name}'. Falling back to '{fallback_class}'")
        return fallback_class

    return None
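
# Example (illustrative): with the aliases defined above, a model whose class is
# named "HunyuanVideoTransformer3DModelPacked" resolves to
# "HunyuanVideoTransformer3DModel", assuming diffusers' _SET_ADAPTER_SCALE_FN_MAPPING
# has an entry for that base class; a class with no entry and no alias resolves
# to None, in which case raw weights are used unchanged.
#
#     resolved = resolve_expansion_class_name(
#         transformer,
#         fallback_aliases=FALLBACK_CLASS_ALIASES,
#         fn_mapping=_SET_ADAPTER_SCALE_FN_MAPPING,
#     )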


def set_adapters(
    transformer: torch.nn.Module,
    adapter_names: Union[List[str], str],
    weights: Optional[Union[float, List[float]]] = None,
):
    """
    Activates and sets the weights for one or more LoRA adapters on the transformer model.

    Args:
        transformer: The transformer model to which LoRA adapters are applied.
        adapter_names: A single adapter name (str) or a list of adapter names (List[str]) to activate.
        weights: Optional. A single float weight or a list of float weights
            corresponding to each adapter name. If None, defaults to 1.0 for each adapter.
            If a single float, it will be applied to all adapters.
    """
    adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

    # Expand a single weight to apply to all adapters if needed
    if not isinstance(weights, list):
        weights = [weights] * len(adapter_names)

    if len(adapter_names) != len(weights):
        raise ValueError(
            f"The number of adapter names ({len(adapter_names)}) does not match the number of weights ({len(weights)})."
        )

    # Replace any None weights with a default value of 1.0
    sanitized_weights = [w if w is not None else 1.0 for w in weights]

    resolved_class_name = resolve_expansion_class_name(
        transformer,
        fallback_aliases=FALLBACK_CLASS_ALIASES,
        fn_mapping=_SET_ADAPTER_SCALE_FN_MAPPING,
    )
    transformer_class_name = transformer.__class__.__name__

    if resolved_class_name:
        scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[resolved_class_name]
        print(f"Using scale expansion function for model class '{resolved_class_name}' (original: '{transformer_class_name}')")
        final_weights = [
            scale_expansion_fn(transformer, [weight])[0] for weight in sanitized_weights
        ]
    else:
        print(f"Warning: No scale expansion function found for '{transformer_class_name}'. Using raw weights.")
        final_weights = sanitized_weights

    set_weights_and_activate_adapters(transformer, adapter_names, final_weights)
    print(f"Adapters {adapter_names} activated with weights {final_weights}.")