# prembed_sd3.5 / app.py
# Author: wangpangintsig
# Commit 071c88a (verified): "Update app.py"
from other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel
from safetensors import safe_open
from huggingface_hub import hf_hub_download
import torch
def load_into(ckpt, model, prefix, device, dtype=None, remap=None):
    """Copy the tensors of an open safetensors file into the attributes of ``model``.

    A debugging-friendly loader. Every checkpoint key is optionally renamed
    through ``remap``. Keys that start with ``prefix`` (and do not start with
    ``"loss."``) are resolved dot-segment by dot-segment against ``model``.
    Numeric segments index into plain Python lists/tuples; all other segments
    go through ``getattr``. The resolved tensor's storage is then replaced
    with the checkpoint tensor via ``Tensor.set_``.

    Keys whose path cannot be resolved are reported and skipped. A shape
    mismatch is only reported, because ``set_`` adopts the checkpoint
    tensor's shape.

    Args:
        ckpt: Object exposing ``keys()`` and ``get_tensor(key)``, e.g. the
            handle returned by ``safetensors.safe_open``.
        model: Module (or any object) whose attributes receive the weights.
        prefix: Key prefix to match; it is stripped before resolving the path.
        device: Device the loaded tensors are moved to.
        dtype: Optional dtype to cast to; ``torch.int32`` tensors are never cast.
        remap: Optional mapping from checkpoint key to model key.

    Raises:
        Exception: Any error raised while fetching, converting or assigning a
            tensor is re-raised after a diagnostic message is printed.
    """
    for key in ckpt.keys():
        model_key = remap[key] if remap is not None and key in remap else key
        if not model_key.startswith(prefix) or model_key.startswith("loss."):
            continue

        obj = model
        for p in model_key[len(prefix):].split("."):
            if isinstance(obj, (list, tuple)):
                # Plain Python sequences must be indexed numerically; getattr
                # on them would always miss.
                try:
                    obj = obj[int(p)]
                except (ValueError, IndexError):
                    obj = None
            else:
                obj = getattr(obj, p, None)
            if obj is None:
                print(
                    f"Skipping key '{model_key}' in safetensors file as '{p}' does not exist in python model"
                )
                break
        if obj is None:
            continue

        try:
            tensor = ckpt.get_tensor(key).to(device=device)
            # Integer index tensors keep their dtype even when a float dtype
            # is requested.
            if dtype is not None and tensor.dtype != torch.int32:
                tensor = tensor.to(dtype=dtype)
            # set_ on a leaf tensor that requires grad is not allowed, so
            # freeze it first.
            obj.requires_grad_(False)
            if obj.shape != tensor.shape:
                print(
                    f"W: shape mismatch for key {model_key}, {obj.shape} != {tensor.shape}"
                )
            obj.set_(tensor)
        except Exception as e:
            print(f"Failed to load key '{key}' in safetensors file: {e}")
            raise
# Text-transformer hyperparameters for the OpenCLIP bigG encoder, passed to SDXLClipG.
CLIPG_CONFIG = dict(
    hidden_act="gelu",
    hidden_size=1280,
    intermediate_size=5120,
    num_attention_heads=20,
    num_hidden_layers=32,
)
class ClipG:
    """OpenCLIP bigG text encoder whose weights come from a Hugging Face Hub repo.

    ``clip_g.safetensors`` is fetched from ``model_folder`` with
    ``hf_hub_download``. An ``SDXLClipG`` built from ``CLIPG_CONFIG`` in
    float32 on ``device`` receives those weights. The wrapped module is
    exposed as ``self.model``.
    """

    def __init__(self, model_folder: str, device: str = "cpu"):
        weights_file = hf_hub_download(
            repo_id=model_folder, filename="clip_g.safetensors", cache_dir=None
        )
        precision = torch.float32
        # The file is always opened on CPU; load_into moves each tensor to `device`.
        with safe_open(weights_file, framework="pt", device="cpu") as handle:
            self.model = SDXLClipG(CLIPG_CONFIG, device=device, dtype=precision)
            load_into(handle, self.model.transformer, "", device, precision)
