# model_slm.py — uploaded by Girinath11 in commit a17939e ("Create model_slm.py", verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from embeddings import TechEmbeddingLayer, create_padding_mask, create_causal_mask
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention optimized for technical content.

    ``query``/``key``/``value`` are projected to ``n_heads`` heads of size
    ``d_model // n_heads``. The per-head Q and K can optionally be transformed by
    a ``pos_encoding`` callable. The attention output is recombined through an
    output projection.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        # Q/K/V projections are bias-free; the output projection keeps its bias.
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self):
        """Initialize every projection weight with Xavier uniform (biases untouched)."""
        for module in [self.w_q, self.w_k, self.w_v, self.w_o]:
            nn.init.xavier_uniform_(module.weight)

    def _split_heads(self, x, batch_size, length):
        """Reshape (batch, length, d_model) -> (batch, n_heads, length, d_k)."""
        return x.view(batch_size, length, self.n_heads, self.d_k).transpose(1, 2)

    @staticmethod
    def _broadcast_mask(mask):
        """Lift a boolean mask to 4-D so it broadcasts over (batch, heads, q, k).

        Accepts (q, k), (batch, q, k) or an already 4-D mask.
        """
        if mask.dim() == 2:  # (q, k): shared by every batch element
            mask = mask.unsqueeze(0)
        if mask.dim() == 3:  # (batch, q, k): shared by every head
            mask = mask.unsqueeze(1)
        return mask

    def forward(self, query, key, value, mask=None, pos_encoding=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: tensor of shape (batch, q_len, d_model).
            key, value: tensors of shape (batch, k_len, d_model).
            mask: optional boolean mask, True = position may NOT be attended.
                Shapes (q_len, k_len), (batch, q_len, k_len) or
                (batch, 1 or n_heads, q_len, k_len) are accepted.
            pos_encoding: optional callable ``(Q, K) -> (Q, K)`` applied to the
                per-head projections before the dot product.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size, seq_len, d_model = query.size()
        key_len = key.size(1)
        Q = self._split_heads(self.w_q(query), batch_size, seq_len)
        K = self._split_heads(self.w_k(key), batch_size, key_len)
        V = self._split_heads(self.w_v(value), batch_size, key_len)
        if pos_encoding is not None:
            Q, K = pos_encoding(Q, K)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            mask = self._broadcast_mask(mask)
            scores = scores.masked_fill(mask, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)
        if mask is not None:
            # A query row whose keys are ALL masked (e.g. a left-padded position
            # under a causal mask) softmaxes to NaN. Zero it out so the NaN cannot
            # leak into V-weighted sums downstream. Partially masked rows are
            # unchanged, because their masked entries are already exactly 0.
            attention_weights = attention_weights.masked_fill(mask, 0.0)
        attention_weights = self.dropout(attention_weights)
        attended = torch.matmul(attention_weights, V)
        attended = attended.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.w_o(attended)
class FeedForward(nn.Module):
    """Two-layer position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Widens the last dimension from ``d_model`` to ``dim_feedforward`` and
    projects it back, so every position is transformed independently.
    """

    def __init__(self, d_model, dim_feedforward, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)  # expansion
        self.linear2 = nn.Linear(dim_feedforward, d_model)  # projection back
        self.dropout = nn.Dropout(dropout)
        # Xavier-uniform weight matrices; biases keep PyTorch's default init.
        for layer in (self.linear1, self.linear2):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        hidden = self.dropout(F.gelu(self.linear1(x)))
        return self.linear2(hidden)
class RecursionRouter(nn.Module):
    """Router to decide recursion steps for different types of technical problems.

    * ``"adaptive"``: a small MLP scores the sequence-mean representation of
      ``x`` and returns, per batch element, an integer step count in
      ``[0, max_steps]`` (a LongTensor of shape ``(batch,)``).
    * ``"fixed"``: always returns the Python int ``max_steps``.

    Raises:
        ValueError: if ``router_type`` is neither ``"adaptive"`` nor ``"fixed"``.
    """

    _ROUTER_TYPES = ("adaptive", "fixed")

    def __init__(self, d_model, max_steps=4, router_type="adaptive"):
        super(RecursionRouter, self).__init__()
        if router_type not in self._ROUTER_TYPES:
            # Fail fast: otherwise forward() would hit an AttributeError on
            # the never-assigned ``fixed_steps``.
            raise ValueError(
                f"router_type must be one of {self._ROUTER_TYPES}, got {router_type!r}"
            )
        self.max_steps = max_steps
        self.router_type = router_type
        if router_type == "adaptive":
            # max_steps + 1 classes so that 0 ... max_steps are all selectable.
            self.complexity_classifier = nn.Sequential(
                nn.Linear(d_model, d_model // 4),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(d_model // 4, max_steps + 1),
                nn.Softmax(dim=-1)
            )
        else:
            self.fixed_steps = max_steps

    def forward(self, x):
        """Return the step count(s) for ``x``.

        ``x`` is expected as (batch, seq_len, d_model) — it is averaged over dim 1.
        """
        if self.router_type == "adaptive":
            # NOTE(review): the mean includes padded positions; no mask is available here.
            seq_repr = x.mean(dim=1)
            step_probs = self.complexity_classifier(seq_repr)
            # argmax is not differentiable, so the classifier receives no gradient
            # through the returned step counts.
            return torch.argmax(step_probs, dim=-1)
        return self.fixed_steps
class RecursiveTransformerLayer(nn.Module):
    """Post-norm transformer layer whose attention + feed-forward block is reused recursively.

    A :class:`RecursionRouter` decides how many steps to run. Before step ``i``
    the state is passed through ``step_projections[i]`` to form the attention
    query/key/value. The residual connections use the un-projected state.

    ``forward`` returns ``(x, computation_loss)``. ``computation_loss`` adds 0.1
    per (sample, step) actually executed. It is built from constants, so it
    carries no gradient.
    """

    def __init__(self, d_model, n_heads, dim_feedforward, max_steps=4,
                 dropout=0.1, router_type="adaptive"):
        super(RecursiveTransformerLayer, self).__init__()
        self.max_steps = max_steps
        self.d_model = d_model
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feedforward = FeedForward(d_model, dim_feedforward, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.router = RecursionRouter(d_model, max_steps, router_type)
        # One input projection per recursion step.
        self.step_projections = nn.ModuleList([
            nn.Linear(d_model, d_model) for _ in range(max_steps)
        ])

    def forward(self, x, mask=None, pos_encoding=None):
        steps = self.router(x)
        if isinstance(steps, int):
            # Fixed router: the same number of steps for the whole batch.
            num_steps = min(steps, self.max_steps)
            return self._recursive_forward_fixed(x, mask, num_steps, pos_encoding)
        # Adaptive router: per-sample step counts.
        return self._recursive_forward_adaptive(x, mask, steps, pos_encoding)

    def _apply_step(self, x, step, mask, pos_encoding):
        """Run one recursion step (projected self-attention + feed-forward) on the whole batch."""
        step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
        attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
        x = self.norm1(x + self.dropout(attended))
        fed_forward = self.feedforward(x)
        return self.norm2(x + self.dropout(fed_forward))

    def _recursive_forward_fixed(self, x, mask, num_steps, pos_encoding):
        device = x.device
        batch_size = x.shape[0]
        computation_loss = torch.tensor(0.0, device=device)
        for step in range(num_steps):
            x = self._apply_step(x, step, mask, pos_encoding)
            computation_loss += torch.tensor(0.1, device=device) * batch_size
        return x, computation_loss

    def _recursive_forward_adaptive(self, x, mask, steps, pos_encoding):
        """Apply ``steps[b]`` recursion steps to sample ``b``.

        Samples whose budget is exhausted keep their hidden state unchanged.
        They are not pushed through norm1/norm2 with a zeroed update.
        """
        device = x.device
        computation_loss = torch.tensor(0.0, device=device)
        if steps.numel() == 0:  # empty batch: steps.max() would raise
            return x, computation_loss
        max_batch_steps = int(steps.max().item())
        for step in range(max_batch_steps):
            # steps > step is monotone in `step`, so this is exactly the still-active set.
            active = steps > step
            if not active.any():
                break
            updated = self._apply_step(x, step, mask, pos_encoding)
            x = torch.where(active.view(-1, 1, 1), updated, x)
            computation_loss += torch.tensor(0.1, device=device) * active.sum()
        return x, computation_loss
class MixtureOfRecursions(nn.Module):
    """Autoregressive language model built from recursive transformer layers.

    Token ids are embedded by ``TechEmbeddingLayer`` with learned positions. They
    then pass through ``n_layers`` :class:`RecursiveTransformerLayer` blocks under
    a combined causal + padding mask. The result is layer-normalized and projected
    to vocabulary logits by a bias-free ``lm_head``.
    """

    def __init__(self, vocab_size, d_model=512, n_layers=6, n_heads=8,
                 max_steps=4, dim_feedforward=2048, dropout=0.1,
                 max_seq_len=512, router_type="adaptive", padding_idx=0):
        super(MixtureOfRecursions, self).__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx
        self.embeddings = TechEmbeddingLayer(
            vocab_size=vocab_size,
            d_model=d_model,
            max_seq_len=max_seq_len,
            dropout=dropout,
            padding_idx=padding_idx,
            pos_encoding="learned"
        )
        self.layers = nn.ModuleList([
            RecursiveTransformerLayer(
                d_model=d_model,
                n_heads=n_heads,
                dim_feedforward=dim_feedforward,
                max_steps=max_steps,
                dropout=dropout,
                router_type=router_type
            ) for _ in range(n_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize the output projection.

        Other submodules keep the initialization done by their own constructors.
        """
        nn.init.xavier_uniform_(self.lm_head.weight)

    def forward(self, input_ids, attention_mask=None):
        """Compute next-token logits for every position.

        Args:
            input_ids: token ids of shape (batch, seq_len).
            attention_mask: optional (batch, seq_len) mask, 0 = padding. If None,
                padding is derived from ``padding_idx``.

        Returns:
            ``(logits, total_computation_loss)``. ``logits`` has ``vocab_size`` as
            its last dimension. The loss is the sum of every layer's computation loss.
        """
        batch_size, seq_len = input_ids.shape
        # True marks key positions that must not be attended.
        padding_mask = create_padding_mask(input_ids, self.padding_idx) if attention_mask is None else (attention_mask == 0)
        # NOTE(review): assumes create_padding_mask / create_causal_mask return
        # boolean masks with True = masked (required by `|` and masked_fill) — confirm.
        causal_mask = create_causal_mask(seq_len, input_ids.device)
        # Broadcast key padding over every query row: (batch, seq) -> (batch, seq, seq).
        padding_mask = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len)
        combined_mask = padding_mask | causal_mask.unsqueeze(0)
        x = self.embeddings(input_ids)
        pos_encoding = self.embeddings.get_positional_encoding()
        device = x.device
        total_computation_loss = torch.tensor(0.0, device=device)
        for layer in self.layers:
            x, comp_loss = layer(x, combined_mask, pos_encoding)
            total_computation_loss += comp_loss
        x = self.final_norm(x)
        logits = self.lm_head(x)
        return logits, total_computation_loss

    def generate_step(self, input_ids, temperature=1.0, top_k=None, top_p=None):
        """Sample one next token per sequence from the last-position logits.

        Args:
            input_ids: token ids of shape (batch, seq_len).
            temperature: divisor applied to the logits before filtering/sampling.
            top_k: if given, keep only the ``top_k`` highest logits (capped at the
                vocabulary size).
            top_p: if given, nucleus filtering — keep the smallest set of tokens
                whose cumulative probability exceeds ``top_p`` (the top token is always kept).

        Returns:
            LongTensor of shape (batch, 1) with the sampled token ids.

        The module's train/eval mode is restored on exit; dropout is disabled
        only while sampling.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits, _ = self.forward(input_ids)
                last_logits = logits[:, -1, :] / temperature
                if top_k is not None:
                    k = min(top_k, last_logits.size(-1))  # topk raises if k > vocab
                    kth_best = torch.topk(last_logits, k)[0][..., -1, None]
                    last_logits[last_logits < kth_best] = float('-inf')
                if top_p is not None:
                    sorted_logits, sorted_indices = torch.sort(last_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the token that crosses the threshold is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False
                    # Map the removal flags from sorted order back to vocabulary order.
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    last_logits[indices_to_remove] = float('-inf')
                probs = F.softmax(last_logits, dim=-1)
                return torch.multinomial(probs, num_samples=1)
        finally:
            self.train(was_training)
class TextGenerator:
    """Text generation utility for the tech model.

    ``model(ids)`` must return ``(logits, aux)``, with ``logits`` indexed as
    ``[batch, position, vocab]``. The ``"nucleus"``/``"top_p"`` method also calls
    ``model.generate_step``. The tokenizer must expose ``vocab`` (a mapping),
    ``encode_ids`` and ``decode_ids``.
    """

    def __init__(self, model, tokenizer, max_length=100, device=None):
        self.model = model
        self.tokenizer = tokenizer
        # Used both as the default generation budget and as the context window.
        self.max_length = max_length
        self.device = device if device else next(model.parameters()).device
        self.model.to(self.device)
        # -1 means "not in vocab"; sampled token ids are never negative.
        self.eos_token_id = tokenizer.vocab.get('<|endoftext|>', -1)
        self.assistant_token_id = tokenizer.vocab.get('<|assistant|>', -1)

    def generate(self, prompt, method="nucleus", temperature=1.0, top_k=50, top_p=0.9, max_new_tokens=None):
        """Generate a response to ``prompt`` and return it as text.

        Args:
            prompt: user text. It is wrapped as ``"<|user|> {prompt}"`` before encoding.
            method: ``"greedy"``, ``"sample"``, ``"top_k"``, or ``"nucleus"``/``"top_p"``.
            temperature: logits divisor for the sampling methods.
            top_k: candidate count for ``"top_k"`` (capped at the vocabulary size).
            top_p: cumulative-probability cutoff for ``"nucleus"``/``"top_p"``.
            max_new_tokens: generation budget. Defaults to ``self.max_length``.

        Returns:
            The extracted response text, or ``"No response generated."`` if it is empty.

        Raises:
            ValueError: if ``method`` is unknown and at least one token is requested.
        """
        if max_new_tokens is None:
            max_new_tokens = self.max_length
        input_text = f"<|user|> {prompt}"
        input_ids = self.tokenizer.encode_ids(input_text, add_special_tokens=True)
        input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
        self.model.eval()
        generated_ids = []
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Sliding context window: keep only the latest max_length tokens.
                if input_tensor.size(1) > self.max_length:
                    input_tensor = input_tensor[:, -self.max_length:]
                next_token = self._next_token(input_tensor, method, temperature, top_k, top_p)
                next_token_id = next_token.item()
                generated_ids.append(next_token_id)
                # next_token holds a single id. Reshape it to (1, 1) so the context
                # stays 2-D; unsqueeze(0) would make it 3-D and torch.cat would fail.
                input_tensor = torch.cat([input_tensor, next_token.view(1, 1)], dim=1)
                if next_token_id == self.eos_token_id or (
                    self.assistant_token_id != -1 and next_token_id == self.assistant_token_id
                ):
                    break
        # Decode prompt + continuation, keeping special tokens so the markers can be split on.
        full_ids = input_ids + generated_ids
        full_text = self.tokenizer.decode_ids(full_ids, skip_special_tokens=False)
        # NOTE(review): the prompt does not end with <|assistant|>, and generation
        # stops as soon as that token is emitted. When this branch is reached that
        # way, presumably nothing follows the marker — confirm the intended chat template.
        if "<|assistant|>" in full_text:
            response = full_text.split("<|assistant|>")[-1].split("<|endoftext|>")[0].strip()
        else:
            response = full_text.split("<|endoftext|>")[0].strip()
        return response if response else "No response generated."

    def _next_token(self, input_tensor, method, temperature, top_k, top_p):
        """Dispatch to the sampling strategy named by ``method``."""
        if method == "greedy":
            return self._greedy_generate(input_tensor)
        if method == "sample":
            return self._sample_generate(input_tensor, temperature)
        if method == "top_k":
            return self._top_k_generate(input_tensor, temperature, top_k)
        if method in ("nucleus", "top_p"):
            return self._nucleus_generate(input_tensor, temperature, top_p)
        raise ValueError(f"Unknown generation method: {method}")

    def _greedy_generate(self, input_tensor):
        """Pick the highest-scoring next token; returns shape (batch, 1)."""
        logits, _ = self.model(input_tensor)
        return torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)

    def _sample_generate(self, input_tensor, temperature):
        """Sample from the full temperature-scaled distribution; returns shape (batch, 1)."""
        logits, _ = self.model(input_tensor)
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def _top_k_generate(self, input_tensor, temperature, top_k):
        """Sample among the ``top_k`` best tokens (capped at the vocabulary size)."""
        logits, _ = self.model(input_tensor)
        logits = logits[:, -1, :] / temperature
        k = min(top_k, logits.size(-1))  # torch.topk raises if k > vocab size
        top_k_logits, top_k_indices = torch.topk(logits, k)
        probs = F.softmax(top_k_logits, dim=-1)
        next_token_idx = torch.multinomial(probs, num_samples=1)
        # Map the position within the top-k back to the vocabulary id.
        return top_k_indices.gather(-1, next_token_idx)

    def _nucleus_generate(self, input_tensor, temperature, top_p):
        """Delegate nucleus sampling to the model's own ``generate_step``."""
        return self.model.generate_step(input_tensor, temperature, top_p=top_p)
def count_parameters(model):
    """Return ``(total, trainable)`` parameter element counts for *model*.

    ``trainable`` counts only the parameters whose ``requires_grad`` is set.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        n_elements = param.numel()
        total_params += n_elements
        if param.requires_grad:
            trainable_params += n_elements
    return total_params, trainable_params
def main():
    """Smoke test the model.

    Builds a default-sized model, reports its parameter counts, runs one padded
    forward pass and samples a single next token.
    """
    vocab_size, d_model, n_layers, n_heads = 10000, 512, 6, 8
    seq_len, batch_size = 128, 4

    print("Initializing MixtureOfRecursions model...")
    model = MixtureOfRecursions(
        vocab_size=vocab_size,
        d_model=d_model,
        n_layers=n_layers,
        n_heads=n_heads,
        max_steps=4,
        dim_feedforward=2048,
        dropout=0.1,
        router_type="adaptive",
    )
    total_params, trainable_params = count_parameters(model)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    print("\nTesting forward pass...")
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    # 1 = real token, 0 = padding: the final 10 positions of every row are padding.
    attention_mask = torch.ones_like(input_ids)
    attention_mask[:, -10:] = 0
    print(f"Input shape: {input_ids.shape}")
    logits, comp_loss = model(input_ids, attention_mask)
    print(f"Output logits shape: {logits.shape}")
    print(f"Computation loss: {comp_loss}")
    print(f"Expected logits shape: ({batch_size}, {seq_len}, {vocab_size})")

    print("\nTesting generation step...")
    next_token = model.generate_step(input_ids[:1], temperature=0.8, top_p=0.9)
    print(f"Generated next token: {next_token}")
    print("\nModel test completed successfully!")


if __name__ == "__main__":
    main()