# Hugging Face Space status: Running on Zero
import os | |
import torch | |
import spaces | |
import gradio as gr | |
import torch | |
from artistic_portrait.pipeline import ArtisticPortraitXLPipeline | |
from diffusers import ControlNetModel, DPMSolverMultistepScheduler | |
from ip_adapter_diffusers.ip_adapter import * | |
from huggingface_hub import hf_hub_download | |
# Local paths of the adapter checkpoints that the pipeline below loads.
style_adapter_path = "models/ip_adapter_art_sdxl_512.pth"
id_adapter_path = "models/pulid_adapter_diffusers_1.1.pth"

# Fetch every checkpoint that is not already on disk from the IP-Adapter-Art
# Hub repository into ./models (local path checked, Hub filename downloaded).
for _local_ckpt, _hub_filename in (
    ("models/csd_clip.pth", "csd_clip.pth"),
    (style_adapter_path, "ip_adapter_art_sdxl_512.pth"),
    (id_adapter_path, "pulid_adapter_diffusers_1.1.pth"),
):
    if os.path.exists(_local_ckpt):
        continue
    hf_hub_download(
        repo_id="AisingioroHao0/IP-Adapter-Art",
        filename=_hub_filename,
        local_dir="models",
    )
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Base SDXL checkpoint the artistic-portrait pipeline is built on.
sdxl_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
# Half precision on CUDA devices, full precision everywhere else.
torch_dtype = torch.float16 if "cuda" in str(device) else torch.float32
# Load pretrained models.
print("Initializing pipeline...")
# OpenPose ControlNet; `generation` feeds the pose reference image to it as
# `control_image`.
controlnet = ControlNetModel.from_pretrained(
    "xinsir/controlnet-openpose-sdxl-1.0",
    torch_dtype=torch_dtype,
).to(device)
# Reuse the module-level constants instead of re-spelling the same literals,
# so changing the base checkpoint or adapter path only needs one edit.
pipe = ArtisticPortraitXLPipeline.from_pretrained(
    sdxl_repo_id,
    controlnet=controlnet,
    safety_checker=None,
    torch_dtype=torch_dtype,
    style_adapter_path=style_adapter_path,
    id_adapter_path=id_adapter_path,
    device=device,
).to(device)
# Replace the default scheduler with DPMSolverMultistepScheduler built from the
# same config, using "trailing" timestep spacing.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)
# Also attach the style adapter to the ControlNet; its strength is set per
# request through set_ip_adapter_scale in `generation`.
load_ip_adapter(
    pipe.controlnet,
    style_adapter_path,
)
# (style reference, ID reference) image pairs offered in the Gradio Examples
# widget; each pair maps onto the [style_image, id_image] inputs.
_EXAMPLE_PAIRS = (
    ("datasets/test/style_dataset/Abstract D'Oyley.jpg", "datasets/test/id_dataset/lifeifei.jpg"),
    ("datasets/test/style_dataset/Adam Zyglis.jpg", "datasets/test/id_dataset/lecun.jpg"),
    ("datasets/test/style_dataset/Diffused lighting.jpg", "datasets/test/id_dataset/liuyifei.jpg"),
    ("datasets/test/style_dataset/Shirley Hughes.jpg", "datasets/test/id_dataset/rihanna.jpg"),
    ("datasets/test/style_dataset/Winter.jpg", "datasets/test/id_dataset/hinton.jpg"),
)
example_inputs = [list(pair) for pair in _EXAMPLE_PAIRS]
@spaces.GPU  # ZeroGPU attaches a GPU only for the duration of decorated calls.
def generation(
    style_image=None,
    id_image=None,
    pose_image=None,
    prompt="portrait, solo, looking at viewer, best quality, masterpiece",
    negative_prompt="flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,artifacts noise, text, watermark, glitch, deformed, mutated, ugly, disfigured, hands, low resolution, partially rendered objects, deformed or partially rendered eyes, deformed, deformed eyeballs, cross-eyed",
    num_inference_steps=20,
    guidance_scale=7.0,
    style_scale=1.0,
    id_scale=1.0,
    controlnet_scale=0.9,
    seed=42,
    height=1024,
    width=1024,
    artify_contorlnet_scale=0.0,
):
    """Generate one artistic portrait from style, identity and pose references.

    Args:
        style_image: Style reference image (PIL image from the Gradio input).
        id_image: Identity reference image; required by the current pipeline.
        pose_image: Pose reference, passed to the ControlNet as `control_image`.
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.
        style_scale: Strength of the style adapter.
        id_scale: Strength of the identity adapter.
        controlnet_scale: ControlNet conditioning scale.
        seed: Seed for the torch generator, for reproducible results.
        height: Output height in pixels.
        width: Output width in pixels.
        artify_contorlnet_scale: Strength of the style adapter attached to the
            ControlNet; applied via set_ip_adapter_scale before each call.

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: If no ID reference image was provided.
    """
    if id_image is None:
        # Fail early with a readable message instead of a pipeline traceback.
        raise gr.Error("ID reference image cannot be empty.")
    # gr.Number inputs may arrive as floats; the seed, step count and image size
    # must be integers for torch.Generator.manual_seed and latent sizing.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    height = int(height)
    width = int(width)

    # The ControlNet-side style adapter scale is shared state on the pipeline,
    # so it is (re)set on every request.
    set_ip_adapter_scale(pipe.controlnet, artify_contorlnet_scale)
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        control_image=pose_image,
        controlnet_conditioning_scale=controlnet_scale,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        style_image=style_image,
        id_image=id_image,
        generator=torch.Generator(device).manual_seed(seed),
        id_scale=id_scale,
        style_scale=style_scale,
    ).images[0]
    return result
