File size: 1,455 Bytes
26557da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch


def set_stand_in(pipe, train=False, model_path=None):
    for block in pipe.dit.blocks:
        block.self_attn.init_lora(train)
    if model_path is not None:
        print(f"Loading Stand-In weights from: {model_path}")
        load_lora_weights_into_pipe(pipe, model_path)


def load_lora_weights_into_pipe(pipe, ckpt_path, strict=True):
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state_dict = ckpt.get("state_dict", ckpt)

    model = {}
    for i, block in enumerate(pipe.dit.blocks):
        prefix = f"blocks.{i}.self_attn."
        attn = block.self_attn
        for name in ["q_loras", "k_loras", "v_loras"]:
            for sub in ["down", "up"]:
                key = f"{prefix}{name}.{sub}.weight"
                if hasattr(getattr(attn, name), sub):
                    model[key] = getattr(getattr(attn, name), sub).weight
                else:
                    if strict:
                        raise KeyError(f"Missing module: {key}")

    for k, param in state_dict.items():
        if k in model:
            if model[k].shape != param.shape:
                if strict:
                    raise ValueError(
                        f"Shape mismatch: {k} | {model[k].shape} vs {param.shape}"
                    )
                else:
                    continue
            model[k].data.copy_(param)
        else:
            if strict:
                raise KeyError(f"Unexpected key in ckpt: {k}")