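"""Cross-frame ("FETA") attention score for video token enhancement.

Measures how strongly video tokens attend across frames (the off-diagonal
entries of a frame-by-frame attention map) and converts that into a single
scalar enhance score; see get_feta_scores and feta_score below.
"""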
import torch
from einops import rearrange
from diffusers.models.attention import Attention
from .globals import get_enhance_weight, get_num_frames

# Legacy implementation (assumed [B, S, N, C] inputs), kept for reference:
#
# def get_feta_scores(query, key):
#     img_q, img_k = query, key
#
#     num_frames = get_num_frames()
#
#     B, S, N, C = img_q.shape
#
#     # Calculate spatial dimension
#     spatial_dim = S // num_frames
#
#     # Add time dimension between spatial and head dims
#     query_image = img_q.reshape(B, spatial_dim, num_frames, N, C)
#     key_image = img_k.reshape(B, spatial_dim, num_frames, N, C)
#
#     # Expand time dimension
#     query_image = query_image.expand(-1, -1, num_frames, -1, -1)  # [B, S, T, N, C]
#     key_image = key_image.expand(-1, -1, num_frames, -1, -1)      # [B, S, T, N, C]
#
#     # Reshape to match feta_score input format: [(B S) N T C]
#     query_image = rearrange(query_image, "b s t n c -> (b s) n t c")  # torch.Size([3200, 24, 5, 128])
#     key_image = rearrange(key_image, "b s t n c -> (b s) n t c")
#
#     return feta_score(query_image, key_image, C, num_frames)
 
def get_feta_scores(
    attn: Attention,
    query: torch.Tensor,
    key: torch.Tensor,
    head_dim: int,
    text_seq_length: int,
) -> torch.Tensor:
    """Compute the enhance score from the video (non-text) query/key tokens.

    `query` and `key` are expected in [B, heads, text_seq_length + T * S, head_dim]
    layout, with the text tokens first and the video tokens ordered frame-major.
    """
    num_frames = get_num_frames()
    spatial_dim = (query.shape[2] - text_seq_length) // num_frames

    # Drop the text tokens, then fold the spatial axis into the batch so that
    # attention is computed across frames only: [(B S), N, T, C].
    query_image = rearrange(
        query[:, :, text_seq_length:],
        "B N (T S) C -> (B S) N T C",
        N=attn.heads,
        T=num_frames,
        S=spatial_dim,
        C=head_dim,
    )
    key_image = rearrange(
        key[:, :, text_seq_length:],
        "B N (T S) C -> (B S) N T C",
        N=attn.heads,
        T=num_frames,
        S=spatial_dim,
        C=head_dim,
    )
    return feta_score(query_image, key_image, head_dim, num_frames)

def feta_score(query_image, key_image, head_dim, num_frames):
    """Average off-diagonal cross-frame attention, scaled into an enhance score."""
    scale = head_dim**-0.5
    query_image = query_image * scale
    attn_temp = query_image @ key_image.transpose(-2, -1)
    attn_temp = attn_temp.to(torch.float32)  # softmax in float32 for stability
    attn_temp = attn_temp.softmax(dim=-1)

    # Flatten the leading dims to [(B S) * heads, num_frames, num_frames]
    attn_temp = attn_temp.reshape(-1, num_frames, num_frames)

    # Create a mask for the diagonal (each frame attending to itself)
    diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
    diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)

    # Zero out diagonal elements
    attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)

    # Mean over the off-diagonal entries; each matrix has n*n - n of them
    num_off_diag = num_frames * num_frames - num_frames
    mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

    # Scale by (num_frames + enhance_weight) and floor the score at 1
    enhance_scores = mean_scores.mean() * (num_frames + get_enhance_weight())
    enhance_scores = enhance_scores.clamp(min=1)
    return enhance_scores
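
# Usage sketch (illustrative; everything except get_feta_scores below is
# hypothetical). In a diffusers-style attention processor, with query/key
# already reshaped to [B, heads, text_seq_length + num_frames * spatial,
# head_dim], the score would be computed before attention and applied to the
# output afterwards, e.g.:
#
#     feta_scores = get_feta_scores(attn, query, key, head_dim, text_seq_length)
#     ...  # run scaled-dot-product attention as usual
#     hidden_states = hidden_states * feta_scores  # amplify cross-frame features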