import os
import sys
import copy
import torch
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

sys.path.append("src/")
from model import make_1step_sched, my_vae_encoder_fwd, my_vae_decoder_fwd, download_url


class VAE_encode(nn.Module):
    """Encodes images to scaled latents, picking the VAE that matches the translation direction."""

    def __init__(self, vae, vae_b2a=None):
        super(VAE_encode, self).__init__()
        self.vae = vae
        self.vae_b2a = vae_b2a

    def forward(self, x, direction):
        assert direction in ["a2b", "b2a"]
        _vae = self.vae if direction == "a2b" else self.vae_b2a
        return _vae.encode(x).latent_dist.sample() * _vae.config.scaling_factor


class VAE_decode(nn.Module):
    """Decodes latents back to images, feeding the encoder's cached down-block
    activations into the decoder as skip connections."""

    def __init__(self, vae, vae_b2a=None):
        super(VAE_decode, self).__init__()
        self.vae = vae
        self.vae_b2a = vae_b2a

    def forward(self, x, direction):
        assert direction in ["a2b", "b2a"]
        _vae = self.vae if direction == "a2b" else self.vae_b2a
        # the encoder must have run first, so its down-block activations are cached
        assert _vae.encoder.current_down_blocks is not None
        _vae.decoder.incoming_skip_acts = _vae.encoder.current_down_blocks
        x_decoded = (_vae.decode(x / _vae.config.scaling_factor).sample).clamp(-1, 1)
        return x_decoded
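
# Usage sketch (not part of the original module): round-trip a batch through the
# skip-connected VAE. Assumes `vae` was produced by initialize_vae() below, so the
# patched encoder caches the activations the decoder consumes:
#
#   enc, dec = VAE_encode(vae), VAE_decode(vae)
#   z = enc(images, direction="a2b")    # images in [-1, 1], shape (B, 3, H, W)
#   rec = dec(z, direction="a2b")       # decoded with encoder skips, clamped to [-1, 1]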

def initialize_unet(rank, return_lora_module_names=False):
    unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
    unet.requires_grad_(False)
    unet.train()
    # sort the attention / conv modules into encoder, decoder, and "other" groups,
    # so each group gets its own LoRA adapter
    l_target_modules_encoder, l_target_modules_decoder, l_modules_others = [], [], []
    l_grep = ["to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_in", "conv_shortcut", "conv_out", "proj_out", "proj_in", "ff.net.2", "ff.net.0.proj"]
    for n, p in unet.named_parameters():
        if "bias" in n or "norm" in n:
            continue
        for pattern in l_grep:
            if pattern in n and ("down_blocks" in n or "conv_in" in n):
                l_target_modules_encoder.append(n.replace(".weight", ""))
                break
            elif pattern in n and "up_blocks" in n:
                l_target_modules_decoder.append(n.replace(".weight", ""))
                break
            elif pattern in n:
                l_modules_others.append(n.replace(".weight", ""))
                break
    lora_conf_encoder = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_target_modules_encoder, lora_alpha=rank)
    lora_conf_decoder = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_target_modules_decoder, lora_alpha=rank)
    lora_conf_others = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_modules_others, lora_alpha=rank)
    unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
    unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
    unet.add_adapter(lora_conf_others, adapter_name="default_others")
    unet.set_adapters(["default_encoder", "default_decoder", "default_others"])
    if return_lora_module_names:
        return unet, l_target_modules_encoder, l_target_modules_decoder, l_modules_others
    else:
        return unet
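
# Usage sketch: build the LoRA-augmented SD-Turbo UNet; the module-name lists are
# what gets stored in a checkpoint so load_ckpt_from_state_dict() can re-create
# the same adapters later:
#
#   unet, enc_mods, dec_mods, other_mods = initialize_unet(rank=8, return_lora_module_names=True)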

def initialize_vae(rank=4, return_lora_module_names=False):
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
    vae.requires_grad_(False)
    # patch in the custom forwards: the encoder caches skip activations, the decoder consumes them
    vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
    vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
    vae.requires_grad_(True)
    vae.train()
    # add the skip connection convs, initialized near zero so training starts close to the base VAE
    vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
    vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
    vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
    vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
    torch.nn.init.constant_(vae.decoder.skip_conv_1.weight, 1e-5)
    torch.nn.init.constant_(vae.decoder.skip_conv_2.weight, 1e-5)
    torch.nn.init.constant_(vae.decoder.skip_conv_3.weight, 1e-5)
    torch.nn.init.constant_(vae.decoder.skip_conv_4.weight, 1e-5)
    vae.decoder.ignore_skip = False
    vae.decoder.gamma = 1
    l_vae_target_modules = ["conv1", "conv2", "conv_in", "conv_shortcut",
        "conv", "conv_out", "skip_conv_1", "skip_conv_2", "skip_conv_3",
        "skip_conv_4", "to_k", "to_q", "to_v", "to_out.0",
    ]
    vae_lora_config = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_vae_target_modules)
    vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
    if return_lora_module_names:
        return vae, l_vae_target_modules
    else:
        return vae
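
# Usage sketch, mirroring initialize_unet() above:
#
#   vae_a2b, vae_lora_modules = initialize_vae(rank=4, return_lora_module_names=True)
#   vae_b2a = copy.deepcopy(vae_a2b)    # the reverse direction gets its own copy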

class CycleGAN_Turbo(torch.nn.Module):
    def __init__(self, pretrained_name=None, pretrained_path=None, ckpt_folder="checkpoints", lora_rank_unet=8, lora_rank_vae=4):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda()
        self.sched = make_1step_sched()
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
        vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
        vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
        # add the skip connection convs
        vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.ignore_skip = False
        self.unet, self.vae = unet, vae
        if pretrained_name == "day_to_night":
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/day2night.pkl"
            self.load_ckpt_from_url(url, ckpt_folder)
            self.timesteps = torch.tensor([999], device="cuda").long()
            self.caption = "driving in the night"
            self.direction = "a2b"
        elif pretrained_name == "night_to_day":
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/night2day.pkl"
            self.load_ckpt_from_url(url, ckpt_folder)
            self.timesteps = torch.tensor([999], device="cuda").long()
            self.caption = "driving in the day"
            self.direction = "b2a"
        elif pretrained_name == "clear_to_rainy":
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/clear2rainy.pkl"
            self.load_ckpt_from_url(url, ckpt_folder)
            self.timesteps = torch.tensor([999], device="cuda").long()
            self.caption = "driving in heavy rain"
            self.direction = "a2b"
        elif pretrained_name == "rainy_to_clear":
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/rainy2clear.pkl"
            self.load_ckpt_from_url(url, ckpt_folder)
            self.timesteps = torch.tensor([999], device="cuda").long()
            self.caption = "driving in the day"
            self.direction = "b2a"
        elif pretrained_path is not None:
            sd = torch.load(pretrained_path)
            self.load_ckpt_from_state_dict(sd)
            self.timesteps = torch.tensor([999], device="cuda").long()
            self.caption = None
            self.direction = None
        self.vae_enc.cuda()
        self.vae_dec.cuda()
        self.unet.cuda()
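
    # Usage sketch: passing a released pretrained_name downloads the matching
    # checkpoint into ckpt_folder and fixes the caption/direction used by forward():
    #
    #   model = CycleGAN_Turbo(pretrained_name="day_to_night")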

    def load_ckpt_from_state_dict(self, sd):
        # re-create the three UNet LoRA adapters with the ranks and module lists stored in the checkpoint
        lora_conf_encoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_encoder"], lora_alpha=sd["rank_unet"])
        lora_conf_decoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_decoder"], lora_alpha=sd["rank_unet"])
        lora_conf_others = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_modules_others"], lora_alpha=sd["rank_unet"])
        self.unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
        self.unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
        self.unet.add_adapter(lora_conf_others, adapter_name="default_others")
        # copy the saved LoRA weights into each adapter
        for n, p in self.unet.named_parameters():
            name_sd = n.replace(".default_encoder.weight", ".weight")
            if "lora" in n and "default_encoder" in n:
                p.data.copy_(sd["sd_encoder"][name_sd])
        for n, p in self.unet.named_parameters():
            name_sd = n.replace(".default_decoder.weight", ".weight")
            if "lora" in n and "default_decoder" in n:
                p.data.copy_(sd["sd_decoder"][name_sd])
        for n, p in self.unet.named_parameters():
            name_sd = n.replace(".default_others.weight", ".weight")
            if "lora" in n and "default_others" in n:
                p.data.copy_(sd["sd_other"][name_sd])
        self.unet.set_adapter(["default_encoder", "default_decoder", "default_others"])
        vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
        self.vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
        self.vae.decoder.gamma = 1
        # the b2a direction gets its own deep copy of the adapted VAE
        self.vae_b2a = copy.deepcopy(self.vae)
        self.vae_enc = VAE_encode(self.vae, vae_b2a=self.vae_b2a)
        self.vae_enc.load_state_dict(sd["sd_vae_enc"])
        self.vae_dec = VAE_decode(self.vae, vae_b2a=self.vae_b2a)
        self.vae_dec.load_state_dict(sd["sd_vae_dec"])

    def load_ckpt_from_url(self, url, ckpt_folder):
        os.makedirs(ckpt_folder, exist_ok=True)
        outf = os.path.join(ckpt_folder, os.path.basename(url))
        download_url(url, outf)
        sd = torch.load(outf)
        self.load_ckpt_from_state_dict(sd)

    @staticmethod
    def forward_with_networks(x, direction, vae_enc, unet, vae_dec, sched, timesteps, text_emb):
        B = x.shape[0]
        assert direction in ["a2b", "b2a"]
        x_enc = vae_enc(x, direction=direction).to(x.dtype)
        model_pred = unet(x_enc, timesteps, encoder_hidden_states=text_emb).sample
        x_out = torch.stack([sched.step(model_pred[i], timesteps[i], x_enc[i], return_dict=True).prev_sample for i in range(B)])
        x_out_decoded = vae_dec(x_out, direction=direction)
        return x_out_decoded
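
    # The translation is a single denoising step at t=999: encode to latents, one
    # UNet noise prediction, one scheduler step back to a clean latent, then decode
    # with skip connections. Sketch (hypothetical tensors `real_a`, `tokens`):
    #
    #   emb = text_encoder(tokens)[0]
    #   fake_b = CycleGAN_Turbo.forward_with_networks(real_a, "a2b", vae_enc, unet, vae_dec, sched, timesteps, emb)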

    @staticmethod
    def get_traininable_params(unet, vae_a2b, vae_b2a):
        # add all unet parameters: the (unfrozen) conv_in plus every LoRA weight
        params_gen = list(unet.conv_in.parameters())
        unet.conv_in.requires_grad_(True)
        unet.set_adapters(["default_encoder", "default_decoder", "default_others"])
        for n, p in unet.named_parameters():
            if "lora" in n and "default" in n:
                assert p.requires_grad
                params_gen.append(p)
        # add all vae_a2b parameters: LoRA weights plus the four skip convs
        for n, p in vae_a2b.named_parameters():
            if "lora" in n and "vae_skip" in n:
                assert p.requires_grad
                params_gen.append(p)
        for i in range(1, 5):
            params_gen = params_gen + list(getattr(vae_a2b.decoder, f"skip_conv_{i}").parameters())
        # add all vae_b2a parameters, likewise
        for n, p in vae_b2a.named_parameters():
            if "lora" in n and "vae_skip" in n:
                assert p.requires_grad
                params_gen.append(p)
        for i in range(1, 5):
            params_gen = params_gen + list(getattr(vae_b2a.decoder, f"skip_conv_{i}").parameters())
        return params_gen
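
    # Usage sketch: hand the collected parameters to an optimizer (hyperparameters
    # here are illustrative, not prescribed training settings):
    #
    #   params_gen = CycleGAN_Turbo.get_traininable_params(unet, vae_a2b, vae_b2a)
    #   opt_gen = torch.optim.AdamW(params_gen, lr=1e-5)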

    def forward(self, x_t, direction=None, caption=None, caption_emb=None):
        if direction is None:
            assert self.direction is not None
            direction = self.direction
        if caption is None and caption_emb is None:
            assert self.caption is not None
            caption = self.caption
        if caption_emb is not None:
            caption_enc = caption_emb
        else:
            caption_tokens = self.tokenizer(caption, max_length=self.tokenizer.model_max_length,
                                            padding="max_length", truncation=True, return_tensors="pt").input_ids.to(x_t.device)
            caption_enc = self.text_encoder(caption_tokens)[0].detach().clone()
        return self.forward_with_networks(x_t, direction, self.vae_enc, self.unet, self.vae_dec, self.sched, self.timesteps, caption_enc)
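
# End-to-end inference sketch (not part of the original module; assumes CUDA and a
# 512x512 RGB input). Preprocessing maps pixels to the [-1, 1] range the VAE expects:
#
#   from PIL import Image
#   import torchvision.transforms.functional as TF
#
#   model = CycleGAN_Turbo(pretrained_name="day_to_night")
#   model.eval()
#   img = Image.open("input.png").convert("RGB").resize((512, 512))
#   x = TF.to_tensor(img).unsqueeze(0).cuda() * 2 - 1
#   with torch.no_grad():
#       out = model(x)                                  # (1, 3, 512, 512) in [-1, 1]
#   TF.to_pil_image(out[0].cpu() * 0.5 + 0.5).save("output.png")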