# Hugging Face Space (running on ZeroGPU hardware)
# %% | |
import copy | |
from datetime import datetime | |
import os | |
# os.environ["CUDA_VISIBLE_DEVICES"] = "0" | |
from my_ipadapter_model import load_ipadapter, image_grid, generate | |
from my_intrinsic_dim import get_intrinsic_dim | |
from dino_clip_featextract import extract_dino_image_embeds, extract_clip_image_embeds, img_transform, img_transform_inv | |
from gradio_utils import add_download_button | |
from my_dino_correspondence import get_correspondence_plot, ncut_tsne_multiple_images, kway_cluster_per_image, get_single_multi_discrete_rgbs, match_centers_three_images, match_centers_two_images, get_center_features | |
from compression_model_mkii import CompressionModel, train_compression_model, free_memory, get_fg_mask | |
def resolve_zerogpu_flag(raw_value):
    """Decide whether GPU work should be routed through Hugging Face ZeroGPU (``spaces``).

    Args:
        raw_value: Value of the ``USE_HUGGINGFACE_ZEROGPU`` environment variable,
            or ``None`` when the variable is unset.

    Returns:
        ``False`` when ``raw_value`` is an explicit "off" spelling ("", "0",
        "false", "no", "off"; case/whitespace-insensitive) or when the ``spaces``
        package cannot be imported; ``True`` otherwise. An unset variable means
        "auto-detect": enabled whenever ``spaces`` is importable.
    """
    # Parse the string instead of testing its truthiness: any non-empty string,
    # including "false", is truthy in Python.
    if raw_value is not None and raw_value.strip().lower() in {"", "0", "false", "no", "off"}:
        return False
    try:
        import spaces  # noqa: F401  -- availability probe only
    except ImportError:
        return False
    return True


USE_HUGGINGFACE_ZEROGPU = resolve_zerogpu_flag(os.getenv("USE_HUGGINGFACE_ZEROGPU"))
if USE_HUGGINGFACE_ZEROGPU:  # huggingface ZeroGPU, dynamic GPU allocation
    import spaces
import torch | |
from PIL import Image | |
import numpy as np | |
import skdim | |
import matplotlib.pyplot as plt | |
# Use a monospace font for every matplotlib figure drawn in this process.
plt.rcParams.update({'font.family': 'monospace'})
from omegaconf import OmegaConf | |
def train_mood_space(pil_images, lr=0.001, steps=5000, width=512, layers=4, dim=None, config_path="./config.yaml"):
    """Train a Mood Space compression model on a set of mood-board images.

    Args:
        pil_images: Images in any form accepted by ``load_gradio_images_helper``
            (e.g. Gradio Gallery ``(image, caption)`` tuples, file paths, PIL images).
        lr: Learning rate, written to ``cfg.lr``.
        steps: Number of training steps, written to ``cfg.steps``.
        width: MLP width, written to ``cfg.latent_dim``.
        layers: Number of MLP layers, written to ``cfg.n_layer``.
        dim: Mood-space dimensionality. When ``None`` it is estimated with
            ``get_intrinsic_dim`` from the DINO token embeddings.
        config_path: Base OmegaConf YAML; the arguments above override its fields.

    Returns:
        ``(model, trainer)``: the ``CompressionModel`` and whatever
        ``train_compression_model`` returns.
    """
    rgb_images = load_gradio_images_helper(pil_images)
    batch = torch.stack(list(map(img_transform, rgb_images)))
    dino_embeds = extract_dino_image_embeds(batch)
    clip_embeds = extract_clip_image_embeds(batch)

    if dim is not None:
        print(f"using user-specified dim: {dim}")
    else:
        # Estimate over every token of every image (all leading axes flattened).
        dim = int(get_intrinsic_dim(dino_embeds.flatten(end_dim=-2)))
        print(f"intrinsic dim is {dim}")

    cfg = OmegaConf.load(config_path)
    cfg.mood_dim = dim
    cfg.lr = lr
    cfg.steps = steps
    cfg.latent_dim = width
    cfg.n_layer = layers

    model = CompressionModel(cfg, gradio_progress=True)  # TODO: check if gradio_progress works without gradio
    trainer = train_compression_model(model, cfg, dino_embeds, clip_embeds)
    return model, trainer


