File size: 4,362 Bytes
fc6bdf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import torch
from safetensors import safe_open
from loguru import logger
import gc
from functools import lru_cache
from tqdm import tqdm

@lru_cache(maxsize=None)
def GET_DTYPE():
    """Return the ``DTYPE`` environment variable (memoized after the first read)."""
    return os.environ.get("DTYPE")

class WanLoraWrapper:
    """Merge LoRA weight deltas into a Wan diffusion model's parameters in place.

    LoRA files are registered lazily by path (``load_lora``); the tensors are
    only read from disk when ``apply_lora`` is called.
    """

    def __init__(self, wan_model):
        # The base model whose parameters are patched in place.
        self.model = wan_model
        # name -> {"path": ...} for every registered LoRA file.
        self.lora_metadata = {}

    def load_lora(self, lora_path, lora_name=None):
        """Register a LoRA file and return the name it was stored under.

        Registration is metadata-only; no tensors are loaded here.
        """
        if lora_name is None:
            # Default name: file basename without its (first) extension.
            lora_name = os.path.basename(lora_path).split(".")[0]

        if lora_name in self.lora_metadata:
            logger.info(f"LoRA {lora_name} already loaded, skipping...")
            return lora_name

        self.lora_metadata[lora_name] = {"path": lora_path}
        logger.info(f"Registered LoRA metadata for: {lora_name} from {lora_path}")

        return lora_name

    def _load_lora_file(self, file_path, param_dtype):
        """Load every tensor from a safetensors file, cast to ``param_dtype``."""
        with safe_open(file_path, framework="pt") as f:
            tensor_dict = {key: f.get_tensor(key).to(param_dtype) for key in f.keys()}
        return tensor_dict

    def apply_lora(self, lora_name, alpha=1.0, param_dtype=torch.bfloat16, device='cpu'):
        """Merge the named LoRA into the model; return True on success, False if unknown.

        ``alpha`` scales the delta; ``device`` selects where the low-rank matmul
        runs (CPU computation is promoted to float32 for accuracy).
        """
        if lora_name not in self.lora_metadata:
            logger.info(f"LoRA {lora_name} not found. Please load it first.")
            # Bug fix: previously fell through and raised KeyError on the
            # metadata lookup below instead of failing gracefully.
            return False

        lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"], param_dtype)
        self._apply_lora_weights(lora_weights, alpha, device)

        logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
        return True

    def get_parameter_by_name(self, model, param_name):
        """Resolve a dotted path (e.g. ``'blocks.0.attn.weight'``) to a tensor."""
        current = model
        for part in param_name.split('.'):
            # Numeric components index into module containers; others are attributes.
            current = current[int(part)] if part.isdigit() else getattr(current, part)
        return current

    @torch.no_grad()
    def _apply_lora_weights(self, lora_weights, alpha, device):
        """Add alpha-scaled LoRA deltas to matching model parameters in place.

        Supported key conventions (all under the ``diffusion_model.`` prefix):
          * ``<layer>.lora_down.weight`` + ``<layer>.lora_up.weight`` ->
            low-rank delta ``up @ down`` added to ``<layer>.weight``
          * ``<layer>.diff_b`` -> dense delta added to ``<layer>.bias``
          * ``<layer>.diff``   -> dense delta added to ``<layer>.weight``
        """
        prefix = "diffusion_model."
        # base parameter name -> (down_key, up_key) tuple, or a single diff key.
        lora_pairs = {}

        for key in lora_weights:
            if not key.startswith(prefix):
                continue
            if key.endswith("lora_down.weight"):
                base_name = key[len(prefix):].replace("lora_down.weight", "weight")
                up_key = key.replace("lora_down.weight", "lora_up.weight")
                # Only record the pair when the matching up-projection exists.
                if up_key in lora_weights:
                    lora_pairs[base_name] = (key, up_key)
            elif key.endswith("diff_b"):
                base_name = key[len(prefix):].replace("diff_b", "bias")
                # Plain string (not a 1-tuple): marks a single-tensor delta.
                lora_pairs[base_name] = key
            elif key.endswith("diff"):
                base_name = key[len(prefix):].replace("diff", "weight")
                lora_pairs[base_name] = key

        applied_count = 0
        for name, keys in tqdm(lora_pairs.items(), desc="Loading LoRA weights"):
            param = self.get_parameter_by_name(self.model, name)
            # Compute in float32 on CPU to avoid low-precision accumulation
            # error; on other devices keep the parameter's own dtype.
            compute_dtype = torch.float32 if device == 'cpu' else param.dtype
            if isinstance(keys, tuple):
                down_key, up_key = keys
                lora_down = lora_weights[down_key].to(device, compute_dtype)
                lora_up = lora_weights[up_key].to(device, compute_dtype)
                delta = torch.matmul(lora_up, lora_down) * alpha
            else:
                delta = lora_weights[keys].to(device, compute_dtype) * alpha
            # Cast back to the parameter's device/dtype before the in-place add.
            param.add_(delta.to(param.device, param.dtype))
            applied_count += 1

        logger.info(f"Applied {applied_count} LoRA weight adjustments")
        if applied_count == 0:
            # Bug fix: the old message named 'lora_A.weight'/'lora_B.weight',
            # which this matcher never looks for.
            logger.info(
                "Warning: No LoRA weights were applied. Expected naming conventions: "
                "'diffusion_model.<layer>.lora_down.weight' + 'diffusion_model.<layer>.lora_up.weight', "
                "'diffusion_model.<layer>.diff', or 'diffusion_model.<layer>.diff_b'. "
                "Please verify the LoRA weight file."
            )

    def list_loaded_loras(self):
        """Return the names of all LoRAs registered via ``load_lora``."""
        return list(self.lora_metadata.keys())

    def get_current_lora(self):
        """Return the LoRA name the wrapped model reports as active.

        NOTE(review): assumes the wrapped model exposes ``current_lora``;
        this wrapper never sets it — confirm against the model class.
        """
        return self.model.current_lora