# Calligrapher/models/projection_models.py
import torch
import torch.nn as nn

class QFormerProjModel(nn.Module):
    """Q-Former-style projection: learned query tokens cross-attend to projected ID embeddings.

    The ID embeddings are first expanded by an MLP into `num_query_tokens` key/value
    tokens of width `cross_attention_dim`; a fixed set of `num_tokens` learned queries
    then attends to them, producing `num_tokens` context tokens for downstream
    cross-attention.
    """

    def __init__(self,
                 cross_attention_dim=4096,
                 id_embeddings_dim=1152,
                 num_tokens=128,
                 num_heads=8,
                 num_query_tokens=32):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        # Learned query tokens, shared across the batch.
        self.query_embeds = nn.Parameter(torch.randn(num_tokens, cross_attention_dim))

        # MLP that maps each ID embedding to num_query_tokens key/value tokens.
        self.id_proj = nn.Sequential(
            nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
            nn.GELU(),
            nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_query_tokens)
        )

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=cross_attention_dim,
            num_heads=num_heads,
            batch_first=True
        )
        self.cross_attn_norm = nn.LayerNorm(cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        batch_size = id_embeds.size(0)

        # Project the ID embeddings and reshape them into key/value tokens:
        # (batch, ..., id_embeddings_dim) -> (batch, -1, cross_attention_dim).
        projected = self.id_proj(id_embeds)
        kv = projected.view(batch_size, -1, self.cross_attention_dim)

        # Broadcast the learned queries over the batch.
        queries = self.query_embeds.unsqueeze(0).expand(batch_size, -1, -1)

        # Queries attend to the projected ID tokens.
        attn_output, _ = self.cross_attn(
            query=queries,
            key=kv,
            value=kv
        )

        # Residual connection followed by normalization.
        attn_output = self.cross_attn_norm(attn_output + queries)
        return self.norm(attn_output)
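

# A minimal shape-check sketch, not part of the original module: the helper name
# and the toy dimensions below are illustrative assumptions, chosen only to show
# how id_embeds flows through QFormerProjModel without allocating the full-size
# default layers.
def _demo_qformer_shapes():
    model = QFormerProjModel(
        cross_attention_dim=64,
        id_embeddings_dim=32,
        num_tokens=8,
        num_heads=4,
        num_query_tokens=4,
    )
    id_embeds = torch.randn(2, 32)   # (batch, id_embeddings_dim)
    out = model(id_embeds)           # (batch, num_tokens, cross_attention_dim)
    assert out.shape == (2, 8, 64)
    return out.shape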

class MLPProjModel(torch.nn.Module):
    """MLP projection: maps a single ID embedding to `num_tokens` context tokens."""

    def __init__(self,
                 cross_attention_dim=768,
                 id_embeddings_dim=512,
                 num_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        # (batch, id_embeddings_dim) -> (batch, num_tokens * cross_attention_dim)
        x = self.proj(id_embeds)
        # Split into num_tokens tokens of width cross_attention_dim.
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        x = self.norm(x)
        return x
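

# Minimal usage sketch (an assumption, not part of the original file): a quick
# shape check for MLPProjModel with its default sizes, mirroring the
# QFormerProjModel sketch above. Run as `python projection_models.py`.
if __name__ == "__main__":
    mlp = MLPProjModel(cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4)
    id_embeds = torch.randn(2, 512)  # (batch, id_embeddings_dim)
    tokens = mlp(id_embeds)          # (batch, num_tokens, cross_attention_dim)
    print(tokens.shape)              # expected: torch.Size([2, 4, 768])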