if USE_HUGGINGFACE_ZEROGPU:
    # Request a ZeroGPU slot (up to 60 s) for each call.
    train_mood_space = spaces.GPU(duration=60)(train_mood_space)
def train_mood_space_visualize(image_embeds, dim=2, config_path="/workspace/n25c9900_2d.yaml"):
    """Train a compression model directly on precomputed embeddings.

    Args:
        image_embeds: Embedding tensor; its last axis size becomes both
            ``cfg.in_dim`` and ``cfg.out_dim``.
        dim: Mood-space dimensionality (``cfg.mood_dim``); defaults to 2.
        config_path: OmegaConf YAML to start from. NOTE(review): the default is an
            absolute path that may not exist outside the original environment.

    Returns:
        ``(model, trainer)``. The same ``image_embeds`` tensor is passed as both
        embedding arguments of ``train_compression_model``.
    """
    feature_dim = image_embeds.shape[-1]
    cfg = OmegaConf.load(config_path)
    cfg.mood_dim = dim
    cfg.in_dim = feature_dim
    cfg.out_dim = feature_dim
    model = CompressionModel(cfg, gradio_progress=True)  # TODO: check if gradio_progress works without gradio
    return model, train_compression_model(model, cfg, image_embeds, image_embeds)
def load_gradio_images_helper(pil_images):
    """Normalize Gradio image payloads into a list of RGB PIL images.

    Every entry is handled on its own, so mixed lists work. An example is gallery
    ``(image, caption)`` tuples followed by freshly appended PIL images. Accepted
    entry forms:

    * an ``(image, caption)`` tuple/list, whose first element is used;
    * a file path (``str`` or ``os.PathLike``), opened with ``Image.open``;
    * a ``numpy`` array, converted with ``Image.fromarray``;
    * an image object exposing ``.convert`` (e.g. ``PIL.Image.Image``).

    Args:
        pil_images: Sequence of entries as described above; ``None`` is treated as empty.

    Returns:
        A new list with every image converted to ``"RGB"`` mode (``[]`` for empty input).

    Raises:
        ValueError: If an entry (or a tuple's image part) is ``None``, e.g. an
            unfilled ``gr.Image`` slot.
    """
    normalized = []
    for index, item in enumerate(pil_images or []):
        # Gallery values arrive as (image, caption) pairs; keep only the image.
        if isinstance(item, (tuple, list)):
            item = item[0]
        if item is None:
            raise ValueError(f"Image #{index} is missing; please provide all input images.")
        if isinstance(item, (str, os.PathLike)):
            item = Image.open(item)
        elif isinstance(item, np.ndarray):
            item = Image.fromarray(item)
        # convert to RGB
        normalized.append(item.convert("RGB"))
    return normalized
def find_direction_three_images(image_embeds, eigvecs, A2_to_A1, A1_to_B1):
    """Lift the A1 -> B1 segment offsets onto the tokens of A2.

    Args:
        image_embeds: Per-image token embeddings ordered ``[A2, A1, B1]``
            (described upstream as ``(b, l, c)`` with ``b == 3``).
        eigvecs: Per-image cluster assignments in the same order. ``argmax`` over
            the last axis gives each token's segment id, and the size of that axis
            is the number of segments.
        A2_to_A1: For each A2 segment, the index of its matched A1 segment.
        A1_to_B1: For each A1 segment, the index of its matched B1 segment.

    Returns:
        A tensor shaped like ``image_embeds[0]``. Each A2 token holds
        ``centre(B1 segment) - centre(A1 segment)`` for the A1 segment its own
        segment is matched to.
    """
    num_segments = eigvecs[0].shape[-1]
    centres_A1 = get_center_features(image_embeds[1], eigvecs[1].argmax(-1).cpu(), n_cluster=num_segments)
    centres_B1 = get_center_features(image_embeds[2], eigvecs[2].argmax(-1).cpu(), n_cluster=num_segments)

    # One offset per A1 segment, pointing at the centre of its matched B1 segment.
    offsets = torch.stack([centres_B1[b_idx] - centres_A1[a_idx] for a_idx, b_idx in enumerate(A1_to_B1)])

    labels_A2 = eigvecs[0].argmax(-1).cpu()
    lifted = torch.zeros_like(image_embeds[0])
    for seg in range(num_segments):
        in_segment = labels_A2 == seg
        if in_segment.sum() > 0:
            # Route through the A2 -> A1 match to pick which A1 offset applies.
            lifted[in_segment] = offsets[A2_to_A1[seg]]
    return lifted
