"""
BitNet2 model with H-BitLinear layers for Hugging Face compatibility.
This maintains the original BitNetModel2 architecture with H-BitLinear layers.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
class BitNet2Config(PretrainedConfig):
model_type = "bitnet2"
def __init__(self, **kwargs):
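        # Architecture hyperparameters (vocab_size, hidden_size, num_hidden_layers,
        # num_attention_heads, intermediate_size, max_position_embeddings,
        # layer_norm_eps, attention_probs_dropout_prob) are read by the model
        # classes below and are expected to arrive via **kwargs, i.e. from the
        # checkpoint's config.json; PretrainedConfig stores them as attributes.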
super().__init__(**kwargs)
class HBitLinear(nn.Module):
"""H-BitLinear layer implementation."""
def __init__(self, in_features, out_features, bias=False):
super().__init__()
self.in_features = in_features
self.out_features = out_features
# Initialize weights
self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
if bias:
self.bias = nn.Parameter(torch.zeros(out_features))
else:
self.register_parameter('bias', None)
    def forward(self, x):
        # Simplified forward pass: a full H-BitLinear layer would apply a
        # Hadamard transform and weight/activation quantization here; this
        # version falls back to a plain linear projection.
        return F.linear(x, self.weight, self.bias)
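

# A minimal, illustrative sketch of the quantized matmul that a full H-BitLinear
# forward pass would perform (absmean ternary weights, absmax 8-bit activations,
# BitNet-style). It is NOT used by the model above, the Hadamard transform is
# omitted, and the exact scheme used to train this checkpoint is an assumption.
def _sketch_hbitlinear_quantized_forward(x, weight, bias=None, eps=1e-5):
    # Ternarize weights to {-1, 0, +1} using a per-tensor absmean scale.
    w_scale = weight.abs().mean().clamp(min=eps)
    w_q = (weight / w_scale).round().clamp(-1, 1)
    # Quantize activations per token to the int8 range with an absmax scale.
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / 127.0
    x_q = (x / x_scale).round().clamp(-128, 127)
    # Matmul with the quantized tensors, then rescale back to real values.
    out = F.linear(x_q, w_q) * x_scale * w_scale
    if bias is not None:
        out = out + bias
    return out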
class BitNet2Layer(nn.Module):
"""Single BitNet2 layer with H-BitLinear."""
def __init__(self, config):
super().__init__()
self.config = config
# Layer norms
self.self_attn_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.feed_forward_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# Self attention
self.self_attn = nn.MultiheadAttention(
config.hidden_size,
config.num_attention_heads,
dropout=config.attention_probs_dropout_prob,
batch_first=True
)
# Feed forward with H-BitLinear
self.feed_forward = nn.Sequential(
HBitLinear(config.hidden_size, config.intermediate_size, bias=False),
nn.GELU(),
HBitLinear(config.intermediate_size, config.hidden_size, bias=False)
)
    def forward(self, hidden_states, attention_mask=None):
        seq_len = hidden_states.size(1)
        # Always apply a causal mask so each position attends only to itself
        # and earlier positions (True entries are masked out).
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=hidden_states.device, dtype=torch.bool),
            diagonal=1,
        )
        # Convert a Hugging Face style padding mask (1 = attend, 0 = pad) into
        # the key_padding_mask expected by nn.MultiheadAttention (True = ignore).
        key_padding_mask = None
        if attention_mask is not None and attention_mask.dim() == 2:
            key_padding_mask = ~attention_mask.bool()
        # Pre-norm self attention with a residual connection
        normed = self.self_attn_norm(hidden_states)
        attn_output, _ = self.self_attn(
            normed,
            normed,
            normed,
            attn_mask=causal_mask,
            key_padding_mask=key_padding_mask,
        )
        hidden_states = hidden_states + attn_output
        # Pre-norm feed forward with a residual connection
        ff_output = self.feed_forward(self.feed_forward_norm(hidden_states))
        hidden_states = hidden_states + ff_output
        return hidden_states
class BitNet2Model(PreTrainedModel):
config_class = BitNet2Config
def __init__(self, config):
super().__init__(config)
self.config = config
# Embeddings
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.embed_positions = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# Transformer layers
self.layers = nn.ModuleList([
BitNet2Layer(config)
for _ in range(config.num_hidden_layers)
])
# Output layers
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
if input_ids is None:
raise ValueError("input_ids must be provided")
batch_size, seq_length = input_ids.shape
# Get embeddings
inputs_embeds = self.embed_tokens(input_ids)
position_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
position_embeds = self.embed_positions(position_ids)
hidden_states = inputs_embeds + position_embeds
# Process through layers
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask)
# Final norm and projection
hidden_states = self.layer_norm(hidden_states)
logits = self.lm_head(hidden_states)
# Calculate loss
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=None,
hidden_states=None,
attentions=None,
)
    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # Keep the attention mask so padded batches generate correctly.
        return {"input_ids": input_ids, "attention_mask": attention_mask}
class BitNet2ForCausalLM(BitNet2Model):
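    """Causal language-model entry point for BitNet2.

    BitNet2Model already applies the LM head and returns a
    CausalLMOutputWithPast, so this subclass only exposes the architecture
    under the class name conventionally used for causal-LM loading.
    """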
def __init__(self, config):
super().__init__(config)
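

if __name__ == "__main__":
    # Minimal smoke test with a tiny, randomly initialized model. The config
    # values below are illustrative placeholders only, not the settings of the
    # released checkpoint (those live in config.json).
    config = BitNet2Config(
        vocab_size=1000,
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=512,
        max_position_embeddings=64,
        layer_norm_eps=1e-5,
        attention_probs_dropout_prob=0.0,
    )
    model = BitNet2ForCausalLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    outputs = model(input_ids=input_ids, labels=input_ids)
    print("logits shape:", tuple(outputs.logits.shape), "loss:", float(outputs.loss))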