File size: 5,097 Bytes
c219d70
 
 
071c88a
c219d70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel
from safetensors import safe_open
from huggingface_hub import hf_hub_download
import torch

def load_into(ckpt, model, prefix, device, dtype=None, remap=None):
    """Copy tensors from an open safetensors handle into matching attributes of ``model``.

    A debugging-friendly alternative to ``load_state_dict``. Each checkpoint key
    is handled as follows:

    * It is optionally renamed through ``remap``.
    * If it starts with ``prefix`` (and not with ``"loss."``), the prefix is
      stripped and the remainder is split on ``"."``.
    * The parts are walked from ``model`` by attribute lookup, or by integer
      index when the current object is a plain Python ``list``.
    * The target tensor's storage is replaced in place with ``Tensor.set_``.

    Keys whose path does not exist on the model are reported and skipped.

    Args:
        ckpt: Object exposing ``keys()`` and ``get_tensor(key)``, e.g. the handle
            returned by ``safetensors.safe_open``.
        model: Root object (typically a ``torch.nn.Module``) to load into.
        prefix: Only model keys starting with this prefix are loaded; the prefix
            is stripped before the attribute walk.
        device: Device each loaded tensor is moved to.
        dtype: If given, floating-point tensors are cast to this dtype. Integer
            and boolean tensors keep their stored dtype.
        remap: Optional mapping from checkpoint key to model key.

    Raises:
        Exception: Any error while fetching or assigning a tensor is printed
            together with the offending key and then re-raised unchanged.
    """
    for key in ckpt.keys():
        model_key = key
        if remap is not None and key in remap:
            model_key = remap[key]
        if not model_key.startswith(prefix) or model_key.startswith("loss."):
            continue

        obj = model
        for p in model_key[len(prefix) :].split("."):
            # `isinstance`, not `obj is list`: the latter compares the instance
            # with the builtin type itself and is therefore never true.
            if isinstance(obj, list):
                obj = obj[int(p)]
            else:
                obj = getattr(obj, p, None)
            if obj is None:
                print(
                    f"Skipping key '{model_key}' in safetensors file as '{p}' does not exist in python model"
                )
                break
        if obj is None:
            continue

        try:
            tensor = ckpt.get_tensor(key).to(device=device)
            # Only cast floating-point data; integer/bool buffers (indices,
            # masks) must keep their integral dtype.
            if dtype is not None and tensor.is_floating_point():
                tensor = tensor.to(dtype=dtype)
            # set_ on a tensor that requires grad is rejected by autograd.
            obj.requires_grad_(False)
            if obj.shape != tensor.shape:
                print(
                    f"W: shape mismatch for key {model_key}, {obj.shape} != {tensor.shape}"
                )
            obj.set_(tensor)
        except Exception as e:
            print(f"Failed to load key '{key}' in safetensors file: {e}")
            raise

# Text-transformer hyper-parameters for the OpenCLIP bigG encoder (ClipG).
CLIPG_CONFIG = dict(
    hidden_act="gelu",
    hidden_size=1280,
    intermediate_size=5120,
    num_attention_heads=20,
    num_hidden_layers=32,
)


class ClipG:
    """OpenCLIP bigG text encoder whose weights come from ``clip_g.safetensors``."""

    def __init__(self, model_folder: str, device: str = "cpu"):
        """Fetch ``clip_g.safetensors`` from Hub repo ``model_folder`` and load it.

        Args:
            model_folder: Hugging Face Hub repo id that holds the weights file.
            device: Device the encoder is built on and that ``load_into``
                targets (with ``torch.float32``) when copying the weights.
        """
        weights_file = hf_hub_download(
            repo_id=model_folder, filename="clip_g.safetensors", cache_dir=None
        )
        # The safetensors handle itself is opened on CPU.
        with safe_open(weights_file, framework="pt", device="cpu") as handle:
            encoder = SDXLClipG(CLIPG_CONFIG, device=device, dtype=torch.float32)
            self.model = encoder
            load_into(handle, encoder.transformer, "", device, torch.float32)


# Text-transformer hyper-parameters for the OpenAI CLIP L encoder (ClipL).
CLIPL_CONFIG = dict(
    hidden_act="quick_gelu",
    hidden_size=768,
    intermediate_size=3072,
    num_attention_heads=12,
    num_hidden_layers=12,
)