def find_direction_two_images(image_embeds, eigvecs, A_to_B, unit_norm_direction=False):
    """Build a per-token edit direction that moves image A toward image B.

    Args:
        image_embeds: Token embeddings ordered ``[A, B]``.
        eigvecs: Cluster assignments ordered ``[A, B]``. ``argmax`` over the last
            axis gives each token's segment id, and the size of that axis is the
            number of segments.
        A_to_B: For each A segment, the index of its matched B segment.
        unit_norm_direction: If True, rescale each segment offset to unit L2 norm.
            Offsets of (near) zero length are left near zero rather than becoming NaN.

    Returns:
        A tensor shaped like ``image_embeds[0]``. Each A token holds
        ``centre(B segment) - centre(A segment)`` for its own segment.
    """
    n_cluster = eigvecs[0].shape[-1]
    A_center_features = get_center_features(image_embeds[0], eigvecs[0].argmax(-1).cpu(), n_cluster=n_cluster)
    B_center_features = get_center_features(image_embeds[1], eigvecs[1].argmax(-1).cpu(), n_cluster=n_cluster)

    direction_A_to_B = torch.stack(
        [B_center_features[i_B] - A_center_features[i_A] for i_A, i_B in enumerate(A_to_B)]
    )
    if unit_norm_direction:
        norms = direction_A_to_B.norm(dim=-1, keepdim=True)
        # Clamp so identical segment centres (zero offset) do not divide by zero -> NaN.
        direction_A_to_B = direction_A_to_B / norms.clamp_min(torch.finfo(direction_A_to_B.dtype).eps)

    cluster_labels = eigvecs[0].argmax(-1).cpu()
    direction_for_A = torch.zeros_like(image_embeds[0])
    for i_cluster in range(n_cluster):
        mask = cluster_labels == i_cluster
        if mask.any():
            direction_for_A[mask] = direction_A_to_B[i_cluster]
    return direction_for_A
def analogy_three_images(image_list, model, ws, n_cluster=30, n_sample=1, match_method='hungarian'):
    """Path lifting: apply the A1 -> B1 change to A2 and render results along it.

    Args:
        image_list: Three images ordered ``[A2, A1, B1]``.
        model: Trained compression model exposing ``compress`` / ``uncompress``.
        ws: Sequence of interpolation weights; each scales the lifted direction.
        n_cluster: Segments per image used for correspondence matching.
        n_sample: Images generated per weight.
        match_method: Strategy forwarded to ``match_centers_three_images``.

    Returns:
        ``(correspondence_image, fig, interpolated_images)``:

        * ``fig`` is a matplotlib grid with ``n_sample`` rows and one column per weight.
        * ``interpolated_images`` is the flat list of generated images, in weight order.
    """
    # ``gr`` is otherwise only bound when this file runs as __main__; import it
    # locally so the function also works when the module is imported.
    import gradio as gr

    free_memory()
    images = torch.stack([img_transform(image) for image in image_list])
    dino_image_embeds = extract_dino_image_embeds(images)
    compressed_image_embeds = model.compress(dino_image_embeds)

    # Only the joint RGBs are needed (for the correspondence plot colours).
    _, joint_rgbs = ncut_tsne_multiple_images(dino_image_embeds, n_eig=30, gamma=0.5)
    single_eigvecs = kway_cluster_per_image(dino_image_embeds, n_cluster=n_cluster, gamma=0.5)
    # single_eigvecs = kway_cluster_multiple_images(dino_image_embeds, n_cluster=n_cluster, gamma=0.5)
    discrete_rgbs = get_single_multi_discrete_rgbs(joint_rgbs, single_eigvecs)

    A2_to_A1, A1_to_B1 = match_centers_three_images(dino_image_embeds, single_eigvecs, match_method=match_method)
    direction = find_direction_three_images(compressed_image_embeds, single_eigvecs, A2_to_A1, A1_to_B1)
    cluster_orders = [
        np.arange(n_cluster),
        A2_to_A1,
        A1_to_B1[A2_to_A1],
    ]
    correspondence_image = get_correspondence_plot(images, single_eigvecs, cluster_orders, discrete_rgbs, hw=16, n_cols=10)

    n_steps = len(ws)
    interpolated_images = []
    # squeeze=False keeps ``axs`` 2-D even for a 1x1 grid; without it,
    # plt.subplots returns a bare Axes and indexing/flattening fails.
    fig, axs = plt.subplots(n_sample, n_steps, figsize=(n_steps * 2, n_sample * 3), squeeze=False)
    progress = gr.Progress()
    ip_model = load_ipadapter()
    try:
        for i_w, w in enumerate(ws):
            progress(i_w / n_steps, desc=f"Interpolating w={w:.2f}")
            A2_interpolated = model.uncompress(compressed_image_embeds[0] + direction * w)
            gen_images = generate(ip_model, A2_interpolated, num_samples=n_sample)
            interpolated_images.extend(gen_images)
            for i_img in range(n_sample):
                ax = axs[i_img, i_w]
                ax.imshow(gen_images[i_img])
                ax.axis('off')
                if i_img == 0:
                    ax.set_title(f"w={w:.2f}")
    finally:
        # Release the IP-Adapter even if generation fails part-way.
        del ip_model
        free_memory()
    fig.tight_layout()
    return correspondence_image, fig, interpolated_images


