Spaces:
Runtime error
Runtime error
from other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel | |
from safetensors import safe_open | |
from huggingface_hub import hf_hub_download | |
import torch | |
def load_into(ckpt, model, prefix, device, dtype=None, remap=None):
    """Copy tensors from an open safetensors handle into matching attributes of ``model``.

    A debugging-friendly alternative to ``load_state_dict``: each checkpoint key
    is resolved attribute-by-attribute on ``model`` and the target tensor's
    storage is replaced in place with ``Tensor.set_``. Keys that cannot be
    resolved are reported and skipped.

    Args:
        ckpt: Object exposing ``keys()`` and ``get_tensor(key)`` (such as the
            handle returned by ``safetensors.safe_open``).
        model: Root object (typically a ``torch.nn.Module``) to load into.
        prefix: Only model keys starting with this prefix are loaded; the prefix
            is stripped before resolving the attribute path.
        device: Device every loaded tensor is moved to.
        dtype: If given, floating-point tensors are cast to this dtype.
            Integer and bool tensors keep their original dtype.
        remap: Optional mapping from checkpoint key to model key.

    Raises:
        Exception: Any error raised while moving or assigning a tensor is
            printed and then re-raised with its original traceback.
    """
    for key in ckpt.keys():
        model_key = remap[key] if remap is not None and key in remap else key
        if not model_key.startswith(prefix) or model_key.startswith("loss."):
            continue

        path = model_key[len(prefix):].split(".")
        obj = model
        for p in path:
            if isinstance(obj, (list, tuple)):
                # Plain Python sequences have no attribute named "0", "1", ...;
                # index them instead of using getattr.
                try:
                    obj = obj[int(p)]
                except (ValueError, IndexError):
                    obj = None
            else:
                obj = getattr(obj, p, None)
            if obj is None:
                print(
                    f"Skipping key '{model_key}' in safetensors file as '{p}' does not exist in python model"
                )
                break
        if obj is None:
            continue

        try:
            tensor = ckpt.get_tensor(key).to(device=device)
            # Only cast floating-point data; integer/bool tensors (index
            # buffers and the like) must keep their dtype.
            if dtype is not None and tensor.is_floating_point():
                tensor = tensor.to(dtype=dtype)
            # Parameters that require grad cannot have their storage swapped in place.
            obj.requires_grad_(False)
            if obj.shape != tensor.shape:
                print(
                    f"W: shape mismatch for key {model_key}, {obj.shape} != {tensor.shape}"
                )
            obj.set_(tensor)
        except Exception as e:
            print(f"Failed to load key '{key}' in safetensors file: {e}")
            raise
# Transformer hyper-parameters for the OpenCLIP bigG text encoder,
# handed to SDXLClipG by ClipG.
CLIPG_CONFIG = dict(
    hidden_act="gelu",
    hidden_size=1280,
    intermediate_size=5120,
    num_attention_heads=20,
    num_hidden_layers=32,
)
class ClipG:
    """OpenCLIP bigG text encoder (``SDXLClipG``) loaded from a Hugging Face Hub repo.

    ``self.model`` is the ``SDXLClipG`` instance; the safetensors weights are
    loaded into ``self.model.transformer``.
    """

    def __init__(
        self,
        model_folder: str,
        device: str = "cpu",
        dtype=torch.float32,
        filename: str = "clip_g.safetensors",
    ):
        """Download the CLIP-G weights and build the encoder.

        Args:
            model_folder: Hugging Face Hub repo id passed to ``hf_hub_download``.
            device: Device used to build the model and to load weights onto.
            dtype: dtype used to build the model and passed to ``load_into``
                (defaults to ``torch.float32``).
            filename: Path of the safetensors file inside the repo.
        """
        # NOTE(review): confirm ``filename`` exists at this path in
        # ``model_folder``; if the repo keeps text encoders in a subfolder,
        # pass e.g. "text_encoders/clip_g.safetensors".
        safetensors_path = hf_hub_download(
            repo_id=model_folder, filename=filename, cache_dir=None
        )
        # The file is opened on CPU; load_into moves each tensor to ``device``.
        with safe_open(safetensors_path, framework="pt", device="cpu") as f:
            self.model = SDXLClipG(CLIPG_CONFIG, device=device, dtype=dtype)
            load_into(f, self.model.transformer, "", device, dtype)
