File size: 7,639 Bytes
05fcd0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
from pathlib import Path, PurePath
from typing import Dict, List, Optional, Union, Tuple
from diffusers.loaders.lora_pipeline import _fetch_state_dict
from diffusers.loaders.lora_conversion_utils import _convert_hunyuan_video_lora_to_diffusers
from diffusers.utils.peft_utils import set_weights_and_activate_adapters
from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
import torch

# Maps a model class name that has no entry in diffusers'
# `_SET_ADAPTER_SCALE_FN_MAPPING` to a class name that does, so that
# `resolve_expansion_class_name` can reuse that class's scale-expansion function.
FALLBACK_CLASS_ALIASES = dict(
    HunyuanVideoTransformer3DModelPacked="HunyuanVideoTransformer3DModel",
)

def load_lora(transformer: torch.nn.Module, lora_path: Path, weight_name: str) -> Tuple[torch.nn.Module, str]:
    """
    Load LoRA weights into the transformer model.

    The state dict is fetched with diffusers' ``_fetch_state_dict``, converted
    with ``_convert_hunyuan_video_lora_to_diffusers`` and attached to
    ``transformer`` via ``load_lora_adapter``. If an adapter with the same name
    is already registered in ``transformer.peft_config`` it is deleted first,
    so loading the same file again replaces the earlier copy.

    The adapter name is ``weight_name`` with its final suffix removed and every
    remaining '.' replaced by '_DOT_' (e.g. ``"my.lora.safetensors"`` becomes
    ``"my_DOT_lora"``).

    Args:
        transformer: The transformer model to which LoRA weights will be applied.
        lora_path: Path to the folder containing the LoRA weights file.
        weight_name: Filename of the weight to load.

    Returns:
        A tuple containing the modified transformer and the canonical adapter name.
    """
    # These positional arguments follow the order of diffusers' private
    # `_fetch_state_dict` signature. The names in the trailing comments reflect
    # that signature at the time of writing; TODO confirm when upgrading diffusers.
    state_dict = _fetch_state_dict(
        lora_path,
        weight_name,
        True,   # use_safetensors
        True,   # local_files_only
        None,   # cache_dir
        None,   # force_download
        None,   # proxies
        None,   # token
        None,   # revision
        None,   # subfolder
        None,   # user_agent
        None)   # allow_pickle

    state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)

    # The adapter name ends up in module names, which must not contain a '.'.
    # See https://github.com/pytorch/pytorch/pull/6639/files#diff-4be56271f7bfe650e3521c81fd363da58f109cd23ee80d243156d2d6ccda6263R133-R134
    stem = str(PurePath(weight_name).with_suffix(''))
    adapter_name = stem.replace('.', '_DOT_')
    # Warn only when the stem really contained dots. Checking the escaped
    # name for '_DOT_' would also fire on files literally named "x_DOT_y".
    if '.' in stem:
        print(
            f"LoRA file '{weight_name}' contains a '.' in the name. "
            'This may cause issues. Consider renaming the file.'
            f" Using '{adapter_name}' as the adapter name to be safe."
        )

    # Reloading under an existing name: drop the old adapter first so the new
    # weights replace it instead of colliding with it.
    existing_adapters = getattr(transformer, 'peft_config', None) or {}
    if adapter_name in existing_adapters:
        print(f"Adapter '{adapter_name}' already exists. Removing it before loading again.")
        transformer.delete_adapters([adapter_name])

    transformer.load_lora_adapter(state_dict, network_alphas=None, adapter_name=adapter_name)
    print(f"LoRA weights '{adapter_name}' loaded successfully.")

    return transformer, adapter_name

def unload_all_loras(transformer: torch.nn.Module) -> torch.nn.Module:
    """
    Completely unload all LoRA adapters from the transformer model.

    Every adapter name in ``transformer.peft_config`` is passed to
    ``transformer.delete_adapters``. Afterwards, any ``lora_A`` / ``lora_B`` /
    ``scaling`` containers still present on submodules are emptied. Finally,
    garbage collection runs and, when CUDA is available, the CUDA cache is
    emptied so the freed adapter tensors can be released. The gc/cache step
    runs even when there was nothing to remove.

    Args:
        transformer: The transformer model from which LoRA adapters will be removed.

    Returns:
        The same transformer object, after all LoRA adapters have been removed.
    """
    import gc

    peft_config = getattr(transformer, 'peft_config', None)
    if not peft_config:
        print("Model doesn't have any LoRA adapters or peft_config.")
    else:
        adapter_names = list(peft_config.keys())
        if not adapter_names:
            print("No LoRA adapters found to remove.")
        else:
            print(f"Removing all LoRA adapters: {', '.join(adapter_names)}")
            transformer.delete_adapters(adapter_names)

            # Force cleanup of any remaining adapter references.
            if hasattr(transformer, 'active_adapter'):
                transformer.active_adapter = None

            # LoRA layers may store per-adapter weights in nn.ModuleDict
            # (as PEFT's LoraLayer does for lora_A / lora_B). nn.ModuleDict is
            # not a dict subclass, so an isinstance(..., dict) check alone
            # skipped them. Clear both container kinds.
            for module in transformer.modules():
                for attr in ('lora_A', 'lora_B', 'scaling'):
                    container = getattr(module, attr, None)
                    if isinstance(container, (dict, torch.nn.ModuleDict)):
                        container.clear()

            print("All LoRA adapters have been completely removed.")

    # Release memory held by the deleted adapters.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return transformer

