# EvoTransformer-v2.1 — watchdog.py
# Retrains EvoTransformer from logged Firestore feedback with a mutated architecture.
import os
import json
import random
import torch
import firebase_admin
from firebase_admin import credentials, firestore
from evo_model import EvoTransformerForClassification, EvoTransformerConfig
from transformers import BertTokenizer
from init_model import load_model
from dashboard import evolution_accuracy_plot
# === Initialize Firebase (guarded so re-imports don't re-initialize the app)
if not firebase_admin._apps:
    firebase_admin.initialize_app(credentials.Certificate("firebase_key.json"))
db = firestore.client()
# === Fetch and tokenize feedback from Firestore
def fetch_training_data(tokenizer):
    """Pull feedback documents from the `evo_feedback` collection and encode them.

    Each usable document contributes one example: the goal plus the *winning*
    solution, tokenized to a fixed 128-token sequence. The label encodes which
    solution won (0 for "1", 1 otherwise).

    Returns (input_ids, attention_masks, labels) as stacked tensors, or
    (None, None, None) when fewer than two usable examples were found.
    """
    ids, masks, labels = [], [], []
    for snapshot in db.collection("evo_feedback").stream():
        record = snapshot.to_dict()
        goal = record.get("goal", "").strip()
        sol1 = record.get("solution_1", "").strip()
        sol2 = record.get("solution_2", "").strip()
        winner = record.get("winner", "").strip()
        # Skip documents missing any required field.
        if not (goal and sol1 and sol2 and winner):
            continue
        chosen = sol1 if winner == '1' else sol2
        enc = tokenizer(
            f"{goal} [SEP] {chosen}",
            truncation=True,
            padding="max_length",
            max_length=128,
            return_tensors="pt",
        )
        ids.append(enc["input_ids"][0])
        masks.append(enc["attention_mask"][0])
        labels.append(0 if winner == "1" else 1)
    # Too little data to train on — signal the caller to bail out.
    if len(ids) < 2:
        return None, None, None
    return torch.stack(ids), torch.stack(masks), torch.tensor(labels, dtype=torch.long)
# === Random architecture mutation (NAS-like)
def mutate_config():
    """Sample a randomly mutated EvoTransformer architecture (NAS-style search step)."""
    layer_choices = [4, 6, 8]
    head_choices = [4, 6, 8]
    ffn_choices = [512, 1024, 2048]
    memory_choices = [False, True]
    return EvoTransformerConfig(
        hidden_size=384,
        num_layers=random.choice(layer_choices),
        num_labels=2,
        num_heads=random.choice(head_choices),
        ffn_dim=random.choice(ffn_choices),
        use_memory=random.choice(memory_choices),
    )
# === Model summary text
def get_architecture_summary(model):
    """Render a short multi-line description of the model's architecture config."""
    cfg = model.config
    parts = [
        f"Layers: {cfg.num_layers}",
        f"Attention Heads: {cfg.num_heads}",
        f"FFN Dim: {cfg.ffn_dim}",
        f"Memory Enabled: {cfg.use_memory}",
    ]
    return "\n".join(parts)
# === Main retraining logic
def retrain_model():
    """Retrain EvoTransformer on logged feedback using a freshly mutated architecture.

    Pipeline: fetch + tokenize Firestore feedback, sample a mutated config,
    train for 3 full-batch epochs, log the run to trained_model/evolution_log.json,
    save model + tokenizer, and rebuild the UI summary/plot.

    Returns a 3-tuple for the UI: (architecture summary text, accuracy plot,
    status message). On any failure the first element is an error marker and
    the plot slot is None.
    """
    try:
        print("πŸ” Starting retrain... fetching data")
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        input_ids, attention_masks, labels = fetch_training_data(tokenizer)
        # fetch_training_data returns (None, None, None) when < 2 usable examples exist.
        if input_ids is None:
            return "⚠️ Not enough data to retrain.", None, "Please log more feedback first."
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        config = mutate_config()  # new random architecture each retrain (NAS-like)
        model = EvoTransformerForClassification(config).to(device)
        model.train()
        input_ids = input_ids.to(device)
        attention_masks = attention_masks.to(device)
        labels = labels.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
        loss_fn = torch.nn.CrossEntropyLoss()
        # Full-batch training: the entire feedback set goes through one forward pass per epoch.
        for epoch in range(3):
            optimizer.zero_grad()
            logits = model(input_ids, attention_mask=attention_masks)
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()
            print(f"πŸ” Epoch {epoch+1}: Loss = {loss.item():.4f}")
        # Sanity check: the mutated model must still emit 2-class logits.
        if logits.shape[-1] < 2:
            raise ValueError("Logits shape invalid. Retrained model did not output 2 classes.")
        # Accuracy — measured on the training batch itself (no held-out split),
        # reusing the logits from the final training step.
        model.eval()
        with torch.no_grad():
            preds = torch.argmax(logits, dim=1)
            correct = (preds == labels).sum().item()
            accuracy = round(correct / len(labels), 4)
        # Append this run's accuracy + architecture to the evolution log on disk.
        log_path = "trained_model/evolution_log.json"
        os.makedirs("trained_model", exist_ok=True)
        history = []
        if os.path.exists(log_path):
            with open(log_path, "r") as f:
                history = json.load(f)
        history.append({
            "accuracy": accuracy,
            "num_layers": config.num_layers,
            "num_heads": config.num_heads,
            "ffn_dim": config.ffn_dim,
            "use_memory": config.use_memory
        })
        with open(log_path, "w") as f:
            json.dump(history, f, indent=2)
        # Save model + tokenizer so the serving path picks up the new weights.
        model.save_pretrained("trained_model")
        tokenizer.save_pretrained("trained_model")
        print("βœ… EvoTransformer retrained and saved.")
        # Reload the saved model for the summary, and rebuild the accuracy plot.
        updated_model = load_model()
        arch_text = get_architecture_summary(updated_model)
        plot = evolution_accuracy_plot()
        return arch_text, plot, f"βœ… Retrained successfully β€” Accuracy: {accuracy * 100:.2f}%"
    except Exception as e:
        # Broad catch is deliberate: this is a UI/CLI boundary — surface the error
        # as a status string rather than crashing the app.
        print(f"❌ Retraining failed: {e}")
        return "❌ Error", None, f"Retrain failed: {e}"
# CLI Trigger — allows running `python watchdog.py` to kick off a retrain directly.
if __name__ == "__main__":
    retrain_model()