# llama_diffusion_model.py
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast
from transformers import AutoModelForCausalLM, PreTrainedModel, PretrainedConfig
from peft import LoraConfig, get_peft_model

# Hugging Face access token used to download the gated Llama base model.
hf_token = os.getenv("HF_TOKEN")

class BidirectionalLlamaAttention(nn.Module):
    """Wraps a Llama self-attention layer so it can attend bidirectionally instead of causally."""

    def __init__(self, original_layer, masking='unidirectional'):
        super().__init__()
        self.original = original_layer
        self.masking = masking

        # Reuse the projection weights of the wrapped attention layer.
        self.q_proj = original_layer.q_proj
        self.k_proj = original_layer.k_proj
        self.v_proj = original_layer.v_proj
        self.o_proj = original_layer.o_proj

        self.head_dim = self.q_proj.out_features // original_layer.num_heads
        self.num_heads = original_layer.num_heads
        self.num_key_value_groups = original_layer.num_key_value_groups
        self.attention_dropout = original_layer.attention_dropout
        self.layer_idx = original_layer.layer_idx
        self.scaling = original_layer.scaling

    def forward(self, hidden_states, position_embeddings, attention_mask=None, past_key_value=None, cache_position=None, **kwargs):
        # Note: the incoming attention_mask / past_key_value are ignored; the
        # mask is rebuilt below from self.masking, and no KV cache is used.
        bsz, seq_len, _ = hidden_states.size()

        # Project and split into (batch, heads, seq, head_dim). Keys/values are
        # split with num_heads, which assumes full multi-head attention.
        query_states = self._split_heads(self.q_proj(hidden_states))
        key_states = self._split_heads(self.k_proj(hidden_states))
        value_states = self._split_heads(self.v_proj(hidden_states))

        # Apply rotary position embeddings.
        cos, sin = position_embeddings
        query_states, key_states = self._apply_rotary(query_states, key_states, cos, sin)

        # All-ones mask for bidirectional attention, lower-triangular for the
        # default causal (unidirectional) case.
        if self.masking == 'bidirectional':
            attn_mask = torch.ones((bsz, 1, seq_len, seq_len), device=hidden_states.device)
        else:
            attn_mask = torch.tril(torch.ones(seq_len, seq_len, device=hidden_states.device)).unsqueeze(0).unsqueeze(0)

        # Scaled dot-product attention; attn_mask.log() turns the 0/1 mask into
        # an additive bias.
        attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling
        attn_weights = attn_weights + attn_mask.log()
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        return self.o_proj(attn_output), attn_weights

    def _split_heads(self, x):
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        return x.view(x.size(0), x.size(1), self.num_heads, self.head_dim).transpose(1, 2)

    def _merge_heads(self, x):
        # (batch, heads, seq, head_dim) -> (batch, seq, hidden)
        return x.transpose(1, 2).contiguous().view(x.size(0), -1, self.num_heads * self.head_dim)

    def _apply_rotary(self, q, k, cos, sin):
        # Rotary position embeddings applied to queries and keys.
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)
        q_rot = (q * cos) + (self._rotate_half(q) * sin)
        k_rot = (k * cos) + (self._rotate_half(k) * sin)
        return q_rot, k_rot

    def _rotate_half(self, x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)
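
# Illustrative note (not part of the original file): attn_mask.log() is what
# switches the two masking modes in forward(), e.g. for n = 4:
#   torch.ones(n, n).log()             -> all zeros (bidirectional, nothing masked)
#   torch.tril(torch.ones(n, n)).log() -> -inf above the diagonal (causal)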

class CustomTransformerConfig(PretrainedConfig):
    def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32, prediction_chunk=256, dropout=0, max_position_embeddings=4096, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.prediction_chunk = prediction_chunk
        self.max_position_embeddings = max_position_embeddings
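
# Sketch (illustrative, not in the original file): as a PretrainedConfig
# subclass, this config round-trips through the standard Hugging Face API, e.g.
#   cfg = CustomTransformerConfig(prediction_chunk=256)
#   cfg.save_pretrained("my_checkpoint")          # "my_checkpoint" is a placeholder path
#   cfg = CustomTransformerConfig.from_pretrained("my_checkpoint")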

class CustomTransformerModel(PreTrainedModel):
    config_class = CustomTransformerConfig

    def __init__(self, config):
        super().__init__(config)

        # Load the Llama backbone in half precision.
        self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, token=hf_token)
        self.llama.resize_token_embeddings(config.vocab_size)

        # Swap every self-attention layer for the bidirectional wrapper.
        for layer in self.llama.model.layers:
            layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking='bidirectional')

        # Freeze the backbone; only the LM head and the LoRA adapters below are trainable.
        for param in self.llama.parameters():
            param.requires_grad = False
        for param in self.llama.lm_head.parameters():
            param.requires_grad = True

        # Attach LoRA adapters to the attention projections.
        lora_config = LoraConfig(
            r=256,
            lora_alpha=256,
            lora_dropout=0.0,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            bias="none",
            task_type=None,
        )
        self.llama = get_peft_model(self.llama, lora_config)
        self.llama = self.llama.to(torch.float16)

    def forward(self, input_ids, labels=None, **kwargs):
        batch_size, seq_length = input_ids.shape
        assert seq_length == self.config.prediction_chunk, f"Expected sequence length {self.config.prediction_chunk}, got {seq_length}"

        with autocast("cuda", dtype=torch.float16):
            outputs = self.llama(input_ids=input_ids, output_hidden_states=True, **kwargs)

        # Restrict logits to the configured vocabulary and reshape to (batch, chunk, vocab).
        logits = outputs.logits[:, :, :self.config.vocab_size].view(batch_size, self.config.prediction_chunk, self.config.vocab_size)

        loss = None
        if labels is not None:
            # Token-level cross-entropy over the whole prediction chunk.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}

def disable_dropout(model):
    # Replace every nn.Dropout with nn.Identity. setattr must target the
    # dropout's parent module, not the root, for dotted names like "a.b.dropout".
    for name, module in model.named_modules():
        if isinstance(module, nn.Dropout):
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, nn.Identity())
    return model
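
# Minimal usage sketch (not part of the original file). It assumes a CUDA GPU,
# an HF_TOKEN with access to meta-llama/Llama-3.2-3B, and inputs of exactly
# config.prediction_chunk (256) tokens; names below are illustrative only.
if __name__ == "__main__":
    config = CustomTransformerConfig(prediction_chunk=256)
    model = CustomTransformerModel(config).to("cuda")
    model = disable_dropout(model)

    # Dummy batch of 2 sequences with 256 random token ids; real inputs would
    # come from the Llama tokenizer.
    input_ids = torch.randint(0, config.vocab_size, (2, config.prediction_chunk), device="cuda")
    labels = input_ids.clone()

    out = model(input_ids=input_ids, labels=labels)
    print(out["loss"].item(), out["logits"].shape)  # logits: (2, 256, 128256)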