|
|
import torch
|
|
|
import numpy as np
|
|
|
from PIL import Image
|
|
|
from safetensors.torch import load_file
|
|
|
from .models.gpt_t2i import GPT_models
|
|
|
from .models.generate import generate
|
|
|
from .tokenizer.vq_model import VQ_models
|
|
|
|
|
|
class CondRefARPipeline:
    """Inference pipeline: prompt embedding + control image -> PIL images.

    The pipeline holds two models: a GPT model from which ``generate``
    samples token indices, and a VQ model whose ``decode_code`` turns those
    indices into an image tensor. Instances are normally created with
    :meth:`from_pretrained`; a bare instance has no models attached and
    raises ``RuntimeError`` when called.
    """

    def __init__(self, device=None, torch_dtype=torch.bfloat16):
        """Create an empty pipeline with no models attached.

        Args:
            device: Device the models and inputs are moved to. When falsy,
                ``"cuda"`` is used if available, otherwise ``"cpu"``.
            torch_dtype: dtype applied to the GPT model, the prompt
                embedding and the control tensor.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch_dtype
        self.gpt = None
        self.vq = None
        self.image_size = None
        self.downsample = None
        # Used in __call__ to recognise a sampled sequence that carries
        # n_q indices per latent position.
        self.n_q = 8
        # Second entry of the latent shape handed to the VQ decoder.
        # from_pretrained overwrites it with vq_config["codebook_embed_dim"];
        # 8 keeps the value that used to be hard-coded in __call__.
        self.codebook_embed_dim = 8

    @classmethod
    def from_pretrained(
        cls,
        repo_or_path,
        gpt_config,
        vq_config,
        gpt_weights="weights/sketch-gpt-xl.safetensors",
        vq_weights="weights/vq-16.safetensors",
        device=None,
        torch_dtype=torch.bfloat16,
    ):
        """Build a pipeline and load both models from safetensors files.

        Args:
            repo_or_path: Prefix joined with ``gpt_weights`` / ``vq_weights``
                using ``"/"`` to form the paths given to ``load_file``.
                NOTE(review): nothing is downloaded here, so this has to
                resolve to a local directory - confirm against callers.
            gpt_config: Mapping with required keys ``image_size`` and
                ``vocab_size``; optional keys ``num_classes``,
                ``cls_token_num``, ``model_type``, ``adapter_size``,
                ``condition_type`` and ``gpt_name``.
            vq_config: Mapping with required keys ``downsample_size``,
                ``codebook_size`` and ``codebook_embed_dim``; optional key
                ``model_name``.
            gpt_weights: GPT checkpoint path relative to ``repo_or_path``.
            vq_weights: VQ checkpoint path relative to ``repo_or_path``.
            device: Forwarded to the constructor.
            torch_dtype: Forwarded to the constructor.

        Returns:
            A pipeline with both models loaded and switched to eval mode.

        Raises:
            KeyError: If a required config key or a model name is missing.
        """
        pipe = cls(device=device, torch_dtype=torch_dtype)

        # --- VQ model ---
        pipe.downsample = int(vq_config["downsample_size"])
        codebook_size = int(vq_config["codebook_size"])
        # Kept on the pipeline so __call__ decodes with the configured value
        # instead of a hard-coded 8.
        pipe.codebook_embed_dim = int(vq_config["codebook_embed_dim"])
        vq_factory = VQ_models[vq_config.get("model_name", "VQ-16")]
        pipe.vq = vq_factory(
            codebook_size=codebook_size,
            codebook_embed_dim=pipe.codebook_embed_dim,
        )
        vq_state = load_file(f"{repo_or_path}/{vq_weights}")
        pipe.vq.load_state_dict(vq_state, strict=True)
        pipe.vq.to(pipe.device)
        pipe.vq.eval()

        # --- GPT model ---
        pipe.image_size = int(gpt_config["image_size"])
        vocab_size = int(gpt_config["vocab_size"])
        latent_size = pipe.image_size // pipe.downsample
        gpt_factory = GPT_models[gpt_config.get("gpt_name", "GPT-XL")]
        pipe.gpt = gpt_factory(
            vocab_size=vocab_size,
            block_size=latent_size ** 2,
            num_classes=int(gpt_config.get("num_classes", 1000)),
            cls_token_num=int(gpt_config.get("cls_token_num", 120)),
            model_type=gpt_config.get("model_type", "t2i"),
            adapter_size=gpt_config.get("adapter_size", "small"),
            condition_type=gpt_config.get("condition_type", "sketch"),
        ).to(device=pipe.device, dtype=pipe.dtype)
        gpt_state = load_file(f"{repo_or_path}/{gpt_weights}")
        # strict=False: keys missing from or unknown to the checkpoint are
        # ignored without any report.
        pipe.gpt.load_state_dict(gpt_state, strict=False)
        pipe.gpt.eval()

        return pipe

    @torch.inference_mode()
    def __call__(
        self,
        prompt_emb,
        control_image,
        cfg_scale=4,
        cfg_interval=-1,
        temperature=1.0,
        top_k=2000,
        top_p=1.0,
    ):
        """Sample tokens for a prompt and control image and decode them.

        Args:
            prompt_emb: ``torch.Tensor`` holding the prompt embedding,
                documented by the original author as ``[B, T_txt, D]``.
            control_image: ``PIL.Image.Image`` (converted to RGB) or an
                ``[H, W, C]`` ``np.ndarray``. Both are turned into a
                ``[1, C, H, W]`` float tensor scaled to ``[-1, 1]``. Any
                other object is moved to the device unchanged - presumably
                an already prepared tensor; TODO confirm.
            cfg_scale: Forwarded to ``generate``.
            cfg_interval: Forwarded to ``generate``.
            temperature: Forwarded to ``generate``.
            top_k: Forwarded to ``generate``.
            top_p: Forwarded to ``generate``.

        Returns:
            A list of ``PIL.Image.Image``, one per decoded sample.

        Raises:
            RuntimeError: If the models or size settings were never loaded.
        """
        required = (self.gpt, self.vq, self.image_size, self.downsample)
        if any(item is None for item in required):
            raise RuntimeError(
                "CondRefARPipeline is not loaded; create it with "
                "CondRefARPipeline.from_pretrained(...) before calling it."
            )

        # 1) Control image -> tensor.
        if isinstance(control_image, Image.Image):
            control_image = np.array(control_image.convert("RGB"))
        if isinstance(control_image, np.ndarray):
            # [H, W, C] -> [1, C, H, W]
            control_image = (
                torch.from_numpy(control_image).permute(2, 0, 1).unsqueeze(0).float()
            )
            # Heuristic: any value above 1 means the data is on a 0..255 scale.
            if control_image.max() > 1.0:
                control_image = control_image / 255.0
            # [0, 1] -> [-1, 1]
            control_image = 2.0 * (control_image - 0.5)
        control = control_image.to(self.device, dtype=self.dtype)

        # 2) Sample token indices.
        c_indices = prompt_emb.to(self.device, dtype=self.dtype)
        c_emb_masks = None
        latent_h = self.image_size // self.downsample
        latent_w = latent_h
        positions = latent_h * latent_w

        index_sample = generate(
            self.gpt,
            c_indices,
            positions,
            c_emb_masks,
            condition=control,
            cfg_scale=cfg_scale,
            cfg_interval=cfg_interval,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            sample_logits=True,
        )

        # 3) Arrange the flat indices as [B, planes, H, W].
        is_flat = index_sample.dim() == 2
        if is_flat and index_sample.shape[1] == self.n_q * positions:
            planes = self.n_q
        elif is_flat and index_sample.shape[1] == positions:
            planes = 1
        else:
            # Unexpected length: keep as many whole planes as fit and drop
            # the remainder.
            planes = max(1, index_sample.shape[1] // positions)
            index_sample = index_sample[:, : planes * positions]
        # reshape instead of view: the column slice above is not contiguous
        # when the batch holds more than one row.
        tokens = index_sample.reshape(
            index_sample.size(0), planes, latent_h, latent_w
        ).long()
        tokens = tokens.to(self.device)

        # 4) Decode to pixels.
        qzshape = [tokens.size(0), self.codebook_embed_dim, latent_h, latent_w]
        samples = self.vq.decode_code(tokens, qzshape).detach().float().cpu()

        # Heuristic: a minimum below -0.9 is taken to mean the decoder output
        # is in [-1, 1]; map it to [0, 1] before clamping.
        if samples.min() < -0.9:
            samples = (samples + 1.0) / 2.0
        samples = samples.clamp(0, 1)

        # [B, C, H, W] in [0, 1] -> [B, H, W, C] uint8
        arr = (samples * 255).to(torch.uint8).permute(0, 2, 3, 1).numpy()
        return [Image.fromarray(frame) for frame in arr]