# Text-transformer hyperparameters for the OpenAI CLIP-L encoder, passed to SDClipModel.
CLIPL_CONFIG = dict(
    hidden_act="quick_gelu",
    hidden_size=768,
    intermediate_size=3072,
    num_attention_heads=12,
    num_hidden_layers=12,
)
class ClipL:
    """OpenAI CLIP-L text encoder whose weights come from a Hugging Face Hub repo.

    ``clip_l.safetensors`` is fetched from ``model_folder`` with
    ``hf_hub_download``. An ``SDClipModel`` configured from ``CLIPL_CONFIG``
    in float32 receives those weights. The wrapped module is exposed as
    ``self.model``.

    Args:
        model_folder: Hugging Face Hub repo id containing ``clip_l.safetensors``.
        device: Device the model is built on and the weights are loaded to.
            Defaults to ``"cpu"``, the value previously hard-coded, so existing
            callers are unaffected. This matches ``ClipG``/``T5XXL``, which
            already accept ``device``.
    """

    def __init__(self, model_folder: str, device: str = "cpu"):
        safetensors_path = hf_hub_download(
            repo_id=model_folder,
            filename="clip_l.safetensors",
            cache_dir=None,
        )
        # The file is always opened on CPU; load_into moves each tensor to `device`.
        with safe_open(safetensors_path, framework="pt", device="cpu") as f:
            self.model = SDClipModel(
                # NOTE(review): layer="hidden", layer_idx=-2 presumably selects the
                # penultimate hidden state — confirm against other_impls.SDClipModel.
                layer="hidden",
                layer_idx=-2,
                device=device,
                dtype=torch.float32,
                layer_norm_hidden_state=False,
                return_projected_pooled=False,
                textmodel_json_config=CLIPL_CONFIG,
            )
            load_into(f, self.model.transformer, "", device, torch.float32)
# Encoder hyperparameters for T5-XXL, passed to T5XXLModel.
T5_CONFIG = dict(
    d_ff=10240,
    d_model=4096,
    num_heads=64,
    num_layers=24,
    vocab_size=32128,
)
class T5XXL:
    """T5-XXL text encoder whose weights come from a Hugging Face Hub repo.

    ``t5xxl.safetensors`` is fetched from ``model_folder`` with
    ``hf_hub_download``. A ``T5XXLModel`` built from ``T5_CONFIG`` on
    ``device`` with ``dtype`` receives those weights. The wrapped module is
    exposed as ``self.model``.
    """

    def __init__(self, model_folder: str, device: str = "cpu", dtype=torch.float32):
        checkpoint = hf_hub_download(
            repo_id=model_folder,
            filename="t5xxl.safetensors",
            cache_dir=None,
        )
        # The file is always opened on CPU; load_into moves/casts each tensor.
        with safe_open(checkpoint, framework="pt", device="cpu") as reader:
            encoder = T5XXLModel(T5_CONFIG, device=device, dtype=dtype)
            self.model = encoder
            load_into(reader, encoder.transformer, "", device, dtype)
# Hub repo the text-encoder weights are downloaded from, and where they run.
model_folder = "stabilityai/stable-diffusion-3.5-medium"
text_encoder_device = "cpu"
tokenizer = SD3Tokenizer()

# Load the three text encoders once at import time; get_cond reads these globals.
print("Loading Google T5-v1-XXL...")
t5xxl = T5XXL(model_folder, device=text_encoder_device, dtype=torch.float32)
print("Loading OpenAI CLIP L...")
clip_l = ClipL(model_folder=model_folder)
print("Loading OpenCLIP bigG...")
clip_g = ClipG(model_folder, device=text_encoder_device)
def get_cond(self, prompt):
    """Encode ``prompt`` with CLIP-L, CLIP-G and T5-XXL into SD3 conditioning.

    Uses the module-level ``tokenizer``, ``clip_l``, ``clip_g`` and ``t5xxl``.

    Args:
        self: Unused; kept so existing call sites keep working.
        prompt: Text prompt to encode.

    Returns:
        A ``(context, pooled)`` pair.

        ``context`` is built in two steps. First, the CLIP-L and CLIP-G hidden
        states are concatenated on the last dim and zero-padded on the right
        up to the width of the T5 output. Second, that padded tensor is
        concatenated in front of the T5 hidden states along dim -2.

        ``pooled`` is the CLIP-L and CLIP-G pooled outputs concatenated on the
        last dim. The T5 pooled output is not used.
    """
    print("Encode prompt...")
    tokens = tokenizer.tokenize_with_weights(prompt)
    l_out, l_pooled = clip_l.model.encode_token_weights(tokens["l"])
    g_out, g_pooled = clip_g.model.encode_token_weights(tokens["g"])
    t5_out, _t5_pooled = t5xxl.model.encode_token_weights(tokens["t5xxl"])

    lg_out = torch.cat([l_out, g_out], dim=-1)
    # Pad the CLIP features to the T5 width (4096 with T5_CONFIG) so the two
    # can be stacked along dim -2; derived from t5_out instead of hard-coded.
    pad_width = max(0, t5_out.shape[-1] - lg_out.shape[-1])
    lg_out = torch.nn.functional.pad(lg_out, (0, pad_width))

    context = torch.cat([lg_out, t5_out], dim=-2)
    pooled = torch.cat((l_pooled, g_pooled), dim=-1)
    return context, pooled