"""
BitNet2 model with H-BitLinear layers for Hugging Face compatibility.
This maintains the original BitNetModel2 architecture with H-BitLinear layers.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast


class BitNet2Config(PretrainedConfig):
    model_type = "bitnet2"

    def __init__(self, **kwargs):
        # The layers below expect the config to provide: vocab_size, hidden_size,
        # num_hidden_layers, num_attention_heads, intermediate_size,
        # max_position_embeddings, layer_norm_eps and attention_probs_dropout_prob.
        # PretrainedConfig stores any extra keyword arguments as attributes.
        super().__init__(**kwargs)


class HBitLinear(nn.Module):
    """H-BitLinear layer implementation."""

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Initialize weights
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        # Simplified stand-in: a plain linear projection. The full H-BitLinear
        # forward would quantize the weights and Hadamard-rotate the
        # activations; see the sketch below this class.
        return F.linear(x, self.weight, self.bias)
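

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the original BitNetModel2 code):
# roughly what a fuller H-BitLinear forward could look like, combining a fast
# Walsh-Hadamard rotation of the activations with ternary (absmean) weight
# quantization. The function names and details below are assumptions; the
# real H-BitLinear implementation may differ, and training would typically
# also need a straight-through estimator so gradients pass the rounding.
# ---------------------------------------------------------------------------
def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
    """Fast Walsh-Hadamard transform along the last dimension.

    Assumes the last dimension is a power of two. Not called by HBitLinear
    above; provided only to illustrate the rotation step.
    """
    n = x.size(-1)
    h = 1
    while h < n:
        x = x.reshape(*x.shape[:-1], n // (2 * h), 2, h)
        a, b = x[..., 0, :], x[..., 1, :]
        x = torch.stack((a + b, a - b), dim=-2).reshape(*x.shape[:-3], n)
        h *= 2
    return x / n ** 0.5


def h_bitlinear_forward_sketch(x, weight, bias=None):
    """Sketch of an H-BitLinear forward pass: rotate the activations, then
    apply weights quantized to {-1, 0, +1} with a single per-tensor scale.
    The weights are assumed to be learned in the rotated basis, so no inverse
    transform is applied here."""
    scale = weight.abs().mean().clamp(min=1e-8)
    w_ternary = (weight / scale).round().clamp(-1, 1)
    return F.linear(hadamard_transform(x), w_ternary * scale, bias)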


class BitNet2Layer(nn.Module):
    """Single BitNet2 layer with H-BitLinear."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Layer norms
        self.self_attn_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Self attention
        self.self_attn = nn.MultiheadAttention(
            config.hidden_size,
            config.num_attention_heads,
            dropout=config.attention_probs_dropout_prob,
            batch_first=True,
        )
        # Feed forward with H-BitLinear
        self.feed_forward = nn.Sequential(
            HBitLinear(config.hidden_size, config.intermediate_size, bias=False),
            nn.GELU(),
            HBitLinear(config.intermediate_size, config.hidden_size, bias=False),
        )

    def forward(self, hidden_states, attention_mask=None):
        seq_len = hidden_states.size(1)
        # Causal mask for autoregressive attention; True marks positions that
        # may NOT be attended to (upper triangle = future tokens).
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=hidden_states.device, dtype=torch.bool),
            diagonal=1,
        )
        # A Hugging Face style padding mask (1 = token, 0 = padding) maps onto
        # MultiheadAttention's key_padding_mask, where True means "ignore".
        key_padding_mask = None
        if attention_mask is not None and attention_mask.dim() == 2:
            key_padding_mask = ~attention_mask.bool()
        # Self attention with pre-norm and residual connection
        normed = self.self_attn_norm(hidden_states)
        attn_output, _ = self.self_attn(
            normed,
            normed,
            normed,
            attn_mask=causal_mask,
            key_padding_mask=key_padding_mask,
        )
        hidden_states = hidden_states + attn_output
        # Feed forward with pre-norm and residual connection
        ff_output = self.feed_forward(self.feed_forward_norm(hidden_states))
        hidden_states = hidden_states + ff_output
        return hidden_states
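
    # Mask conventions used in forward() above (torch.nn.MultiheadAttention
    # with batch_first=True):
    #   attn_mask:        bool [tgt_len, src_len], True = attention is blocked
    #   key_padding_mask: bool [batch, src_len],   True = key is padding, ignore
    # Example: for a Hugging Face padding mask [[1, 1, 1, 1], [1, 1, 1, 0]],
    # key_padding_mask = ~attention_mask.bool() marks only the last position
    # of the second sequence as padding.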


class BitNet2Model(PreTrainedModel):
    config_class = BitNet2Config

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embed_positions = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # Transformer layers
        self.layers = nn.ModuleList([
            BitNet2Layer(config)
            for _ in range(config.num_hidden_layers)
        ])
        # Output layers
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        batch_size, seq_length = input_ids.shape
        # Get embeddings
        inputs_embeds = self.embed_tokens(input_ids)
        position_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
        position_embeds = self.embed_positions(position_ids)
        hidden_states = inputs_embeds + position_embeds
        # Process through layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)
        # Final norm and projection
        hidden_states = self.layer_norm(hidden_states)
        logits = self.lm_head(hidden_states)
        # Calculate loss
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}


class BitNet2ForCausalLM(BitNet2Model):
    def __init__(self, config):
        super().__init__(config)
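

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not from the original file). All sizes
# below are arbitrary placeholders, not the dimensions of any trained BitNet2
# checkpoint; swap in the real config values before loading weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = BitNet2Config(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=256,
        max_position_embeddings=128,
        layer_norm_eps=1e-5,
        attention_probs_dropout_prob=0.0,
    )
    model = BitNet2ForCausalLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    print("logits:", tuple(outputs.logits.shape), "loss:", float(outputs.loss))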