# CondRef-AR / CondRefAR / pipeline.py
# (uploaded by PuTorch: "upload CondRef-AR model", commit 6ffba01, verified)
import torch
import numpy as np
from PIL import Image
from safetensors.torch import load_file
from .models.gpt_t2i import GPT_models
from .models.generate import generate
from .tokenizer.vq_model import VQ_models
class CondRefARPipeline:
    """Conditional image-generation pipeline: GPT token sampler + VQ decoder.

    The GPT (built from ``GPT_models``) samples VQ codebook indices conditioned
    on a text embedding and a control image via ``generate``; the VQ model
    (built from ``VQ_models``) decodes those indices back into RGB images.

    Use :meth:`from_pretrained` to build a usable instance; constructing the
    class directly leaves ``gpt`` and ``vq`` unset.
    """

    def __init__(self, device=None, torch_dtype=torch.bfloat16):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # dtype applied to the GPT weights, prompt embeddings and control tensor.
        self.dtype = torch_dtype
        self.gpt = None
        self.vq = None
        self.image_size = None
        # Image pixels per latent token along one side
        # (latent grid side = image_size // downsample).
        self.downsample = None
        # Number of token layers assumed when reshaping the sampled sequence.
        self.n_q = 8
        # Channel dim of the quantized latent passed to ``vq.decode_code``;
        # overwritten from ``vq_config["codebook_embed_dim"]`` in from_pretrained.
        self.codebook_embed_dim = 8

    @classmethod
    def from_pretrained(cls, repo_or_path, gpt_config, vq_config,
                        gpt_weights="weights/sketch-gpt-xl.safetensors",
                        vq_weights="weights/vq-16.safetensors",
                        device=None, torch_dtype=torch.bfloat16):
        """Build a pipeline and load VQ and GPT weights from safetensors files.

        Args:
            repo_or_path: Local directory; weight files are resolved as
                ``f"{repo_or_path}/{weights}"``.
            gpt_config: Mapping with required keys ``image_size`` and
                ``vocab_size``. Optional keys and their defaults: ``num_classes``
                (1000), ``cls_token_num`` (120), ``model_type`` ("t2i"),
                ``adapter_size`` ("small"), ``condition_type`` ("sketch"),
                ``gpt_name`` ("GPT-XL").
            vq_config: Mapping with required keys ``downsample_size``,
                ``codebook_size`` and ``codebook_embed_dim``, plus optional
                ``model_name`` ("VQ-16").
            gpt_weights, vq_weights: Weight file paths relative to
                ``repo_or_path``.
            device, torch_dtype: Forwarded to the constructor. Only the GPT is
                cast to ``torch_dtype``; the VQ model is just moved to the device.

        Returns:
            A ``CondRefARPipeline`` with both models loaded and in eval mode.
        """
        pipe = cls(device=device, torch_dtype=torch_dtype)

        # 1) VQ tokenizer/decoder (weights must match exactly).
        pipe.downsample = int(vq_config["downsample_size"])
        codebook_size = int(vq_config["codebook_size"])
        pipe.codebook_embed_dim = int(vq_config["codebook_embed_dim"])
        vq_name = vq_config.get("model_name", "VQ-16")
        pipe.vq = VQ_models[vq_name](
            codebook_size=codebook_size,
            codebook_embed_dim=pipe.codebook_embed_dim,
        )
        vq_state = load_file(f"{repo_or_path}/{vq_weights}")
        pipe.vq.load_state_dict(vq_state, strict=True)
        pipe.vq.to(pipe.device)
        pipe.vq.eval()

        # 2) Autoregressive GPT; block_size is one token per latent grid cell.
        pipe.image_size = int(gpt_config["image_size"])
        vocab_size = int(gpt_config["vocab_size"])
        latent_size = pipe.image_size // pipe.downsample
        gpt_name = gpt_config.get("gpt_name", "GPT-XL")
        pipe.gpt = GPT_models[gpt_name](
            vocab_size=vocab_size,
            block_size=latent_size ** 2,
            num_classes=int(gpt_config.get("num_classes", 1000)),
            cls_token_num=int(gpt_config.get("cls_token_num", 120)),
            model_type=gpt_config.get("model_type", "t2i"),
            adapter_size=gpt_config.get("adapter_size", "small"),
            condition_type=gpt_config.get("condition_type", "sketch"),
        ).to(device=pipe.device, dtype=pipe.dtype)
        gpt_state = load_file(f"{repo_or_path}/{gpt_weights}")
        # strict=False: checkpoint keys missing from / extra to the model are
        # ignored without notice. NOTE(review): confirm this is intended
        # (e.g. adapter weights supplied separately).
        pipe.gpt.load_state_dict(gpt_state, strict=False)
        pipe.gpt.eval()
        return pipe

    @staticmethod
    def _control_to_tensor(control_image):
        """Convert a numpy control image to a ``[1, C, H, W]`` float tensor in [-1, 1].

        * uint8 arrays are always divided by 255.
        * Float arrays are divided by 255 only if their max exceeds 1.
        * 2-D (grayscale) arrays are replicated to 3 channels.
        * Non-ndarray inputs (e.g. an already preprocessed tensor) are
          returned unchanged.
        """
        if not isinstance(control_image, np.ndarray):
            return control_image
        # Contiguous copy: torch.from_numpy rejects negative-stride views.
        arr = np.ascontiguousarray(control_image)
        if arr.ndim == 2:
            arr = np.repeat(arr[:, :, None], 3, axis=2)
        is_uint8 = arr.dtype == np.uint8
        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).float()
        # Check the dtype first: a uint8 image with values <= 1 is still 0..255 data.
        if is_uint8 or tensor.max() > 1.0:
            tensor = tensor / 255.0
        return 2.0 * (tensor - 0.5)

    @staticmethod
    def _reshape_tokens(index_sample, hq, wq, n_q):
        """Reshape a ``[B, L]`` index sequence to a ``[B, layers, hq, wq]`` long tensor.

        * ``L == n_q*hq*wq`` gives ``n_q`` layers.
        * ``L == hq*wq`` gives 1 layer.
        * Otherwise the number of layers is ``L // (hq*wq)`` and any trailing
          remainder is dropped.

        Raises:
            ValueError: if ``index_sample`` is not 2-D or is shorter than ``hq*wq``.
        """
        grid = hq * wq
        if index_sample.dim() != 2 or index_sample.shape[1] < grid:
            raise ValueError(
                f"expected generated indices of shape [B, >= {grid}], "
                f"got {tuple(index_sample.shape)}"
            )
        batch, length = index_sample.shape
        if length == n_q * grid:
            layers = n_q
        elif length == grid:
            layers = 1
        else:
            layers = length // grid
            index_sample = index_sample[:, : layers * grid]
        return index_sample.reshape(batch, layers, hq, wq).long()

    @staticmethod
    def _samples_to_uint8(samples):
        """Map ``[B, C, H, W]`` decoder output in [-1, 1] to a uint8 ``[B, H, W, C]`` array.

        Every sample is rescaled the same way; there is no data-dependent range
        detection. Values are clamped and rounded to the nearest integer.
        """
        unit = ((samples + 1.0) / 2.0).clamp(0, 1)
        return (unit * 255).round().to(torch.uint8).permute(0, 2, 3, 1).numpy()

    @torch.inference_mode()
    def __call__(self, prompt_emb, control_image, cfg_scale=4, cfg_interval=-1,
                 temperature=1.0, top_k=2000, top_p=1.0, emb_mask=None):
        """Sample images conditioned on a text embedding and a control image.

        Args:
            prompt_emb: torch.Tensor of text embeddings, presumably
                ``[B, T_txt, D]``. It is moved to ``self.device``/``self.dtype``
                and passed to ``generate`` unchanged.
            control_image: One of:
                * a PIL image (converted to RGB);
                * a numpy array ``[H, W, C]`` or ``[H, W]``, either uint8 or
                  float in [0, 1];
                * an already preprocessed torch.Tensor, used as-is.
            cfg_scale, cfg_interval, temperature, top_k, top_p: Sampling options
                forwarded to ``generate``.
            emb_mask: Optional text-embedding mask forwarded to ``generate``
                (None by default, matching the previous behavior).

        Returns:
            list of PIL.Image.Image, one per sampled image.

        Raises:
            RuntimeError: if the pipeline was not built via ``from_pretrained``.
            ValueError: if the generated indices cannot be mapped onto the
                latent grid.
        """
        if self.gpt is None or self.vq is None:
            raise RuntimeError(
                "CondRefARPipeline is not loaded; build it with from_pretrained()."
            )

        # Control image -> [1, C, H, W] tensor in [-1, 1].
        if isinstance(control_image, Image.Image):
            control_image = np.array(control_image.convert("RGB"))
        control = self._control_to_tensor(control_image).to(self.device, dtype=self.dtype)

        # Text conditioning.
        c_indices = prompt_emb.to(self.device, dtype=self.dtype)
        if isinstance(emb_mask, torch.Tensor):
            emb_mask = emb_mask.to(self.device)

        # Square latent grid.
        hq = self.image_size // self.downsample
        wq = hq
        seq_len = hq * wq

        # Sample codebook indices. NOTE(review): the reshape below accepts
        # either [B, seq_len] or [B, n_q * seq_len]; confirm which one
        # ``generate`` actually returns.
        index_sample = generate(
            self.gpt, c_indices, seq_len, emb_mask,
            condition=control, cfg_scale=cfg_scale, cfg_interval=cfg_interval,
            temperature=temperature, top_k=top_k, top_p=top_p, sample_logits=True,
        )
        tokens = self._reshape_tokens(index_sample, hq, wq, self.n_q).to(self.device)

        # Decode with the latent shape [B, codebook_embed_dim, hq, wq], taken
        # from the VQ config rather than a hard-coded channel count.
        qzshape = [tokens.size(0), self.codebook_embed_dim, hq, wq]
        samples = self.vq.decode_code(tokens, qzshape).detach().float().cpu()

        frames = self._samples_to_uint8(samples)
        return [Image.fromarray(frame) for frame in frames]