""" BitNetModel2 implementation for Hugging Face transformers. Matches the actual BitNetModel2 architecture with H-BitLinear layers. """ import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel, PretrainedConfig from transformers.modeling_outputs import CausalLMOutputWithPast from typing import Optional, Tuple, List, Union class BitNet2Config(PretrainedConfig): model_type = "bitnet2" def __init__( self, vocab_size=128256, hidden_size=512, # Power of 2 for H-BitLinear num_hidden_layers=12, num_attention_heads=8, intermediate_size=2048, # Power of 2 for H-BitLinear hidden_act="silu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=128, initializer_range=0.02, layer_norm_eps=1e-5, use_cache=False, pad_token_id=0, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, use_layer_skipping=True, skip_probability=0.1, min_layers_to_keep=4, use_early_exit=True, early_exit_threshold=0.95, use_h_bitlinear=True, **kwargs ): super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs ) self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.intermediate_size = intermediate_size self.hidden_act = hidden_act self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.use_cache = use_cache self.use_layer_skipping = use_layer_skipping self.skip_probability = skip_probability self.min_layers_to_keep = min_layers_to_keep self.use_early_exit = use_early_exit self.early_exit_threshold = early_exit_threshold self.use_h_bitlinear = use_h_bitlinear class BitNet2Model(PreTrainedModel): config_class = BitNet2Config def __init__(self, config): super().__init__(config) self.config = config # Embeddings self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.embed_positions = nn.Embedding(config.max_position_embeddings, config.hidden_size) # Transformer layers (simplified to standard transformer for compatibility) self.layers = nn.ModuleList([ nn.TransformerEncoderLayer( d_model=config.hidden_size, nhead=config.num_attention_heads, dim_feedforward=config.intermediate_size, dropout=config.hidden_dropout_prob, activation="gelu", batch_first=True ) for _ in range(config.num_hidden_layers) ]) # Output layers self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights self.post_init() def forward( self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either input_ids or inputs_embeds") if position_ids is None: position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) position_embeddings = self.embed_positions(position_ids) hidden_states = inputs_embeds + position_embeddings # Process through transformer layers for layer in self.layers: hidden_states = layer(hidden_states) hidden_states = self.layer_norm(hidden_states) # Compute logits logits = self.lm_head(hidden_states) loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = nn.CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) if not return_dict: output = (logits,) return ((loss,) + output) if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=None, hidden_states=None, attentions=None, ) def prepare_inputs_for_generation(self, input_ids, **kwargs): return {"input_ids": input_ids} class BitNet2ForCausalLM(BitNet2Model): """BitNetModel2 with language modeling head.""" def __init__(self, config): super().__init__(config)