import pandas as pd
import gradio as gr
import os
import re
import requests
from dotenv import load_dotenv
from matplotlib.colors import LinearSegmentedColormap
import plotly.express as px
import plotly.graph_objects as go
# from sklearn.linear_model import LinearRegression
import numpy as np
from huggingface_hub import HfApi
from huggingface_hub.hf_api import HTTPError
from huggingface_hub.utils import GatedRepoError
from gradio_rangeslider import RangeSlider
import datetime
from gradio.themes.utils.colors import slate
# Load variables from a local .env file into os.environ before reading config.
load_dotenv()
# Endpoint that receives model submissions; None when WEBHOOK_URL is unset.
webhook_url = os.getenv("WEBHOOK_URL")
# Data-file stems, one per model-size bucket (parallel to model_size_list).
file_name_list = ["14b", "9b", "7b", "3b", "1b5", "other"]
# Sheet names, one per metric (parallel to metric_list).
sheet_name_list = ["cr", "bpc", "bpb"]
# Human-readable metric labels (the keys of metric_to_sheet).
metric_list = [
    "Compression Rate (%)",
    "Bits Per Character (BPC)",
    "Bits Per Byte (BPB)",
]
# Human-readable model-size labels (the keys of model_size_to_file_name).
model_size_list = ["~14B", "~9B", "~7B", "~3B", "~1.5B", "Other"]
# Label -> sheet / file-stem lookups, derived from the parallel lists above so
# each pairing is written down in exactly one place.
metric_to_sheet = dict(zip(metric_list, sheet_name_list))
model_size_to_file_name = dict(zip(model_size_list, file_name_list))
def read_about_md():
    """Return the full text of ``about.md`` (relative to the working directory).

    The file is decoded as UTF-8.
    """
    about_path = "about.md"
    with open(about_path, encoding="utf-8") as about_file:
        text = about_file.read()
    return text
def rename_columns(df):
    """Strip the last ``_<suffix>`` from every column label of *df*.

    Labels without an underscore are left unchanged. *df* is modified in place
    and also returned for convenience.
    """
    stripped = []
    for label in df.columns:
        head, separator, _suffix = label.rpartition("_")
        stripped.append(head if separator else label)
    df.columns = stripped
    return df
def get_folders_matching_format(directory):
    """Return the paths of the ``YYYY-MM``-named sub-directories of *directory*.

    Args:
        directory: Directory to scan.

    Returns:
        A list of ``os.path.join(directory, name)`` paths, sorted by name (which
        is chronological for ``YYYY-MM`` names). Empty when *directory* does not
        exist or is not a directory.
    """
    # fullmatch: with re.match, "$" would also accept a name ending in "\n".
    pattern = re.compile(r"\d{4}-\d{2}")
    # isdir rather than exists: a plain file here would make os.listdir raise.
    if not os.path.isdir(directory):
        return []
    folders = []
    # os.listdir order is arbitrary; sort so callers get a deterministic order.
    for name in sorted(os.listdir(directory)):
        full_path = os.path.join(directory, name)
        if pattern.fullmatch(name) and os.path.isdir(full_path):
            folders.append(full_path)
    return folders
def get_unique_column_names(data=None):
    """Return the benchmark test column names in their canonical display order.

    Each name joins its words with an underscore followed by U+200B (zero-width
    space); the names must match the data's column labels exactly. *data* is
    accepted for call compatibility and ignored.
    """
    word_groups = (
        ("ao3", "english"),
        ("bbc", "news"),
        ("wikipedia", "english"),
        ("arxiv", "computer", "science"),
        ("arxiv", "physics"),
        ("github", "cpp"),
        ("github", "python"),
    )
    return ["_\u200b".join(words) for words in word_groups]
def color_cell(value):
    """Return the CSS used to highlight a cell that holds a value.

    Suitable as a pandas ``Styler.map`` callback. Missing values (NaN/None) get
    an empty string, which the Styler treats as "no style". A non-CSS token such
    as ``"default"`` is not a valid ``attr: value`` string and makes the Styler
    raise ValueError when it parses the styles.
    """
    return "background-color: #fffdd0" if pd.notna(value) else ""
