""" BitNet2 model with H-BitLinear layers for Hugging Face compatibility. This maintains the original BitNetModel2 architecture with H-BitLinear layers. """ import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel, PretrainedConfig from transformers.modeling_outputs import CausalLMOutputWithPast class BitNet2Config(PretrainedConfig): model_type = "bitnet2" def __init__(self, **kwargs): super().__init__(**kwargs) class HBitLinear(nn.Module): """H-BitLinear layer implementation.""" def __init__(self, in_features, out_features, bias=False): super().__init__() self.in_features = in_features self.out_features = out_features # Initialize weights self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02) if bias: self.bias = nn.Parameter(torch.zeros(out_features)) else: self.register_parameter('bias', None) def forward(self, x): # Apply Hadamard transform for H-BitLinear # This is a simplified version - in practice you'd use the full H-BitLinear implementation return F.linear(x, self.weight, self.bias) class BitNet2Layer(nn.Module): """Single BitNet2 layer with H-BitLinear.""" def __init__(self, config): super().__init__() self.config = config # Layer norms self.self_attn_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.feed_forward_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # Self attention self.self_attn = nn.MultiheadAttention( config.hidden_size, config.num_attention_heads, dropout=config.attention_probs_dropout_prob, batch_first=True ) # Feed forward with H-BitLinear self.feed_forward = nn.Sequential( HBitLinear(config.hidden_size, config.intermediate_size, bias=False), nn.GELU(), HBitLinear(config.intermediate_size, config.hidden_size, bias=False) ) def forward(self, hidden_states, attention_mask=None): # Self attention # Handle attention mask properly for MultiheadAttention if attention_mask is not None: # Convert to boolean and expand to 2D if needed if attention_mask.dtype != torch.bool: attention_mask = attention_mask.bool() # If it's a 1D mask, we need to create a 2D causal mask if attention_mask.dim() == 1: seq_len = hidden_states.size(1) # Create causal mask (lower triangular) causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=hidden_states.device), diagonal=1).bool() attention_mask = causal_mask elif attention_mask.dim() == 2 and attention_mask.size(0) == 1: # Expand batch dimension seq_len = attention_mask.size(1) causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=hidden_states.device), diagonal=1).bool() attention_mask = causal_mask attn_output, _ = self.self_attn( self.self_attn_norm(hidden_states), self.self_attn_norm(hidden_states), self.self_attn_norm(hidden_states), attn_mask=attention_mask ) hidden_states = hidden_states + attn_output # Feed forward ff_output = self.feed_forward(self.feed_forward_norm(hidden_states)) hidden_states = hidden_states + ff_output return hidden_states class BitNet2Model(PreTrainedModel): config_class = BitNet2Config def __init__(self, config): super().__init__(config) self.config = config # Embeddings self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.embed_positions = nn.Embedding(config.max_position_embeddings, config.hidden_size) # Transformer layers self.layers = nn.ModuleList([ BitNet2Layer(config) for _ in range(config.num_hidden_layers) ]) # Output layers self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.lm_head = nn.Linear(config.hidden_size, 
class BitNet2Layer(nn.Module):
    """Single BitNet2 layer with H-BitLinear."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Layer norms
        self.self_attn_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Self attention
        self.self_attn = nn.MultiheadAttention(
            config.hidden_size,
            config.num_attention_heads,
            dropout=config.attention_probs_dropout_prob,
            batch_first=True
        )

        # Feed forward with H-BitLinear
        self.feed_forward = nn.Sequential(
            HBitLinear(config.hidden_size, config.intermediate_size, bias=False),
            nn.GELU(),
            HBitLinear(config.intermediate_size, config.hidden_size, bias=False)
        )

    def forward(self, hidden_states, attention_mask=None):
        seq_len = hidden_states.size(1)

        # Causal mask for autoregressive attention: True marks positions that may
        # NOT be attended to (the upper triangle, i.e. future tokens).
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=hidden_states.device, dtype=torch.bool),
            diagonal=1,
        )

        # Hugging Face passes a (batch, seq_len) padding mask with 1 = keep and
        # 0 = pad; nn.MultiheadAttention uses the opposite convention for
        # key_padding_mask (True = ignore), so invert it here.
        key_padding_mask = None
        if attention_mask is not None and attention_mask.dim() == 2:
            key_padding_mask = ~attention_mask.bool()

        # Pre-norm self-attention with a residual connection
        normed = self.self_attn_norm(hidden_states)
        attn_output, _ = self.self_attn(
            normed, normed, normed,
            attn_mask=causal_mask,
            key_padding_mask=key_padding_mask,
        )
        hidden_states = hidden_states + attn_output

        # Pre-norm feed-forward with a residual connection
        ff_output = self.feed_forward(self.feed_forward_norm(hidden_states))
        hidden_states = hidden_states + ff_output

        return hidden_states


class BitNet2Model(PreTrainedModel):
    config_class = BitNet2Config

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embed_positions = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # Transformer layers
        self.layers = nn.ModuleList([
            BitNet2Layer(config) for _ in range(config.num_hidden_layers)
        ])

        # Output layers
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        batch_size, seq_length = input_ids.shape

        # Token and learned position embeddings
        inputs_embeds = self.embed_tokens(input_ids)
        position_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
        position_embeds = self.embed_positions(position_ids)
        hidden_states = inputs_embeds + position_embeds

        # Process through layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        # Final norm and projection
        hidden_states = self.layer_norm(hidden_states)
        logits = self.lm_head(hidden_states)

        # Calculate loss (shift so each position predicts the next token)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask")}


class BitNet2ForCausalLM(BitNet2Model):
    def __init__(self, config):
        super().__init__(config)
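
# --- Usage sketch (not part of the original file). ---
# A minimal smoke test, assuming the config fields referenced above are passed
# explicitly, since BitNet2Config defines no defaults. The hyperparameter values
# are illustrative only.
if __name__ == "__main__":
    config = BitNet2Config(
        vocab_size=32000,
        hidden_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=1024,
        max_position_embeddings=512,
        layer_norm_eps=1e-5,
        attention_probs_dropout_prob=0.0,
    )
    model = BitNet2ForCausalLM(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    output = model(input_ids=input_ids, labels=input_ids)
    print(output.logits.shape)  # torch.Size([2, 16, 32000])
    print(output.loss)          # scalar next-token cross-entropy loss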