# Transformer hyper-parameters for the OpenAI CLIP-L text encoder,
# handed to SDClipModel (as textmodel_json_config) by ClipL.
CLIPL_CONFIG = dict(
    hidden_act="quick_gelu",
    hidden_size=768,
    intermediate_size=3072,
    num_attention_heads=12,
    num_hidden_layers=12,
)
class ClipL:
    """OpenAI CLIP-L text encoder (``SDClipModel``) loaded from a Hugging Face Hub repo.

    ``self.model`` is the ``SDClipModel`` instance; the safetensors weights are
    loaded into ``self.model.transformer``.
    """

    def __init__(
        self,
        model_folder: str,
        device: str = "cpu",
        dtype=torch.float32,
        filename: str = "clip_l.safetensors",
    ):
        """Download the CLIP-L weights and build the encoder.

        Args:
            model_folder: Hugging Face Hub repo id passed to ``hf_hub_download``.
            device: Device used to build the model and to load weights onto
                (defaults to ``"cpu"``).
            dtype: dtype used to build the model and passed to ``load_into``
                (defaults to ``torch.float32``).
            filename: Path of the safetensors file inside the repo.
        """
        # NOTE(review): confirm ``filename`` exists at this path in
        # ``model_folder``; if the repo keeps text encoders in a subfolder,
        # pass e.g. "text_encoders/clip_l.safetensors".
        safetensors_path = hf_hub_download(
            repo_id=model_folder, filename=filename, cache_dir=None
        )
        # The file is opened on CPU; load_into moves each tensor to ``device``.
        with safe_open(safetensors_path, framework="pt", device="cpu") as f:
            self.model = SDClipModel(
                layer="hidden",
                layer_idx=-2,
                device=device,
                dtype=dtype,
                layer_norm_hidden_state=False,
                return_projected_pooled=False,
                textmodel_json_config=CLIPL_CONFIG,
            )
            load_into(f, self.model.transformer, "", device, dtype)
# Hyper-parameters for the T5-XXL text encoder, handed to T5XXLModel by T5XXL.
T5_CONFIG = dict(
    d_ff=10240,
    d_model=4096,
    num_heads=64,
    num_layers=24,
    vocab_size=32128,
)
class T5XXL:
    """T5-XXL text encoder (``T5XXLModel``) loaded from a Hugging Face Hub repo.

    ``self.model`` is the ``T5XXLModel`` instance; the safetensors weights are
    loaded into ``self.model.transformer``.
    """

    def __init__(
        self,
        model_folder: str,
        device: str = "cpu",
        dtype=torch.float32,
        filename: str = "t5xxl.safetensors",
    ):
        """Download the T5-XXL weights and build the encoder.

        Args:
            model_folder: Hugging Face Hub repo id passed to ``hf_hub_download``.
            device: Device used to build the model and to load weights onto.
            dtype: dtype used to build the model and passed to ``load_into``.
            filename: Path of the safetensors file inside the repo.
        """
        # NOTE(review): confirm ``filename`` exists at this path in
        # ``model_folder``; some repos ship the T5 weights in a subfolder and/or
        # under a precision-suffixed name, so override it here if needed.
        safetensors_path = hf_hub_download(
            repo_id=model_folder, filename=filename, cache_dir=None
        )
        # The file is opened on CPU; load_into moves each tensor to ``device``.
        with safe_open(safetensors_path, framework="pt", device="cpu") as f:
            self.model = T5XXLModel(T5_CONFIG, device=device, dtype=dtype)
            load_into(f, self.model.transformer, "", device, dtype)
# --- Module-level text-encoder singletons (read by get_cond); built at import time. ---
tokenizer = SD3Tokenizer()

model_folder = "stabilityai/stable-diffusion-3.5-medium"  # Hub repo id for all three encoders
text_encoder_device = "cpu"

# NOTE(review): T5-XXL is loaded in float32 on CPU; check host memory if
# startup fails.
print("Loading Google T5-v1-XXL...")
t5xxl = T5XXL(model_folder, device=text_encoder_device, dtype=torch.float32)

print("Loading OpenAI CLIP L...")
clip_l = ClipL(model_folder)

print("Loading OpenCLIP bigG...")
clip_g = ClipG(model_folder, device=text_encoder_device)
def get_cond(self, prompt):
    """Encode ``prompt`` with CLIP-L, CLIP-G and T5-XXL into conditioning tensors.

    Uses the module-level ``tokenizer``, ``clip_l``, ``clip_g`` and ``t5xxl``.

    Args:
        self: Unused; kept so existing callers keep working.
        prompt: Text passed to ``tokenizer.tokenize_with_weights``.

    Returns:
        ``(context, pooled)``:
        * ``context`` is built in two steps. The CLIP-L and CLIP-G hidden
          states are concatenated on the last dim and right-padded with zeros
          to the T5 output width. The result is then concatenated with the T5
          hidden states along dim -2 (presumably the token axis; confirm).
        * ``pooled`` is the CLIP-L and CLIP-G pooled outputs concatenated on
          the last dim.
    """
    print("Encode prompt...")
    tokens = tokenizer.tokenize_with_weights(prompt)
    l_out, l_pooled = clip_l.model.encode_token_weights(tokens["l"])
    g_out, g_pooled = clip_g.model.encode_token_weights(tokens["g"])
    t5_out, _t5_pooled = t5xxl.model.encode_token_weights(tokens["t5xxl"])

    lg_out = torch.cat([l_out, g_out], dim=-1)
    # Pad to the actual T5 width (formerly a hard-coded 4096) so the dim -2
    # concatenation below always has matching last dimensions.
    lg_out = torch.nn.functional.pad(lg_out, (0, t5_out.shape[-1] - lg_out.shape[-1]))

    context = torch.cat([lg_out, t5_out], dim=-2)
    pooled = torch.cat((l_pooled, g_pooled), dim=-1)
    return context, pooled