import torch


class GeneralLoRALoader:
    def __init__(self, device="cpu", torch_dtype=torch.float32):
        self.device = device
        self.torch_dtype = torch_dtype

    def get_name_dict(self, lora_state_dict):
        """Map each target layer name to its pair of (up, down) LoRA weight keys."""
        lora_name_dict = {}
        has_lora_A = any(k.endswith(".lora_A.weight") for k in lora_state_dict)
        has_lora_down = any(k.endswith(".lora_down.weight") for k in lora_state_dict)
        if has_lora_A:
            # Convention 1: "<base>.lora_A.weight" paired with "<base>.lora_B.weight".
            lora_a_keys = [k for k in lora_state_dict if k.endswith(".lora_A.weight")]
            for lora_a_key in lora_a_keys:
                base_name = lora_a_key.replace(".lora_A.weight", "")
                lora_b_key = base_name + ".lora_B.weight"
                if lora_b_key in lora_state_dict:
                    # Drop the leading "diffusion_model." prefix so the name matches
                    # the model's module names.
                    target_name = base_name.replace("diffusion_model.", "", 1)
                    lora_name_dict[target_name] = (lora_b_key, lora_a_key)
        elif has_lora_down:
            # Convention 2: "lora_unet_<base>.lora_down.weight" paired with
            # "lora_unet_<base>.lora_up.weight".
            lora_down_keys = [
                k for k in lora_state_dict if k.endswith(".lora_down.weight")
            ]
            for lora_down_key in lora_down_keys:
                base_name = lora_down_key.replace(".lora_down.weight", "")
                lora_up_key = base_name + ".lora_up.weight"
                if lora_up_key in lora_state_dict:
                    # Turn the underscore-separated name into a dotted module path,
                    # then convert ".attn." back to "_attn." (e.g. "self_attn").
                    target_name = base_name.replace("lora_unet_", "").replace("_", ".")
                    target_name = target_name.replace(".attn.", "_attn.")
                    lora_name_dict[target_name] = (lora_up_key, lora_down_key)
        else:
            print(
                "Warning: No recognizable LoRA key names found in state_dict "
                "(neither 'lora_A' nor 'lora_down')."
            )
        return lora_name_dict
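
    # Example of the mapping get_name_dict returns, with hypothetical key names used
    # purely for illustration (real checkpoints use their own layer names):
    #   {"blocks.0.attn.to_q": ("diffusion_model.blocks.0.attn.to_q.lora_B.weight",
    #                           "diffusion_model.blocks.0.attn.to_q.lora_A.weight")}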

    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
        """Merge alpha * (up @ down) into the weights of every matched layer."""
        lora_name_dict = self.get_name_dict(state_dict_lora)
        updated_num = 0
        lora_target_names = set(lora_name_dict.keys())
        model_layer_names = {
            name for name, module in model.named_modules() if hasattr(module, "weight")
        }
        matched_names = lora_target_names.intersection(model_layer_names)
        unmatched_lora_names = lora_target_names - model_layer_names
        print(f"Successfully matched {len(matched_names)} layers.")
        if unmatched_lora_names:
            print(
                f"Warning: {len(unmatched_lora_names)} LoRA layers not matched and will be ignored."
            )
        for name, module in model.named_modules():
            if name in matched_names:
                lora_b_key, lora_a_key = lora_name_dict[name]
                weight_up = state_dict_lora[lora_b_key].to(
                    device=self.device, dtype=self.torch_dtype
                )
                weight_down = state_dict_lora[lora_a_key].to(
                    device=self.device, dtype=self.torch_dtype
                )
                if len(weight_up.shape) == 4:
                    # 1x1 convolution LoRA: drop the trailing spatial dims, multiply
                    # as matrices, then restore the 4D shape.
                    weight_up = weight_up.squeeze(3).squeeze(2)
                    weight_down = weight_down.squeeze(3).squeeze(2)
                    weight_lora = alpha * torch.mm(weight_up, weight_down)
                    weight_lora = weight_lora.unsqueeze(2).unsqueeze(3)
                else:
                    weight_lora = alpha * torch.mm(weight_up, weight_down)
                if module.weight.shape != weight_lora.shape:
                    print(f"Error: Shape mismatch for layer '{name}'! Skipping update.")
                    continue
                # Add the LoRA delta to the base weight in place.
                module.weight.data = (
                    module.weight.data.to(weight_lora.device, dtype=weight_lora.dtype)
                    + weight_lora
                )
                updated_num += 1
        print(f"LoRA loading complete, updated {updated_num} tensors in total.\n")