# def color_cell_themed(value):
# return "background-color: rgba(255, 253, 208, 1.0)" if pd.notna(value) else "default"
# Build the leaderboard table. ``request`` defaults to None so the function also
# works when called without a Gradio request (e.g. for the initial render); a
# supplied request is only used to choose the light or dark colour scheme.
def update_table(period: str, models_size: list, metric: str, visible_columns: list, color_columns: list, size_range: list, midpoint: float = 0.5, sort_by: str = "Average (lower=better)", ascending: bool = True, request: gr.Request = None):
    """Render the filtered, sorted and colour-graded leaderboard as HTML.

    Args:
        period: Key into the module-level ``all_data`` mapping.
        models_size: Model-size labels (keys of ``model_size_to_file_name``).
        metric: Metric label (key of ``metric_to_sheet``).
        visible_columns: Test columns to display; the average is recomputed
            over the displayed test columns.
        color_columns: May contain ``"Average"`` and/or ``"Individual Tests"``
            to enable gradient colouring of those columns.
        size_range: ``(low, high)``; rows whose "Parameters Count (B)" falls
            outside this inclusive range are dropped.
        midpoint: Fraction in [0, 1]; the value at this position among a
            column's sorted values gets the neutral gradient colour. Values
            outside [0, 1] are clamped to the first/last value.
        sort_by: Column to sort by, named as before underscores are replaced
            with spaces (e.g. ``"Average (lower=better)"``).
        ascending: Sort direction.
        request: Optional Gradio request. If it exposes a truthy ``is_dark``
            attribute, the dark colour scheme is used.

    Returns:
        The table as an HTML string, or a plain message when no rows remain.
    """
    # NOTE(review): gr.Request is not documented to expose ``is_dark``; getattr
    # falls back to the light scheme instead of raising AttributeError. Confirm
    # where this flag is meant to come from.
    is_dark_mode = bool(getattr(request, "is_dark", False))
    print(f"Updating - time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, period: {period}, models: {models_size}, metric: {metric}, visible_columns: {visible_columns}, color_columns: {color_columns}, size_range: {size_range}, sort_by: {sort_by}, ascending: {ascending}, is_dark: {is_dark_mode}\n")

    no_data_message = "No data available for the selected models and period."
    if not models_size:
        return no_data_message

    # Stack the chosen metric's sheet for every selected size, dropping columns
    # that are entirely empty in a sheet.
    target_period_data = all_data[period]
    sheet_name = metric_to_sheet[metric]
    frames = [
        target_period_data[model_size_to_file_name[model]][sheet_name].dropna(axis=1, how="all")
        for model in models_size
    ]
    combined_data = pd.concat(frames, axis=0)
    if len(combined_data) == 0:
        return no_data_message

    in_size_range = combined_data["Parameters Count (B)"].between(size_range[0], size_range[1])
    combined_data = combined_data[in_size_range].reset_index(drop=True)
    if len(combined_data) == 0:
        return no_data_message

    combined_data["Name"] = combined_data["Name"].str.replace(".pth", "", regex=False)

    # Displayed test columns in canonical order. Columns absent from every
    # selected sheet are skipped so the indexing below cannot raise KeyError.
    summary_columns = {"Name", "Parameters Count (B)", "Average (The lower the better)"}
    relevant_columns = [
        col for col in get_unique_column_names()
        if col in visible_columns and col not in summary_columns and col in combined_data.columns
    ]
    if relevant_columns:
        # Recompute the average over the displayed test columns only.
        combined_data["Average (The lower the better)"] = combined_data[relevant_columns].mean(axis=1).round(3)

    combined_data = combined_data.rename(columns={"Parameters Count (B)": "Params (B)", "Average (The lower the better)": "Average (lower=better)"})
    sorted_data = combined_data.sort_values(by=sort_by, ascending=ascending)
    visible_columns_final = ["Name", "Params (B)", "Average (lower=better)"] + relevant_columns
    # Display names: underscores become spaces.
    filtered_data = sorted_data[visible_columns_final].rename(columns=lambda col: col.replace("_", " "))
    formatter = {col: "{:.3f}" for col in filtered_data.columns if filtered_data[col].dtype in ["float64", "float32"]}

    # Theme-dependent colours. Gradients run green (low, better) -> neutral -> red.
    if is_dark_mode:
        colors = ["#2ca02c", "#2b2b2b", "#d62728"]  # green -> dark grey -> red
        # Light text: black would be unreadable on the dark-grey midpoint.
        gradient_text_color = "#f0f0f0"
        params_css = "background-color: #4b4936; color: #f0f0f0;"  # muted gold, light text
    else:
        colors = ["#63be7b", "#ffffff", "#f8696b"]  # green -> white -> red
        gradient_text_color = "black"
        params_css = "background-color: #fffdd0; color: black;"  # cream, black text
    custom_cmap = LinearSegmentedColormap.from_list("custom_cmap", colors)

    # Per-column gradient anchors: min, max, and the value at the ``midpoint``
    # fraction of the sorted non-missing values. Columns with < 2 values get none.
    vmin, vmax, vmid = {}, {}, {}
    for column in filtered_data.columns:
        if column in ("Name", "Params (B)"):
            continue
        sorted_values = np.sort(filtered_data[column].dropna().to_numpy())
        if len(sorted_values) > 1:
            vmin[column] = sorted_values[0]
            vmax[column] = sorted_values[-1]
            # Clamp: midpoint == 1.0 would otherwise index one past the end.
            idx = min(max(int(len(sorted_values) * midpoint), 0), len(sorted_values) - 1)
            vmid[column] = sorted_values[idx]

    def custom_background_gradient(series, cmap, vmin_val, vmax_val, vmid_val):
        """Return one CSS string (background + text colour) per cell of *series*."""
        def normalize(x):
            # Piecewise-linear: [vmin, vmid] -> [0, 0.5], (vmid, vmax] -> (0.5, 1].
            if pd.isna(x):
                return 0.5  # neutral colour for missing values
            if vmid_val == vmin_val and x <= vmid_val:
                return 0.0
            if vmid_val == vmax_val and x >= vmid_val:
                return 1.0
            # The two guards above rule out a zero denominator on either branch,
            # so each half keeps its full gradient even when vmid hits an endpoint.
            if x <= vmid_val:
                return 0.5 * (x - vmin_val) / (vmid_val - vmin_val)
            return 0.5 + 0.5 * (x - vmid_val) / (vmax_val - vmid_val)

        styles = []
        for value in series:
            red, green, blue, alpha = cmap(normalize(value))
            styles.append(f"background-color: rgba({int(255 * red)}, {int(255 * green)}, {int(255 * blue)}, {alpha}); color: {gradient_text_color};")
        return styles

    target_color_columns = []
    if "Average" in color_columns:
        target_color_columns.append("Average (lower=better)")
    if "Individual Tests" in color_columns:
        target_color_columns.extend(
            col for col in filtered_data.columns if col not in ("Name", "Params (B)", "Average (lower=better)")
        )

    def color_params_column_dynamic(value):
        # "" means "no style"; the Styler rejects non-CSS strings such as
        # "default" with ValueError.
        return params_css if pd.notna(value) else ""

    styler = filtered_data.style.format(formatter).map(color_params_column_dynamic, subset=["Params (B)"])
    for column in target_color_columns:
        if column in vmin:
            styler = styler.apply(custom_background_gradient, cmap=custom_cmap, vmin_val=vmin[column], vmax_val=vmax[column], vmid_val=vmid[column], subset=[column])
    styler = styler.hide(axis="index")

    widths = [300, 150, 150, 100, 100, 100, 100, 100, 100, 100, 100]
    table_styles = [
        {"selector": "th", "props": [("background-color", "var(--background-fill-secondary)"), ("color", "var(--body-text-color)"), ("padding", "8px"), ("font-weight", "bold")]},
        {"selector": "table", "props": [("border-collapse", "collapse"), ("border", "1px solid var(--border-color-primary)")]},
    ]
    for i, w in enumerate(widths):
        table_styles.append({"selector": f"th.col{i}, td.col{i}", "props": [("min-width", f"{w}px"), ("max-width", f"{w}px"), ("text-align", "center"), ("border", "1px solid var(--border-color-primary)")]})
    styler = styler.set_table_styles(table_styles)
    return styler.to_html()
