Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,455 Bytes
26557da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import torch
def set_stand_in(pipe, train=False, model_path=None):
    """Initialise Stand-In LoRA layers on every DiT block, optionally loading weights.

    Calls ``self_attn.init_lora(train)`` on each entry of ``pipe.dit.blocks``.
    When ``model_path`` is given, a message is printed and the checkpoint is
    loaded with ``load_lora_weights_into_pipe`` (default ``strict=True``).

    Args:
        pipe: Pipeline object exposing ``pipe.dit.blocks``.
        train: Forwarded positionally to each block's ``init_lora``
            (presumably selects training vs. inference setup — confirm in
            ``init_lora``).
        model_path: Optional checkpoint path; ``None`` skips loading.
    """
    dit_blocks = pipe.dit.blocks
    for dit_block in dit_blocks:
        dit_block.self_attn.init_lora(train)

    # Nothing more to do unless a checkpoint was requested.
    if model_path is None:
        return

    print(f"Loading Stand-In weights from: {model_path}")
    load_lora_weights_into_pipe(pipe, model_path)
def load_lora_weights_into_pipe(pipe, ckpt_path, strict=True):
    """Copy Stand-In LoRA weights from a checkpoint into ``pipe``'s DiT blocks.

    Target parameters are addressed by keys of the form
    ``blocks.{i}.self_attn.{q,k,v}_loras.{down,up}.weight``, where ``i`` is the
    block's index in ``pipe.dit.blocks``.

    Args:
        pipe: Pipeline object exposing ``pipe.dit.blocks``. Each block's
            ``self_attn`` is expected to hold ``q_loras``/``k_loras``/``v_loras``
            containers whose ``down``/``up`` members have a ``.weight`` tensor
            (presumably created by ``init_lora`` — confirm).
        ckpt_path: Path passed to ``torch.load`` (mapped to CPU). If the loaded
            mapping has a ``"state_dict"`` entry, that entry is used; otherwise
            the loaded object itself is treated as the state dict.
        strict: When True, raise on a missing LoRA module, an unexpected
            checkpoint key, or a shape mismatch. When False, those entries are
            skipped.

    Raises:
        KeyError: (strict only) a LoRA module is missing from a block, or the
            checkpoint contains a key with no matching parameter.
        ValueError: (strict only) a checkpoint tensor's shape differs from the
            target parameter's shape.

    Warns:
        UserWarning: if some target parameters have no entry in the checkpoint;
            those parameters keep their current values.
    """
    import warnings

    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # trusted sources. Consider passing weights_only=True where the installed
    # torch supports it and the checkpoint holds only tensors/plain containers.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state_dict = ckpt.get("state_dict", ckpt)

    # Map checkpoint key -> target weight tensor for every LoRA projection found.
    model = {}
    for i, block in enumerate(pipe.dit.blocks):
        prefix = f"blocks.{i}.self_attn."
        attn = block.self_attn
        for name in ("q_loras", "k_loras", "v_loras"):
            # Default to None so a block lacking the whole container is treated
            # as a missing module rather than raising an unrelated AttributeError.
            lora = getattr(attn, name, None)
            for sub in ("down", "up"):
                key = f"{prefix}{name}.{sub}.weight"
                module = getattr(lora, sub, None)
                if module is not None:
                    model[key] = module.weight
                elif strict:
                    raise KeyError(f"Missing module: {key}")

    loaded = set()
    # no_grad lets us write into parameters that require grad without the
    # legacy `.data` escape hatch.
    with torch.no_grad():
        for k, param in state_dict.items():
            target = model.get(k)
            if target is None:
                if strict:
                    raise KeyError(f"Unexpected key in ckpt: {k}")
                continue
            if target.shape != param.shape:
                if strict:
                    raise ValueError(
                        f"Shape mismatch: {k} | {target.shape} vs {param.shape}"
                    )
                continue
            # copy_ casts dtype and moves the CPU tensor to the target's device.
            target.copy_(param)
            loaded.add(k)

    # Parameters the checkpoint never provided would otherwise be left at their
    # initial values with no indication; surface that without failing.
    missing = [k for k in model if k not in loaded]
    if missing:
        warnings.warn(
            f"{len(missing)} Stand-In LoRA parameter(s) not found in checkpoint "
            f"{ckpt_path!r} were left unchanged (e.g. {missing[0]})",
            stacklevel=2,
        )
|