class ClipL:
    """OpenAI CLIP L text encoder loaded from ``clip_l.safetensors`` in a Hub repo."""

    def __init__(self, model_folder: str, device: str = "cpu"):
        """Download the CLIP-L weights and build the encoder.

        Args:
            model_folder: Hugging Face Hub repo id that contains ``clip_l.safetensors``.
            device: Device the encoder is built on and the weights are loaded to.
                Defaults to ``"cpu"``, the value that used to be hard-coded, so
                existing ``ClipL(model_folder)`` callers behave as before.
        """
        safetensors_path = hf_hub_download(
            repo_id=model_folder,
            filename="clip_l.safetensors",
            cache_dir=None,
        )
        # The safetensors handle is opened on CPU; load_into moves tensors to `device`.
        with safe_open(safetensors_path, framework="pt", device="cpu") as f:
            self.model = SDClipModel(
                # NOTE(review): presumably selects the penultimate hidden state as
                # the output; exact semantics live in other_impls.SDClipModel.
                layer="hidden",
                layer_idx=-2,
                device=device,
                dtype=torch.float32,
                layer_norm_hidden_state=False,
                return_projected_pooled=False,
                textmodel_json_config=CLIPL_CONFIG,
            )
            load_into(f, self.model.transformer, "", device, torch.float32)


# Hyper-parameters for the Google T5-v1-XXL encoder (T5XXL).
T5_CONFIG = dict(
    d_ff=10240,
    d_model=4096,
    num_heads=64,
    num_layers=24,
    vocab_size=32128,
)


class T5XXL:
    """Google T5-v1-XXL text encoder restored from ``t5xxl.safetensors``."""

    def __init__(self, model_folder: str, device: str = "cpu", dtype=torch.float32):
        """Download ``t5xxl.safetensors`` from Hub repo ``model_folder`` and load it.

        Args:
            model_folder: Hugging Face Hub repo id holding the weights file.
            device: Device the encoder is built on and the weights are moved to.
            dtype: Dtype used to build the encoder; also passed to ``load_into``.
        """
        weights_file = hf_hub_download(
            repo_id=model_folder, filename="t5xxl.safetensors", cache_dir=None
        )
        # The safetensors handle itself is opened on CPU.
        with safe_open(weights_file, framework="pt", device="cpu") as handle:
            encoder = T5XXLModel(T5_CONFIG, device=device, dtype=dtype)
            self.model = encoder
            load_into(handle, encoder.transformer, "", device, dtype)


# Import-time setup: the tokenizer and all three text encoders are created
# eagerly here, and their weights are fetched through hf_hub_download.
model_folder = "stabilityai/stable-diffusion-3.5-medium"
text_encoder_device = "cpu"
tokenizer = SD3Tokenizer()

print("Loading Google T5-v1-XXL...")
t5xxl = T5XXL(model_folder, device=text_encoder_device, dtype=torch.float32)
print("Loading OpenAI CLIP L...")
clip_l = ClipL(model_folder)
print("Loading OpenCLIP bigG...")
clip_g = ClipG(model_folder, device=text_encoder_device)


def get_cond(self, prompt):
    """Encode ``prompt`` with CLIP-L, CLIP-G and T5-XXL into conditioning tensors.

    Uses the module-level ``tokenizer``, ``clip_l``, ``clip_g`` and ``t5xxl``.
    ``self`` is unused; it is kept only so existing callers that pass it
    positionally keep working.

    Returns:
        ``(context, pooled)``, where:

        * ``context`` is the zero-padded concatenation of the CLIP-L and CLIP-G
          token features, stacked with the T5 features along dim -2.
        * ``pooled`` is the CLIP-L and CLIP-G pooled outputs concatenated along
          dim -1.
    """
    print("Encode prompt...")
    tokens = tokenizer.tokenize_with_weights(prompt)
    l_out, l_pooled = clip_l.model.encode_token_weights(tokens["l"])
    g_out, g_pooled = clip_g.model.encode_token_weights(tokens["g"])
    t5_out, _t5_pooled = t5xxl.model.encode_token_weights(tokens["t5xxl"])
    # Join the CLIP features channel-wise, then zero-pad the channel dim up to
    # the T5 width so both sequences can be stacked along the token axis.
    # Padding to t5_out's actual width (rather than a literal 4096) keeps this
    # correct if the T5 config changes.
    lg_out = torch.cat([l_out, g_out], dim=-1)
    lg_out = torch.nn.functional.pad(lg_out, (0, t5_out.shape[-1] - lg_out.shape[-1]))
    context = torch.cat([lg_out, t5_out], dim=-2)
    pooled = torch.cat((l_pooled, g_pooled), dim=-1)
    return context, pooled