import torch
import torch.nn as nn


class QFormerProjModel(nn.Module):
    """Q-Former-style projection model.

    A set of learned query tokens cross-attends to key/value tokens projected
    from an ID embedding, producing `num_tokens` context tokens of width
    `cross_attention_dim`.
    """

    def __init__(self,
                 cross_attention_dim=4096,
                 id_embeddings_dim=1152,
                 num_tokens=128,
                 num_heads=8,
                 num_query_tokens=32):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        # Learned query tokens, shared across the batch.
        self.query_embeds = nn.Parameter(torch.randn(num_tokens, cross_attention_dim))

        # Project the ID embedding into `num_query_tokens` key/value tokens.
        self.id_proj = nn.Sequential(
            nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
            nn.GELU(),
            nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_query_tokens)
        )

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=cross_attention_dim,
            num_heads=num_heads,
            batch_first=True
        )
        self.cross_attn_norm = nn.LayerNorm(cross_attention_dim)

        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        batch_size = id_embeds.size(0)

        # (batch, id_embeddings_dim) -> (batch, num_query_tokens, cross_attention_dim)
        projected = self.id_proj(id_embeds)
        kv = projected.view(batch_size, -1, self.cross_attention_dim)

        # Broadcast the learned queries over the batch:
        # (batch, num_tokens, cross_attention_dim).
        queries = self.query_embeds.unsqueeze(0).expand(batch_size, -1, -1)

        # Queries attend to the projected ID tokens, followed by a residual
        # connection and layer norm.
        attn_output, _ = self.cross_attn(
            query=queries,
            key=kv,
            value=kv
        )
        attn_output = self.cross_attn_norm(attn_output + queries)

        return self.norm(attn_output)


class MLPProjModel(torch.nn.Module):
    """MLP projection model.

    Maps a single ID embedding to `num_tokens` context tokens of width
    `cross_attention_dim`.
    """

    def __init__(self,
                 cross_attention_dim=768,
                 id_embeddings_dim=512,
                 num_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        # (batch, id_embeddings_dim) -> (batch, num_tokens * cross_attention_dim)
        x = self.proj(id_embeds)
        # Split into `num_tokens` tokens: (batch, num_tokens, cross_attention_dim).
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        x = self.norm(x)
        return x
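

# Minimal smoke test (a sketch, not part of the original module): it assumes
# `id_embeds` is one pooled ID embedding per sample, shaped
# (batch, id_embeddings_dim), matching each model's default dimensions.
if __name__ == "__main__":
    batch = 2

    qformer = QFormerProjModel()
    tokens = qformer(torch.randn(batch, 1152))
    print(tokens.shape)  # expected: torch.Size([2, 128, 4096])

    mlp = MLPProjModel()
    tokens = mlp(torch.randn(batch, 512))
    print(tokens.shape)  # expected: torch.Size([2, 4, 768])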