|
import os |
|
import binascii |
|
from safetensors import safe_open |
|
|
|
import torch |
|
|
|
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint |
|
|
|
def rand_name(length=8, suffix=''): |
|
name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') |
|
if suffix: |
|
if not suffix.startswith('.'): |
|
suffix = '.' + suffix |
|
name += suffix |
|
return name |
|
|
|
def cycle(dl): |
|
while True: |
|
for data in dl: |
|
yield data |
|
|
|
def exists(x): |
|
return x is not None |
|
|
|
def identity(x): |
|
return x |
|
|
|
def load_dreambooth_lora(unet, vae=None, model_path=None, alpha=1.0, model_base=""): |
|
if model_path is None: return unet |
|
|
|
if model_path.endswith(".ckpt"): |
|
base_state_dict = torch.load(model_path)['state_dict'] |
|
elif model_path.endswith(".safetensors"): |
|
state_dict = {} |
|
with safe_open(model_path, framework="pt", device="cpu") as f: |
|
for key in f.keys(): |
|
state_dict[key] = f.get_tensor(key) |
|
|
|
is_lora = all("lora" in k for k in state_dict.keys()) |
|
if not is_lora: |
|
base_state_dict = state_dict |
|
else: |
|
base_state_dict = {} |
|
with safe_open(model_base, framework="pt", device="cpu") as f: |
|
for key in f.keys(): |
|
base_state_dict[key] = f.get_tensor(key) |
|
|
|
converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, unet.config) |
|
unet_state_dict = unet.state_dict() |
|
for key in converted_unet_checkpoint: |
|
converted_unet_checkpoint[key] = alpha * converted_unet_checkpoint[key] + (1.0-alpha) * unet_state_dict[key] |
|
unet.load_state_dict(converted_unet_checkpoint, strict=False) |
|
|
|
if vae is not None: |
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, vae.config) |
|
vae.load_state_dict(converted_vae_checkpoint) |
|
|
|
return unet, vae |