import os
import json
import random
import torch
import firebase_admin
from firebase_admin import credentials, firestore
from evo_model import EvoTransformerForClassification, EvoTransformerConfig
from transformers import BertTokenizer
from init_model import load_model
from dashboard import evolution_accuracy_plot
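# NOTE: torch, transformers and firebase-admin are third-party packages, while
# evo_model, init_model and dashboard are assumed to be project-local modules
# that sit next to this script.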
# === Initialize Firebase
if not firebase_admin._apps:
    cred = credentials.Certificate("firebase_key.json")
    firebase_admin.initialize_app(cred)

db = firestore.client()
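# NOTE: "firebase_key.json" is assumed to be a Firebase service-account key file
# (the kind downloaded from the Firebase console); firestore.client() with no
# arguments uses the default app initialized above.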
# === Fetch and tokenize feedback from Firestore
def fetch_training_data(tokenizer):
    docs = db.collection("evo_feedback").stream()

    input_ids, attention_masks, labels = [], [], []

    for doc in docs:
        data = doc.to_dict()
        goal = data.get("goal", "").strip()
        sol1 = data.get("solution_1", "").strip()
        sol2 = data.get("solution_2", "").strip()
        winner = data.get("winner", "").strip()

        if not goal or not sol1 or not sol2 or not winner:
            continue

        text = f"{goal} [SEP] {sol1 if winner == '1' else sol2}"
        encoding = tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=128,
            return_tensors="pt"
        )
        input_ids.append(encoding["input_ids"][0])
        attention_masks.append(encoding["attention_mask"][0])

        label = 0 if winner == "1" else 1
        labels.append(label)

    if len(input_ids) < 2:
        return None, None, None

    return (
        torch.stack(input_ids),
        torch.stack(attention_masks),
        torch.tensor(labels, dtype=torch.long)
    )
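# Illustrative shape of an "evo_feedback" document as read above. Field values
# here are hypothetical; only the four fields used by fetch_training_data are shown:
#
#   {
#       "goal": "Evacuate the building quickly",
#       "solution_1": "Use the stairs",
#       "solution_2": "Use the elevator",
#       "winner": "1"        # "1" -> label 0, anything else -> label 1
#   }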
# === Random architecture mutation (NAS-like)
def mutate_config():
    return EvoTransformerConfig(
        hidden_size=384,
        num_layers=random.choice([4, 6, 8]),
        num_labels=2,
        num_heads=random.choice([4, 6, 8]),
        ffn_dim=random.choice([512, 1024, 2048]),
        use_memory=random.choice([False, True])
    )
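# The mutation above samples uniformly from a small discrete search space:
# 3 layer counts x 3 head counts x 3 FFN widths x 2 memory settings = 54
# candidate architectures per retrain call (hidden_size and num_labels stay fixed).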
# === Model summary text
def get_architecture_summary(model):
    cfg = model.config
    return (
        f"Layers: {cfg.num_layers}\n"
        f"Attention Heads: {cfg.num_heads}\n"
        f"FFN Dim: {cfg.ffn_dim}\n"
        f"Memory Enabled: {cfg.use_memory}"
    )
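# Example summary string for an illustrative sampled config
# (6 layers, 8 heads, ffn_dim=1024, memory enabled):
#
#   Layers: 6
#   Attention Heads: 8
#   FFN Dim: 1024
#   Memory Enabled: True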
# === Main retraining logic
def retrain_model():
    try:
        print("🔁 Starting retrain... fetching data")
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        input_ids, attention_masks, labels = fetch_training_data(tokenizer)

        if input_ids is None:
            return "⚠️ Not enough data to retrain.", None, "Please log more feedback first."

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        config = mutate_config()
        model = EvoTransformerForClassification(config).to(device)
        model.train()

        input_ids = input_ids.to(device)
        attention_masks = attention_masks.to(device)
        labels = labels.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
        loss_fn = torch.nn.CrossEntropyLoss()

        for epoch in range(3):
            optimizer.zero_grad()
            logits = model(input_ids, attention_mask=attention_masks)
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()
            print(f"📉 Epoch {epoch+1}: Loss = {loss.item():.4f}")

        # Sanity check logits
        if logits.shape[-1] < 2:
            raise ValueError("Logits shape invalid. Retrained model did not output 2 classes.")

        # Accuracy
        model.eval()
        with torch.no_grad():
            preds = torch.argmax(logits, dim=1)
            correct = (preds == labels).sum().item()
            accuracy = round(correct / len(labels), 4)

        # Log evolution
        log_path = "trained_model/evolution_log.json"
        os.makedirs("trained_model", exist_ok=True)
        history = []
        if os.path.exists(log_path):
            with open(log_path, "r") as f:
                history = json.load(f)
        history.append({
            "accuracy": accuracy,
            "num_layers": config.num_layers,
            "num_heads": config.num_heads,
            "ffn_dim": config.ffn_dim,
            "use_memory": config.use_memory
        })
        with open(log_path, "w") as f:
            json.dump(history, f, indent=2)

        # Save model + tokenizer
        model.save_pretrained("trained_model")
        tokenizer.save_pretrained("trained_model")
        print("✅ EvoTransformer retrained and saved.")

        # Load updated summary + plot
        updated_model = load_model()
        arch_text = get_architecture_summary(updated_model)
        plot = evolution_accuracy_plot()

        return arch_text, plot, f"✅ Retrained successfully – Accuracy: {accuracy * 100:.2f}%"

    except Exception as e:
        print(f"❌ Retraining failed: {e}")
        return "❌ Error", None, f"Retrain failed: {e}"
# CLI Trigger
if __name__ == "__main__":
    retrain_model()