with gr.Blocks(delete_cache=(3600, 3600)) as demo:
    # Header / usage notes. Control names mentioned here must match the labels
    # of the inputs defined below.
    gr.Markdown(
        """
# Artistic Portrait Generation 0.9: Generate Customized Artistic Portrait through Style Reference Images
**Implementation based on [Art-Adapter](https://github.com/aihao2000/IP-Adapter-Art), [PuLID-Adapter](https://github.com/ToTheBeginning/PuLID), and [Instant Style](https://github.com/instantX-research/InstantStyle).**
## Basic usage:
- Stylized Portrait Generation: Upload the style reference image and ID reference image, and click "Generation" to generate the artistic portrait directly.
- Text-guided Stylization Generation: Set ID Scale to 0, modify prompt, and then try text-guided stylized image generation through **Art-Adapter**. **(Note that ID image cannot be empty in the current version.)**
_If the style similarity is low, try increasing the Stylize ControlNet Scale, or set the ControlNet Scale to 0._
## News
- 2025.3.24: We released Artistic Portrait Generation 0.9.
"""
    )
    with gr.Row():
        # Left column: reference images, actions and generation parameters.
        with gr.Column():
            with gr.Row():
                style_image = gr.Image(
                    label="Style Reference Image",
                    type="pil",
                )
                id_image = gr.Image(
                    label="ID Reference Image",
                    type="pil",
                )
                pose_image = gr.Image(
                    label="Pose Reference Image",
                    type="pil",
                    value="datasets/test/pose.jpg",
                )
            with gr.Row():
                clear_btn = gr.ClearButton()
                generation_btn = gr.Button("Generation")
            with gr.Row():
                id_scale = gr.Number(label="ID Scale", value=1.0, step=0.1)
                style_scale = gr.Number(label="Style Scale", value=1.0, step=0.1)
                controlnet_scale = gr.Number(
                    label="ControlNet Scale", value=0.9, step=0.1
                )
                stylize_contorlnet_scale = gr.Number(
                    label="Stylize ControlNet Scale", value=0.0, step=0.1
                )
                guidance_scale = gr.Number(label="CFG Scale", value=7.0, step=0.1)
            with gr.Row():
                # precision=0 makes Gradio deliver these as ints, which the
                # pipeline's size/step arguments and the seed require.
                height = gr.Number(
                    label="Height", step=1, maximum=1024, value=1024, precision=0
                )
                width = gr.Number(
                    label="Width", step=1, maximum=1024, value=1024, precision=0
                )
                seed = gr.Number(label="Seed", value=42, step=1, precision=0)
                num_inference_steps = gr.Number(
                    label="Steps", value=20, step=1, precision=0
                )
            prompt = gr.Textbox(
                label="Prompt",
                value="portrait, solo, looking at viewer, best quality, masterpiece",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,artifacts noise, text, watermark, glitch, deformed, mutated, ugly, disfigured, hands, low resolution, partially rendered objects, deformed or partially rendered eyes, deformed, deformed eyeballs, cross-eyed",
            )
        # Right column: generated result.
        with gr.Column():
            output = gr.Image(label="Result", type="pil")
    with gr.Row():
        # Examples only populate the two reference inputs; with caching off the
        # placeholder fn is not used to produce outputs.
        examples = gr.Examples(
            examples=example_inputs,
            inputs=[style_image, id_image],
            outputs=[
                output,
            ],
            fn=lambda x, y: None,
            cache_examples=False,
        )
    clear_btn.add([style_image, id_image, pose_image, output])
    # Input order must match the positional parameters of `generation`.
    generation_btn.click(
        generation,
        inputs=[
            style_image,
            id_image,
            pose_image,
            prompt,
            negative_prompt,
            num_inference_steps,
            guidance_scale,
            style_scale,
            id_scale,
            controlnet_scale,
            seed,
            height,
            width,
            stylize_contorlnet_scale,
        ],
        outputs=[output],
        api_name="artistic_portrait_gen",
    )
if __name__ == "__main__":
    # Only start the server when run as a script, not when the module is
    # imported (e.g. by Gradio's reload mode, which looks up `demo` itself).
    # queue() returns the Blocks instance, so launch() can be chained.
    demo.queue().launch(share=True)