|
import json |
|
import pandas as pd |
|
from statistics import mean |
|
from huggingface_hub import HfApi, create_repo |
|
from datasets import load_dataset, Dataset |
|
from datasets.data_files import EmptyDatasetError |
|
import re |
|
|
|
from constants import ( |
|
REPO_ID, |
|
HF_TOKEN, |
|
DATASETS, |
|
SHORT_DATASET_NAMES, |
|
DATASET_DESCRIPTIONS, |
|
) |
|
|
|
# Module-wide Hugging Face Hub client authenticated with HF_TOKEN.
# init_repo() uses it to check whether the leaderboard dataset repo exists.
api = HfApi(token=HF_TOKEN)
|
|
|
|
|
# Lower-case keywords that mark a license string as open. A keyword matches
# only at the start of a word (see _keyword_regex), so "mit" does not fire
# inside "permitted" and "mpl" does not fire inside "simple".
OPEN_LICENSE_KEYWORDS = {
    "mit", "apache", "apache-2", "apache-2.0",
    "bsd", "bsd-2", "bsd-3", "bsd-2-clause", "bsd-3-clause",
    "isc", "mpl", "mpl-2.0",
    "lgpl", "lgpl-2.1", "lgpl-3.0",
    "gpl", "gpl-2.0", "gpl-3.0", "agpl", "agpl-3.0",
    "epl", "epl-2.0", "cddl", "cddl-1.0", "cddl-1.1",
    "bsl", "bsl-1.0", "boost", "zlib", "unlicense", "artistic-2.0",
    "cc0", "cc0-1.0",
    "cc-by", "cc-by-3.0", "cc-by-4.0",
    "cc-by-sa", "cc-by-sa-3.0", "cc-by-sa-4.0",
    "openrail", "openrail-m", "bigscience openrail", "bigscience openrail-m",
    "open-source", "opensource", "open source"
}

# Lower-case keywords that mark a license string as restricted. They are
# checked before OPEN_LICENSE_KEYWORDS and take precedence, so e.g.
# "cc-by-nc-4.0" is not accepted just because it also starts with "cc-by".
RESTRICTIVE_LICENSE_KEYWORDS = {
    "cc-by-nc", "cc-by-nc-sa", "cc-nc", "nc-sa", "nc-nd",
    "cc-by-nd", "cc-nd", "no-derivatives", "no derivatives",
    "non-commercial", "noncommercial", "research-only", "research only",
    "llama", "llama-2", "community license",
    "proprietary", "closed", "unknown", "custom"
}


def _keyword_regex(keywords):
    """Compile a regex that finds any of *keywords* starting at a word boundary.

    A keyword may be followed by anything ("gpl" matches "gplv3", "apache"
    matches "apache2"), but must not be preceded by a letter or digit, so
    keywords are not found in the middle of unrelated words.
    """
    alternatives = "|".join(re.escape(keyword) for keyword in sorted(keywords))
    # [^\W_] means "a word character other than underscore", i.e. a letter or
    # digit (Unicode-aware); the negative lookbehind forbids one right before.
    return re.compile(rf"(?<![^\W_])(?:{alternatives})")


_OPEN_LICENSE_RE = _keyword_regex(OPEN_LICENSE_KEYWORDS)
_RESTRICTIVE_LICENSE_RE = _keyword_regex(RESTRICTIVE_LICENSE_KEYWORDS)


def is_open_license(license_str: str) -> bool:
    """Classify a free-form license string as open (True) or not (False).

    The string is stripped and lower-cased. ``None`` or blank input is not
    open. Any restrictive keyword wins; otherwise the license is open iff an
    open keyword is present. Non-string values (e.g. a pandas NaN) are
    converted with ``str()`` first.
    """
    s = "" if license_str is None else str(license_str).strip().lower()
    if not s:
        return False
    if _RESTRICTIVE_LICENSE_RE.search(s):
        return False
    return _OPEN_LICENSE_RE.search(s) is not None
|
|
|
|
|
def init_repo():
    """Create the private leaderboard dataset repo REPO_ID if it does not exist.

    Only a "repository not found" answer from the Hub triggers creation.
    Other failures (network errors, invalid token, ...) propagate with their
    real cause instead of being masked by a follow-up create_repo error.
    """
    from huggingface_hub.utils import RepositoryNotFoundError

    try:
        api.repo_info(REPO_ID, repo_type="dataset")
    except RepositoryNotFoundError:
        create_repo(REPO_ID, repo_type="dataset", private=True, token=HF_TOKEN)
