Anonumous's picture
Update utils.py
120c53a verified
import json
import pandas as pd
from statistics import mean
from huggingface_hub import HfApi, create_repo
from datasets import load_dataset, Dataset
from datasets.data_files import EmptyDatasetError
import re
from constants import (
REPO_ID,
HF_TOKEN,
DATASETS,
SHORT_DATASET_NAMES,
DATASET_DESCRIPTIONS,
)
# Shared Hugging Face Hub client, authenticated with HF_TOKEN; init_repo()
# uses it to look up REPO_ID.
api = HfApi(token=HF_TOKEN)
# Lower-case markers that identify a license as open (permissive, copyleft,
# open-data or OpenRAIL style).
OPEN_LICENSE_KEYWORDS = {
    "mit", "apache", "apache-2", "apache-2.0",
    "bsd", "bsd-2", "bsd-3", "bsd-2-clause", "bsd-3-clause",
    "isc", "mpl", "mpl-2.0",
    "lgpl", "lgpl-2.1", "lgpl-3.0",
    "gpl", "gpl-2.0", "gpl-3.0", "agpl", "agpl-3.0",
    "epl", "epl-2.0", "cddl", "cddl-1.0", "cddl-1.1",
    "bsl", "bsl-1.0", "boost", "zlib", "unlicense", "artistic-2.0",
    "cc0", "cc0-1.0",
    "cc-by", "cc-by-3.0", "cc-by-4.0",
    "cc-by-sa", "cc-by-sa-3.0", "cc-by-sa-4.0",
    "openrail", "openrail-m", "bigscience openrail", "bigscience openrail-m",
    "open-source", "opensource", "open source",
}

# Lower-case markers that force a license to be treated as closed. They are
# checked first, so "cc-by-nc" wins over the open "cc-by" prefix.
RESTRICTIVE_LICENSE_KEYWORDS = {
    "cc-by-nc", "cc-by-nc-sa", "cc-nc", "nc-sa", "nc-nd",
    "cc-by-nd", "cc-nd", "no-derivatives", "no derivatives",
    "non-commercial", "noncommercial", "research-only", "research only",
    "llama", "llama-2", "community license",
    "proprietary", "closed", "unknown", "custom",
}

# Open keywords must start a token: the look-behind rejects a match that is
# glued to a preceding letter or digit, so "mit" no longer fires inside
# "limited" / "permitted" and "isc" no longer fires inside "disclaimer".
# No right-hand boundary is required, so forms like "gplv3" or "apache2"
# are still recognised.
_OPEN_LICENSE_RE = re.compile(
    r"(?<![a-z0-9])(?:"
    + "|".join(re.escape(keyword) for keyword in sorted(OPEN_LICENSE_KEYWORDS))
    + r")"
)


def is_open_license(license_str: str) -> bool:
    """Classify a free-form license string as open (True) or not (False).

    The value is stringified, stripped and lower-cased before matching.

    Args:
        license_str: License text as submitted; ``None`` and empty or
            whitespace-only values are treated as "not open".

    Returns:
        ``False`` if the text is empty or contains any restrictive keyword
        anywhere (deliberately loose, so any hint of a restriction wins);
        otherwise ``True`` only if an open keyword occurs at the start of a
        token.
    """
    s = (str(license_str) if license_str is not None else "").strip().lower()
    if not s:
        return False
    if any(pat in s for pat in RESTRICTIVE_LICENSE_KEYWORDS):
        return False
    return _OPEN_LICENSE_RE.search(s) is not None
def init_repo():
    """Make sure the leaderboard dataset repository exists on the Hub.

    Looks up ``REPO_ID`` as a dataset repository and creates it as a private
    dataset only when the Hub reports that it cannot be found. Any other
    failure (network error, interrupt, ...) propagates to the caller instead
    of being mistaken for "repository is missing".
    """
    # Imported locally so this function carries its own dependency;
    # huggingface_hub is already a dependency of this module.
    from huggingface_hub.utils import RepositoryNotFoundError

    try:
        api.repo_info(REPO_ID, repo_type="dataset")
    except RepositoryNotFoundError:
        create_repo(REPO_ID, repo_type="dataset", private=True, token=HF_TOKEN)