if USE_HUGGINGFACE_ZEROGPU:
    analogy_three_images = spaces.GPU(duration=60)(analogy_three_images)
def interpolate_two_images(image1, image2, model, ws, n_cluster=20, match_method='hungarian', unit_norm_direction=False, dino_matching=True, seed=None):
    """Generate images along the compressed-space path from ``image1`` toward ``image2``.

    Args:
        image1, image2: Endpoint images (A and B).
        model: Trained compression model exposing ``compress`` / ``uncompress``.
        ws: Sequence of interpolation weights applied to the A -> B direction.
        n_cluster: Segments per image for correspondence matching (used only
            when ``dino_matching`` is True).
        match_method: Strategy forwarded to ``match_centers_two_images``.
        unit_norm_direction: Forwarded to ``find_direction_two_images``.
        dino_matching: If True, use segment-matched directions. Otherwise use the
            plain token-wise difference ``compress(B) - compress(A)``.
        seed: Forwarded to ``generate``.

    Returns:
        List of generated images, extended in weight order (``num_samples=1`` per
        weight; presumably one image each).
    """
    free_memory()
    images = torch.stack([img_transform(image) for image in [image1, image2]])
    dino_image_embeds = extract_dino_image_embeds(images)
    compressed_image_embeds = model.compress(dino_image_embeds)

    if dino_matching:
        # Clustering and matching are only needed for the segment-matched
        # direction, so they are skipped entirely in the plain-difference mode.
        single_eigvecs = kway_cluster_per_image(dino_image_embeds, n_cluster=n_cluster, gamma=0.5)
        A_to_B = match_centers_two_images(dino_image_embeds[0], dino_image_embeds[1], single_eigvecs[0], single_eigvecs[1], match_method=match_method)
        direction = find_direction_two_images(compressed_image_embeds, single_eigvecs, A_to_B, unit_norm_direction=unit_norm_direction)
    else:
        direction = compressed_image_embeds[1] - compressed_image_embeds[0]

    interpolated_images = []
    ip_model = load_ipadapter()
    try:
        for w in ws:
            A_interpolated = model.uncompress(compressed_image_embeds[0] + direction * w)
            interpolated_images.extend(generate(ip_model, A_interpolated, num_samples=1, seed=seed))
    finally:
        # Release the IP-Adapter even if generation fails part-way.
        del ip_model
        free_memory()
    return interpolated_images


if USE_HUGGINGFACE_ZEROGPU:
    interpolate_two_images = spaces.GPU(duration=60)(interpolate_two_images)
