# bitnet-8bit-v2 / modeling_bitnet.py
"""
BitNet model implementation for Hugging Face transformers.
This is a simplified version for inference compatibility.
"""
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
class BitNetConfig(PretrainedConfig):
model_type = "bitnet"
def __init__(
self,
vocab_size=128256,
hidden_size=1024,
num_hidden_layers=12,
num_attention_heads=8,
intermediate_size=4096,
hidden_act="silu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=2048,
initializer_range=0.02,
layer_norm_eps=1e-5,
use_cache=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
use_layer_skipping=True,
skip_probability=0.1,
min_layers_to_keep=4,
use_early_exit=True,
early_exit_threshold=0.95,
**kwargs
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs
)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache
        # Layer-skipping / early-exit options are stored for checkpoint
        # compatibility but are not used by this simplified inference model.
        self.use_layer_skipping = use_layer_skipping
        self.skip_probability = skip_probability
        self.min_layers_to_keep = min_layers_to_keep
        self.use_early_exit = use_early_exit
        self.early_exit_threshold = early_exit_threshold
class BitNetModel(PreTrainedModel):
config_class = BitNetConfig
def __init__(self, config):
super().__init__(config)
self.config = config
# Embeddings
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.embed_positions = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# Transformer layers
self.layers = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=config.hidden_size,
nhead=config.num_attention_heads,
dim_feedforward=config.intermediate_size,
dropout=config.hidden_dropout_prob,
activation="gelu",
batch_first=True
)
for _ in range(config.num_hidden_layers)
])
# Output layers
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights
self.post_init()
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
        if position_ids is None:
            # Derive the device from whichever input was provided (input_ids may
            # be None when inputs_embeds is passed directly).
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
position_embeddings = self.embed_positions(position_ids)
hidden_states = inputs_embeds + position_embeddings
        # Boolean causal mask (True = blocked) so each position attends only to
        # earlier positions, as required for autoregressive language modelling.
        causal_mask = torch.triu(
            torch.ones(seq_length, seq_length, dtype=torch.bool, device=hidden_states.device),
            diagonal=1,
        )
        # HF attention_mask uses 1 = keep, 0 = pad; TransformerEncoderLayer expects
        # a boolean key-padding mask with True = ignore.
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        # Process through transformer layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, src_mask=causal_mask,
                                  src_key_padding_mask=key_padding_mask)
hidden_states = self.layer_norm(hidden_states)
# Compute logits
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (logits,)
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=None,
hidden_states=None,
attentions=None,
)
    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # Forward the attention mask so generation respects padded positions.
        return {"input_ids": input_ids, "attention_mask": attention_mask}
class BitNetForCausalLM(BitNetModel):
"""BitNet model with language modeling head."""
def __init__(self, config):
super().__init__(config)
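

# Minimal smoke test (a sketch added for illustration, not part of the original
# file): instantiate a tiny, randomly initialised BitNet model and run a single
# forward pass with dummy token ids. The small sizes below are assumptions made
# only to keep the check fast; they do not match the released checkpoint.
if __name__ == "__main__":
    config = BitNetConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        max_position_embeddings=128,
    )
    model = BitNetForCausalLM(config)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)

    print("logits shape:", tuple(outputs.logits.shape))  # expected: (2, 16, 1000)
    print("loss:", outputs.loss.item())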