import os
import json
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig

class EvoTransformerConfig(PretrainedConfig):
    def __init__(
        self,
        hidden_size=384,
        num_layers=6,
        num_labels=2,
        num_heads=6,
        ffn_dim=1024,
        use_memory=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_labels = num_labels
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.use_memory = use_memory


class EvoTransformerForClassification(PreTrainedModel):
    config_class = EvoTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # === Architecture traits exposed for UI, mutation, etc. ===
        self.num_layers = config.num_layers
        self.num_heads = config.num_heads
        self.ffn_dim = config.ffn_dim
        self.use_memory = config.use_memory

        self.embedding = nn.Embedding(30522, config.hidden_size)  # BERT vocab size

        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_heads,
                dim_feedforward=config.ffn_dim,
                batch_first=False  # layer expects [seq_len, batch, hidden]; forward() transposes inputs accordingly
            )
            for _ in range(config.num_layers)
        ])

        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, config.num_labels)
        )

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Embedding and prep for transformer
        x = self.embedding(input_ids)  # [batch, seq_len, hidden]
        x = x.transpose(0, 1)          # [seq_len, batch, hidden]

        # Build key padding mask: True marks padding positions that attention should ignore
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None

        for layer in self.layers:
            x = layer(x, src_key_padding_mask=key_padding_mask)

        x = x.mean(dim=0)  # [batch, hidden] — mean pooling
        logits = self.classifier(x)

        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)
            return loss, logits

        return logits

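    # Note: save_pretrained/from_pretrained below override the Hugging Face
    # defaults with minimal local-directory versions (state_dict + config.json
    # only); they do not handle Hub downloads or the base-class keyword arguments.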
    def save_pretrained(self, save_directory):
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            f.write(self.config.to_json_string())

    @classmethod
    def from_pretrained(cls, load_directory):
        config_path = os.path.join(load_directory, "config.json")
        model_path = os.path.join(load_directory, "pytorch_model.bin")
        config = EvoTransformerConfig.from_json_file(config_path)
        model = cls(config)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        return model
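

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal smoke test showing how the pieces above fit together. It assumes a
# BERT-compatible tokenizer ("bert-base-uncased") to match the 30522-token
# embedding table; the checkpoint directory name "evo_checkpoint" is arbitrary.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    config = EvoTransformerConfig(hidden_size=384, num_layers=6, num_heads=6)
    model = EvoTransformerForClassification(config)

    batch = tokenizer(
        ["Evolution favors the adaptable.", "Static models stagnate."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    labels = torch.tensor([1, 0])

    # With labels the model returns (loss, logits); without, just logits.
    loss, logits = model(batch["input_ids"], batch["attention_mask"], labels)
    print(f"loss={loss.item():.4f}, logits shape={tuple(logits.shape)}")

    # Round-trip through the custom save/load helpers.
    model.save_pretrained("evo_checkpoint")
    reloaded = EvoTransformerForClassification.from_pretrained("evo_checkpoint")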