def load_data():
    """Build the leaderboard table as an HTML string.

    Loads the ``train`` split of the ``REPO_ID`` dataset, sorts rows by
    ``overall_wer`` ascending, assigns ranks, converts every metric column
    to a percentage rounded to two decimals, and renders one column per
    dataset showing the WER with the CER in the tooltip. The lowest WER of
    each dataset gets the ``best-metric`` CSS class.

    ``model_name`` and ``link`` are HTML-escaped before being placed into
    the markup, because the table is rendered with ``escape=False``.

    Returns:
        The table wrapped in the ``leaderboard-wrapper`` div. When the
        dataset is empty, a header-only table is returned.
    """
    # Local import: keeps the new dependency inside this function.
    from html import escape

    metric_columns = (
        ["overall_wer", "overall_cer"]
        + [f"wer_{ds}" for ds in DATASETS]
        + [f"cer_{ds}" for ds in DATASETS]
    )

    try:
        dataset = load_dataset(REPO_ID, token=HF_TOKEN)
        df = dataset["train"].to_pandas()
    except EmptyDatasetError:
        df = pd.DataFrame(columns=["model_name", "link", "license"] + metric_columns)

    if df.empty:
        return (
            '<div class="leaderboard-wrapper"><div class="leaderboard-table"><table><thead><tr><th>Ранг</th><th>Модель</th><th>Тип модели</th><th>Средний WER ⬇️</th><th>Средний CER ⬇️</th>'
            + "".join(f"<th>{short}</th>" for short in SHORT_DATASET_NAMES)
            + "</tr></thead><tbody></tbody></table></div></div>"
        )

    df = df.sort_values("overall_wer").reset_index(drop=True)
    df.insert(0, "rank", df.index + 1)

    # Metric columns are multiplied by 100 and shown as percentages.
    for col in metric_columns:
        df[col] = (df[col] * 100).round(2)

    # Best (lowest) WER per dataset, taken after rounding so the equality
    # check below compares the displayed values.
    best_values = {ds: df[f"wer_{ds}"].min() for ds in DATASETS}

    for short_ds, ds in zip(SHORT_DATASET_NAMES, DATASETS):
        wer_col, cer_col = f"wer_{ds}", f"cer_{ds}"
        best = best_values[ds]
        df[short_ds] = [
            f'<span title="CER: {cer:.2f}%" '
            f'class="metric-cell{" best-metric" if wer == best else ""}">'
            f"{wer:.2f}%</span>"
            for wer, cer in zip(df[wer_col], df[cer_col])
        ]
        df = df.drop(columns=[wer_col, cer_col])

    # These two fields come from submissions, so escape them before they are
    # embedded in raw HTML.
    # NOTE(review): the URL scheme is not validated here; a "javascript:"
    # link would still be emitted as-is.
    df["model_name"] = [
        f'<a href="{escape(str(link))}" target="_blank">{escape(str(name))}</a>'
        for link, name in zip(df["link"], df["model_name"])
    ]
    df = df.drop(columns=["link"])

    df["license"] = df["license"].apply(
        lambda x: "Открытая" if is_open_license(x) else "Закрытая"
    )

    medals = {1: "🥇", 2: "🥈", 3: "🥉"}
    df["rank"] = df["rank"].apply(lambda r: medals.get(int(r), str(r)))

    df = df.rename(
        columns={
            "overall_wer": "Средний WER ⬇️",
            "overall_cer": "Средний CER ⬇️",
            "license": "Тип модели",
            "model_name": "Модель",
            "rank": "Ранг",
        }
    )
    table_html = df.to_html(
        escape=False, index=False, classes="display cell-border compact stripe"
    )
    return f'<div class="leaderboard-wrapper"><div class="leaderboard-table">{table_html}</div></div>'