def resolve_expansion_class_name(
    transformer: torch.nn.Module, 
    fallback_aliases: Dict[str, str], 
    fn_mapping: Dict[str, callable]
) -> Optional[str]:
    """
    Pick the key of ``fn_mapping`` to use for ``transformer``'s class.

    The model's own class name wins when it is a key of ``fn_mapping``.
    Otherwise the name is looked up in ``fallback_aliases``. If the alias is a
    key of ``fn_mapping``, a warning is printed and the alias is returned.

    Args:
        transformer: The transformer model instance.
        fallback_aliases: Class name -> substitute class name.
        fn_mapping: Class name -> scale expansion function.

    Returns:
        The chosen key, or None when neither the class name nor its alias is
        present in ``fn_mapping``.
    """
    own_name = transformer.__class__.__name__

    # Direct hit: no aliasing needed.
    if own_name in fn_mapping:
        return own_name

    alias = fallback_aliases.get(own_name)
    if alias not in fn_mapping:
        return None

    print(f"Warning: No scale function for '{own_name}'. Falling back to '{alias}'")
    return alias

def set_adapters(
    transformer: torch.nn.Module,
    adapter_names: Union[List[str], str],
    weights: Optional[Union[float, List[float]]] = None,
):
    """
    Activates and sets the weights for one or more LoRA adapters on the transformer model.

    Each weight is passed through the scale expansion function registered in
    diffusers' ``_SET_ADAPTER_SCALE_FN_MAPPING`` for the model's class. A
    fallback alias from ``FALLBACK_CLASS_ALIASES`` is used when the class has
    no entry, and raw weights are used when neither has one. The result is
    then applied with ``set_weights_and_activate_adapters``.

    Args:
        transformer: The transformer model to which LoRA adapters are applied.
        adapter_names: A single adapter name (str) or a sequence of adapter names to activate.
        weights: Optional. A single float weight, or a list/tuple of float weights
                 corresponding to each adapter name. If None, defaults to 1.0 for each adapter.
                 If a single float, it will be applied to all adapters. Individual
                 None entries in a sequence also default to 1.0.

    Raises:
        ValueError: If a weight sequence's length differs from the number of adapters.
    """
    adapter_names = [adapter_names] if isinstance(adapter_names, str) else list(adapter_names)

    # A list/tuple supplies one weight per adapter. Anything else (a scalar or
    # None) is broadcast to every adapter. Previously a tuple was mistaken for
    # a single scalar and replicated.
    if isinstance(weights, (list, tuple)):
        weights = list(weights)
    else:
        weights = [weights] * len(adapter_names)

    if len(adapter_names) != len(weights):
        raise ValueError(
            f"The number of adapter names ({len(adapter_names)}) does not match the number of weights ({len(weights)})."
        )

    # Replace any None weights with a default value of 1.0
    sanitized_weights = [w if w is not None else 1.0 for w in weights]

    resolved_class_name = resolve_expansion_class_name(
        transformer,
        fallback_aliases=FALLBACK_CLASS_ALIASES,
        fn_mapping=_SET_ADAPTER_SCALE_FN_MAPPING
    )

    transformer_class_name = transformer.__class__.__name__

    if resolved_class_name:
        scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[resolved_class_name]
        print(f"Using scale expansion function for model class '{resolved_class_name}' (original: '{transformer_class_name}')")
        # Expand each adapter's weight independently; the function takes and
        # returns a list, so wrap the single weight and unwrap the result.
        final_weights = [
            scale_expansion_fn(transformer, [weight])[0] for weight in sanitized_weights
        ]
    else:
        print(f"Warning: No scale expansion function found for '{transformer_class_name}'. Using raw weights.")
        final_weights = sanitized_weights

    set_weights_and_activate_adapters(transformer, adapter_names, final_weights)

    print(f"Adapters {adapter_names} activated with weights {final_weights}.")