import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class PhaseFormerTransformerLayer(nn.Module):
    """
    Transformer layer with phase-based temporal gating applied
    to attention and feed-forward residual paths.

    Args:
        d_model (int): Input/output dimension.
        nhead (int): Number of attention heads.
        dim_feedforward (int): FFN hidden layer size.
        dropout (float): Dropout probability.
        decay_rate (float): Decay coefficient lambda.
    """
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, decay_rate=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # Decay coefficient lambda and learned projection producing the phase term.
        self.decay_rate = decay_rate
        self.phase_proj = nn.Linear(d_model, d_model)

    def forward(self, src, t: float):
        """
        Args:
            src (Tensor): Input of shape (seq_len, batch, d_model).
            t (float): Elapsed time fed to the decay envelope.
        """
        # Scalar decay envelope D(t) = exp(-lambda * t).
        D_t = math.exp(-self.decay_rate * t)
        # Learned per-element phase; the gate is a decayed sinusoid.
        phase = self.phase_proj(src)
        g = D_t * torch.sin(phase)

        # Gate the self-attention output before the first residual + norm.
        attn_out, _ = self.self_attn(src, src, src)
        src2 = self.norm1(src + g * attn_out)

        # Gate the feed-forward output before the second residual + norm.
        ff = self.linear2(self.dropout(F.relu(self.linear1(src2))))
        return self.norm2(src2 + g * ff)
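

# Minimal usage sketch (illustrative addition, not part of the original file):
# instantiate the layer and run one forward pass at time t. Shapes assume
# nn.MultiheadAttention's default batch_first=False layout, i.e.
# (seq_len, batch, d_model); the chosen sizes are arbitrary.
if __name__ == "__main__":
    layer = PhaseFormerTransformerLayer(d_model=64, nhead=4, dim_feedforward=128)
    src = torch.randn(10, 2, 64)   # (seq_len=10, batch=2, d_model=64)
    out = layer(src, t=3.0)        # gate amplitude is exp(-0.1 * 3.0)
    print(out.shape)               # torch.Size([10, 2, 64])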