# Stand-In/models/wan_video_vace.py
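"""VACE control branch for the Wan video DiT.

A stack of VaceWanAttentionBlock modules patch-embeds a control video
(vace_context), mirrors one DiT block per entry in vace_layers, and
returns one hint tensor per mirrored layer for the main model to
consume at the layers listed in vace_layers_mapping.
"""
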
import torch
from .wan_video_dit import DiTBlock
from .utils import hash_state_dict_keys


class VaceWanAttentionBlock(DiTBlock):
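    """One DiT block plus the projections that tie it into the control path.

    Every block owns after_proj, whose output becomes the hint for its
    paired DiT layer; block 0 additionally owns before_proj, which projects
    the incoming control tokens before the main hidden states x are added.
    """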
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
self.block_id = block_id
if block_id == 0:
self.before_proj = torch.nn.Linear(self.dim, self.dim)
        self.after_proj = torch.nn.Linear(self.dim, self.dim)

def forward(self, c, x, context, t_mod, freqs):
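        # c is a stacked carrier: [hint_0, ..., hint_{k-1}, running_state].
        # Block 0 starts the stack; later blocks pop the running state,
        # transform it, and re-append their after_proj hint plus the
        # updated state.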
if self.block_id == 0:
c = self.before_proj(c) + x
all_c = []
else:
all_c = list(torch.unbind(c))
c = all_c.pop(-1)
c, _ = super().forward(c, context, t_mod, freqs)
c_skip = self.after_proj(c)
all_c += [c_skip, c]
c = torch.stack(all_c)
        return c


class VaceWanModel(torch.nn.Module):
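    """Bundles the VACE patch embedding with the stack of control blocks.

    The defaults correspond to the 1.3B Wan configuration; the 14B VACE
    variant is detected by VaceWanModelDictConverter below.
    """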
def __init__(
self,
vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
vace_in_dim=96,
patch_size=(1, 2, 2),
has_image_input=False,
dim=1536,
num_heads=12,
ffn_dim=8960,
eps=1e-6,
):
super().__init__()
self.vace_layers = vace_layers
self.vace_in_dim = vace_in_dim
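        # DiT layer index -> index of the paired block in self.vace_blocks.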
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
# vace blocks
self.vace_blocks = torch.nn.ModuleList(
[
VaceWanAttentionBlock(
has_image_input, dim, num_heads, ffn_dim, eps, block_id=i
)
for i in self.vace_layers
]
)
# vace patch embeddings
self.vace_patch_embedding = torch.nn.Conv3d(
vace_in_dim, dim, kernel_size=patch_size, stride=patch_size
        )

def forward(
self,
x,
vace_context,
context,
t_mod,
freqs,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
):
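        # Patch-embed each control video, flatten to a token sequence, and
        # zero-pad every sequence to the main length x.shape[1] so the
        # batch can be concatenated.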
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = torch.cat(
[
torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))], dim=1)
for u in c
]
)
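
        # Plain closure so torch.utils.checkpoint can re-run the block with
        # positional inputs during recomputation.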
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
            return custom_forward

for block in self.vace_blocks:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
c = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
c,
x,
context,
t_mod,
freqs,
use_reentrant=False,
)
elif use_gradient_checkpointing:
c = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
c,
x,
context,
t_mod,
freqs,
use_reentrant=False,
)
else:
c = block(c, x, context, t_mod, freqs)
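
        # Drop the trailing running state; what remains is one hint per block.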
hints = torch.unbind(c)[:-1]
        return hints

@staticmethod
def state_dict_converter():
        return VaceWanModelDictConverter()


class VaceWanModelDictConverter:
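    """Extracts the VACE branch from a full checkpoint and infers its config."""
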
def __init__(self):
        pass

def from_civitai(self, state_dict):
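        # Keep only the VACE branch weights; the variant is then detected by
        # hashing the filtered key set.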
state_dict_ = {
name: param for name, param in state_dict.items() if name.startswith("vace")
}
if (
hash_state_dict_keys(state_dict_) == "3b2726384e4f64837bdf216eea3f310d"
): # vace 14B
config = {
"vace_layers": (0, 5, 10, 15, 20, 25, 30, 35),
"vace_in_dim": 96,
"patch_size": (1, 2, 2),
"has_image_input": False,
"dim": 5120,
"num_heads": 40,
"ffn_dim": 13824,
"eps": 1e-06,
}
else:
config = {}
return state_dict_, config
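

if __name__ == "__main__":
    # Minimal shape sketch (illustrative; all sizes below are assumptions).
    # It mirrors the first two lines of VaceWanModel.forward: the Conv3d
    # patch embedding turns a 96-channel control video into dim-1536 tokens,
    # one token per (1, 2, 2) spatio-temporal patch. Run with `python -m`,
    # since this module uses relative imports.
    embed = torch.nn.Conv3d(96, 1536, kernel_size=(1, 2, 2), stride=(1, 2, 2))
    u = torch.randn(96, 4, 32, 32)  # (C, F, H, W) dummy control video
    c = embed(u.unsqueeze(0)).flatten(2).transpose(1, 2)
    print(c.shape)  # torch.Size([1, 1024, 1536]) == (1, 4 * 16 * 16, dim)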