# G-Transformer / src/model.py
# Copyright (c) 2025
# G-Transformer: Energy-Efficient Transformer based on GIT
# Author: Syamsuddin B. Ideris, S.Pd.MM
import math
from typing import Optional, Tuple, List, Dict, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
except Exception as e:
    raise ImportError(
        "Please install transformers >= 4.40.0: "
        "pip install transformers"
    ) from e
# ----------------------------
# Configuration
# ----------------------------
class GTransformerConfig(PretrainedConfig):
model_type = "gtransformer"
def __init__(
self,
vocab_size: int = 65536,
hidden_size: int = 8192,
intermediate_size: int = 22016,
num_hidden_layers: int = 48,
num_attention_heads: int = 64,
max_position_embeddings: int = 65536,
hidden_act: str = "swiglu",
layer_norm_epsilon: float = 1e-5,
attention_dropout: float = 0.05,
hidden_dropout_prob: float = 0.05,
rotary_emb_base: int = 10000,
use_flash_attention: bool = True,
use_low_rank_ffn: bool = True,
use_entropy_gate: bool = True,
use_moe: bool = False,
num_experts: int = 0,
top_k_experts: int = 0,
fp8_precision: bool = False,
dvfs_enabled: bool = False,
informational_constant_kI: float = 2.612e-20,
energy_per_token_target_J: float = 0.07,
delta_I_gate: float = 0.75,
local_window: int = 512,
global_rank: int = 64,
kv_compression_rank: int = 64,
bos_token_id: int = 1,
eos_token_id: int = 2,
pad_token_id: int = 0,
**kwargs,
):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.max_position_embeddings = max_position_embeddings
self.hidden_act = hidden_act
self.layer_norm_epsilon = layer_norm_epsilon
self.attention_dropout = attention_dropout
self.hidden_dropout_prob = hidden_dropout_prob
self.rotary_emb_base = rotary_emb_base
self.use_flash_attention = use_flash_attention
self.use_low_rank_ffn = use_low_rank_ffn
self.use_entropy_gate = use_entropy_gate
self.use_moe = use_moe
self.num_experts = num_experts
self.top_k_experts = top_k_experts
self.fp8_precision = fp8_precision
self.dvfs_enabled = dvfs_enabled
self.informational_constant_kI = informational_constant_kI
self.energy_per_token_target_J = energy_per_token_target_J
self.delta_I_gate = delta_I_gate
self.local_window = local_window
self.global_rank = global_rank
self.kv_compression_rank = kv_compression_rank
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
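# Usage sketch (illustrative sizes, not the released defaults): a tiny
# configuration like the one below is convenient for CPU smoke tests of the
# modules in this file.
#
#   tiny_cfg = GTransformerConfig(
#       vocab_size=1000, hidden_size=64, intermediate_size=128,
#       num_hidden_layers=2, num_attention_heads=4,
#       max_position_embeddings=256, local_window=16, global_rank=8,
#   )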
# ----------------------------
# Utilities
# ----------------------------
def swiglu(x: torch.Tensor) -> torch.Tensor:
x1, x2 = x.chunk(2, dim=-1)
return F.silu(x1) * x2
def build_activation(name: str):
if name.lower() == "swiglu":
return swiglu
return getattr(F, name)
# Simple rotary position embedding
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, base: int = 10000):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
    def forward(self, x: torch.Tensor, seq_len: int):
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # [1, 1, seq_len, dim]; cast to the input dtype for mixed-precision use
        cos = emb.cos()[None, None, :, :].to(x.dtype)
        sin = emb.sin()[None, None, :, :].to(x.dtype)
        return cos, sin
def apply_rotary(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # q, k: [B, H, T, D]. cos/sin duplicate the frequencies across the two halves
    # of the last dimension, so the matching rotation is "rotate-half".
    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot
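# Shape sketch for the two helpers above (assumed layout): q and k are per-head
# tensors of shape [B, H, T, D], and RotaryEmbedding returns cos/sin of shape
# [1, 1, T, D] that broadcast over batch and heads.
#
#   rope = RotaryEmbedding(dim=16)
#   cos, sin = rope(q, seq_len=q.size(2))
#   q, k = apply_rotary(q, k, cos, sin)   # shapes unchanged: [B, H, T, D]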
# ----------------------------
# IA-Attention
# ----------------------------
class InformationalAttention(nn.Module):
"""
Atensi hemat energi.
1. Atensi lokal dengan jendela w.
2. Seleksi token global berbasis skor informasi.
3. Proyeksi low-rank untuk jalur global.
"""
def __init__(self, config: GTransformerConfig):
super().__init__()
self.config = config
self.d_model = config.hidden_size
self.n_heads = config.num_attention_heads
self.head_dim = self.d_model // self.n_heads
assert self.d_model % self.n_heads == 0
self.w_qkv = nn.Linear(self.d_model, 3 * self.d_model, bias=False)
self.w_o = nn.Linear(self.d_model, self.d_model, bias=False)
        self.rotary = RotaryEmbedding(self.head_dim, base=config.rotary_emb_base)
        # Low-rank projections for the global path
self.rank = config.global_rank
self.Pk = nn.Linear(self.head_dim, self.rank, bias=False)
self.Pv = nn.Linear(self.head_dim, self.rank, bias=False)
self.Uo = nn.Linear(self.rank, self.head_dim, bias=False)
        # Information scorer
self.info_scorer = nn.Sequential(
nn.Linear(self.d_model, self.d_model // 4, bias=False),
nn.GELU(),
nn.Linear(self.d_model // 4, 1, bias=False),
)
self.attn_drop = nn.Dropout(config.attention_dropout)
self.proj_drop = nn.Dropout(config.hidden_dropout_prob)
self.local_window = config.local_window
self.delta_I_gate = config.delta_I_gate
self.use_entropy_gate = config.use_entropy_gate
def _causal_local_mask(self, T: int, w: int, device) -> torch.Tensor:
idxs = torch.arange(T, device=device)
mask = idxs[None, :] - idxs[:, None]
        # allow only past positions within the local window
        mask = (mask > 0) | (mask < -(w - 1))
        return mask  # True means masked
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
B, T, C = x.shape
H, D = self.n_heads, self.head_dim
qkv = self.w_qkv(x) # [B, T, 3C]
q, k, v = qkv.split(C, dim=-1)
q = q.view(B, T, H, D).transpose(1, 2) # [B, H, T, D]
k = k.view(B, T, H, D).transpose(1, 2)
v = v.view(B, T, H, D).transpose(1, 2)
        # Rotate with absolute positions: offset by the cached length so queries
        # and new keys are placed after the tokens already in the KV cache.
        past_len = 0 if past_key_value is None else past_key_value[0].size(2)
        cos, sin = self.rotary(q, past_len + T)
        cos = cos[:, :, past_len:, :]
        sin = sin[:, :, past_len:, :]
        q, k = apply_rotary(q, k, cos, sin)
        # Append the cache if present
        if past_key_value is not None:
            pk, pv = past_key_value  # [B, H, T_past, D]
            k = torch.cat([pk, k], dim=2)
            v = torch.cat([pv, v], dim=2)
            T_total = k.size(2)
        else:
            T_total = T
        # Local attention
        w = min(self.local_window, T_total)
        scale = 1.0 / math.sqrt(D)
        attn_scores = torch.einsum("bhtd,bhSd->bhtS", q, k) * scale  # S = T_total
        # Local causal mask
        local_mask = self._causal_local_mask(T_total, w, x.device)  # [T_total, T_total]
        local_mask = local_mask[-T:]  # rows for the current queries
        attn_scores = attn_scores.masked_fill(local_mask[None, None, :, :], float("-inf"))
if attention_mask is not None:
            attn_scores = attn_scores + attention_mask  # additive mask; must broadcast to [B, H, T, T_total]
attn_w_local = F.softmax(attn_scores, dim=-1)
attn_w_local = self.attn_drop(attn_w_local)
ctx_local = torch.einsum("bhtS,bhSd->bhtd", attn_w_local, v)
        # Information-based global selection
        # Information score from the representation x
        with torch.no_grad():
            info_score = self.info_scorer(x).squeeze(-1)  # [B, T]
            # scale to 0..1 via sigmoid
            info_score = torch.sigmoid(info_score)
        if self.use_entropy_gate:
            gate = (info_score > self.delta_I_gate).float()  # [B, T]
        else:
            gate = torch.ones_like(info_score)
        # Low-rank projection for the global path, applied only to gated tokens.
        # Simple form: compress k and v to a small rank, then attend over the subset.
        # Build the per-batch global gate
ctx_global = torch.zeros_like(ctx_local)
        if gate.sum() > 0:
            # compress k, v
            k_r = self.Pk(k)  # [B, H, T_total, R]
            v_r = self.Pv(v)  # [B, H, T_total, R]
            q_r = self.Pk(q)  # reuse Pk for q
            # Simple form: low-rank attention over all past positions, with the
            # output scaled by the query gate.
            gate_q = gate[:, -T:].unsqueeze(1).unsqueeze(-1)  # [B, 1, T, 1]
            attn_scores_g = torch.einsum("bhtr,bhsr->bhts", q_r, k_r) * (scale * D / self.rank)
            # Keep the global path causal as well: a query at absolute position p
            # may only attend to keys at positions <= p.
            q_pos = torch.arange(T_total - T, T_total, device=x.device)
            k_pos = torch.arange(T_total, device=x.device)
            causal_g = k_pos[None, :] > q_pos[:, None]  # [T, T_total], True means masked
            attn_scores_g = attn_scores_g.masked_fill(causal_g[None, None, :, :], float("-inf"))
            attn_w_g = F.softmax(attn_scores_g, dim=-1)
            attn_w_g = self.attn_drop(attn_w_g)
            ctx_g_r = torch.einsum("bhts,bhsr->bhtr", attn_w_g, v_r)
            ctx_g = self.Uo(ctx_g_r)  # [B, H, T, D]
            ctx_global = ctx_g * gate_q
ctx = ctx_local + ctx_global
ctx = ctx.transpose(1, 2).contiguous().view(B, T, C)
out = self.w_o(ctx)
out = self.proj_drop(out)
present = (k, v) if use_cache else None
return out, present
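# Standalone usage sketch for InformationalAttention (illustrative sizes):
#
#   cfg = GTransformerConfig(hidden_size=64, num_attention_heads=4,
#                            local_window=16, global_rank=8)
#   attn = InformationalAttention(cfg)
#   y, present = attn(torch.randn(2, 32, 64), use_cache=True)
#   # y: [2, 32, 64]; present is (k, v), each of shape [2, 4, 32, 16]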
# ----------------------------
# Low-Rank FFN
# ----------------------------
class LowRankFFN(nn.Module):
def __init__(self, config: GTransformerConfig):
super().__init__()
d = config.hidden_size
i = config.intermediate_size
act = build_activation(config.hidden_act)
self.act = act
        # Factorization: d -> intermediate -> d, with a bottleneck of rank r_ffn
r_ffn = max(128, i // 8)
self.w1a = nn.Linear(d, r_ffn, bias=False)
self.w1b = nn.Linear(d, r_ffn, bias=False)
self.w2 = nn.Linear(r_ffn, d, bias=False)
self.drop = nn.Dropout(config.hidden_dropout_prob)
def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Low-rank SwiGLU
u = self.w1a(x)
v = self.w1b(x)
h = swiglu(torch.cat([u, v], dim=-1))
out = self.w2(h)
return self.drop(out)
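# Parameter sketch: with hidden size d and bottleneck r_ffn = max(128, i // 8),
# this block holds roughly 3 * d * r_ffn weights, versus about 3 * d * i for a
# dense SwiGLU FFN, which is where the intended compute/energy savings come
# from (assuming r_ffn << intermediate_size i).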
# ----------------------------
# Optional MoE router
# ----------------------------
class EntropyMoE(nn.Module):
def __init__(self, config: GTransformerConfig):
super().__init__()
assert config.num_experts > 0
self.num_experts = config.num_experts
self.top_k = max(1, config.top_k_experts)
d = config.hidden_size
i = config.intermediate_size
self.router = nn.Sequential(
nn.Linear(d, d // 2, bias=False),
nn.GELU(),
nn.Linear(d // 2, self.num_experts, bias=False),
)
self.experts = nn.ModuleList(
[nn.Sequential(nn.Linear(d, i), nn.GELU(), nn.Linear(i, d)) for _ in range(self.num_experts)]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, T, D = x.shape
logits = self.router(x) # [B,T,E]
probs = F.softmax(logits, dim=-1)
topk = torch.topk(probs, k=self.top_k, dim=-1)
idx = topk.indices # [B,T,K]
wgt = topk.values # [B,T,K]
out = torch.zeros_like(x)
        for k in range(self.top_k):
            sel = idx[..., k]  # [B, T]
            # gather the tokens routed to each expert
            for e in range(self.num_experts):
                mask = (sel == e).float().unsqueeze(-1)  # [B, T, 1]
                if mask.sum() == 0:
                    continue
                xe = x * mask
                ye = self.experts[e](xe) * mask  # re-mask so expert biases do not leak into unrouted tokens
                out = out + ye * (wgt[..., k].unsqueeze(-1))
return out
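# Routing sketch: for x of shape [B, T, d] the router emits [B, T, num_experts]
# logits; each token's output is the probability-weighted sum of its top_k
# expert outputs. This dense-masked formulation is easy to read but evaluates
# every selected expert on the full (mostly zeroed) tensor, so it favors
# clarity over efficiency.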
# ----------------------------
# Transformer block
# ----------------------------
class GTransformerBlock(nn.Module):
def __init__(self, config: GTransformerConfig):
super().__init__()
self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.attn = InformationalAttention(config)
self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
if config.use_moe and config.num_experts > 0:
self.ff = EntropyMoE(config)
else:
self.ff = LowRankFFN(config) if config.use_low_rank_ffn else nn.Sequential(
nn.Linear(config.hidden_size, config.intermediate_size),
nn.GELU(),
nn.Linear(config.intermediate_size, config.hidden_size),
)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
h, present = self.attn(self.ln1(x), attention_mask=attention_mask, past_key_value=past_key_value, use_cache=use_cache)
x = x + h
x = x + self.ff(self.ln2(x))
return x, present
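# Block layout: pre-LayerNorm residual wiring,
#   x = x + Attn(LN1(x));  x = x + FFN(LN2(x))
# the standard pre-norm arrangement favored for training stability.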
# ----------------------------
# Base model
# ----------------------------
class GTransformerModel(PreTrainedModel):
config_class = GTransformerConfig
def __init__(self, config: GTransformerConfig):
super().__init__(config)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([GTransformerBlock(config) for _ in range(config.num_hidden_layers)])
self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.gradient_checkpointing = False
self.post_init()
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
B, T = input_ids.shape
x = self.embed_tokens(input_ids)
new_past = [] if use_cache else None
for i, layer in enumerate(self.layers):
pkv = None if past_key_values is None else past_key_values[i]
x, present = layer(x, attention_mask=attention_mask, past_key_value=pkv, use_cache=use_cache)
if use_cache:
new_past.append(present)
x = self.ln_f(x)
return x, new_past
# ----------------------------
# Causal LM
# ----------------------------
class GTransformerForCausalLM(PreTrainedModel):
config_class = GTransformerConfig
def __init__(self, config: GTransformerConfig):
super().__init__(config)
self.transformer = GTransformerModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.transformer.embed_tokens
def set_input_embeddings(self, new_embeddings):
self.transformer.embed_tokens = new_embeddings
def tie_weights(self):
        # intentionally not tied, to keep FP8 training stable
pass
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
**kwargs,
) -> CausalLMOutputWithPast:
hidden_states, new_past = self.transformer(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
)
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100,
)
            # Simple information regularization: a small penalty on the mean
            # entropy of the next-token distribution (computed with gradients so
            # the term actually regularizes).
            if self.config.use_entropy_gate:
                probs = F.softmax(shift_logits, dim=-1)
                logp = torch.log(probs + 1e-9)
                H = -(probs * logp).sum(dim=-1).mean()
                # target a moderate reduction in entropy
                loss = loss + 1e-4 * H
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=new_past,
hidden_states=None,
attentions=None,
)
@torch.no_grad()
def generate_simple(
self,
input_ids: torch.LongTensor,
max_new_tokens: int = 64,
temperature: float = 1.0,
) -> torch.LongTensor:
        self.eval()
        past = None
        out = input_ids
        for _ in range(max_new_tokens):
            # Feed the full prompt on the first step, then only the newest token
            # together with the KV cache; a single forward pass per step.
            step_input = out if past is None else out[:, -1:].contiguous()
            step = self(step_input, use_cache=True, past_key_values=past)
            logits = step.logits
            past = step.past_key_values
            next_token = torch.distributions.Categorical(
                logits=logits[:, -1, :] / max(1e-6, temperature)
            ).sample()
            out = torch.cat([out, next_token.unsqueeze(-1)], dim=1)
            if int(next_token[0].item()) == self.config.eos_token_id:
                break
        return out
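# ----------------------------
# Smoke test
# ----------------------------
# Minimal end-to-end sketch (illustrative sizes, random weights): builds a tiny
# model on CPU and samples a few tokens with generate_simple.
if __name__ == "__main__":
    torch.manual_seed(0)
    cfg = GTransformerConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        max_position_embeddings=256,
        local_window=16,
        global_rank=8,
    )
    model = GTransformerForCausalLM(cfg)
    prompt = torch.randint(3, cfg.vocab_size, (1, 8), dtype=torch.long)
    generated = model.generate_simple(prompt, max_new_tokens=8, temperature=1.0)
    print("generated ids:", generated.tolist())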