import warnings

import torch
import numpy as np
from PIL import Image
from safetensors.torch import load_file

from .models.gpt_t2i import GPT_models
from .models.generate import generate
from .tokenizer.vq_model import VQ_models


class CondRefARPipeline:
    """Condition-guided text-to-image pipeline: a GPT prior plus a VQ decoder.

    The GPT model samples VQ codebook indices conditioned on a text embedding
    and a control image (through ``generate``); the VQ model then decodes those
    indices back to pixels (through ``decode_code``).
    """

    def __init__(self, device=None, torch_dtype=torch.bfloat16):
        """Create an empty pipeline; use :meth:`from_pretrained` to load models.

        Args:
            device: Target device. Defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
            torch_dtype: dtype used for the GPT model, the text embedding and
                the control image. The VQ model is only moved to ``device``;
                its dtype is not changed.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch_dtype
        self.gpt = None
        self.vq = None
        self.image_size = None  # output resolution in pixels (gpt_config)
        self.downsample = None  # VQ spatial downsampling factor (vq_config)
        # Kept for backward compatibility; the number of index maps per latent
        # position is now inferred from the length of the sampled sequence.
        self.n_q = 8
        # Channel dimension of the codebook embeddings, used as the second
        # entry of the latent shape passed to ``vq.decode_code``. Overwritten
        # from ``vq_config["codebook_embed_dim"]`` by ``from_pretrained``.
        self.codebook_embed_dim = 8

    @classmethod
    def from_pretrained(cls, repo_or_path, gpt_config, vq_config,
                        gpt_weights="weights/sketch-gpt-xl.safetensors",
                        vq_weights="weights/vq-16.safetensors",
                        device=None, torch_dtype=torch.bfloat16):
        """Build a pipeline and load VQ and GPT weights from safetensors files.

        Args:
            repo_or_path: Directory prefix; weight files are read from
                ``f"{repo_or_path}/{gpt_weights}"`` and
                ``f"{repo_or_path}/{vq_weights}"``.
            gpt_config: Mapping with required keys ``image_size`` and
                ``vocab_size`` and optional keys ``num_classes``,
                ``cls_token_num``, ``model_type``, ``adapter_size``,
                ``condition_type`` and ``gpt_name``.
            vq_config: Mapping with required keys ``downsample_size``,
                ``codebook_size`` and ``codebook_embed_dim`` and optional
                ``model_name``.
            gpt_weights: Relative path of the GPT safetensors file.
            vq_weights: Relative path of the VQ safetensors file.
            device: Forwarded to the constructor.
            torch_dtype: Forwarded to the constructor.

        Returns:
            CondRefARPipeline: a pipeline with both models in eval mode.

        Raises:
            KeyError: if a required config key or a model name is missing.
        """
        pipe = cls(device=device, torch_dtype=torch_dtype)

        # 1) VQ tokenizer/decoder (loaded strictly: every key must match).
        pipe.downsample = int(vq_config["downsample_size"])
        codebook_size = int(vq_config["codebook_size"])
        codebook_embed_dim = int(vq_config["codebook_embed_dim"])
        pipe.codebook_embed_dim = codebook_embed_dim
        pipe.vq = VQ_models[vq_config.get("model_name", "VQ-16")](
            codebook_size=codebook_size,
            codebook_embed_dim=codebook_embed_dim,
        )
        vq_state = load_file(f"{repo_or_path}/{vq_weights}")
        pipe.vq.load_state_dict(vq_state, strict=True)
        pipe.vq.to(pipe.device)
        pipe.vq.eval()

        # 2) GPT prior; its context must hold one token per latent position.
        pipe.image_size = int(gpt_config["image_size"])
        vocab_size = int(gpt_config["vocab_size"])
        latent_size = pipe.image_size // pipe.downsample
        block_size = latent_size ** 2
        num_classes = int(gpt_config.get("num_classes", 1000))
        cls_token_num = int(gpt_config.get("cls_token_num", 120))
        model_type = gpt_config.get("model_type", "t2i")
        adapter_size = gpt_config.get("adapter_size", "small")
        condition_type = gpt_config.get("condition_type", "sketch")
        pipe.gpt = GPT_models[gpt_config.get("gpt_name", "GPT-XL")](
            vocab_size=vocab_size,
            block_size=block_size,
            num_classes=num_classes,
            cls_token_num=cls_token_num,
            model_type=model_type,
            adapter_size=adapter_size,
            condition_type=condition_type,
        ).to(device=pipe.device, dtype=pipe.dtype)
        gpt_path = f"{repo_or_path}/{gpt_weights}"
        gpt_state = load_file(gpt_path)
        # strict=False is kept (the checkpoint may legitimately omit keys), but
        # parameters left at their initial values are reported, not hidden.
        load_result = pipe.gpt.load_state_dict(gpt_state, strict=False)
        missing = list(getattr(load_result, "missing_keys", []) or [])
        if missing:
            warnings.warn(
                f"{len(missing)} GPT state entries were not found in {gpt_path} "
                f"and keep their initial values (first few: {missing[:5]})",
                stacklevel=2,
            )
        pipe.gpt.eval()
        return pipe

    @torch.inference_mode()
    def __call__(self, prompt_emb, control_image, cfg_scale=4, cfg_interval=-1,
                 temperature=1.0, top_k=2000, top_p=1.0):
        """Generate images from text embeddings and a control image.

        Args:
            prompt_emb: Text embedding tensor, documented as [B, T_txt, D].
            control_image: A PIL image, an HxW / HxWx1 / HxWx3 / HxWx4 numpy
                array (uint8 in [0, 255] or float in [0, 1] / [0, 255]), or a
                tensor that is assumed to be already normalized ([C, H, W] or
                [B, C, H, W]). A single control image is reused for every
                prompt in the batch.
            cfg_scale, cfg_interval, temperature, top_k, top_p: Forwarded
                unchanged to ``generate``.

        Returns:
            list[PIL.Image.Image]: one image per generated sample.
        """
        control = self._prepare_control(control_image)

        # Text embeddings (used by the GPT as its text condition).
        c_indices = prompt_emb.to(self.device, dtype=self.dtype)
        # An embedding mask can be built by the caller if needed; this minimal
        # pipeline passes None.
        c_emb_masks = None

        batch = c_indices.size(0)
        if control.dim() >= 1 and control.size(0) == 1 and batch > 1:
            # Reuse the single control image for every prompt in the batch.
            control = control.repeat(batch, *([1] * (control.dim() - 1)))

        latent_h = latent_w = self.image_size // self.downsample
        seq_len = latent_h * latent_w

        # Sample codebook indices; generate returns a flat index sequence per
        # sample (seq_len long, or a multiple of it).
        index_sample = generate(
            self.gpt, c_indices, seq_len, c_emb_masks,
            condition=control, cfg_scale=cfg_scale, cfg_interval=cfg_interval,
            temperature=temperature, top_k=top_k, top_p=top_p,
            sample_logits=True,
        )

        tokens = self._indices_to_grid(index_sample, latent_h, latent_w).to(self.device)
        embed_dim = getattr(self, "codebook_embed_dim", 8)
        qzshape = [tokens.size(0), embed_dim, latent_h, latent_w]
        samples = self.vq.decode_code(tokens, qzshape).detach().float().cpu()
        return self._to_pil(samples)

    def _prepare_control(self, control_image):
        """Convert a control image to a batched tensor on the pipeline device.

        PIL images and numpy arrays are converted to [1, 3, H, W] and mapped
        to [-1, 1]; tensors are passed through (a 3-D tensor gets a batch dim).
        """
        if isinstance(control_image, Image.Image):
            control_image = np.array(control_image.convert("RGB"))

        if isinstance(control_image, np.ndarray):
            arr = control_image
            if arr.ndim == 2:
                # Grayscale HxW -> HxWx3.
                arr = np.stack([arr] * 3, axis=-1)
            elif arr.ndim == 3 and arr.shape[-1] == 1:
                arr = np.repeat(arr, 3, axis=-1)
            elif arr.ndim == 3 and arr.shape[-1] == 4:
                # Drop the alpha channel, as PIL's convert("RGB") would.
                arr = arr[..., :3]
            is_uint8 = arr.dtype == np.uint8
            # Contiguous float32 copy: torch.from_numpy rejects negative strides.
            tensor = torch.from_numpy(np.ascontiguousarray(arr, dtype=np.float32))
            tensor = tensor.permute(2, 0, 1).unsqueeze(0)
            # uint8 is always [0, 255]; other dtypes fall back to the range check.
            if is_uint8 or tensor.max() > 1.0:
                tensor = tensor / 255.0
            control_image = 2.0 * (tensor - 0.5)  # [0, 1] -> [-1, 1]
        elif isinstance(control_image, torch.Tensor) and control_image.dim() == 3:
            control_image = control_image.unsqueeze(0)

        return control_image.to(self.device, dtype=self.dtype)

    @staticmethod
    def _indices_to_grid(index_sample, latent_h, latent_w):
        """Reshape sampled indices [B, L, ...] into a [B, n_q, H, W] long tensor.

        ``n_q`` is ``L // (H * W)``; trailing indices beyond ``n_q * H * W``
        are dropped.

        Raises:
            ValueError: if the sample has no batch dimension or holds fewer
                indices than one latent grid.
        """
        if index_sample.dim() < 2:
            raise ValueError(
                f"expected sampled indices with a batch dimension, got shape "
                f"{tuple(index_sample.shape)}"
            )
        flat = index_sample.reshape(index_sample.size(0), -1)
        grid = latent_h * latent_w
        length = flat.size(1)
        if length < grid:
            raise ValueError(
                f"sampled sequence has {length} indices per sample, fewer than "
                f"the {latent_h}x{latent_w} latent grid ({grid})"
            )
        n_q = length // grid
        return flat[:, : n_q * grid].reshape(flat.size(0), n_q, latent_h, latent_w).long()

    @staticmethod
    def _to_pil(samples):
        """Map decoder output [B, C, H, W] in [-1, 1] to a list of PIL images."""
        # Always rescale: the decoder output range is [-1, 1] regardless of
        # image content (an image without dark pixels is still in [-1, 1]).
        samples = ((samples + 1.0) / 2.0).clamp(0, 1)
        # Round rather than truncate when quantizing to 8 bits.
        arr = samples.mul(255).round().to(torch.uint8).permute(0, 2, 3, 1).numpy()
        return [Image.fromarray(frame) for frame in arr]