|
|
|
|
|
def load_data():
    """Load the leaderboard from the Hub dataset and render it as an HTML table.

    Rows come from the ``train`` split of REPO_ID. If the repo holds no data
    yet (``EmptyDatasetError``), a header-only table is returned.

    Otherwise:
    - rows are sorted by ``overall_wer`` ascending and ranked, with medals
      for the top three;
    - stored metric fractions are shown as percentages;
    - each dataset becomes one column showing WER, with CER in the hover
      tooltip;
    - the lowest WER per dataset gets the ``best-metric`` CSS class.

    ``model_name`` and ``link`` come straight from user submissions, and the
    table is rendered with ``escape=False`` because cells contain our own
    markup. Both values are therefore HTML-escaped here, and only http(s)
    links are rendered as clickable anchors.

    Returns:
        str: the table wrapped in ``leaderboard-wrapper`` /
        ``leaderboard-table`` divs.
    """
    from html import escape

    columns = (
        ["model_name", "link", "license", "overall_wer", "overall_cer"]
        + [f"wer_{ds}" for ds in DATASETS]
        + [f"cer_{ds}" for ds in DATASETS]
    )
    try:
        dataset = load_dataset(REPO_ID, token=HF_TOKEN)
        df = dataset["train"].to_pandas()
    except EmptyDatasetError:
        df = pd.DataFrame(columns=columns)

    if df.empty:
        # Header-only table so the page layout is the same before any submission.
        return (
            '<div class="leaderboard-wrapper"><div class="leaderboard-table"><table><thead><tr><th>Ранг</th><th>Модель</th><th>Тип модели</th><th>Средний WER ⬇️</th><th>Средний CER ⬇️</th>'
            + "".join(f"<th>{short}</th>" for short in SHORT_DATASET_NAMES)
            + "</tr></thead><tbody></tbody></table></div></div>"
        )

    def metric_cell(wer, cer, best):
        """Render one dataset cell: WER as text, CER in the tooltip."""
        css_class = "metric-cell best-metric" if wer == best else "metric-cell"
        return f'<span title="CER: {cer:.2f}%" class="{css_class}">{wer:.2f}%</span>'

    def model_link(name, link):
        """Render the escaped model name, linked only for http(s) URLs."""
        safe_name = escape(str(name))
        url = str(link).strip()
        # Untrusted input: refuse javascript:, data: and other schemes, and
        # escape quotes so the URL cannot break out of the href attribute.
        if not url.lower().startswith(("http://", "https://")):
            return safe_name
        return f'<a href="{escape(url)}" target="_blank" rel="noopener noreferrer">{safe_name}</a>'

    df = df.sort_values("overall_wer").reset_index(drop=True)
    df.insert(0, "rank", df.index + 1)

    # Stored metrics are fractions; columns[3:] are exactly the metric columns.
    for col in columns[3:]:
        df[col] = (df[col] * 100).round(2)

    for short_ds, ds in zip(SHORT_DATASET_NAMES, DATASETS):
        wer_col, cer_col = f"wer_{ds}", f"cer_{ds}"
        best = df[wer_col].min()
        df[short_ds] = [
            metric_cell(wer, cer, best) for wer, cer in zip(df[wer_col], df[cer_col])
        ]
        df = df.drop(columns=[wer_col, cer_col])

    df["model_name"] = [
        model_link(name, url) for name, url in zip(df["model_name"], df["link"])
    ]
    df = df.drop(columns=["link"])

    df["license"] = df["license"].apply(lambda x: "Открытая" if is_open_license(x) else "Закрытая")

    df["rank"] = df["rank"].apply(
        lambda r: "🥇" if r == 1 else "🥈" if r == 2 else "🥉" if r == 3 else str(r)
    )

    df.rename(
        columns={
            "overall_wer": "Средний WER ⬇️",
            "overall_cer": "Средний CER ⬇️",
            "license": "Тип модели",
            "model_name": "Модель",
            "rank": "Ранг",
        },
        inplace=True,
    )

    # escape=False is required for the markup built above; every user-supplied
    # value has already been escaped.
    table_html = df.to_html(
        escape=False, index=False, classes="display cell-border compact stripe"
    )
    return f'<div class="leaderboard-wrapper"><div class="leaderboard-table">{table_html}</div></div>'