def process_submit(json_str):
    """Validate a JSON submission, append it to the dataset and re-render.

    The payload must be a JSON object with ``model_name``, ``link``,
    ``license`` and ``metrics``. ``metrics`` must be an object whose keys
    are exactly ``DATASETS``, each holding numeric ``wer`` and ``cer``.
    Metric values must be finite, non-negative numbers (booleans and NaN /
    Infinity, which ``json.loads`` accepts, are rejected). The overall
    scores are the arithmetic means over all datasets.

    Args:
        json_str: Raw JSON text of the submission.

    Returns:
        ``(html, message, text)``: on success the refreshed leaderboard
        HTML, a success message and an empty string; on any failure
        ``None``, an error message and the original ``json_str``.
    """
    # Local import: keeps the new dependency inside this function.
    from math import isfinite

    try:
        data = json.loads(json_str)

        required_keys = ("model_name", "link", "license", "metrics")
        # The isinstance check matters: a bare JSON string would otherwise
        # pass the key test via substring matching.
        if not isinstance(data, dict) or any(key not in data for key in required_keys):
            raise ValueError(
                "Неверная структура JSON. Требуемые поля: model_name, link, license, metrics"
            )

        metrics = data["metrics"]
        if not isinstance(metrics, dict) or set(metrics) != set(DATASETS):
            raise ValueError(
                f"Метрики должны быть для всех датасетов: {', '.join(DATASETS)}"
            )

        row = {
            "model_name": data["model_name"],
            "link": data["link"],
            "license": data["license"],
        }
        wers, cers = [], []
        for ds in DATASETS:
            scores = metrics[ds]
            if not isinstance(scores, dict) or "wer" not in scores or "cer" not in scores:
                raise ValueError(f"Для {ds} требуются wer и cer")
            for metric_name in ("wer", "cer"):
                value = scores[metric_name]
                is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
                if not is_number or not isfinite(value) or value < 0:
                    raise ValueError(
                        f"Для {ds} значение {metric_name} должно быть неотрицательным числом"
                    )
            row[f"wer_{ds}"] = scores["wer"]
            row[f"cer_{ds}"] = scores["cer"]
            wers.append(scores["wer"])
            cers.append(scores["cer"])
        row["overall_wer"] = mean(wers)
        row["overall_cer"] = mean(cers)

        try:
            dataset = load_dataset(REPO_ID, token=HF_TOKEN)
            df = dataset["train"].to_pandas()
        except EmptyDatasetError:
            # First submission: start from an empty frame with the full schema.
            columns = (
                ["model_name", "link", "license", "overall_wer", "overall_cer"]
                + [f"wer_{ds}" for ds in DATASETS]
                + [f"cer_{ds}" for ds in DATASETS]
            )
            df = pd.DataFrame(columns=columns)

        new_df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
        new_dataset = Dataset.from_pandas(new_df)
        new_dataset.push_to_hub(REPO_ID, token=HF_TOKEN)

        updated_html = load_data()
        return updated_html, "Успешно добавлено!", ""
    except Exception as e:
        # Every failure is reported as a message and the submitted text is
        # returned unchanged alongside it.
        return None, f"Ошибка: {str(e)}", json_str
def get_datasets_description():
    """Render the dataset overview as HTML.

    Emits one ``dataset-card`` per entry of ``DATASET_DESCRIPTIONS`` (short
    name, full name, description and row count), wrapped in a single
    ``datasets-container`` div.
    """
    cards = [
        f"""
<div class="dataset-card">
<h3>{short_ds} <span class="full-name">{info["full_name"]}</span></h3>
<p>{info["description"]}</p>
<p class="records">📊 {info["num_rows"]} записей</p>
</div>
"""
        for short_ds, info in DATASET_DESCRIPTIONS.items()
    ]
    return '<div class="datasets-container">' + "".join(cards) + "</div>"
def _strip_punct(text: str) -> str:
return re.sub(r"[^\w\s]+", "", text, flags=re.UNICODE)
def normalize_text(s: str) -> str:
return _strip_punct(s.lower()).strip()
def _edit_distance(a, b):
n, m = len(a), len(b)
dp = [[0] * (m + 1) for _ in range(n + 1)]
for i in range(n + 1):
dp[i][0] = i
for j in range(m + 1):
dp[0][j] = j
for i in range(1, n + 1):
ai = a[i - 1]
for j in range(1, m + 1):
cost = 0 if ai == b[j - 1] else 1
dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost)
return dp[n][m]
def compute_wer_cer(ref: str, hyp: str, normalize: bool = True):
    """Compute WER and CER of a hypothesis against a reference.

    Args:
        ref: Reference transcript.
        hyp: Hypothesis transcript.
        normalize: When true, both texts go through ``normalize_text``
            first; otherwise they are compared as given.

    Returns:
        ``(wer, cer)`` as percentages rounded to two decimals. Each rate is
        the edit distance divided by the reference length (whitespace-split
        words for WER, characters for CER); the denominator is clamped to at
        least 1 so an empty reference does not divide by zero.
    """
    reference = normalize_text(ref) if normalize else ref
    hypothesis = normalize_text(hyp) if normalize else hyp

    def _error_rate(ref_units, hyp_units):
        return _edit_distance(ref_units, hyp_units) / max(1, len(ref_units))

    word_rate = _error_rate(reference.split(), hypothesis.split())
    char_rate = _error_rate(list(reference), list(hypothesis))
    return round(word_rate * 100, 2), round(char_rate * 100, 2)
