# EvoTransformer-v2.1 / evo_model.py

import os
import json
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class EvoTransformerConfig(PretrainedConfig):
    """Architecture hyperparameters (depth, width, heads, FFN size) for EvoTransformer."""
def __init__(
self,
hidden_size=384,
num_layers=6,
num_labels=2,
num_heads=6,
ffn_dim=1024,
use_memory=False,
**kwargs
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_labels = num_labels
self.num_heads = num_heads
self.ffn_dim = ffn_dim
self.use_memory = use_memory
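

# A minimal sketch (added for illustration, not part of the original file) of
# the config serialization round-trip that save_pretrained() and
# from_pretrained() below rely on; `_config_roundtrip_example` is a
# hypothetical helper name.
def _config_roundtrip_example():
    cfg = EvoTransformerConfig(hidden_size=384, num_layers=6)
    blob = cfg.to_json_string()  # what save_pretrained() writes to config.json
    restored = EvoTransformerConfig.from_dict(json.loads(blob))
    assert restored.hidden_size == cfg.hidden_size
    return restored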


class EvoTransformerForClassification(PreTrainedModel):
    config_class = EvoTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config  # also set by PreTrainedModel.__init__; kept for clarity

        # Architecture traits surfaced for the UI and for mutation/search logic
self.num_layers = config.num_layers
self.num_heads = config.num_heads
self.ffn_dim = config.ffn_dim
self.use_memory = config.use_memory
        # Token embeddings only; no positional encodings are added.
        # 30522 is the bert-base-uncased vocabulary size (hard-coded).
        self.embedding = nn.Embedding(30522, config.hidden_size)
self.layers = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=config.hidden_size,
nhead=config.num_heads,
dim_feedforward=config.ffn_dim,
                batch_first=False  # sequence-first layout; forward() transposes inputs to match
)
for _ in range(config.num_layers)
])
self.classifier = nn.Sequential(
nn.Linear(config.hidden_size, 256),
nn.ReLU(),
nn.Linear(256, config.num_labels)
)
self.init_weights()

    def forward(self, input_ids, attention_mask=None, labels=None):
# Embedding and prep for transformer
x = self.embedding(input_ids) # [batch, seq_len, hidden]
x = x.transpose(0, 1) # [seq_len, batch, hidden]
        # True marks padding positions that attention should ignore
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
for layer in self.layers:
x = layer(x, src_key_padding_mask=key_padding_mask)
        x = x.mean(dim=0)  # [batch, hidden]; mean pooling over all positions, padding included (see masked_mean_pool sketch below)
logits = self.classifier(x)
if labels is not None:
loss = nn.functional.cross_entropy(logits, labels)
            return loss, logits  # (loss, logits) when labels are supplied; logits alone otherwise
return logits

    # Simplified local-only override of PreTrainedModel.save_pretrained
    def save_pretrained(self, save_directory):
os.makedirs(save_directory, exist_ok=True)
torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
with open(os.path.join(save_directory, "config.json"), "w") as f:
f.write(self.config.to_json_string())

    # Simplified local-only override of PreTrainedModel.from_pretrained
    @classmethod
    def from_pretrained(cls, load_directory):
config_path = os.path.join(load_directory, "config.json")
model_path = os.path.join(load_directory, "pytorch_model.bin")
config = EvoTransformerConfig.from_json_file(config_path)
model = cls(config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
return model
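

# Hedged sketch, not part of the original model: forward() above averages over
# *all* sequence positions, including padding. A mask-aware variant would look
# like this; `masked_mean_pool` is a hypothetical helper name.
def masked_mean_pool(x, attention_mask):
    # x: [seq_len, batch, hidden]; attention_mask: [batch, seq_len], 1 = real token
    mask = attention_mask.transpose(0, 1).unsqueeze(-1).to(x.dtype)  # [seq_len, batch, 1]
    summed = (x * mask).sum(dim=0)                                   # [batch, hidden]
    counts = mask.sum(dim=0).clamp(min=1e-9)                         # avoid division by zero
    return summed / counts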
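

# Minimal smoke test (added for illustration, not in the original repo): uses
# random token ids below the 30522 vocab size instead of a real tokenizer, and
# a temporary directory for the save/load round-trip.
if __name__ == "__main__":
    import tempfile

    cfg = EvoTransformerConfig()
    model = EvoTransformerForClassification(cfg)
    input_ids = torch.randint(0, 30522, (2, 16))  # [batch, seq_len]
    attention_mask = torch.ones_like(input_ids)
    loss, logits = model(input_ids, attention_mask, labels=torch.tensor([0, 1]))
    print("loss:", loss.item(), "logits shape:", tuple(logits.shape))

    with tempfile.TemporaryDirectory() as tmp:
        model.save_pretrained(tmp)
        reloaded = EvoTransformerForClassification.from_pretrained(tmp)
        print("reloaded on:", next(reloaded.parameters()).device)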