|
|
|
|
|
def process_submit(json_str):
    """Validate a JSON submission, append it to the Hub dataset, and re-render.

    Expected input is a JSON object with:
    - non-empty string fields ``model_name``, ``link`` and ``license``;
    - a ``metrics`` object that has exactly one entry per dataset in DATASETS,
      each ``{"wer": <number>, "cer": <number>}`` with values in [0, 1], as
      the submission form asks.

    ``overall_wer`` / ``overall_cer`` are the plain means over all datasets.

    Returns:
        tuple: ``(table_html, status_message, textbox_value)``.
        - On success: the refreshed ``load_data()`` HTML,
          "Успешно добавлено!" and "" (clears the input).
        - On any error: ``None``, "Ошибка: <reason>" and the original
          ``json_str``, so the user can correct it.
    """
    import math
    from numbers import Real

    columns = (
        ["model_name", "link", "license", "overall_wer", "overall_cer"]
        + [f"wer_{ds}" for ds in DATASETS]
        + [f"cer_{ds}" for ds in DATASETS]
    )

    def checked_metric(ds, name, value):
        """Return *value* as float, or raise ValueError if it is not a valid rate."""
        # bool is a subclass of int, so reject it explicitly. json.loads accepts
        # NaN/Infinity, which would poison the averages and the ranking.
        if isinstance(value, bool) or not isinstance(value, Real) or not math.isfinite(value):
            raise ValueError(f"{name} для {ds} должен быть числом")
        if not 0 <= value <= 1:
            raise ValueError(f"{name} для {ds} должен быть в диапазоне от 0 до 1")
        return float(value)

    try:
        data = json.loads(json_str)
        required_keys = ["model_name", "link", "license", "metrics"]
        if not isinstance(data, dict) or not all(key in data for key in required_keys):
            raise ValueError(
                "Неверная структура JSON. Требуемые поля: model_name, link, license, metrics"
            )
        for key in ("model_name", "link", "license"):
            if not isinstance(data[key], str) or not data[key].strip():
                raise ValueError(f"Поле {key} должно быть непустой строкой")

        metrics = data["metrics"]
        if not isinstance(metrics, dict) or set(metrics) != set(DATASETS):
            raise ValueError(
                f"Метрики должны быть для всех датасетов: {', '.join(DATASETS)}"
            )

        row = {
            "model_name": data["model_name"],
            "link": data["link"],
            "license": data["license"],
        }
        wers, cers = [], []
        for ds in DATASETS:
            entry = metrics[ds]
            if not isinstance(entry, dict) or "wer" not in entry or "cer" not in entry:
                raise ValueError(f"Для {ds} требуются wer и cer")
            wer = checked_metric(ds, "wer", entry["wer"])
            cer = checked_metric(ds, "cer", entry["cer"])
            row[f"wer_{ds}"] = wer
            row[f"cer_{ds}"] = cer
            wers.append(wer)
            cers.append(cer)
        row["overall_wer"] = mean(wers)
        row["overall_cer"] = mean(cers)

        try:
            dataset = load_dataset(REPO_ID, token=HF_TOKEN)
            df = dataset["train"].to_pandas()
        except EmptyDatasetError:
            df = pd.DataFrame(columns=columns)

        new_df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
        Dataset.from_pandas(new_df).push_to_hub(REPO_ID, token=HF_TOKEN)

        return load_data(), "Успешно добавлено!", ""
    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing the app.
        return None, f"Ошибка: {str(e)}", json_str
|
|
|
|
|
def get_datasets_description():
    """Build the HTML grid of dataset cards from DATASET_DESCRIPTIONS.

    Each card shows the short name, the full name, the description and the
    row count. Cards appear in the mapping's iteration order.
    """
    cards = []
    for short_name, meta in DATASET_DESCRIPTIONS.items():
        cards.append(
            f"""
<div class="dataset-card">
<h3>{short_name} <span class="full-name">{meta['full_name']}</span></h3>
<p>{meta['description']}</p>
<p class="records">📊 {meta['num_rows']} записей</p>
</div>
"""
        )
    return '<div class="datasets-container">' + "".join(cards) + "</div>"
|
|
|
|
|
def _strip_punct(text: str) -> str: |
|
return re.sub(r"[^\w\s]+", "", text, flags=re.UNICODE) |
|
|
|
|
|
def normalize_text(s: str) -> str: |
|
return _strip_punct(s.lower()).strip() |
|
|
|
|
|
def _edit_distance(a, b): |
|
n, m = len(a), len(b) |
|
dp = [[0] * (m + 1) for _ in range(n + 1)] |
|
for i in range(n + 1): |
|
dp[i][0] = i |
|
for j in range(m + 1): |
|
dp[0][j] = j |
|
for i in range(1, n + 1): |
|
ai = a[i - 1] |
|
for j in range(1, m + 1): |
|
cost = 0 if ai == b[j - 1] else 1 |
|
dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost) |
|
return dp[n][m] |
|
|
|
|
|
def compute_wer_cer(ref: str, hyp: str, normalize: bool = True):
    """Compute word and character error rates of *hyp* against *ref*.

    When *normalize* is true, both texts go through normalize_text first.
    Each rate is the edit distance divided by the reference length (at least
    1, so an empty reference does not divide by zero).

    Returns:
        tuple[float, float]: (WER, CER) as percentages rounded to 2 decimals.
    """
    if normalize:
        reference, hypothesis = normalize_text(ref), normalize_text(hyp)
    else:
        reference, hypothesis = ref, hyp

    def error_rate(ref_units, hyp_units):
        return _edit_distance(ref_units, hyp_units) / max(1, len(ref_units))

    word_rate = error_rate(reference.split(), hypothesis.split())
    char_rate = error_rate(list(reference), list(hypothesis))
    return round(word_rate * 100, 2), round(char_rate * 100, 2)