def get_metrics_html():
    """Return the static HTML that explains the metrics.

    The markup holds four cards: the WER formula, the CER formula, the
    normalization note and the ranking note.
    """
    metrics_markup = """
<div class="metrics-grid">
<div class="metric-card">
<h3>WER — Word Error Rate</h3>
<div class="formula">WER = ( <span>S</span> + <span>D</span> + <span>I</span> ) / <span>N</span></div>
<div class="chips">
<div class="chip"><b>S</b><small>замены</small></div>
<div class="chip"><b>D</b><small>удаления</small></div>
<div class="chip"><b>I</b><small>вставки</small></div>
<div class="chip"><b>N</b><small>слов в референсе</small></div>
</div>
</div>
<div class="metric-card">
<h3>CER — Character Error Rate</h3>
<div class="formula">CER = ( <span>S</span> + <span>D</span> + <span>I</span> ) / <span>N</span></div>
<div class="chips">
<div class="chip"><b>S, D, I</b><small>операции редактирования</small></div>
<div class="chip"><b>N</b><small>символов в референсе</small></div>
</div>
</div>
<div class="metric-card">
<h3>Нормализация</h3>
<p class="metric-text">Перед расчётом приводим текст к нижнему регистру и удаляем пунктуацию.</p>
</div>
<div class="metric-card">
<h3>Сравнение</h3>
<p class="metric-text">Сортировка по среднему WER по всем датасетам. Метрики отображаются в процентах.</p>
</div>
</div>
"""
    return metrics_markup
def get_submit_html():
    """Return the static HTML with submission instructions.

    The markup holds two cards: the general fields (model name, link,
    license) and a highlighted JSON example of per-dataset WER/CER values.
    """
    submit_markup = """
<div class="submit-grid">
<div class="form-card">
<h3>Общая информация</h3>
<ul>
<li><b>Название модели</b> — коротко и понятно.</li>
<li><b>Ссылка</b> — HuggingFace, GitHub или сайт.</li>
<li><b>Лицензия</b> — MIT, Apache-2.0, GPL или Closed.</li>
</ul>
</div>
<div class="form-card">
<h3>Метрики</h3>
<p>Укажите WER и CER для всех датасетов в формате JSON. Значения — от 0 до 1.</p>
<pre class="code-block json">{
<span class="key">"Russian_LibriSpeech"</span>: { <span class="key">"wer"</span>: <span class="number">0.1234</span>, <span class="key">"cer"</span>: <span class="number">0.0567</span> },
<span class="key">"Common_Voice_Corpus_22.0"</span>: { <span class="key">"wer"</span>: <span class="number">0.2345</span>, <span class="key">"cer"</span>: <span class="number">0.0789</span> },
<span class="key">"Tone_Webinars"</span>: { <span class="key">"wer"</span>: <span class="number">0.3456</span>, <span class="key">"cer"</span>: <span class="number">0.0987</span> },
<span class="key">"Tone_Books"</span>: { <span class="key">"wer"</span>: <span class="number">0.4567</span>, <span class="key">"cer"</span>: <span class="number">0.1098</span> },
<span class="key">"Tone_Speak"</span>: { <span class="key">"wer"</span>: <span class="number">0.5678</span>, <span class="key">"cer"</span>: <span class="number">0.1209</span> },
<span class="key">"Sova_RuDevices"</span>: { <span class="key">"wer"</span>: <span class="number">0.6789</span>, <span class="key">"cer"</span>: <span class="number">0.1310</span> }
}</pre>
</div>
</div>
"""
    return submit_markup