import torch from diffusers import ModelMixin from einops import rearrange from torch import nn class Motion2bucketModel(ModelMixin): def __init__(self, window_size=5, blocks=12, channels=1024, clip_channels=1280, intermediate_dim=512, output_dim=768, context_tokens=32, clip_token_num=1, final_output_dim=5120): super().__init__() self.window_size = window_size self.clip_token_num = clip_token_num self.blocks = blocks self.channels = channels # self.input_dim = (window_size * blocks * channels + clip_channels*clip_token_num) self.input_dim = (window_size * channels + clip_channels * clip_token_num) self.intermediate_dim = intermediate_dim self.context_tokens = context_tokens self.output_dim = output_dim # define multiple linear layers self.proj1 = nn.Linear(self.input_dim, intermediate_dim) self.proj2 = nn.Linear(intermediate_dim, intermediate_dim) self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim) self.act = nn.SiLU() self.final_proj = torch.nn.Linear(output_dim, final_output_dim) self.final_norm = torch.nn.LayerNorm(final_output_dim) nn.init.constant_(self.final_proj.weight, 0) if self.final_proj.bias is not None: nn.init.constant_(self.final_proj.bias, 0) def forward(self, audio_embeds, clip_embeds): """ Defines the forward pass for the AudioProjModel. Parameters: audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels). Returns: context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim). """ # merge video_length = audio_embeds.shape[1] # audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") audio_embeds = rearrange(audio_embeds, "bz f w c -> (bz f) w c") clip_embeds = clip_embeds.repeat(audio_embeds.size()[0]//clip_embeds.size()[0], 1, 1) clip_embeds = rearrange(clip_embeds, "b n d -> b (n d)") # batch_size, window_size, blocks, channels = audio_embeds.shape # audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) batch_size, window_size, channels = audio_embeds.shape audio_embeds = audio_embeds.view(batch_size, window_size * channels) audio_embeds = torch.cat([audio_embeds, clip_embeds], dim=-1) audio_embeds = self.act(self.proj1(audio_embeds)) audio_embeds = self.act(self.proj2(audio_embeds)) context_tokens = self.proj3(audio_embeds).reshape( batch_size, self.context_tokens, self.output_dim ) # context_tokens = self.norm(context_tokens) context_tokens = rearrange( context_tokens, "(bz f) m c -> bz f m c", f=video_length ) context_tokens = self.act(context_tokens) context_tokens = self.final_norm(self.final_proj(context_tokens)) return context_tokens if __name__ == '__main__': model = Motion2bucketModel(window_size=5) # audio_features = torch.randn(1, 81, 5, 12, 768) audio_features = torch.randn(1, 81, 5, 1024) clip_image_features = torch.randn(1, 1, 1280) out = model(audio_features, clip_image_features).mean(dim=2).mean(dim=1) print(out.size())