def plot_loss(model):
    """Plot the training curves stored on a trained model.

    Args:
        model: Object whose ``loss_history`` mapping holds per-step ``'recon'``
            and ``'eigvec'`` sequences.

    Returns:
        A matplotlib ``Figure`` with the reconstruction loss (left) and the
        negated eigenvector history (right).
    """
    fig, (ax_recon, ax_eigvec) = plt.subplots(1, 2, figsize=(8, 4))
    panels = [
        (ax_recon, model.loss_history['recon'], 'Reconstruction Loss'),
        # NOTE(review): the eigvec history is negated before plotting -- presumably
        # the stored value is a maximized objective; confirm the sign convention.
        (ax_eigvec, -np.array(model.loss_history['eigvec']), 'Eigenvector Loss'),
    ]
    for ax, values, title in panels:
        ax.plot(values)
        ax.set_xlabel('Steps')
        ax.set_ylabel('Loss')
        ax.set_title(title)
        ax.grid(True)
    # Lay out this figure explicitly. plt.tight_layout() targets pyplot's shared
    # "current figure", which is not necessarily the figure created above.
    fig.tight_layout()
    return fig
# Default example images (paths are relative to the working directory).
DEFAULT_IMAGES_PATH = [
    "./images/black_bear1.jpg",
    "./images/black_bear2.jpg",
    "./images/pink_bear1.jpg",
]
DEFAULT_IMAGES = list(map(Image.open, DEFAULT_IMAGES_PATH))
# DEFAULT_IMAGES = [image.resize((512, 512), resample=Image.Resampling.LANCZOS) for image in DEFAULT_IMAGES]

if USE_HUGGINGFACE_ZEROGPU:
    # On ZeroGPU, fetch the IP-Adapter weights at import time, before the UI starts.
    from download_models import download_ipadapter
    download_ipadapter()
