import torch
import numpy as np
from PIL import Image
from safetensors.torch import load_file
from .models.gpt_t2i import GPT_models
from .models.generate import generate
from .tokenizer.vq_model import VQ_models

class CondRefARPipeline:
    def __init__(self, device=None, torch_dtype=torch.bfloat16):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch_dtype
        self.gpt = None
        self.vq = None
        self.image_size = None
        self.downsample = None
        self.n_q = 8  # number of codebook levels expected in the sampled index sequence
        self.codebook_embed_dim = None  # set from vq_config in from_pretrained

    @classmethod
    def from_pretrained(
        cls,
        repo_or_path,
        gpt_config,
        vq_config,
        gpt_weights="weights/sketch-gpt-xl.safetensors",
        vq_weights="weights/vq-16.safetensors",
        device=None,
        torch_dtype=torch.bfloat16,
    ):
        pipe = cls(device=device, torch_dtype=torch_dtype)

        # 1) VQ
        pipe.downsample = int(vq_config["downsample_size"])
        codebook_size = int(vq_config["codebook_size"])
        codebook_embed_dim = int(vq_config["codebook_embed_dim"])
        pipe.codebook_embed_dim = codebook_embed_dim
        pipe.vq = VQ_models[vq_config.get("model_name", "VQ-16")](codebook_size=codebook_size, codebook_embed_dim=codebook_embed_dim)
        vq_state = load_file(f"{repo_or_path}/{vq_weights}")
        pipe.vq.load_state_dict(vq_state, strict=True)
        pipe.vq.to(pipe.device)
        pipe.vq.eval()

        # 2) GPT
        pipe.image_size = int(gpt_config["image_size"])
        vocab_size = int(gpt_config["vocab_size"])
        latent_size = pipe.image_size // pipe.downsample
        block_size = latent_size ** 2  # number of image-token positions
        num_classes = int(gpt_config.get("num_classes", 1000))
        cls_token_num = int(gpt_config.get("cls_token_num", 120))
        model_type = gpt_config.get("model_type", "t2i")
        adapter_size = gpt_config.get("adapter_size", "small")
        condition_type = gpt_config.get("condition_type", "sketch")

        pipe.gpt = GPT_models[gpt_config.get("gpt_name", "GPT-XL")](
            vocab_size=vocab_size,
            block_size=block_size,
            num_classes=num_classes,
            cls_token_num=cls_token_num,
            model_type=model_type,
            adapter_size=adapter_size,
            condition_type=condition_type
        ).to(device=pipe.device, dtype=pipe.dtype)
        gpt_state = load_file(f"{repo_or_path}/{gpt_weights}")
        pipe.gpt.load_state_dict(gpt_state, strict=False)
        pipe.gpt.eval()

        return pipe

    @torch.inference_mode()
    def __call__(self, prompt_emb, control_image, cfg_scale=4, cfg_interval=-1, temperature=1.0, top_k=2000, top_p=1.0):
        """

        prompt_emb: torch.Tensor [B, T_txt, D]

        control_image: np.ndarray/PIL

        Return: Image

        """
        # Preprocess the control image
        if isinstance(control_image, Image.Image):
            control_image = np.array(control_image.convert("RGB"))
        if isinstance(control_image, np.ndarray):
            # [H,W,C] uint8 -> [-1,1]
            control_image = torch.from_numpy(control_image).permute(2,0,1).unsqueeze(0).float()
            if control_image.max() > 1.0:
                control_image = control_image / 255.0
            control_image = 2.0 * (control_image - 0.5)
        control = control_image.to(self.device, dtype=self.dtype)
        # Text embeddings
        c_indices = prompt_emb.to(self.device, dtype=self.dtype)
        # An emb_mask can be constructed externally and passed in if needed; for this minimal example it is left as None.
        c_emb_masks = None

        Hq = self.image_size // self.downsample
        Wq = Hq
        seq_len = Hq * Wq
        # Sample the codebook index sequence (generate returns [B, n_q*Hq*Wq], or [B, Hq*Wq] when sampling a single codebook)
        index_sample = generate(
            self.gpt, c_indices, seq_len, c_emb_masks,
            condition=control, cfg_scale=cfg_scale, cfg_interval=cfg_interval,
            temperature=temperature, top_k=top_k, top_p=top_p, sample_logits=True
        )
        # Rearrange into [B, n_q, Hq, Wq]
        if index_sample.dim() == 2 and index_sample.shape[1] == self.n_q * Hq * Wq:
            tokens = index_sample.view(index_sample.size(0), self.n_q, Hq, Wq).long()
        elif index_sample.dim() == 2 and index_sample.shape[1] == Hq * Wq:
            tokens = index_sample.view(index_sample.size(0), 1, Hq, Wq).long()
        else:
            # Fall back to inferring n_q automatically
            n_q = max(1, index_sample.shape[1] // (Hq * Wq))
            tokens = index_sample[:, : n_q * Hq * Wq].view(index_sample.size(0), n_q, Hq, Wq).long()
        tokens = tokens.to(self.device)
        qzshape = [tokens.size(0), self.codebook_embed_dim, Hq, Wq]  # latent shape for VQ decoding
        samples = self.vq.decode_code(tokens, qzshape).detach().float().cpu()
        # [-1,1] -> [0,1]
        if samples.min() < -0.9:
            samples = (samples + 1.0) / 2.0
        samples = samples.clamp(0, 1)

        imgs = []
        arr = (samples * 255).to(torch.uint8).permute(0,2,3,1).numpy()
        for i in range(arr.shape[0]):
            imgs.append(Image.fromarray(arr[i]))
        return imgs
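

# Minimal usage sketch. The checkpoint path, config values, and the random prompt
# embedding below are illustrative placeholders rather than values shipped with this
# repo; in practice the prompt embedding comes from the text encoder paired with the
# checkpoint. Because this file uses relative imports, run it as a module
# (e.g. `python -m <package>.pipeline`) rather than as a standalone script.
if __name__ == "__main__":
    example_gpt_config = {"image_size": 512, "vocab_size": 16384, "cls_token_num": 120}
    example_vq_config = {"downsample_size": 16, "codebook_size": 16384, "codebook_embed_dim": 8}

    pipe = CondRefARPipeline.from_pretrained(
        "path/to/checkpoint_dir",  # placeholder: directory containing the weights/ folder
        example_gpt_config,
        example_vq_config,
    )
    prompt_emb = torch.randn(1, 120, 2048)            # stand-in for real text-encoder output
    sketch = Image.open("sketch.png").convert("RGB")  # placeholder control image
    images = pipe(prompt_emb, sketch, cfg_scale=4.0)
    images[0].save("sample.png")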