def create_world_languages_gdp_chart():
    """Build a donut chart of world languages by share of global GDP.

    The shares are hard-coded percentages that sum to 100.

    Returns:
        A ``plotly.graph_objects.Figure`` sized 700x500 px, without a legend.
    """
    # (language, share in %, slice colour), in drawing order.
    pie_rows = [
        ("English", 27, "#FF7F7F"),
        ("Chinese", 18, "#FFA07A"),
        ("Spanish", 8, "#FFDB58"),
        ("Japanese", 6, "#90EE90"),
        ("German", 5, "#98FB98"),
        ("French", 4, "#87CEFA"),
        ("Arabic", 3, "#B0C4DE"),
        ("Italian", 2, "#DDA0DD"),
        ("Portuguese", 2, "#D8BFD8"),
        ("Korean", 2, "#F0E68C"),
        ("Other", 23, "#E0FFFF"),
    ]
    labels = [row[0] for row in pie_rows]
    values = [row[1] for row in pie_rows]
    slice_colors = [row[2] for row in pie_rows]

    donut = go.Pie(
        labels=labels,
        values=values,
        hole=0.3,
        marker={"colors": slice_colors, "line": {"color": "#FFFFFF", "width": 2}},
        textinfo="label+percent",
        textposition="outside",
        insidetextorientation="radial",
        textfont={"size": 12},
    )
    chart = go.Figure(data=[donut])
    chart.update_layout(
        title={
            "text": "World Languages by Share of Global GDP",
            "x": 0.5,
            "y": 0.95,
            "xanchor": "center",
            "yanchor": "top",
            "font": {"size": 20, "color": "black"},
        },
        showlegend=False,
        width=700,
        height=500,
        margin={"t": 80, "b": 20, "l": 20, "r": 20},
    )
    return chart