# %% | |
if __name__ == "__main__":
    import gradio as gr

    def _require_model(model):
        """Abort the current Gradio event with a visible error if no model is trained yet.

        ``model`` is the value of the ``gr.State`` below, which starts as ``[]``.
        """
        # gr.Error is only shown to the user when *raised*; constructing it alone does nothing.
        if model is None or model == []:
            raise gr.Error("Please train a model first.")

    def _gallery_item_image(item):
        """Return the image part of a Gallery entry (``(image, caption)`` pair or bare image)."""
        return item[0] if isinstance(item, (tuple, list)) else item

    demo = gr.Blocks(
        theme=gr.themes.Base(spacing_size='md', text_size='lg', primary_hue='blue', neutral_hue='slate', secondary_hue='pink'),
    )
    with demo:
        model = gr.State([])

        # ------------------------------------------------------------------ Tab 1
        with gr.Tab("1. Mood Space"):
            gr.Markdown("""
Instructions:
Please use the tabs to navigate through the app.
- Tab 1: Train a Mood Space compression model
- Tab 2: Interpolate between two images
- Tab 3: Path Lifting, given A1 -> B1, what's the A2 -> B2?
""")
            # gr.Markdown("Train a Mood Space compression model")
            with gr.Row():
                with gr.Column():
                    input_images = gr.Gallery(label="Mood Board Images", show_label=False)
                    upload_button = gr.UploadButton(elem_id="upload_button", label="Upload", variant='secondary', file_types=["image"], file_count="multiple")

                    def convert_to_pil_and_append(images, new_images):
                        """Append newly uploaded image(s) to the current gallery value."""
                        if images is None:
                            images = []
                        if new_images is None:
                            return images
                        if isinstance(new_images, Image.Image):
                            images.append(new_images)
                        if isinstance(new_images, list):
                            images += [Image.open(new_image) for new_image in new_images]
                        if isinstance(new_images, str):
                            images.append(Image.open(new_images))
                        gr.Info(f"Total images: {len(images)}")
                        return images

                    upload_button.upload(convert_to_pil_and_append, inputs=[input_images, upload_button], outputs=[input_images])
                    # def load_example():
                    #     default_images = DEFAULT_IMAGES
                    #     return default_images
                    # load_example_button = gr.Button("Load Example Images")
                    # load_example_button.click(load_example, inputs=[], outputs=input_images)
                    # add_download_button(input_images, filename_prefix="mood_board_images")
                with gr.Column():
                    with gr.Accordion("Training Parameters", open=False):
                        lr = gr.Slider(minimum=0.0001, maximum=0.01, step=0.0001, value=0.001, label="Learning Rate")
                        steps = gr.Slider(minimum=1000, maximum=100000, step=100, value=1500, label="Training Steps")
                        width = gr.Slider(minimum=16, maximum=4096, step=16, value=512, label="MLP Width")
                        layers = gr.Slider(minimum=1, maximum=8, step=1, value=4, label="MLP Layers")
                    train_button = gr.Button("Train", variant="primary")

                    def _train_wrapper(images, lr, steps, width, layers):
                        """Train a Mood Space model on the gallery images and plot its loss curves."""
                        if not images:
                            raise gr.Error("Please add mood board images first.")
                        trained_model, _ = train_mood_space(images, lr, int(steps), int(width), int(layers))
                        loss_plot = plot_loss(trained_model)
                        gr.Info("Training complete.")
                        return trained_model, loss_plot

                    loss_plot = gr.Plot(label="Training Loss")
                    train_button.click(_train_wrapper, inputs=[input_images, lr, steps, width, layers], outputs=[model, loss_plot])

            example_groups = {
                "Dog -> Fish": ["./images/dog1.jpg", "./images/fish.jpg"],
                "Dog -> Paper": ["./images/dog1.jpg", "./images/paper2.jpg"],
                "Rotation": ["./images/black_bear1.jpg", "./images/black_bear2.jpg"],
                "Rotation (Analogy)": ["./images/black_bear1.jpg", "./images/black_bear2.jpg", "./images/pink_bear1.jpg"],
                "Duck -> Pixel": ["./images/duck1.jpg", "./images/duck_pixel.jpg"],
                "Duck -> Paper": ["./images/duck1.jpg", "./images/toilet_paper.jpg"],
                "Duck -> Paper (Analogy)": ["./images/duck1.jpg", "./images/toilet_paper.jpg", "./images/duck_pixel.jpg"],
            }

            def add_image_group_fn(group_gallery):
                """Replace the mood board with the images of one example group."""
                images = [_gallery_item_image(item) for item in group_gallery]
                # resize images to 512x512
                # images = [image.resize((512, 512), resample=Image.Resampling.LANCZOS) for image in images]
                return images

            gr.Markdown('## Examples')
            for group_name, group_images in example_groups.items():
                with gr.Row():
                    with gr.Column(scale=3):
                        add_button = gr.Button(value=f'add example [{group_name}]', elem_classes=['small-button'])
                    with gr.Column(scale=7):
                        group_gallery = gr.Gallery(
                            value=group_images,
                            columns=5,
                            rows=1,
                            height=200,
                            object_fit='scale-down',
                            label=group_name,
                            elem_classes=['large-gallery'],
                        )
                # Wired inside the loop, so each button binds its own gallery.
                add_button.click(
                    add_image_group_fn,
                    inputs=[group_gallery],
                    outputs=[input_images],
                )

        # ------------------------------------------------------------------ Tab 2
        with gr.Tab("2. Interpolate"):
            # gr.Markdown("Interpolate between two images")
            with gr.Row():
                input_A1 = gr.Image(label="A1", type="pil")
                input_B1 = gr.Image(label="B1", type="pil")
                with gr.Column():
                    # def _load_two_images():
                    #     default_images = DEFAULT_IMAGES[:2]
                    #     return default_images
                    # load_example_button3 = gr.Button("Load Example Images")
                    # load_example_button3.click(_load_two_images, inputs=[], outputs=[input_A1, input_B1])
                    fill_in_images_button = gr.Button("Reload Images")
                    with gr.Accordion("Interpolation Parameters", open=False):
                        w_left = gr.Slider(minimum=-10, maximum=10, step=0.01, value=0, label="Start w")
                        w_right = gr.Slider(minimum=-10, maximum=10, step=0.01, value=1, label="End w")
                        n_steps = gr.Slider(minimum=1, maximum=100, step=2, value=10, label="N interpolation")
                        n_sample = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="N samples per interpolation")
                        n_cluster = gr.Slider(minimum=1, maximum=100, step=1, value=10, label="N segments", info="for correspondence matching")
                        match_method = gr.Radio(choices=['hungarian', 'argmin'], value='hungarian', label="Matching Method")
                    interpolate_button = gr.Button("Run Interpolation", variant="primary")
            interpolated_images_plot = gr.Image(label="interpolated images")
            interpolated_images = gr.Gallery(label="Interpolated Images", show_label=False, visible=False)
            add_download_button(interpolated_images, filename_prefix="interpolated_images")

            def _infer_two_images(A1, B1, model, w_left, w_right, n_steps, n_cluster, n_sample, match_method):
                """Interpolate A1 -> B1. Returns (generated images, grid with endpoints).

                ``n_sample`` is accepted for input-wiring symmetry but is not used here.
                """
                _require_model(model)
                if A1 is None or B1 is None:
                    raise gr.Error("Please provide both the A1 and B1 images.")
                images = load_gradio_images_helper([A1, B1])
                ws = torch.linspace(w_left, w_right, int(n_steps))
                interpolated_images = interpolate_two_images(*images, model, ws, int(n_cluster), match_method)
                # resize interpolated_images to 512x512
                interpolated_images = [image.resize((512, 512), resample=Image.Resampling.LANCZOS) for image in interpolated_images]
                endpoints = [image.resize((512, 512), resample=Image.Resampling.LANCZOS) for image in images]
                plot_images = [endpoints[0], *interpolated_images, endpoints[1]]
                # Two rows only when the count splits evenly; otherwise one row, so that
                # rows * cols always equals the number of images (odd N is reachable).
                n_rows = 2 if len(plot_images) % 2 == 0 else 1
                plot_images = image_grid(plot_images, n_rows, len(plot_images) // n_rows)
                return interpolated_images, plot_images

            interpolate_button.click(_infer_two_images,
                                     inputs=[input_A1, input_B1, model, w_left, w_right, n_steps, n_cluster, n_sample, match_method],
                                     outputs=[interpolated_images, interpolated_images_plot])

            ## fill in the images from input_images
            def fill_in_two_images(input_images):
                """Copy the first two mood-board images into A1/B1; no-op with fewer than two."""
                if not input_images or len(input_images) < 2:
                    return gr.update(), gr.update()
                return _gallery_item_image(input_images[0]), _gallery_item_image(input_images[1])

            fill_in_images_button.click(fill_in_two_images, inputs=[input_images], outputs=[input_A1, input_B1])
            input_images.change(fill_in_two_images, inputs=[input_images], outputs=[input_A1, input_B1])

        # ------------------------------------------------------------------ Tab 3
        with gr.Tab("3. Path Lifting"):
            gr.Markdown("""
given A1 -> B1, infer A2 -> B2
""")
            with gr.Row():
                input_A1 = gr.Image(label="A1", type="pil")
                input_B1 = gr.Image(label="B1", type="pil")
                input_A2 = gr.Image(label="A2", type="pil")
                picked_B2 = gr.Image(label="B2", type="pil", interactive=False)
                with gr.Column():
                    # def _load_three_images():
                    #     default_images = DEFAULT_IMAGES
                    #     return default_images
                    # load_example_button2 = gr.Button("Load Example Images")
                    # load_example_button2.click(_load_three_images, inputs=[], outputs=[input_A2, input_A1, input_B1])
                    fill_in_images_button2 = gr.Button("Reload Images")
                    with gr.Accordion("Interpolation Parameters", open=False):
                        w_left = gr.Slider(minimum=-10, maximum=10, step=0.01, value=0, label="Start w")
                        w_right = gr.Slider(minimum=-10, maximum=10, step=0.01, value=1., label="End w")
                        n_steps = gr.Slider(minimum=1, maximum=100, step=2, value=12, label="N interpolation")
                        n_sample = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="N samples per interpolation")
                        n_cluster = gr.Slider(minimum=1, maximum=100, step=1, value=10, label="N segments", info="for correspondence matching")
                        match_method = gr.Radio(choices=['hungarian', 'argmin'], value='hungarian', label="Matching Method")
                    interpolate_button = gr.Button("Run Path Lifting", variant="primary")

                    def revert_images(A1, B1, A2, B2):
                        """Swap A1<->B1 and A2<->B2 so the edit runs in the opposite direction."""
                        return B1, A1, B2, A2

                    revert_button = gr.Button("Revert Images", variant="secondary")
                    revert_button.click(revert_images, inputs=[input_A1, input_B1, input_A2, picked_B2], outputs=[input_A1, input_B1, input_A2, picked_B2])
            output_B2 = gr.Plot(label="B2 (interpolated)")
            interpolated_images = gr.Gallery(label="Interpolated Images", show_label=False, visible=False)
            correspondence_image = gr.Image(label="Correspondence Image", interactive=False)
            add_download_button(interpolated_images, filename_prefix="interpolated_images")

            def pick_best_image(interpolated_images, evt: gr.SelectData):
                """Copy the clicked gallery image into the B2 slot."""
                return _gallery_item_image(interpolated_images[evt.index])

            interpolated_images.select(pick_best_image, interpolated_images, [picked_B2])

            def _infer_three_images(A2, A1, B1, model, w_left, w_right, n_steps, n_cluster, n_sample, match_method):
                """Run path lifting. Returns (correspondence image, B2 figure, generated images)."""
                _require_model(model)
                if A2 is None or A1 is None or B1 is None:
                    raise gr.Error("Please provide the A1, B1 and A2 images.")
                images = load_gradio_images_helper([A2, A1, B1])
                ws = torch.linspace(w_left, w_right, int(n_steps))
                correspondence_image, fig, interpolated_images = analogy_three_images(images, model, ws, int(n_cluster), int(n_sample), match_method)
                # resize interpolated_images to 512x512
                interpolated_images = [image.resize((512, 512), resample=Image.Resampling.LANCZOS) for image in interpolated_images]
                return correspondence_image, fig, interpolated_images

            interpolate_button.click(_infer_three_images,
                                     inputs=[input_A2, input_A1, input_B1, model, w_left, w_right, n_steps, n_cluster, n_sample, match_method],
                                     outputs=[correspondence_image, output_B2, interpolated_images])

            ## fill in the images from input_images
            def fill_in_three_images(input_images):
                """Fill A1, B1, A2 from the mood board.

                With exactly two images, A2 reuses the first one. With three or more,
                the first three are used. With fewer than two, the slots are left unchanged.
                """
                if not input_images or len(input_images) < 2:
                    return gr.update(), gr.update(), gr.update()
                first = _gallery_item_image(input_images[0])
                second = _gallery_item_image(input_images[1])
                third = _gallery_item_image(input_images[2]) if len(input_images) >= 3 else first
                return first, second, third

            fill_in_images_button2.click(fill_in_three_images, inputs=[input_images], outputs=[input_A1, input_B1, input_A2])
            input_images.change(fill_in_three_images, inputs=[input_images], outputs=[input_A1, input_B1, input_A2])

        # with gr.Tab("3. Make Plot"):
        #     plot_button = gr.Button("Make Plot", variant="primary")
        #     gallery_fig = gr.Gallery(label="Gallery", show_label=False, type="filepath")
        #     add_download_button(gallery_fig, filename_prefix="output_images")
        #     def open_images(imgA1, imgB1, imgA2, imgB2):
        #         img_list = [imgA1, imgB1, imgA2, imgB2]
        #         for _img in [imgA1, imgB1, imgA2, imgB2]:
        #             img = load_gradio_images_helper([_img])
        #             img = img[0].resize((512, 512), resample=Image.Resampling.LANCZOS)
        #             img_list.append(img)
        #         img_list = img_list[:4]
        #         img_grid = image_grid(img_list[:4], 1, 4)
        #         img_list.append(img_grid)
        #         img_grid = image_grid(img_list[:4], 2, 2)
        #         img_list.append(img_grid)
        #         return img_list
        #     plot_button.click(open_images, inputs=[input_A1, input_B1, input_A2, picked_B2], outputs=[gallery_fig])

    demo.launch(share=True)