import gradio as gr
from io import BytesIO

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms as T
from diffusers import AutoencoderDC

from revq.models.preprocessor import Preprocessor
from revq.models.revq import ReVQ
from revq.models.revq_quantizer import Quantizer
from revq.models.vqgan_hf import VQModelHF
from revq.utils.init import seed_everything

# matplotlib.rcParams['font.family'] = 'Times New Roman'

seed_everything(42)

#################
handler = None
device = torch.device("cpu")
#################
def load_preprocessor(device, is_eval: bool = True, ckpt_path: str = "./ckpt/preprocessor.pth"):
    preprocessor = Preprocessor(
        input_data_size=[32, 8, 8]
    ).to(device)
    preprocessor.load_state_dict(
        torch.load(ckpt_path, map_location=device, weights_only=True)
    )
    if is_eval:
        preprocessor.eval()
    return preprocessor
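
# Minimal usage sketch (assumes ./ckpt/preprocessor.pth exists and that
# Preprocessor normalizes DC-AE latents of shape [B, 32, 8, 8], as used below):
#
#   pre = load_preprocessor(torch.device("cpu"))
#   lat = torch.randn(1, 32, 8, 8)
#   normed = pre(lat)               # forward pass: normalize latents for ReVQ
#   restored = pre.inverse(normed)  # inverse pass: undo the normalization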

# ReVQ: for reset strategy
def fig_to_array(fig):
    buf = BytesIO()
    fig.savefig(buf, format='png')  # use png instead of webp
    buf.seek(0)
    image = Image.open(buf)
    return np.array(image)
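
# Note: get_codebook assumes quantizer.embeddings squeezes to [num_code, code_dim]
# (single group with code_dim == 2 in the reset demo below).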
def get_codebook(quantizer):
    with torch.no_grad():
        codes = quantizer.embeddings.squeeze().detach()
    return codes

def draw_fig(ax, quantizer, data, color="r", title=""):
    codes = get_codebook(quantizer)
    ax.scatter(data[:, 0], data[:, 1], s=60, marker="*")
    code_color = 'red' if color == "r" else 'green'
    ax.scatter(codes[:, 0], codes[:, 1], s=40, c=code_color, alpha=0.5)
    ax.set_xlim(-5, 10)
    ax.set_ylim(-10, 5)
    ax.tick_params(axis='x', labelsize=22)
    ax.tick_params(axis='y', labelsize=22)
    ax.set_xticks(np.arange(-5, 11, 5))
    ax.set_yticks(np.arange(-10, 6, 5))
    ax.grid(linestyle='--', color='#333333', alpha=0.7)
    ax.set_title(title, fontsize=24)

def draw_arrow(ax, start, end):
    for i in range(len(start)):
        ax.arrow(start[i][0], start[i][1], end[i][0] - start[i][0], end[i][1] - start[i][1],
                 head_width=0.1, head_length=0.1, fc='orange', ec='orange', alpha=0.8,
                 ls="-", lw=1)

def draw_reset_result(num_data=16, num_code=12):
    fig_reset, ax_reset = plt.subplots(1, 6, figsize=(36, 6), dpi=400)
    fig_nreset, ax_nreset = plt.subplots(1, 6, figsize=(36, 6), dpi=400)
    x = torch.randn(num_data, 1) * 2 + 5
    y = torch.randn(num_data, 1) * 2 - 5
    data = torch.cat([x, y], dim=1)
    quantizer = Quantizer(TYPE='vq', code_dim=2, num_code=num_code, num_group=1, tokens_per_data=1)
    optimizer = torch.optim.SGD(quantizer.parameters(), lr=0.1)
    quantizer_nreset = Quantizer(TYPE='vq', code_dim=2, num_code=num_code, num_group=1, tokens_per_data=1, auto_reset=False)
    optimizer_nreset = torch.optim.SGD(quantizer_nreset.parameters(), lr=0.1)
    draw_fig(ax_reset[0], quantizer, data, color='g', title="Initialization")
    draw_fig(ax_nreset[0], quantizer_nreset, data, color='r', title="Initialization")
    ax_reset[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
    ax_nreset[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
    i_list = [1, 3, 10, 50, 200]
    count = 0
    for i in range(500):
        optimizer.zero_grad()
        optimizer_nreset.zero_grad()
        output_dict = quantizer(data.unsqueeze(1))
        output_dict_nreset = quantizer_nreset(data.unsqueeze(1))
        quant_data = output_dict["x_quant"].squeeze()
        quant_data_nreset = output_dict_nreset["x_quant"].squeeze()
        indices = output_dict["indices"].squeeze()
        indices_nreset = output_dict_nreset["indices"].squeeze()
        loss = torch.mean((quant_data - data) ** 2)
        loss_nreset = torch.mean((quant_data_nreset - data) ** 2)
        loss.backward()
        loss_nreset.backward()
        optimizer.step()
        optimizer_nreset.step()
        if (i + 1) in i_list:
            count += 1
            draw_fig(ax_reset[count], quantizer, data, color='g', title=f"Iters: {i+1}, MSE: {loss.item():.1f}")
            draw_arrow(ax_reset[count], quant_data.detach().numpy(), data.numpy())
            draw_fig(ax_nreset[count], quantizer_nreset, data, color='r', title=f"Iters: {i+1}, MSE: {loss_nreset.item():.1f}")
            draw_arrow(ax_nreset[count], quant_data_nreset.detach().numpy(), data.numpy())
        quantizer.reset()
    fig_reset.suptitle("VQ Codebook Training with Reset", fontsize=24, y=1.05)
    fig_nreset.suptitle("VQ Codebook Training without Reset", fontsize=24, y=1.05)
    img_reset = fig_to_array(fig_reset)
    img_nreset = fig_to_array(fig_nreset)
    plt.close(fig_reset)
    plt.close(fig_nreset)
    return img_nreset, img_reset
# end
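
# Quick local check (not part of the Gradio app; assumes a writable cwd):
#
#   img_nreset, img_reset = draw_reset_result(num_data=16, num_code=12)
#   Image.fromarray(img_reset).save("with_reset.png")
#   Image.fromarray(img_nreset).save("without_reset.png")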

# ReVQ: for multi-group
def get_codebook_v2(quantizer):
    with torch.no_grad():
        embedding = quantizer.embeddings
        if quantizer.num_group == 1:
            group1 = embedding[0].squeeze()
            group2 = embedding[0].squeeze()
        else:
            group1 = embedding[0].squeeze()
            group2 = embedding[1].squeeze()
        codes = torch.cartesian_prod(group1, group2)
    return codes
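
# With num_group == 2 and code_dim == 1, the effective 2-D codebook is the
# cartesian product of the two 1-D groups: num_code ** 2 points from only
# 2 * num_code learned scalars. A tiny illustration:
#
#   a = torch.tensor([0., 1.])
#   b = torch.tensor([2., 3.])
#   torch.cartesian_prod(a, b)
#   # tensor([[0., 2.], [0., 3.], [1., 2.], [1., 3.]])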

def draw_fig_v2(ax, quantizer, data, color='r', title=""):
    codes = get_codebook_v2(quantizer)
    ax.scatter(data[:, 0], data[:, 1], s=60, marker="*")
    code_color = 'red' if color == "r" else 'green'
    ax.scatter(codes[:, 0], codes[:, 1], s=20, c=code_color, alpha=0.5)
    ax.plot([-12, 12], [-12, 12], color='orange', linestyle='--', linewidth=2)
    ax.set_xlim(-12, 12)
    ax.set_ylim(-12, 12)
    ax.tick_params(axis='x', labelsize=22)
    ax.tick_params(axis='y', labelsize=22)
    ax.set_xticks(np.arange(-10, 11, 5))
    ax.set_yticks(np.arange(-10, 11, 5))
    ax.grid(linestyle='--', color='#333333', alpha=0.7)
    ax.set_title(title, fontsize=26)

def draw_multi_group_result(num_data=16, num_code=12):
    fig_s, ax_s = plt.subplots(1, 6, figsize=(36, 6), dpi=400)
    fig_m, ax_m = plt.subplots(1, 6, figsize=(36, 6), dpi=400)
    x = torch.randn(num_data, 1) * 3 + 4
    y = torch.randn(num_data, 1) * 3 - 4
    data = torch.cat([x, y], dim=1)
    quantizer_s = Quantizer(TYPE='vq', code_dim=1, num_code=num_code, num_group=1, tokens_per_data=2)
    optimizer_s = torch.optim.SGD(quantizer_s.parameters(), lr=0.1)
    quantizer_m = Quantizer(TYPE='vq', code_dim=1, num_code=num_code, num_group=2, tokens_per_data=2)
    optimizer_m = torch.optim.SGD(quantizer_m.parameters(), lr=0.1)
    draw_fig_v2(ax_s[0], quantizer_s, data, color='r', title="Initialization")
    draw_fig_v2(ax_m[0], quantizer_m, data, color='g', title="Initialization")
    ax_s[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
    ax_m[0].legend(["Data", "Code"], loc="upper right", fontsize=24)
    i_list = [5, 20, 50, 200, 1000]
    count = 0
    for i in range(1500):
        optimizer_s.zero_grad()
        optimizer_m.zero_grad()
        quant_data_s = quantizer_s(data.unsqueeze(-1))["x_quant"].squeeze()
        quant_data_m = quantizer_m(data.unsqueeze(-1))["x_quant"].squeeze()
        loss_s = torch.mean((quant_data_s - data) ** 2)
        loss_m = torch.mean((quant_data_m - data) ** 2)
        loss_s.backward()
        loss_m.backward()
        optimizer_s.step()
        optimizer_m.step()
        if (i + 1) in i_list:
            count += 1
            draw_fig_v2(ax_s[count], quantizer_s, data, color='r', title=f"Iters: {i+1}, MSE: {loss_s.item():.1f}")
            draw_fig_v2(ax_m[count], quantizer_m, data, color='g', title=f"Iters: {i+1}, MSE: {loss_m.item():.1f}")
        quantizer_s.reset()
        quantizer_m.reset()
    fig_s.suptitle("VQ Codebook Training with Single Group", fontsize=24, y=1.05)
    fig_m.suptitle("VQ Codebook Training with Multi Group", fontsize=24, y=1.05)
    img_s = fig_to_array(fig_s)
    img_m = fig_to_array(fig_m)
    plt.close(fig_s)
    plt.close(fig_m)
    return img_s, img_m
# end

# ReVQ: for image reconstruction
class Handler:
    def __init__(self, device):
        self.transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(256),
            T.ToTensor()
        ])
        self.device = device
        self.basevq = VQModelHF.from_pretrained("BorelTHU/basevq-16x16x4")
        self.basevq.to(self.device)
        self.basevq.eval()
        self.vqgan = VQModelHF.from_pretrained("BorelTHU/vqgan-16x16")
        self.vqgan.to(self.device)
        self.vqgan.eval()
        self.optvq = VQModelHF.from_pretrained("BorelTHU/optvq-16x16x4")
        self.optvq.to(self.device)
        self.optvq.eval()
        self.vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
        self.vae.to(self.device)
        self.vae.eval()
        self.preprocessor = load_preprocessor(self.device)
        self.revq = ReVQ.from_pretrained("AndyRaoTHU/revq-512T")
        self.revq.to(self.device)
        self.revq.eval()
        # print("Models loaded successfully!")

    def tensor_to_image(self, tensor):
        img = tensor.squeeze(0).cpu().permute(1, 2, 0).numpy()
        img = np.clip((img + 1) / 2 * 255, 0, 255)
        img = img.astype("uint8")
        return img

    def process_image(self, img: np.ndarray):
        img = Image.fromarray(img.astype("uint8"))
        img = self.transform(img)
        img = img.unsqueeze(0).to(self.device)
        with torch.no_grad():
            img = 2 * img - 1
            # basevq
            quant, *_ = self.basevq.encode(img)
            basevq_rec = self.basevq.decode(quant)
            # vqgan
            quant, *_ = self.vqgan.encode(img)
            vqgan_rec = self.vqgan.decode(quant)
            # revq
            lat = self.vae.encode(img).latent
            lat = lat.contiguous()
            lat = self.preprocessor(lat)
            lat = self.revq.quantize(lat)
            revq_rec = self.revq.decode(lat)
            revq_rec = revq_rec.contiguous()
            revq_rec = self.preprocessor.inverse(revq_rec)
            revq_rec = self.vae.decode(revq_rec).sample
        # tensor to PIL image
        img = self.tensor_to_image(img)
        basevq_rec = self.tensor_to_image(basevq_rec)
        vqgan_rec = self.tensor_to_image(vqgan_rec)
        revq_rec = self.tensor_to_image(revq_rec)
        return basevq_rec, vqgan_rec, revq_rec
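
# Headless usage sketch (assumes the checkpoints above are downloadable and an
# example.jpg exists; runs the same pipeline the Gradio demo wires up):
#
#   h = Handler(device=torch.device("cpu"))
#   arr = np.array(Image.open("example.jpg").convert("RGB"))
#   basevq_rec, vqgan_rec, revq_rec = h.process_image(arr)
#   Image.fromarray(revq_rec).save("revq_rec.png")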

if __name__ == "__main__":
    # create the model handler
    handler = Handler(device=device)
    print("Creating Gradio interface...")

    # Demo 1 interface: image reconstruction
    demo1 = gr.Interface(
        fn=handler.process_image,
        inputs=gr.Image(label="Input Image", type="numpy"),
        outputs=[
            gr.Image(label="BaseVQ Reconstruction", type="numpy"),
            gr.Image(label="VQGAN Reconstruction", type="numpy"),
            gr.Image(label="ReVQ Reconstruction", type="numpy"),
        ],
        title="Demo 1: Image Reconstruction",
        description="Upload an image to see how different VQ models (BaseVQ, VQGAN, ReVQ) reconstruct it from latent codes."
    )

    with gr.Blocks() as demo2:
        gr.Markdown("## Demo 2: Codebook Reset Strategy Visualization")
        gr.Markdown("Visualizes codebook and data movement at different training steps, with and without the codebook reset strategy.")
        with gr.Row():
            num_data = gr.Slider(label="num_data", value=16, minimum=10, maximum=20, step=1)
            num_code = gr.Slider(label="num_code", value=12, minimum=8, maximum=16, step=1)
        submit_btn = gr.Button("Run Visualization")
        with gr.Column():  # stack the outputs vertically
            out_without_reset = gr.Image(label="Without Reset")
            out_with_reset = gr.Image(label="With Reset")
        submit_btn.click(fn=draw_reset_result, inputs=[num_data, num_code], outputs=[out_without_reset, out_with_reset])

    with gr.Blocks() as demo3:
        gr.Markdown("## Demo 3: Channel Multi-Group Strategy Visualization")
        gr.Markdown("Visualizes codebook and data movement at different training steps, with and without the multi-group strategy.")
        with gr.Row():
            num_data = gr.Slider(label="num_data", value=32, minimum=28, maximum=40, step=1)
            num_code = gr.Slider(label="num_code", value=8, minimum=6, maximum=10, step=1)
        submit_btn = gr.Button("Run Visualization")
        with gr.Column():  # stack the outputs vertically
            out_s = gr.Image(label="Single Group")
            out_m = gr.Image(label="Multi Group")
        submit_btn.click(fn=draw_multi_group_result, inputs=[num_data, num_code], outputs=[out_s, out_m])

    demo = gr.TabbedInterface(
        interface_list=[demo1, demo2, demo3],
        tab_names=["Image Reconstruction", "Reset Strategy", "Channel Multi-Group Strategy"]
    )
    demo.launch(share=True)