import torch


def set_stand_in(pipe, train=False, model_path=None):
    # Attach a LoRA adapter to every DiT self-attention block, then
    # optionally load pretrained Stand-In weights into those adapters.
    for block in pipe.dit.blocks:
        block.self_attn.init_lora(train)
    if model_path is not None:
        print(f"Loading Stand-In weights from: {model_path}")
        load_lora_weights_into_pipe(pipe, model_path)
def load_lora_weights_into_pipe(pipe, ckpt_path, strict=True):
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state_dict = ckpt.get("state_dict", ckpt)

    # Gather the live LoRA parameters of every self-attention block,
    # keyed the way they appear in the checkpoint:
    # "blocks.{i}.self_attn.{q,k,v}_loras.{down,up}.weight".
    model = {}
    for i, block in enumerate(pipe.dit.blocks):
        prefix = f"blocks.{i}.self_attn."
        attn = block.self_attn
        for name in ["q_loras", "k_loras", "v_loras"]:
            for sub in ["down", "up"]:
                key = f"{prefix}{name}.{sub}.weight"
                if hasattr(getattr(attn, name), sub):
                    model[key] = getattr(getattr(attn, name), sub).weight
                elif strict:
                    raise KeyError(f"Missing module: {key}")
    # Copy checkpoint tensors into the matching parameters; in strict
    # mode any unknown key or shape mismatch aborts the load, otherwise
    # the offending entry is silently skipped.
    for k, param in state_dict.items():
        if k not in model:
            if strict:
                raise KeyError(f"Unexpected key in ckpt: {k}")
            continue
        if model[k].shape != param.shape:
            if strict:
                raise ValueError(
                    f"Shape mismatch: {k} | {model[k].shape} vs {param.shape}"
                )
            continue
        model[k].data.copy_(param)
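
Below is a minimal, self-contained smoke test of the loader. The Dummy* classes, the dim/rank values, and the checkpoint filename are illustrative assumptions, not the real Stand-In DiT modules; they only mirror the attribute layout and key scheme that load_lora_weights_into_pipe expects.

import types

import torch
import torch.nn as nn


class DummyLoRA(nn.Module):
    # Hypothetical stand-in for a LoRA adapter: a down-projection
    # followed by an up-projection, both without bias.
    def __init__(self, dim=8, rank=2):
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)
        self.up = nn.Linear(rank, dim, bias=False)


class DummyAttn(nn.Module):
    # Exposes q/k/v LoRA adapters under the attribute names the
    # loader iterates over.
    def __init__(self):
        super().__init__()
        self.q_loras = DummyLoRA()
        self.k_loras = DummyLoRA()
        self.v_loras = DummyLoRA()


class DummyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = DummyAttn()


# A fake pipe that provides only the `dit.blocks` path the loader reads.
pipe = types.SimpleNamespace(
    dit=types.SimpleNamespace(blocks=[DummyBlock(), DummyBlock()])
)

# Build a checkpoint whose keys mirror the expected naming scheme:
# "blocks.{i}.self_attn.{q,k,v}_loras.{down,up}.weight".
state_dict = {}
for i, block in enumerate(pipe.dit.blocks):
    for name in ["q_loras", "k_loras", "v_loras"]:
        for sub in ["down", "up"]:
            layer = getattr(getattr(block.self_attn, name), sub)
            key = f"blocks.{i}.self_attn.{name}.{sub}.weight"
            state_dict[key] = torch.randn_like(layer.weight)

torch.save({"state_dict": state_dict}, "stand_in_demo.ckpt")
load_lora_weights_into_pipe(pipe, "stand_in_demo.ckpt", strict=True)

With strict=True the call raises on any missing adapter, unexpected key, or shape mismatch; passing strict=False would skip such entries instead, which is useful when loading a checkpoint trained on a slightly different block layout.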