"""
A custom model for causal language modeling, compatible with HuggingFace.
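
Example (an illustrative sketch, not an official usage snippet: it builds a
deliberately tiny configuration just to check output shapes):

    >>> config = SSLLMConfig(vocab_size=1000, d_model=64, num_heads=4, num_layers=2, d_ff=128)
    >>> model = SSLLMForCausalLM(config)
    >>> input_ids = torch.randint(0, config.vocab_size, (1, 8))
    >>> model(input_ids).logits.shape
    torch.Size([1, 8, 1000])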
"""
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions


class SSLLMConfig(PretrainedConfig):
    """Configuration class for SSLLM model compatible with HuggingFace."""

    model_type = "ssllm"

    def __init__(
        self,
        vocab_size=100277,
        d_model=768,
        num_heads=12,
        num_layers=10,
        d_ff=2560,
        max_seq_len=1024,
        dropout_rate=0.1,
        attention_dropout=0.1,
        stochastic_depth_rate=0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.d_ff = d_ff
        self.max_seq_len = max_seq_len
        self.dropout_rate = dropout_rate
        self.attention_dropout = attention_dropout
        self.stochastic_depth_rate = stochastic_depth_rate
        # HuggingFace compatibility aliases
        self.hidden_size = d_model
        self.num_attention_heads = num_heads
        self.num_hidden_layers = num_layers
        self.intermediate_size = d_ff
        self.max_position_embeddings = max_seq_len


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention module matching SSLLM exactly."""

    def __init__(self, d_model, num_heads, attention_dropout, dropout_rate):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            d_model,
            num_heads,
            dropout=attention_dropout,
            bias=True,
            batch_first=True
        )
        self.resid_dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        B, T, C = x.size()
        # Upper-triangular boolean mask: each position may attend only to itself and earlier positions
        causal_mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        # Apply attention; is_causal=True is a hint that attn_mask already encodes the causal structure
        attn_output, _ = self.attention(x, x, x, attn_mask=causal_mask, is_causal=True)
        return self.resid_dropout(attn_output)


class TransformerBlock(nn.Module):
    """Transformer block matching SSLLM exactly."""

    def __init__(self, d_model, num_heads, d_ff, dropout_rate, attention_dropout, stochastic_depth_rate):
        super().__init__()
        self.attn = MultiHeadSelfAttention(d_model, num_heads, attention_dropout, dropout_rate)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout_rate)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)
        # Note: nn.Dropout is element-wise rather than a per-sample stochastic-depth drop,
        # and drop_path is not applied in forward() below.
        self.drop_path = nn.Dropout(stochastic_depth_rate) if stochastic_depth_rate > 0 else nn.Identity()

    def forward(self, x):
        # Pre-layer norm for attention, with residual connection
        normed_x = self.norm1(x)
        attn_out = self.attn(normed_x)
        x = x + self.dropout(attn_out)
        # Pre-layer norm for feed-forward, with residual connection
        normed_x = self.norm2(x)
        ff_out = self.ff(normed_x)
        x = x + self.dropout(ff_out)
        return x


class SSLLMForCausalLM(PreTrainedModel, GenerationMixin):
    """SSLLM model for causal language modeling, compatible with HuggingFace."""

    config_class = SSLLMConfig

    def __init__(self, config):
        super().__init__(config)
        self.token_embed = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embed = nn.Parameter(torch.zeros(1, config.max_seq_len, config.d_model))
        self.dropout = nn.Dropout(config.dropout_rate)
        # Create transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                config.d_model,
                config.num_heads,
                config.d_ff,
                config.dropout_rate,
                config.attention_dropout,
                config.stochastic_depth_rate
            ) for _ in range(config.num_layers)
        ])
        # Final layer norm and head
        self.ln_f = nn.LayerNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size)
        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(self, input_ids, attention_mask=None, labels=None, past_key_values=None, **kwargs):
        # Note: attention_mask is accepted for HuggingFace compatibility but padding
        # positions are not masked out inside the blocks; only causal masking is applied.
        B, T = input_ids.size()
        # Token embeddings plus learned position embeddings
        tok_emb = self.token_embed(input_ids)
        pos_emb = self.pos_embed[:, :T, :]
        x = self.dropout(tok_emb + pos_emb)
        # Apply transformer blocks
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.head(x)
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, past_key_values=None, **kwargs):
        """Prepare inputs for generation."""
        # If attention_mask is not provided, create one that attends to every position
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
        }

    def get_output_embeddings(self):
        """Get output embeddings (LM head) for generation."""
        return self.head

    def set_output_embeddings(self, new_embeddings):
        """Set output embeddings (LM head)."""
        self.head = new_embeddings

    def get_input_embeddings(self):
        """Get input (token) embeddings."""
        return self.token_embed

    def set_input_embeddings(self, new_embeddings):
        """Set input (token) embeddings."""
        self.token_embed = new_embeddings
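

# Minimal smoke test: an illustrative sketch rather than part of the model definition.
# It registers the classes with the Auto* factories (AutoConfig.register and
# AutoModelForCausalLM.register are standard transformers APIs), then runs a forward
# pass and a short greedy generation on random token ids. The tiny hyperparameters
# are arbitrary values chosen only to keep the check fast, and generation assumes a
# reasonably recent transformers release.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModelForCausalLM

    # Make the custom config/model discoverable through the Auto* factories.
    AutoConfig.register("ssllm", SSLLMConfig)
    AutoModelForCausalLM.register(SSLLMConfig, SSLLMForCausalLM)

    config = SSLLMConfig(vocab_size=1000, d_model=64, num_heads=4, num_layers=2, d_ff=128, max_seq_len=128)
    model = SSLLMForCausalLM(config)
    model.eval()

    # Random ids stand in for real tokenizer output in this sketch.
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids, labels=input_ids)
        generated = model.generate(input_ids, max_new_tokens=8, do_sample=False, pad_token_id=0)

    print(f"loss: {out.loss.item():.4f}")
    print(f"generated shape: {tuple(generated.shape)}")  # expect (1, 16): 8 prompt + 8 new tokens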