Stand-In / models /set_condition_branch.py
fffiloni's picture
Migrated from GitHub
26557da verified
import torch
def set_stand_in(pipe, train=False, model_path=None):
    """Initialise the LoRA branch on every DiT block and optionally load weights.

    Calls ``block.self_attn.init_lora(train)`` for each block in
    ``pipe.dit.blocks``. When ``model_path`` is given, the checkpoint is then
    loaded once via ``load_lora_weights_into_pipe``.

    Args:
        pipe: Pipeline object exposing ``pipe.dit.blocks``; each block must
            provide ``self_attn.init_lora``.
        train: Forwarded unchanged to ``init_lora``; its exact meaning is
            defined by that method.
        model_path: Path to a checkpoint to load, or ``None`` to skip loading.
    """
    for block in pipe.dit.blocks:
        block.self_attn.init_lora(train)

    # Load only after *all* blocks have been initialised: the loader walks
    # every block in one pass, so it must run once, outside the loop above.
    # NOTE(review): presumably init_lora is what creates the q/k/v LoRA
    # modules the loader looks up -- confirm against the attention class.
    if model_path is not None:
        print(f"Loading Stand-In weights from: {model_path}")
        load_lora_weights_into_pipe(pipe, model_path)
def load_lora_weights_into_pipe(pipe, ckpt_path, strict=True):
    """Copy LoRA weights from a checkpoint file into the pipe's attention blocks.

    The checkpoint is expected to map keys of the form
    ``blocks.{i}.self_attn.{q|k|v}_loras.{down|up}.weight`` to tensors, either
    directly or nested under a top-level ``"state_dict"`` entry. Matching
    tensors are copied in place into the corresponding ``.weight`` of
    ``pipe.dit.blocks[i].self_attn``.

    Target weights that have no matching entry in the checkpoint are left
    untouched and are not reported, even when ``strict`` is true.

    Args:
        pipe: Pipeline object exposing ``pipe.dit.blocks``; each block must
            have a ``self_attn`` attribute.
        ckpt_path: Path passed to ``torch.load``.
        strict: When true, raise on a missing LoRA module, on a checkpoint
            key with no matching module, and on a shape mismatch. When false,
            those entries are skipped silently.

    Raises:
        KeyError: (strict only) a LoRA module is missing on the pipe, or the
            checkpoint contains a key that matches no module.
        ValueError: (strict only) a checkpoint tensor's shape differs from the
            target weight's shape. Weights copied before the offending key
            stay modified.
    """
    # NOTE(security): torch.load unpickles the file. Only load checkpoints
    # from trusted sources, or consider passing weights_only=True if the
    # checkpoint format allows it.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state_dict = ckpt.get("state_dict", ckpt)

    # Map every expected checkpoint key to the live weight it should fill.
    targets = {}
    for i, block in enumerate(pipe.dit.blocks):
        attn = block.self_attn
        for name in ("q_loras", "k_loras", "v_loras"):
            # A missing container is treated like a missing sub-module so the
            # strict / non-strict handling below applies uniformly instead of
            # surfacing a bare AttributeError.
            lora = getattr(attn, name, None)
            for sub in ("down", "up"):
                key = f"blocks.{i}.self_attn.{name}.{sub}.weight"
                layer = getattr(lora, sub, None)
                if layer is not None:
                    targets[key] = layer.weight
                elif strict:
                    raise KeyError(f"Missing module: {key}")

    # In-place copies into parameters must not be tracked by autograd.
    with torch.no_grad():
        for key, param in state_dict.items():
            if key not in targets:
                if strict:
                    raise KeyError(f"Unexpected key in ckpt: {key}")
                continue

            target = targets[key]
            if target.shape != param.shape:
                if strict:
                    raise ValueError(
                        f"Shape mismatch: {key} | {target.shape} vs {param.shape}"
                    )
                continue

            target.copy_(param)