import re

import torch
from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
from diffusers_helper.utils import crop_or_pad_yield_mask


@torch.no_grad()
def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256):
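    """Encode one prompt with both HunyuanVideo text encoders.

    The Llama prompt is wrapped in a system-prompt template (optionally
    overridden via settings, otherwise DEFAULT_PROMPT_TEMPLATE) whose tokens
    are cropped off the output. Returns (llama_vec, clip_l_pooler): per-token
    hidden states from the Llama encoder's third-to-last layer, and the
    pooled CLIP-L text embedding.
    """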
    assert isinstance(prompt, str)

    prompt = [prompt]

    # LLAMA

    # Check if there's a custom system prompt template in settings
    custom_template = None
    try:
        from modules.settings import Settings
        settings = Settings()
        override_system_prompt = settings.get("override_system_prompt", False)
        custom_template_str = settings.get("system_prompt_template")

        if override_system_prompt and custom_template_str:
            try:
                # Convert the string representation to a dictionary:
                # extract template and crop_start directly from the string using regex
                template_match = re.search(r"['\"]template['\"]\s*:\s*['\"](.+?)['\"](?=\s*,|\s*})", custom_template_str, re.DOTALL)
                crop_start_match = re.search(r"['\"]crop_start['\"]\s*:\s*(\d+)", custom_template_str)

                if template_match and crop_start_match:
                    template_value = template_match.group(1)
                    crop_start_value = int(crop_start_match.group(1))

                    # Unescape any escaped characters in the template
                    template_value = template_value.replace("\\n", "\n").replace("\\\"", "\"").replace("\\'", "'")

                    custom_template = {
                        "template": template_value,
                        "crop_start": crop_start_value
                    }
                    print(f"Using custom system prompt template from settings: {custom_template}")
                else:
                    print("Could not extract template or crop_start from system prompt template string")
                    print("Falling back to default template")
                    custom_template = None
            except Exception as e:
                print(f"Error parsing custom system prompt template: {e}")
                print("Falling back to default template")
                custom_template = None
        else:
            if not override_system_prompt:
                print("Override system prompt is disabled, using default template")
            elif not custom_template_str:
                print("No custom system prompt template found in settings")
            custom_template = None
    except Exception as e:
        print(f"Error loading settings: {e}")
        print("Falling back to default template")
        custom_template = None

    # Use custom template if available, otherwise use default
    template = custom_template if custom_template else DEFAULT_PROMPT_TEMPLATE

    prompt_llama = [template["template"].format(p) for p in prompt]
    crop_start = template["crop_start"]

    llama_inputs = tokenizer(
        prompt_llama,
        padding="max_length",
        max_length=max_length + crop_start,
        truncation=True,
        return_tensors="pt",
        return_length=False,
        return_overflowing_tokens=False,
        return_attention_mask=True,
    )

    llama_input_ids = llama_inputs.input_ids.to(text_encoder.device)
    llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
    llama_attention_length = int(llama_attention_mask.sum())

    llama_outputs = text_encoder(
        input_ids=llama_input_ids,
        attention_mask=llama_attention_mask,
        output_hidden_states=True,
    )

    llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length]
    # llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:]
    llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length]

    assert torch.all(llama_attention_mask.bool())

    # CLIP

    clip_l_input_ids = tokenizer_2(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    ).input_ids
    clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output

    return llama_vec, clip_l_pooler


@torch.no_grad()
def vae_decode_fake(latents):
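    """Approximate latents as RGB without running the VAE.

    Uses a fixed 16->3 channel linear projection (a 1x1x1 conv3d), which is
    cheap enough to render per-step progress previews.
    """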
    latent_rgb_factors = [
        [-0.0395, -0.0331, 0.0445],
        [0.0696, 0.0795, 0.0518],
        [0.0135, -0.0945, -0.0282],
        [0.0108, -0.0250, -0.0765],
        [-0.0209, 0.0032, 0.0224],
        [-0.0804, -0.0254, -0.0639],
        [-0.0991, 0.0271, -0.0669],
        [-0.0646, -0.0422, -0.0400],
        [-0.0696, -0.0595, -0.0894],
        [-0.0799, -0.0208, -0.0375],
        [0.1166, 0.1627, 0.0962],
        [0.1165, 0.0432, 0.0407],
        [-0.2315, -0.1920, -0.1355],
        [-0.0270, 0.0401, -0.0821],
        [-0.0616, -0.0997, -0.0727],
        [0.0249, -0.0469, -0.1703]
    ]  # From comfyui

    latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]

    weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
    bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)

    images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
    images = images.clamp(0.0, 1.0)

    return images


@torch.no_grad()
def vae_decode(latents, vae, image_mode=False):
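    """Decode latents to pixels with the real VAE.

    With image_mode=True, temporal frames are decoded one at a time and
    concatenated, trading speed for lower peak memory.
    """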
    latents = latents / vae.config.scaling_factor

    if not image_mode:
        image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample
    else:
        latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2)
        image = [vae.decode(l.unsqueeze(2)).sample for l in latents]
        image = torch.cat(image, dim=2)

    return image


@torch.no_grad()
def vae_encode(image, vae):
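    """Encode pixels to scaled latents by sampling the VAE posterior."""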
    latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
    latents = latents * vae.config.scaling_factor
    return latents
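

# Usage sketch (not part of the original module; the model id, dtypes, and
# the `pixels` tensor below are illustrative assumptions, following how these
# helpers are typically wired to the HunyuanVideo community checkpoint):
#
#     from transformers import LlamaModel, LlamaTokenizerFast, CLIPTextModel, CLIPTokenizer
#     from diffusers import AutoencoderKLHunyuanVideo
#
#     repo = "hunyuanvideo-community/HunyuanVideo"
#     text_encoder = LlamaModel.from_pretrained(repo, subfolder="text_encoder", torch_dtype=torch.float16).cuda()
#     text_encoder_2 = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder_2", torch_dtype=torch.float16).cuda()
#     tokenizer = LlamaTokenizerFast.from_pretrained(repo, subfolder="tokenizer")
#     tokenizer_2 = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer_2")
#     vae = AutoencoderKLHunyuanVideo.from_pretrained(repo, subfolder="vae", torch_dtype=torch.float16).cuda()
#
#     llama_vec, clip_l_pooler = encode_prompt_conds(
#         "The man dances energetically.", text_encoder, text_encoder_2, tokenizer, tokenizer_2
#     )
#     latents = vae_encode(pixels, vae)   # pixels: (B, 3, T, H, W) in [-1, 1]
#     preview = vae_decode_fake(latents)  # fast preview, no VAE forward pass
#     video = vae_decode(latents, vae)    # full decode back to pixels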