Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
class QFormerProjModel(nn.Module): | |
def __init__(self, | |
cross_attention_dim=4096, | |
id_embeddings_dim=1152, | |
num_tokens=128, | |
num_heads=8, | |
num_query_tokens=32): | |
super().__init__() | |
self.cross_attention_dim = cross_attention_dim | |
self.num_tokens = num_tokens | |
self.query_embeds = nn.Parameter(torch.randn(num_tokens, cross_attention_dim)) | |
self.id_proj = nn.Sequential( | |
nn.Linear(id_embeddings_dim, id_embeddings_dim * 2), | |
nn.GELU(), | |
nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_query_tokens) | |
) | |
self.cross_attn = nn.MultiheadAttention( | |
embed_dim=cross_attention_dim, | |
num_heads=num_heads, | |
batch_first=True | |
) | |
self.cross_attn_norm = nn.LayerNorm(cross_attention_dim) | |
self.norm = nn.LayerNorm(cross_attention_dim) | |
def forward(self, id_embeds): | |
batch_size = id_embeds.size(0) | |
projected = self.id_proj(id_embeds) | |
kv = projected.view(batch_size, -1, self.cross_attention_dim) | |
queries = self.query_embeds.unsqueeze(0).expand(batch_size, -1, -1) | |
attn_output, _ = self.cross_attn( | |
query=queries, | |
key=kv, | |
value=kv | |
) | |
attn_output = self.cross_attn_norm(attn_output + queries) | |
return self.norm(attn_output) | |
class MLPProjModel(torch.nn.Module): | |
def __init__(self, | |
cross_attention_dim=768, | |
id_embeddings_dim=512, | |
num_tokens=4): | |
super().__init__() | |
self.cross_attention_dim = cross_attention_dim | |
self.num_tokens = num_tokens | |
self.proj = torch.nn.Sequential( | |
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2), | |
torch.nn.GELU(), | |
torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens), | |
) | |
self.norm = torch.nn.LayerNorm(cross_attention_dim) | |
def forward(self, id_embeds): | |
x = self.proj(id_embeds) | |
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) | |
x = self.norm(x) | |
return x | |