import math
from typing import Optional, Tuple, List, Dict, Any

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from transformers import PreTrainedModel, PretrainedConfig
    from transformers.modeling_outputs import CausalLMOutputWithPast
except Exception as e:
    raise ImportError(
        "Please install transformers >= 4.40.0: "
        "pip install transformers"
    ) from e

class GTransformerConfig(PretrainedConfig):
    model_type = "gtransformer"

    def __init__(
        self,
        vocab_size: int = 65536,
        hidden_size: int = 8192,
        intermediate_size: int = 22016,
        num_hidden_layers: int = 48,
        num_attention_heads: int = 64,
        max_position_embeddings: int = 65536,
        hidden_act: str = "swiglu",
        layer_norm_epsilon: float = 1e-5,
        attention_dropout: float = 0.05,
        hidden_dropout_prob: float = 0.05,
        rotary_emb_base: int = 10000,
        use_flash_attention: bool = True,
        use_low_rank_ffn: bool = True,
        use_entropy_gate: bool = True,
        use_moe: bool = False,
        num_experts: int = 0,
        top_k_experts: int = 0,
        fp8_precision: bool = False,
        dvfs_enabled: bool = False,
        informational_constant_kI: float = 2.612e-20,
        energy_per_token_target_J: float = 0.07,
        delta_I_gate: float = 0.75,
        local_window: int = 512,
        global_rank: int = 64,
        kv_compression_rank: int = 64,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        pad_token_id: int = 0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.hidden_act = hidden_act
        self.layer_norm_epsilon = layer_norm_epsilon
        self.attention_dropout = attention_dropout
        self.hidden_dropout_prob = hidden_dropout_prob
        self.rotary_emb_base = rotary_emb_base

        self.use_flash_attention = use_flash_attention
        self.use_low_rank_ffn = use_low_rank_ffn
        self.use_entropy_gate = use_entropy_gate

        self.use_moe = use_moe
        self.num_experts = num_experts
        self.top_k_experts = top_k_experts

        self.fp8_precision = fp8_precision
        self.dvfs_enabled = dvfs_enabled

        self.informational_constant_kI = informational_constant_kI
        self.energy_per_token_target_J = energy_per_token_target_J

        self.delta_I_gate = delta_I_gate
        self.local_window = local_window
        self.global_rank = global_rank
        self.kv_compression_rank = kv_compression_rank

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

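# Note: several configuration fields above (use_flash_attention, fp8_precision,
# dvfs_enabled, informational_constant_kI, energy_per_token_target_J,
# kv_compression_rank) are stored on the config but are not consumed by any
# module defined in this file; they appear to be metadata for downstream
# tooling. delta_I_gate, local_window and global_rank are read by
# InformationalAttention below.
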
def swiglu(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return F.silu(x1) * x2


def build_activation(name: str):
    if name.lower() == "swiglu":
        return swiglu
    return getattr(F, name)

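# Note: build_activation falls back to attributes of torch.nn.functional for
# names other than "swiglu" (e.g. "gelu", "relu"). swiglu halves the last
# dimension of its input, so callers must feed it a tensor whose final
# dimension is twice the desired output width.
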
class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int):
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]
        return cos, sin


def apply_rotary(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):

    def rotate_half(x):
        # RotaryEmbedding duplicates the frequencies with cat((freqs, freqs)),
        # so the matching rotation swaps the two halves of the head dimension
        # (GPT-NeoX convention) rather than interleaving even/odd channels.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot

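# Shape reference for the rotary helpers above: given q and k of shape
# [batch, heads, seq, head_dim], RotaryEmbedding(head_dim) returns cos/sin of
# shape [1, 1, seq, head_dim], and apply_rotary returns tensors with the same
# shapes as q and k.
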
class InformationalAttention(nn.Module):
    """
    Energy-efficient attention.

    1. Local attention within a causal window of size w.
    2. Global token selection gated by an information score.
    3. Low-rank projections for the global path.
    """

    def __init__(self, config: GTransformerConfig):
        super().__init__()
        self.config = config
        self.d_model = config.hidden_size
        self.n_heads = config.num_attention_heads
        self.head_dim = self.d_model // self.n_heads
        assert self.d_model % self.n_heads == 0

        self.w_qkv = nn.Linear(self.d_model, 3 * self.d_model, bias=False)
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=False)

        self.rotary = RotaryEmbedding(self.head_dim, base=config.rotary_emb_base)

        # Low-rank projections for the global attention path.
        self.rank = config.global_rank
        self.Pk = nn.Linear(self.head_dim, self.rank, bias=False)
        self.Pv = nn.Linear(self.head_dim, self.rank, bias=False)
        self.Uo = nn.Linear(self.rank, self.head_dim, bias=False)

        # Small MLP that scores how informative each token is.
        self.info_scorer = nn.Sequential(
            nn.Linear(self.d_model, self.d_model // 4, bias=False),
            nn.GELU(),
            nn.Linear(self.d_model // 4, 1, bias=False),
        )

        self.attn_drop = nn.Dropout(config.attention_dropout)
        self.proj_drop = nn.Dropout(config.hidden_dropout_prob)

        self.local_window = config.local_window
        self.delta_I_gate = config.delta_I_gate
        self.use_entropy_gate = config.use_entropy_gate

    def _causal_local_mask(self, T: int, w: int, device) -> torch.Tensor:
        # True entries are masked: future positions and positions farther back
        # than the local window.
        idxs = torch.arange(T, device=device)
        mask = idxs[None, :] - idxs[:, None]
        mask = (mask > 0) | (mask < -(w - 1))
        return mask

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:

        B, T, C = x.shape
        H, D = self.n_heads, self.head_dim

        qkv = self.w_qkv(x)
        q, k, v = qkv.split(C, dim=-1)
        q = q.view(B, T, H, D).transpose(1, 2)
        k = k.view(B, T, H, D).transpose(1, 2)
        v = v.view(B, T, H, D).transpose(1, 2)

        # Rotate the new queries/keys at their absolute positions; cached keys
        # were already rotated when they were first produced.
        past_len = past_key_value[0].size(2) if past_key_value is not None else 0
        cos, sin = self.rotary(q, past_len + T)
        cos = cos[:, :, past_len:, :]
        sin = sin[:, :, past_len:, :]
        q, k = apply_rotary(q, k, cos, sin)

        if past_key_value is not None:
            pk, pv = past_key_value
            k = torch.cat([pk, k], dim=2)
            v = torch.cat([pv, v], dim=2)
            T_total = k.size(2)
        else:
            T_total = T

        # Local path: dense scores restricted to a causal sliding window.
        w = min(self.local_window, T_total)
        scale = 1.0 / math.sqrt(D)
        attn_scores = torch.einsum("bhtd,bhSd->bhtS", q, k) * scale

        local_mask = self._causal_local_mask(T_total, w, x.device)
        local_mask = local_mask[-T:]
        attn_scores = attn_scores.masked_fill(local_mask[None, None, :, :], float("-inf"))
        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask

        attn_w_local = F.softmax(attn_scores, dim=-1)
        attn_w_local = self.attn_drop(attn_w_local)
        ctx_local = torch.einsum("bhtS,bhSd->bhtd", attn_w_local, v)

        # Information score per token; the hard gate is non-differentiable,
        # so the scorer is evaluated without gradients.
        with torch.no_grad():
            info_score = self.info_scorer(x).squeeze(-1)

        info_score = torch.sigmoid(info_score)
        if self.use_entropy_gate:
            gate = (info_score > self.delta_I_gate).float()
        else:
            gate = torch.ones_like(info_score)

        # Global path: low-rank attention over all positions, applied only to
        # gated query tokens.
        ctx_global = torch.zeros_like(ctx_local)
        if gate.sum() > 0:
            k_r = self.Pk(k)
            v_r = self.Pv(v)
            q_r = self.Pk(q)

            gate_q = gate[:, -T:].unsqueeze(1).unsqueeze(-1)
            attn_scores_g = torch.einsum("bhtr,bhsr->bhts", q_r, k_r) * (scale * D / self.rank)
            # Keep the global path causal as well: queries may only attend to
            # keys at earlier or equal absolute positions.
            pos = torch.arange(T_total, device=x.device)
            causal_mask = (pos[None, :] > pos[:, None])[-T:]
            attn_scores_g = attn_scores_g.masked_fill(causal_mask[None, None, :, :], float("-inf"))
            attn_w_g = F.softmax(attn_scores_g, dim=-1)
            attn_w_g = self.attn_drop(attn_w_g)
            ctx_g_r = torch.einsum("bhts,bhsr->bhtr", attn_w_g, v_r)
            ctx_g = self.Uo(ctx_g_r)
            ctx_global = ctx_g * gate_q

        ctx = ctx_local + ctx_global
        ctx = ctx.transpose(1, 2).contiguous().view(B, T, C)
        out = self.w_o(ctx)
        out = self.proj_drop(out)

        present = (k, v) if use_cache else None
        return out, present

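# Note on cost: the "local" path above materialises full T x T_total score
# matrices and then masks them, so this reference implementation is still
# O(T * T_total) in compute and memory per head; the window only limits which
# positions contribute after masking. The global path operates in rank
# config.global_rank instead of head_dim, which is where the intended savings
# come from.
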
class LowRankFFN(nn.Module):
    def __init__(self, config: GTransformerConfig):
        super().__init__()
        d = config.hidden_size
        i = config.intermediate_size
        act = build_activation(config.hidden_act)
        self.act = act

        r_ffn = max(128, i // 8)
        self.w1a = nn.Linear(d, r_ffn, bias=False)
        self.w1b = nn.Linear(d, r_ffn, bias=False)
        self.w2 = nn.Linear(r_ffn, d, bias=False)
        self.drop = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = self.w1a(x)
        v = self.w1b(x)
        h = swiglu(torch.cat([u, v], dim=-1))
        out = self.w2(h)
        return self.drop(out)

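# Note: LowRankFFN replaces the dense d -> intermediate -> d MLP with two
# d -> r_ffn branches gated by swiglu and a single r_ffn -> d output, with
# r_ffn = max(128, intermediate_size // 8). The activation built from
# config.hidden_act is stored on the module but the forward pass always applies
# swiglu to the concatenated branches.
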
class EntropyMoE(nn.Module):
    def __init__(self, config: GTransformerConfig):
        super().__init__()
        assert config.num_experts > 0
        self.num_experts = config.num_experts
        self.top_k = max(1, config.top_k_experts)
        d = config.hidden_size
        i = config.intermediate_size

        self.router = nn.Sequential(
            nn.Linear(d, d // 2, bias=False),
            nn.GELU(),
            nn.Linear(d // 2, self.num_experts, bias=False),
        )
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(d, i), nn.GELU(), nn.Linear(i, d)) for _ in range(self.num_experts)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        logits = self.router(x)
        probs = F.softmax(logits, dim=-1)
        topk = torch.topk(probs, k=self.top_k, dim=-1)
        idx = topk.indices
        wgt = topk.values

        out = torch.zeros_like(x)
        for k in range(self.top_k):
            sel = idx[..., k]
            for e in range(self.num_experts):
                mask = (sel == e).float().unsqueeze(-1)
                if mask.sum() == 0:
                    continue
                xe = x * mask
                ye = self.experts[e](xe)
                # Mask the expert output as well: the expert biases produce
                # non-zero outputs even for zeroed inputs, and only tokens
                # routed to expert e should receive its contribution.
                out = out + ye * mask * (wgt[..., k].unsqueeze(-1))
        return out

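# Note: this is a dense reference implementation of the routed MoE. Every
# selected expert is run on a masked copy of the full sequence rather than on
# gathered tokens, so it trades efficiency for simplicity; a production
# version would dispatch tokens to experts explicitly.
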
class GTransformerBlock(nn.Module):
    def __init__(self, config: GTransformerConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.attn = InformationalAttention(config)
        self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        if config.use_moe and config.num_experts > 0:
            self.ff = EntropyMoE(config)
        else:
            self.ff = LowRankFFN(config) if config.use_low_rank_ffn else nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.GELU(),
                nn.Linear(config.intermediate_size, config.hidden_size),
            )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        h, present = self.attn(
            self.ln1(x),
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        x = x + h
        x = x + self.ff(self.ln2(x))
        return x, present

class GTransformerModel(PreTrainedModel):
    config_class = GTransformerConfig

    def __init__(self, config: GTransformerConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([GTransformerBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.gradient_checkpointing = False

        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:

        B, T = input_ids.shape
        x = self.embed_tokens(input_ids)

        new_past = [] if use_cache else None
        for i, layer in enumerate(self.layers):
            pkv = None if past_key_values is None else past_key_values[i]
            x, present = layer(x, attention_mask=attention_mask, past_key_value=pkv, use_cache=use_cache)
            if use_cache:
                new_past.append(present)

        x = self.ln_f(x)
        return x, new_past

class GTransformerForCausalLM(PreTrainedModel):
    config_class = GTransformerConfig

    def __init__(self, config: GTransformerConfig):
        super().__init__(config)
        self.transformer = GTransformerModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.transformer.embed_tokens = new_embeddings

    def tie_weights(self):
        # The embedding and lm_head weights are intentionally left untied.
        pass

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:

        hidden_states, new_past = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

            if self.config.use_entropy_gate:
                # Mean entropy of the predictive distribution. It is computed
                # without gradients, so it shifts the reported loss but does
                # not contribute to the parameter updates. Only evaluated when
                # labels are provided, since it reuses shift_logits.
                with torch.no_grad():
                    probs = F.softmax(shift_logits, dim=-1)
                    logp = torch.log(probs + 1e-9)
                    H = -(probs * logp).sum(dim=-1).mean()
                loss = loss + 1e-4 * H

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=new_past,
            hidden_states=None,
            attentions=None,
        )

    @torch.no_grad()
    def generate_simple(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 64,
        temperature: float = 1.0,
    ) -> torch.LongTensor:
        self.eval()
        past = None
        out = input_ids
        # Feed the full prompt on the first step, then only the newest token,
        # reusing the cached key/value states from the previous step.
        step_input = input_ids
        for _ in range(max_new_tokens):
            outputs = self(step_input, use_cache=True, past_key_values=past)
            past = outputs.past_key_values
            logits = outputs.logits[:, -1, :] / max(1e-6, temperature)
            next_token = torch.distributions.Categorical(logits=logits).sample()
            out = torch.cat([out, next_token.unsqueeze(-1)], dim=1)
            step_input = next_token.unsqueeze(-1)
            if int(next_token[0].item()) == self.config.eos_token_id:
                break
        return out

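# Minimal smoke test, included as an illustrative usage sketch only. The tiny
# hyperparameters below are arbitrary choices so the forward pass and sampling
# run quickly on CPU; they are not the defaults of the model above.
if __name__ == "__main__":
    cfg = GTransformerConfig(
        vocab_size=100,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        local_window=8,
        global_rank=8,
    )
    model = GTransformerForCausalLM(cfg)
    ids = torch.randint(0, cfg.vocab_size, (2, 16))

    # Training-style call: passing labels triggers the shifted cross-entropy loss.
    out = model(input_ids=ids, labels=ids)
    print("loss:", float(out.loss))

    # Sampling with the KV cache via the helper above.
    gen = model.generate_simple(ids[:, :4], max_new_tokens=8)
    print("generated shape:", tuple(gen.shape))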