# EvoTransformer-v2.1 / dashboard.py
# HemanM's picture
# Update dashboard.py
# 4cc33c1 verified
import matplotlib.pyplot as plt
import firebase_admin
from firebase_admin import credentials, firestore
import io
from PIL import Image
import os
import json
# === Initialize Firebase
# Only create the Firebase app when the SDK's app registry is still empty, so
# executing this module again does not initialise the app a second time.
# NOTE(review): `_apps` is a private attribute of firebase_admin — confirm it
# is still present when upgrading the SDK.
if not firebase_admin._apps:
    # Service-account key is loaded from a path relative to the working directory.
    cred = credentials.Certificate("firebase_key.json")
    firebase_admin.initialize_app(cred)
# Module-level Firestore client used by the Firestore-backed functions below.
db = firestore.client()
# === 🗳️ Feedback Summary Plot (Solution 1 vs Solution 2)
def update_dashboard_plot():
    """Render a bar chart of feedback votes for Solution 1 vs Solution 2.

    Streams every document of the ``evo_feedback`` Firestore collection and
    counts those whose ``winner`` field is ``"1"`` or ``"2"``; any other value
    is ignored.  When no vote was counted, a "No feedback collected yet."
    placeholder is drawn instead of the bars.

    Returns:
        A PIL image holding the chart rendered as PNG.  If reading the votes
        or drawing the chart raises, the image returned by ``_fallback_plot``
        (carrying the error text) is returned instead.
    """
    try:
        count_1, count_2 = 0, 0
        for doc in db.collection("evo_feedback").stream():
            # ``winner`` may be missing, None or stored as a number; normalise
            # it to text so one odd document cannot abort the whole chart.
            winner = str((doc.to_dict() or {}).get("winner") or "").strip()
            if winner == "1":
                count_1 += 1
            elif winner == "2":
                count_2 += 1
        total = count_1 + count_2

        fig, ax = plt.subplots()
        try:
            if total == 0:
                ax.text(0.5, 0.5, "No feedback collected yet.", ha="center", va="center", fontsize=12)
                ax.axis("off")
            else:
                ax.bar(["Solution 1", "Solution 2"], [count_1, count_2], color=["skyblue", "lightgreen"])
                ax.set_ylabel("Votes")
                ax.set_title(f"🗳️ Feedback Summary (Total: {total})")
                # Print each count just above its bar.
                for i, v in enumerate([count_1, count_2]):
                    ax.text(i, v + 0.5, str(v), ha="center", fontweight="bold")
            # Operate on ``fig`` explicitly rather than pyplot's implicit
            # "current figure", which another caller may have changed.
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format="png")
        finally:
            # pyplot keeps every figure alive until it is closed; without this
            # each dashboard refresh would leak one figure.
            plt.close(fig)
        buf.seek(0)
        return Image.open(buf)
    except Exception as e:
        return _fallback_plot(f"Error loading feedback chart: {e}")
# === 📈 Evolution Accuracy Plot (Accuracy vs Generation)
def evolution_accuracy_plot():
    """Plot accuracy per generation from ``trained_model/evolution_log.json``.

    The log is loaded as JSON and iterated entry by entry; entry *i* is
    plotted as generation *i + 1*.  Each point is annotated with the entry's
    ``num_layers``, ``num_heads``, ``ffn_dim`` and ``use_memory`` values
    (``?`` when a key is absent).  A missing ``accuracy`` is plotted as 0.

    Returns:
        A PIL image holding the chart rendered as PNG.  A message image from
        ``_fallback_plot`` is returned when the log file does not exist, when
        it is empty, or when loading/plotting raises.
    """
    try:
        log_path = "trained_model/evolution_log.json"
        try:
            # Open directly instead of checking existence first: avoids the
            # check-then-open race and pins the encoding regardless of locale.
            with open(log_path, "r", encoding="utf-8") as f:
                log_data = json.load(f)
        except FileNotFoundError:
            return _fallback_plot("⚠️ No evolution log found")
        if not log_data:
            return _fallback_plot("⚠️ Evolution log is empty")

        generations = list(range(1, len(log_data) + 1))
        accuracies = [round(entry.get("accuracy", 0), 4) for entry in log_data]
        tooltips = [
            f"L: {entry.get('num_layers', '?')}, H: {entry.get('num_heads', '?')}, "
            f"FFN: {entry.get('ffn_dim', '?')}, Mem: {entry.get('use_memory', '?')}"
            for entry in log_data
        ]

        fig, ax = plt.subplots(figsize=(7, 4))
        try:
            ax.plot(generations, accuracies, marker="o", color="purple", label="Accuracy")
            for i, (gen, acc, tooltip) in enumerate(zip(generations, accuracies, tooltips)):
                # Alternate the vertical offset (-2 / +3 points) so the labels
                # of neighbouring generations overlap less.
                offset = (i % 2) * 5 - 2
                ax.annotate(
                    tooltip,
                    (gen, acc),
                    fontsize=7,
                    xytext=(0, 8 + offset),
                    textcoords="offset points",
                    ha="center",
                )
            ax.set_xlabel("Generation")
            ax.set_ylabel("Accuracy")
            ax.set_title("📈 EvoTransformer Evolution Accuracy")
            ax.set_ylim([0, 1.05])
            ax.grid(True)
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format="png")
        finally:
            # pyplot keeps every figure alive until it is closed; without this
            # each call would leak one figure.
            plt.close(fig)
        buf.seek(0)
        return Image.open(buf)
    except Exception as e:
        return _fallback_plot(f"⚠️ Error loading evolution plot: {e}")
