import torch


class GeneralLoRALoader:
    def __init__(self, device="cpu", torch_dtype=torch.float32):
        self.device = device
        self.torch_dtype = torch_dtype
    def get_name_dict(self, lora_state_dict):
        # Map each target layer name in the base model to its (up, down) pair
        # of LoRA weight keys, supporting both common naming conventions.
        lora_name_dict = {}
        has_lora_A = any(k.endswith(".lora_A.weight") for k in lora_state_dict)
        has_lora_down = any(k.endswith(".lora_down.weight") for k in lora_state_dict)
        if has_lora_A:
            # PEFT-style keys: "<base>.lora_A.weight" / "<base>.lora_B.weight".
            lora_a_keys = [k for k in lora_state_dict if k.endswith(".lora_A.weight")]
            for lora_a_key in lora_a_keys:
                base_name = lora_a_key.replace(".lora_A.weight", "")
                lora_b_key = base_name + ".lora_B.weight"
                if lora_b_key in lora_state_dict:
                    # Drop the leading "diffusion_model." prefix so the name
                    # matches what model.named_modules() reports.
                    target_name = base_name.replace("diffusion_model.", "", 1)
                    lora_name_dict[target_name] = (lora_b_key, lora_a_key)
        elif has_lora_down:
            # Kohya/sd-scripts-style keys: "lora_unet_<base>.lora_down.weight" / ".lora_up.weight".
            lora_down_keys = [
                k for k in lora_state_dict if k.endswith(".lora_down.weight")
            ]
            for lora_down_key in lora_down_keys:
                base_name = lora_down_key.replace(".lora_down.weight", "")
                lora_up_key = base_name + ".lora_up.weight"
                if lora_up_key in lora_state_dict:
                    # Convert the underscore-separated key back to a dotted module
                    # path, then restore the "_attn." segment split by the conversion.
                    target_name = base_name.replace("lora_unet_", "").replace("_", ".")
                    target_name = target_name.replace(".attn.", "_attn.")
                    lora_name_dict[target_name] = (lora_up_key, lora_down_key)
        else:
            print(
                "Warning: No recognizable LoRA key names found in state_dict (neither 'lora_A' nor 'lora_down')."
            )
        return lora_name_dict
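    # Illustrative mappings produced by get_name_dict above (the key names are
    # made-up examples, not taken from any particular checkpoint):
    #   PEFT style:  "diffusion_model.blocks.0.ffn.0.lora_A.weight"
    #                  -> target "blocks.0.ffn.0"
    #   Kohya style: "lora_unet_blocks_0_self_attn_q.lora_down.weight"
    #                  -> target "blocks.0.self_attn.q"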
    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
        # Merge the LoRA update into the base weights in place: W += alpha * (B @ A).
        lora_name_dict = self.get_name_dict(state_dict_lora)
        updated_num = 0
        lora_target_names = set(lora_name_dict.keys())
        model_layer_names = {
            name for name, module in model.named_modules() if hasattr(module, "weight")
        }
        matched_names = lora_target_names.intersection(model_layer_names)
        unmatched_lora_names = lora_target_names - model_layer_names
        print(f"Successfully matched {len(matched_names)} layers.")
        if unmatched_lora_names:
            print(
                f"Warning: {len(unmatched_lora_names)} LoRA layers were not matched and will be ignored."
            )
        for name, module in model.named_modules():
            if name in matched_names:
                lora_b_key, lora_a_key = lora_name_dict[name]
                weight_up = state_dict_lora[lora_b_key].to(
                    device=self.device, dtype=self.torch_dtype
                )
                weight_down = state_dict_lora[lora_a_key].to(
                    device=self.device, dtype=self.torch_dtype
                )
                if len(weight_up.shape) == 4:
                    # 1x1 convolution LoRA weights: drop the trailing spatial dims,
                    # do the matmul in 2-D, then restore the (out, in, 1, 1) shape.
                    weight_up = weight_up.squeeze(3).squeeze(2)
                    weight_down = weight_down.squeeze(3).squeeze(2)
                    weight_lora = (
                        alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
                    )
                else:
                    weight_lora = alpha * torch.mm(weight_up, weight_down)
                if module.weight.shape != weight_lora.shape:
                    print(f"Error: Shape mismatch for layer '{name}'! Skipping update.")
                    continue
                module.weight.data = (
                    module.weight.data.to(weight_lora.device, dtype=weight_lora.dtype)
                    + weight_lora
                )
                updated_num += 1
        print(f"LoRA loading complete, updated {updated_num} tensors in total.\n")