def check_model_exists(model_id):
    """Report whether *model_id* can be found on the Hugging Face Hub.

    Returns one of:
        "Exists and is accessible" - ``HfApi.model_info`` succeeded;
        "Exists but is restricted" - the Hub raised ``GatedRepoError``;
        "Does not exist"           - an ``HTTPError`` whose response status is 404;
        "Error: <message>"         - any other ``HTTPError``.
    Exceptions not caught here propagate to the caller.
    """
    api = HfApi()
    try:
        api.model_info(model_id)  # only success/failure matters, not the result
    except GatedRepoError:
        # Listed before HTTPError so the more specific gated case wins.
        return "Exists but is restricted"
    except HTTPError as e:
        # ``e.response`` may be None; don't let an AttributeError mask the error.
        status_code = getattr(e.response, "status_code", None)
        if status_code == 404:
            return "Does not exist"
        return "Error: " + str(e)
    return "Exists and is accessible"
def submit_model(name):
    """Validate *name* on the Hub and forward it to the submission webhook.

    Args:
        name: Hugging Face model id entered by the user.

    Returns:
        A message string starting with ``# SUCCESS`` or ``# ERROR`` (a Markdown
        heading). Exceptions raised by ``check_model_exists`` propagate.
    """
    if "Exists" not in check_model_exists(name):
        return f"# ERROR: Model {name} does not exist on Hugging Face!"
    try:
        # Without a timeout, requests waits indefinitely on an unresponsive server.
        response = requests.post(webhook_url, json={"content": name}, timeout=30)
        if response.status_code != 200:
            return f"# ERROR: Failed to submit model {name}. Server returned status code {response.status_code}."
        response_data = response.json()
        if response_data.get("status") == "success":
            return "# SUCCESS: We will check the model as soon as possible. Thank you for your submission!"
        return f"# ERROR: {response_data.get('message', 'Unknown error')}"
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout, requests.exceptions.HTTPError):
        # requests.post raises HTTPError only via raise_for_status(); unreachable
        # or slow servers surface as ConnectionError / Timeout.
        return "# ERROR: Network error while contacting queue. Please try again in a few minutes."
    except Exception as e:
        print(e)
        return "# ERROR: Unexpected error. Please try again later."
def create_scaling_plot(all_data, period):
selected_columns = ["Name", "Parameters Count (B)", "Average (The lower the better)"]
target_data = all_data[period]
new_df = pd.DataFrame()
for size in target_data.keys():
new_df = pd.concat([new_df, target_data[size]["cr"].loc[:, selected_columns].dropna(axis=1, how="all")], axis=0)
x_values = new_df["Parameters Count (B)"].astype(float).tolist()
y_values = new_df["Average (The lower the better)"].astype(float).tolist()
names = new_df["Name"].tolist()
x_min, x_max = np.log10(min(x_values)), np.log10(max(x_values))
y_min, y_max = np.log10(min(y_values)), np.log10(max(y_values))
x_dtick = (x_max - x_min) / 4
y_dtick = (y_max - y_min) / 4
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=x_values,
y=y_values,
mode="markers",
name="Models",
marker=dict(size=12, color="#39C5BB", opacity=0.8),
text=names,
customdata=list(zip(x_values, y_values)),
hovertemplate=(
"%{text}
" + "Params: %{customdata[0]:.2f}B
" + "Compression Rate: %{customdata[1]:.2f}%
" + "