# === 🏆 Hall of Fame Plot (most frequent goals)
def leaderboard_plot():
    """Plot the ten most frequent ``goal`` values found in ``evo_feedback``.

    Streams every document of the ``evo_feedback`` Firestore collection,
    counts each non-empty (whitespace-stripped) ``goal`` and draws the ten
    most frequent ones as a horizontal bar chart, most frequent at the top.
    Labels longer than 25 characters are cut to 25 characters plus ``...``.

    Returns:
        A PIL image holding the chart rendered as PNG.  A message image from
        ``_fallback_plot`` is returned when no goal was found or when reading
        or plotting raises.
    """
    from collections import Counter

    try:
        goal_counter = Counter()
        for doc in db.collection("evo_feedback").stream():
            # ``goal`` may be missing, None or not a string; normalise to text
            # so one odd document cannot abort the whole chart.
            goal = str((doc.to_dict() or {}).get("goal") or "").strip()
            if goal:
                goal_counter[goal] += 1

        top_goals = goal_counter.most_common(10)
        if not top_goals:
            return _fallback_plot("🏆 No top goals yet.")

        # Truncate the goal text itself.  (Indexing ``goal[0]`` here would
        # reduce every label to its first character.)
        labels = [goal[:25] + "..." if len(goal) > 25 else goal for goal, _ in top_goals]
        counts = [count for _, count in top_goals]

        fig, ax = plt.subplots(figsize=(8, 4))
        try:
            # Plot at numeric positions rather than on a categorical axis so
            # two goals whose truncated labels coincide keep separate bars.
            # Reversed so the most frequent goal ends up at the top.
            positions = list(range(len(labels)))
            ax.barh(positions, counts[::-1], color="orange")
            ax.set_yticks(positions)
            ax.set_yticklabels(labels[::-1])
            ax.set_xlabel("Votes")
            ax.set_title("🏆 Hall of Fame: Top 10 Goals")
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format="png")
        finally:
            # pyplot keeps every figure alive until it is closed; without this
            # each call would leak one figure.
            plt.close(fig)
        buf.seek(0)
        return Image.open(buf)
    except Exception as e:
        return _fallback_plot(f"⚠️ Error loading leaderboard: {e}")
# === 🔢 Real-time vote count for display
def get_vote_counts():
    """Return the current number of votes for each solution.

    Streams every document of the ``evo_feedback`` Firestore collection and
    counts those whose ``winner`` field is ``"1"`` or ``"2"``; any other value
    is ignored.

    Returns:
        dict: ``{"1": <votes for solution 1>, "2": <votes for solution 2>}``.
        Best effort: if reading the collection raises, both counts are
        reported as 0 rather than propagating the error.
    """
    try:
        count_1, count_2 = 0, 0
        for doc in db.collection("evo_feedback").stream():
            # ``winner`` may be missing, None or stored as a number; normalise
            # it to text so one odd document cannot zero the whole result.
            winner = str((doc.to_dict() or {}).get("winner") or "").strip()
            if winner == "1":
                count_1 += 1
            elif winner == "2":
                count_2 += 1
        return {"1": count_1, "2": count_2}
    except Exception:
        # Deliberately best-effort for display purposes, but do not swallow
        # KeyboardInterrupt / SystemExit the way a bare ``except:`` would.
        return {"1": 0, "2": 0}
# === 🛠️ Fallback Plot
def _fallback_plot(message):
    """Render *message* centred on an otherwise empty figure.

    Args:
        message: Text to display; drawn at the centre of the axes with the
            axes themselves hidden.

    Returns:
        A PIL image holding the figure rendered as PNG.
    """
    fig, ax = plt.subplots()
    try:
        ax.text(0.5, 0.5, message, ha="center", va="center", fontsize=11)
        ax.axis("off")
        # Operate on ``fig`` explicitly rather than pyplot's implicit
        # "current figure".
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each fallback image would leak one figure.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)