|
|
|
|
|
# Static explainer for the WER/CER metrics, the text normalization and the ranking.
_METRICS_HTML = """
<div class="metrics-grid">
<div class="metric-card">
<h3>WER — Word Error Rate</h3>
<div class="formula">WER = ( <span>S</span> + <span>D</span> + <span>I</span> ) / <span>N</span></div>
<div class="chips">
<div class="chip"><b>S</b><small>замены</small></div>
<div class="chip"><b>D</b><small>удаления</small></div>
<div class="chip"><b>I</b><small>вставки</small></div>
<div class="chip"><b>N</b><small>слов в референсе</small></div>
</div>
</div>
<div class="metric-card">
<h3>CER — Character Error Rate</h3>
<div class="formula">CER = ( <span>S</span> + <span>D</span> + <span>I</span> ) / <span>N</span></div>
<div class="chips">
<div class="chip"><b>S, D, I</b><small>операции редактирования</small></div>
<div class="chip"><b>N</b><small>символов в референсе</small></div>
</div>
</div>
<div class="metric-card">
<h3>Нормализация</h3>
<p class="metric-text">Перед расчётом приводим текст к нижнему регистру и удаляем пунктуацию.</p>
</div>
<div class="metric-card">
<h3>Сравнение</h3>
<p class="metric-text">Сортировка по среднему WER по всем датасетам. Метрики отображаются в процентах.</p>
</div>
</div>
"""


def get_metrics_html():
    """Return the static HTML panel describing how metrics are computed."""
    return _METRICS_HTML
|
|
|
|
|
def get_submit_html():
    """Return the static HTML help panel for the submission form.

    The example is a complete submission. It is a single JSON object with
    ``model_name``, ``link``, ``license`` and a ``metrics`` mapping of
    dataset name -> {"wer", "cer"}, which is the top-level structure
    process_submit requires. Once the markup tags are removed, the example
    is valid JSON.
    """
    return """
<div class="submit-grid">
<div class="form-card">
<h3>Общая информация</h3>
<ul>
<li><b>Название модели</b> (<code>model_name</code>) — коротко и понятно.</li>
<li><b>Ссылка</b> (<code>link</code>) — HuggingFace, GitHub или сайт.</li>
<li><b>Лицензия</b> (<code>license</code>) — MIT, Apache-2.0, GPL или Closed.</li>
</ul>
</div>
<div class="form-card">
<h3>Метрики</h3>
<p>Отправьте один JSON-объект: общие поля и WER и CER для всех датасетов в поле <code>metrics</code>. Значения — от 0 до 1.</p>
<pre class="code-block json">{
  <span class="key">"model_name"</span>: <span class="string">"my-asr-model"</span>,
  <span class="key">"link"</span>: <span class="string">"https://huggingface.co/your-org/my-asr-model"</span>,
  <span class="key">"license"</span>: <span class="string">"MIT"</span>,
  <span class="key">"metrics"</span>: {
    <span class="key">"Russian_LibriSpeech"</span>: { <span class="key">"wer"</span>: <span class="number">0.1234</span>, <span class="key">"cer"</span>: <span class="number">0.0567</span> },
    <span class="key">"Common_Voice_Corpus_22.0"</span>: { <span class="key">"wer"</span>: <span class="number">0.2345</span>, <span class="key">"cer"</span>: <span class="number">0.0789</span> },
    <span class="key">"Tone_Webinars"</span>: { <span class="key">"wer"</span>: <span class="number">0.3456</span>, <span class="key">"cer"</span>: <span class="number">0.0987</span> },
    <span class="key">"Tone_Books"</span>: { <span class="key">"wer"</span>: <span class="number">0.4567</span>, <span class="key">"cer"</span>: <span class="number">0.1098</span> },
    <span class="key">"Tone_Speak"</span>: { <span class="key">"wer"</span>: <span class="number">0.5678</span>, <span class="key">"cer"</span>: <span class="number">0.1209</span> },
    <span class="key">"Sova_RuDevices"</span>: { <span class="key">"wer"</span>: <span class="number">0.6789</span>, <span class="key">"cer"</span>: <span class="number">0.1310</span> }
  }
}</pre>
</div>
</div>
"""
|
|