"""

BitNet2 model with H-BitLinear layers for Hugging Face compatibility.

This maintains the original BitNetModel2 architecture with H-BitLinear layers.

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

class BitNet2Config(PretrainedConfig):
    """Configuration for BitNet2 models.

    The hyperparameters read by BitNet2Model (vocab_size, hidden_size,
    num_hidden_layers, num_attention_heads, intermediate_size,
    max_position_embeddings, layer_norm_eps, attention_probs_dropout_prob)
    must be supplied as keyword arguments; PretrainedConfig stores any
    unrecognized keyword arguments as attributes.
    """

    model_type = "bitnet2"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

class HBitLinear(nn.Module):
    """H-BitLinear layer implementation."""
    
    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # Initialize weights
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)
    
    def forward(self, x):
        # Placeholder forward: the full H-BitLinear implementation would apply
        # a Hadamard transform to the activations and quantize the weights
        # before the matmul; here we fall back to a plain full-precision linear.
        return F.linear(x, self.weight, self.bias)
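
# The placeholder forward above skips the pieces that make H-BitLinear "H" and
# "Bit": an activation-side Hadamard rotation and low-bit weight quantization.
# The helpers below are a minimal illustrative sketch of those two pieces
# (assuming BitNet-b1.58-style mean-abs weight scaling); they are not wired
# into HBitLinear and are not the reference H-BitLinear implementation.

def hadamard_transform(x):
    """Unnormalized fast Walsh-Hadamard transform over the last dimension.

    The last dimension must be a power of two. Because the Hadamard matrix is
    orthogonal up to a factor of n, the rotation can be folded into the weight
    matrix so that the overall linear map is unchanged.
    """
    n = x.shape[-1]
    assert n & (n - 1) == 0, "feature dimension must be a power of two"
    h = 1
    while h < n:
        x = x.reshape(*x.shape[:-1], n // (2 * h), 2, h)
        a, b = x[..., 0, :], x[..., 1, :]
        x = torch.stack((a + b, a - b), dim=-2).reshape(*x.shape[:-3], n)
        h *= 2
    return x

def weight_quant_ternary(w):
    """Ternary {-1, 0, +1} weight quantization with a straight-through estimator."""
    scale = w.abs().mean().clamp(min=1e-5)
    w_q = (w / scale).round().clamp(-1, 1) * scale
    # Forward pass sees quantized weights; gradients flow to the latent
    # full-precision weights through the straight-through estimator.
    return w + (w_q - w).detach()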

class BitNet2Layer(nn.Module):
    """Single BitNet2 layer with H-BitLinear."""
    
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Layer norms
        self.self_attn_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        
        # Self attention
        self.self_attn = nn.MultiheadAttention(
            config.hidden_size,
            config.num_attention_heads,
            dropout=config.attention_probs_dropout_prob,
            batch_first=True
        )
        
        # Feed forward with H-BitLinear
        self.feed_forward = nn.Sequential(
            HBitLinear(config.hidden_size, config.intermediate_size, bias=False),
            nn.GELU(),
            HBitLinear(config.intermediate_size, config.hidden_size, bias=False)
        )
    
    def forward(self, hidden_states, attention_mask=None):
        seq_len = hidden_states.size(1)

        # Causal mask for nn.MultiheadAttention: True marks positions a query
        # may NOT attend to, i.e. the strictly upper triangle (future tokens).
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=hidden_states.device),
            diagonal=1,
        )

        # A Hugging Face style attention_mask of shape (batch, seq_len) uses
        # 1 for real tokens and 0 for padding; MultiheadAttention expects a
        # key_padding_mask where True marks padded key positions to ignore.
        key_padding_mask = None
        if attention_mask is not None and attention_mask.dim() == 2:
            key_padding_mask = attention_mask == 0

        # Pre-norm self-attention with a residual connection.
        normed = self.self_attn_norm(hidden_states)
        attn_output, _ = self.self_attn(
            normed,
            normed,
            normed,
            attn_mask=causal_mask,
            key_padding_mask=key_padding_mask,
        )
        hidden_states = hidden_states + attn_output
        
        # Feed forward
        ff_output = self.feed_forward(self.feed_forward_norm(hidden_states))
        hidden_states = hidden_states + ff_output
        
        return hidden_states

class BitNet2Model(PreTrainedModel):
    config_class = BitNet2Config
    
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        
        # Embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embed_positions = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        
        # Transformer layers
        self.layers = nn.ModuleList([
            BitNet2Layer(config)
            for _ in range(config.num_hidden_layers)
        ])
        
        # Output layers
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        
        self.post_init()
    
    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        if input_ids is None:
            raise ValueError("input_ids must be provided")
            
        batch_size, seq_length = input_ids.shape
        
        # Get embeddings
        inputs_embeds = self.embed_tokens(input_ids)
        position_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
        position_embeds = self.embed_positions(position_ids)
        hidden_states = inputs_embeds + position_embeds
        
        # Process through layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)
        
        # Final norm and projection
        hidden_states = self.layer_norm(hidden_states)
        logits = self.lm_head(hidden_states)
        
        # Causal LM loss: shift so each position predicts the next token.
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )
    
    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No KV cache is implemented, so generation re-runs the full sequence.
        return {"input_ids": input_ids, "attention_mask": attention_mask}

class BitNet2ForCausalLM(BitNet2Model):
    def __init__(self, config):
        super().__init__(config)
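
# Quick sanity check for the wrapper classes above: builds a tiny model from a
# hand-picked configuration and runs one forward pass with labels. The
# hyperparameter values are illustrative placeholders, not the settings of any
# released BitNet2 checkpoint.
if __name__ == "__main__":
    config = BitNet2Config(
        vocab_size=32000,
        hidden_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=512,
        max_position_embeddings=128,
        layer_norm_eps=1e-5,
        attention_probs_dropout_prob=0.0,
    )
    model = BitNet2ForCausalLM(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    print("logits:", tuple(outputs.logits.shape), "loss:", float(outputs.loss))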