|
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from embeddings import TechEmbeddingLayer, create_padding_mask, create_causal_mask
|
class MultiHeadAttention(nn.Module):
    """Multi-head attention mechanism optimized for technical content."""
|
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        self._init_weights()
|
    def _init_weights(self):
        """Initialize weights with Xavier uniform."""
        for module in [self.w_q, self.w_k, self.w_v, self.w_o]:
            nn.init.xavier_uniform_(module.weight)
|
    def forward(self, query, key, value, mask=None, pos_encoding=None):
        batch_size, seq_len, d_model = query.size()

        Q = self.w_q(query).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(key).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(value).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
|
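        # Optional positional transform supplied by the embedding layer; it is
        # assumed to be a callable that takes and returns the head-split (Q, K).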
        if pos_encoding is not None:
            Q, K = pos_encoding(Q, K)
|
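        # Scaled dot-product attention; positions where the mask is True are
        # excluded from the softmax.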
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            mask = mask.unsqueeze(1).expand(batch_size, self.n_heads, seq_len, seq_len)
            scores = scores.masked_fill(mask, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        attended = torch.matmul(attention_weights, V)
        attended = attended.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

        output = self.w_o(attended)
        return output
|
class FeedForward(nn.Module):
    """Position-wise feed-forward network with GELU activation."""

    def __init__(self, d_model, dim_feedforward, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.linear1.weight)
        nn.init.xavier_uniform_(self.linear2.weight)
|
    def forward(self, x):
        x = F.gelu(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        return x
|
class RecursionRouter(nn.Module):
    """Router to decide recursion steps for different types of technical problems."""

    def __init__(self, d_model, max_steps=4, router_type="adaptive"):
        super().__init__()
        self.max_steps = max_steps
        self.router_type = router_type
|
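        # Adaptive routing: a small classifier over the mean-pooled sequence
        # predicts a distribution over 0..max_steps recursion steps.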
        if router_type == "adaptive":
            self.complexity_classifier = nn.Sequential(
                nn.Linear(d_model, d_model // 4),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(d_model // 4, max_steps + 1),
                nn.Softmax(dim=-1)
            )
        elif router_type == "fixed":
            self.fixed_steps = max_steps
|
    def forward(self, x):
        if self.router_type == "adaptive":
            seq_repr = x.mean(dim=1)
            step_probs = self.complexity_classifier(seq_repr)
|
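            # Note: argmax is non-differentiable, so the classifier receives no
            # gradient through the routing decision here; training it end to end
            # would require e.g. a straight-through or Gumbel-softmax estimator.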
            steps = torch.argmax(step_probs, dim=-1)
            return steps

        return self.fixed_steps
|
class RecursiveTransformerLayer(nn.Module):
    """Transformer layer with recursive computation capability."""

    def __init__(self, d_model, n_heads, dim_feedforward, max_steps=4,
                 dropout=0.1, router_type="adaptive"):
        super().__init__()
        self.max_steps = max_steps
        self.d_model = d_model

        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feedforward = FeedForward(d_model, dim_feedforward, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.router = RecursionRouter(d_model, max_steps, router_type)
|
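        # One learned projection per recursion step, so each pass can specialize
        # its input before the shared attention/feed-forward blocks.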
        self.step_projections = nn.ModuleList([
            nn.Linear(d_model, d_model) for _ in range(max_steps)
        ])
|
    def forward(self, x, mask=None, pos_encoding=None):
        steps = self.router(x)

        if isinstance(steps, int):
            num_steps = min(steps, self.max_steps)
            return self._recursive_forward_fixed(x, mask, num_steps, pos_encoding)

        return self._recursive_forward_adaptive(x, mask, steps, pos_encoding)
|
    def _recursive_forward_fixed(self, x, mask, num_steps, pos_encoding):
        device = x.device
        batch_size = x.shape[0]
        computation_loss = torch.tensor(0.0, device=device)
|
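        # Reuse the same attention/feed-forward weights for num_steps passes,
        # accumulating a fixed 0.1-per-sequence penalty per step as the
        # computation loss.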
        for step in range(num_steps):
            step_input = self.step_projections[step](x) if step < len(self.step_projections) else x

            attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
            x = self.norm1(x + self.dropout(attended))

            fed_forward = self.feedforward(x)
            x = self.norm2(x + self.dropout(fed_forward))

            computation_loss += 0.1 * batch_size

        return x, computation_loss
|
    def _recursive_forward_adaptive(self, x, mask, steps, pos_encoding):
        batch_size, seq_len, d_model = x.shape
        device = x.device
        max_batch_steps = int(steps.max().item())
        computation_loss = torch.tensor(0.0, device=device)
        active_batches = torch.ones(batch_size, device=device, dtype=torch.bool)
|
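        # Per-sample routing: only sequences whose routed step count exceeds the
        # current step are updated; finished sequences are carried through
        # unchanged via torch.where.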
        for step in range(max_batch_steps):
            step_mask = (steps > step) & active_batches
            if not step_mask.any():
                break

            step_input = self.step_projections[step](x) if step < len(self.step_projections) else x

            attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
            updated = self.norm1(x + self.dropout(attended))
            x = torch.where(step_mask.unsqueeze(-1).unsqueeze(-1), updated, x)

            fed_forward = self.feedforward(x)
            updated = self.norm2(x + self.dropout(fed_forward))
            x = torch.where(step_mask.unsqueeze(-1).unsqueeze(-1), updated, x)

            computation_loss += 0.1 * step_mask.sum()
            active_batches &= (steps > step)

        return x, computation_loss
|
class MixtureOfRecursions(nn.Module):
    """Main model with a mixture of recursive transformer layers."""

    def __init__(self, vocab_size, d_model=512, n_layers=6, n_heads=8,
                 max_steps=4, dim_feedforward=2048, dropout=0.1,
                 max_seq_len=512, router_type="adaptive", padding_idx=0):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx

        self.embeddings = TechEmbeddingLayer(
            vocab_size=vocab_size,
            d_model=d_model,
            max_seq_len=max_seq_len,
            dropout=dropout,
            padding_idx=padding_idx,
            pos_encoding="learned"
        )
|
        self.layers = nn.ModuleList([
            RecursiveTransformerLayer(
                d_model=d_model,
                n_heads=n_heads,
                dim_feedforward=dim_feedforward,
                max_steps=max_steps,
                dropout=dropout,
                router_type=router_type
            ) for _ in range(n_layers)
        ])

        self.final_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.lm_head.weight)
|
    def forward(self, input_ids, attention_mask=None):
        batch_size, seq_len = input_ids.shape
|
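        # Build a combined (batch, seq, seq) boolean mask where True means
        # "do not attend". create_padding_mask is assumed to return a
        # (batch, seq) mask of padded positions, matching the attention_mask
        # convention (1 = real token, 0 = padding).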
        if attention_mask is None:
            padding_mask = create_padding_mask(input_ids, self.padding_idx)
        else:
            padding_mask = (attention_mask == 0)

        causal_mask = create_causal_mask(seq_len, input_ids.device)
        padding_mask = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len)
        combined_mask = padding_mask | causal_mask.unsqueeze(0)

        x = self.embeddings(input_ids)
        pos_encoding = self.embeddings.get_positional_encoding()
|
        device = x.device
        total_computation_loss = torch.tensor(0.0, device=device)

        for layer in self.layers:
            x, comp_loss = layer(x, combined_mask, pos_encoding)
            total_computation_loss += comp_loss

        x = self.final_norm(x)
        logits = self.lm_head(x)

        return logits, total_computation_loss
|
    def generate_step(self, input_ids, temperature=1.0, top_k=None, top_p=None):
        self.eval()
        with torch.no_grad():
            logits, _ = self.forward(input_ids)
            last_logits = logits[:, -1, :] / temperature

            if top_k is not None:
                indices_to_remove = last_logits < torch.topk(last_logits, top_k)[0][..., -1, None]
                last_logits[indices_to_remove] = float('-inf')
|
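            # Nucleus (top-p) filtering: keep the smallest set of tokens whose
            # cumulative probability exceeds top_p, always retaining the most
            # likely token.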
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(last_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = False

                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                last_logits[indices_to_remove] = float('-inf')

            probs = F.softmax(last_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        return next_token
|
class TextGenerator:
    """Text generation utility for the tech model."""

    def __init__(self, model, tokenizer, max_length=100, device=None):
        self.model = model
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.device = device if device is not None else next(model.parameters()).device
        self.model.to(self.device)
|
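        # Special-token ids; -1 serves as a sentinel when a token is missing.
        # tokenizer.vocab is assumed to be a dict mapping token strings to ids.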
        self.eos_token_id = tokenizer.vocab.get('<|endoftext|>', -1)
        self.assistant_token_id = tokenizer.vocab.get('<|assistant|>', -1)
|
    def generate(self, prompt, method="nucleus", temperature=1.0, top_k=50, top_p=0.9, max_new_tokens=None):
        if max_new_tokens is None:
            max_new_tokens = self.max_length
|
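        # Wrap the prompt in the <|user|> chat format; encode_ids/decode_ids are
        # assumed to be provided by the project tokenizer.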
input_text = f"<|user|> {prompt}" |
|
input_ids = self.tokenizer.encode_ids(input_text, add_special_tokens=True) |
|
input_tensor = torch.tensor([input_ids], device=self.device) |
|
self.model.eval() |
|
generated_ids = [] |
|
with torch.no_grad(): |
|
for _ in range(max_new_tokens): |
|
if input_tensor.size(1) > self.max_length: |
|
input_tensor = input_tensor[:, -self.max_length:] |
|
|
|
if method == "greedy": |
|
next_token = self._greedy_generate(input_tensor) |
|
elif method == "sample": |
|
next_token = self._sample_generate(input_tensor, temperature) |
|
elif method == "top_k": |
|
next_token = self._top_k_generate(input_tensor, temperature, top_k) |
|
elif method == "nucleus" or method == "top_p": |
|
next_token = self._nucleus_generate(input_tensor, temperature, top_p) |
|
else: |
|
raise ValueError(f"Unknown generation method: {method}") |
|
                next_token_id = next_token.item()
                generated_ids.append(next_token_id)

                input_tensor = torch.cat([input_tensor, next_token], dim=1)

                if next_token_id == self.eos_token_id or (self.assistant_token_id != -1 and next_token_id == self.assistant_token_id):
                    break
|
        full_ids = input_ids + generated_ids
        full_text = self.tokenizer.decode_ids(full_ids, skip_special_tokens=False)

        if "<|assistant|>" in full_text:
            response = full_text.split("<|assistant|>")[-1].split("<|endoftext|>")[0].strip()
        else:
            response = full_text.split("<|endoftext|>")[0].strip()

        return response if response else "No response generated."
|
    def _greedy_generate(self, input_tensor):
        logits, _ = self.model(input_tensor)
        return torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
|
    def _sample_generate(self, input_tensor, temperature):
        logits, _ = self.model(input_tensor)
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)
|
    def _top_k_generate(self, input_tensor, temperature, top_k):
        logits, _ = self.model(input_tensor)
        logits = logits[:, -1, :] / temperature
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        probs = F.softmax(top_k_logits, dim=-1)
        next_token_idx = torch.multinomial(probs, num_samples=1)
        return top_k_indices.gather(-1, next_token_idx)
|
    def _nucleus_generate(self, input_tensor, temperature, top_p):
        return self.model.generate_step(input_tensor, temperature, top_p=top_p)
|
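# Example usage (sketch): assumes a trained MixtureOfRecursions checkpoint and a
# project tokenizer exposing vocab, encode_ids and decode_ids as used above; the
# TechTokenizer name and file paths below are hypothetical placeholders.
#
#     from tokenizer import TechTokenizer  # hypothetical module/class
#
#     tokenizer = TechTokenizer.load("tokenizer.json")
#     model = MixtureOfRecursions(vocab_size=len(tokenizer.vocab))
#     model.load_state_dict(torch.load("model.pt", map_location="cpu"))
#
#     generator = TextGenerator(model, tokenizer, max_length=256)
#     print(generator.generate("How do I reverse a linked list?", method="nucleus", top_p=0.9))
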
def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params
|
def main():
    vocab_size = 10000
    d_model = 512
    n_layers = 6
    n_heads = 8
    seq_len = 128
    batch_size = 4

    print("Initializing MixtureOfRecursions model...")
|
    model = MixtureOfRecursions(
        vocab_size=vocab_size,
        d_model=d_model,
        n_layers=n_layers,
        n_heads=n_heads,
        max_steps=4,
        dim_feedforward=2048,
        dropout=0.1,
        router_type="adaptive"
    )
|
    total_params, trainable_params = count_parameters(model)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    print("\nTesting forward pass...")
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones_like(input_ids)
    attention_mask[:, -10:] = 0
|
print(f"Input shape: {input_ids.shape}") |
|
logits, comp_loss = model(input_ids, attention_mask) |
|
print(f"Output logits shape: {logits.shape}") |
|
print(f"Computation loss: {comp_loss}") |
|
print(f"Expected logits shape: ({batch_size}, {seq_len}, {vocab_size})") |
|
print("\nTesting generation step...") |
|
next_token = model.generate_step(input_ids[:1], temperature=0.8, top_p=0.9) |
|
print(f"Generated next token: {next_token}") |
|
print("\nModel test completed successfully!") |
|
if __name__ == "__main__":
    main()