import os
import torch
from safetensors import safe_open
from loguru import logger
import gc
from functools import lru_cache
from tqdm import tqdm
@lru_cache(maxsize=None)
def GET_DTYPE():
    RUNNING_FLAG = os.getenv("DTYPE")
    return RUNNING_FLAG
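
# WanLoraWrapper registers LoRA checkpoints by name and merges their low-rank
# (lora_down/lora_up) or full-difference (diff/diff_b) tensors directly into the
# wrapped Wan model's parameters.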
class WanLoraWrapper:
    def __init__(self, wan_model):
        self.model = wan_model
        self.lora_metadata = {}
        # self.override_dict = {}  # On CPU

    def load_lora(self, lora_path, lora_name=None):
        if lora_name is None:
            lora_name = os.path.basename(lora_path).split(".")[0]

        if lora_name in self.lora_metadata:
            logger.info(f"LoRA {lora_name} already loaded, skipping...")
            return lora_name

        self.lora_metadata[lora_name] = {"path": lora_path}
        logger.info(f"Registered LoRA metadata for: {lora_name} from {lora_path}")
        return lora_name
    def _load_lora_file(self, file_path, param_dtype):
        with safe_open(file_path, framework="pt") as f:
            tensor_dict = {key: f.get_tensor(key).to(param_dtype) for key in f.keys()}
        return tensor_dict
    def apply_lora(self, lora_name, alpha=1.0, param_dtype=torch.bfloat16, device="cpu"):
        if lora_name not in self.lora_metadata:
            logger.warning(f"LoRA {lora_name} not found. Please load it first.")
            return False

        lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"], param_dtype)
        # weight_dict = self.model.original_weight_dict
        self._apply_lora_weights(lora_weights, alpha, device)
        # self.model._init_weights(weight_dict)

        logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
        return True
    def get_parameter_by_name(self, model, param_name):
        parts = param_name.split(".")
        current = model
        for part in parts:
            if part.isdigit():
                current = current[int(part)]
            else:
                current = getattr(current, part)
        return current
    @torch.no_grad()
    def _apply_lora_weights(self, lora_weights, alpha, device):
        lora_pairs = {}
        prefix = "diffusion_model."
        for key in lora_weights.keys():
            if key.endswith("lora_down.weight") and key.startswith(prefix):
                # Low-rank pair: lora_down (A) and lora_up (B) merged into the base weight.
                base_name = key[len(prefix):].replace("lora_down.weight", "weight")
                b_key = key.replace("lora_down.weight", "lora_up.weight")
                if b_key in lora_weights:
                    lora_pairs[base_name] = (key, b_key)
            elif key.endswith("diff_b") and key.startswith(prefix):
                # Full bias difference added to the base bias.
                base_name = key[len(prefix):].replace("diff_b", "bias")
                lora_pairs[base_name] = key
            elif key.endswith("diff") and key.startswith(prefix):
                # Full weight difference added to the base weight.
                base_name = key[len(prefix):].replace("diff", "weight")
                lora_pairs[base_name] = key
        applied_count = 0
        for name in tqdm(lora_pairs.keys(), desc="Loading LoRA weights"):
            param = self.get_parameter_by_name(self.model, name)
            # Accumulate in float32 on CPU to limit rounding error; otherwise
            # keep the parameter's own dtype.
            if device == "cpu":
                dtype = torch.float32
            else:
                dtype = param.dtype

            if isinstance(lora_pairs[name], tuple):
                name_lora_A, name_lora_B = lora_pairs[name]
                lora_A = lora_weights[name_lora_A].to(device, dtype)
                lora_B = lora_weights[name_lora_B].to(device, dtype)
                delta = torch.matmul(lora_B, lora_A) * alpha
                delta = delta.to(param.device, param.dtype)
                param.add_(delta)
            else:
                name_lora = lora_pairs[name]
                delta = lora_weights[name_lora].to(param.device, dtype) * alpha
                delta = delta.to(param.device, param.dtype)
                param.add_(delta)
            applied_count += 1

        logger.info(f"Applied {applied_count} LoRA weight adjustments")
        if applied_count == 0:
            logger.warning(
                "No LoRA weights were applied. Expected naming conventions: "
                "'diffusion_model.<layer_name>.lora_down.weight' with a matching "
                "'lora_up.weight', or 'diffusion_model.<layer_name>.diff' / 'diff_b'. "
                "Please verify the LoRA weight file."
            )
    def list_loaded_loras(self):
        return list(self.lora_metadata.keys())

    def get_current_lora(self):
        return self.model.current_lora
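
# Example usage (illustrative sketch: `wan_model` stands for an already-constructed
# Wan model instance and the .safetensors path is a placeholder; neither is defined
# in this module):
#
#     wrapper = WanLoraWrapper(wan_model)
#     name = wrapper.load_lora("loras/example_style.safetensors")
#     wrapper.apply_lora(name, alpha=1.0, param_dtype=torch.bfloat16, device="cpu")
#     logger.info(wrapper.list_loaded_loras())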