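"""Retraining loop for EvoTransformer.

Fetches pairwise feedback (goal, solution_1, solution_2, winner) from the
`evo_feedback` Firestore collection, samples a mutated architecture,
trains a fresh classifier on the feedback, appends the result to
trained_model/evolution_log.json, and saves the updated model and tokenizer.
"""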
import os
import json
import random
import torch
import firebase_admin
from firebase_admin import credentials, firestore
from evo_model import EvoTransformerForClassification, EvoTransformerConfig
from transformers import BertTokenizer
from init_model import load_model
from dashboard import evolution_accuracy_plot

# === Initialize Firebase
if not firebase_admin._apps:
    cred = credentials.Certificate("firebase_key.json")
    firebase_admin.initialize_app(cred)
db = firestore.client()

# === Fetch and tokenize feedback from Firestore
def fetch_training_data(tokenizer):
    docs = db.collection("evo_feedback").stream()

    input_ids, attention_masks, labels = [], [], []
    for doc in docs:
        data = doc.to_dict()

        goal = data.get("goal", "").strip()
        sol1 = data.get("solution_1", "").strip()
        sol2 = data.get("solution_2", "").strip()
        winner = data.get("winner", "").strip()

        if not goal or not sol1 or not sol2 or winner not in ("1", "2"):
            continue

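        # Pair the goal with the winning solution's text; the label records
        # which slot won (0 -> solution_1, 1 -> solution_2)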
        text = f"{goal} [SEP] {sol1 if winner == '1' else sol2}"
        encoding = tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=128,
            return_tensors="pt"
        )

        input_ids.append(encoding["input_ids"][0])
        attention_masks.append(encoding["attention_mask"][0])
        label = 0 if winner == "1" else 1
        labels.append(label)

    if len(input_ids) < 2:
        return None, None, None

    return (
        torch.stack(input_ids),
        torch.stack(attention_masks),
        torch.tensor(labels, dtype=torch.long)
    )

# === Random architecture mutation (NAS-like)
def mutate_config():
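    # Sample one candidate from a small search space (depth, head count,
    # FFN width, optional memory module); hidden size stays fixed at 384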
    return EvoTransformerConfig(
        hidden_size=384,
        num_layers=random.choice([4, 6, 8]),
        num_labels=2,
        num_heads=random.choice([4, 6, 8]),
        ffn_dim=random.choice([512, 1024, 2048]),
        use_memory=random.choice([False, True])
    )

# === Model summary text
def get_architecture_summary(model):
    cfg = model.config
    return (
        f"Layers: {cfg.num_layers}\n"
        f"Attention Heads: {cfg.num_heads}\n"
        f"FFN Dim: {cfg.ffn_dim}\n"
        f"Memory Enabled: {cfg.use_memory}"
    )

# === Main retraining logic
def retrain_model():
    try:
        print("πŸ” Starting retrain... fetching data")
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        input_ids, attention_masks, labels = fetch_training_data(tokenizer)

        if input_ids is None:
            return "⚠️ Not enough data to retrain.", None, "Please log more feedback first."

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        config = mutate_config()
        model = EvoTransformerForClassification(config).to(device)
        model.train()

        input_ids = input_ids.to(device)
        attention_masks = attention_masks.to(device)
        labels = labels.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
        loss_fn = torch.nn.CrossEntropyLoss()

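        # Full-batch training for a few epochs; the feedback set is assumed
        # small enough to process as a single batch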
        for epoch in range(3):
            optimizer.zero_grad()
            logits = model(input_ids, attention_mask=attention_masks)
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()
            print(f"πŸ” Epoch {epoch+1}: Loss = {loss.item():.4f}")

        # Sanity check logits
        if logits.shape[-1] < 2:
            raise ValueError("Logits shape invalid. Retrained model did not output 2 classes.")

        # Training-set accuracy, recomputed in eval mode so the number
        # reflects the final weights rather than the last training step
        model.eval()
        with torch.no_grad():
            logits = model(input_ids, attention_mask=attention_masks)
            preds = torch.argmax(logits, dim=1)
            correct = (preds == labels).sum().item()
            accuracy = round(correct / len(labels), 4)

        # Log evolution
        log_path = "trained_model/evolution_log.json"
        os.makedirs("trained_model", exist_ok=True)
        history = []
        if os.path.exists(log_path):
            with open(log_path, "r") as f:
                history = json.load(f)

        history.append({
            "accuracy": accuracy,
            "num_layers": config.num_layers,
            "num_heads": config.num_heads,
            "ffn_dim": config.ffn_dim,
            "use_memory": config.use_memory
        })

        with open(log_path, "w") as f:
            json.dump(history, f, indent=2)

        # Save model + tokenizer
        model.save_pretrained("trained_model")
        tokenizer.save_pretrained("trained_model")
        print("βœ… EvoTransformer retrained and saved.")

        # Load updated summary + plot
        updated_model = load_model()
        arch_text = get_architecture_summary(updated_model)
        plot = evolution_accuracy_plot()

        return arch_text, plot, f"✅ Retrained successfully — Accuracy: {accuracy * 100:.2f}%"

    except Exception as e:
        print(f"❌ Retraining failed: {e}")
        return "❌ Error", None, f"Retrain failed: {e}"

# === CLI trigger
if __name__ == "__main__":
    retrain_model()