|
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from huggingface_hub import HfApi, Repository
from matplotlib.figure import Figure
|
|
|
# Global Matplotlib defaults for every chart this app draws.
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.size"] = 10

# Shared annotation data. upload_csv and save_changes replace it; the CSV
# export, chart and Hub-push handlers read it. Empty until the first upload.
# NOTE(review): this is a single module-level object, so every session served
# by the same app process sees (and overwrites) the same data.
df = pd.DataFrame()
|
|
|
def upload_csv(file):
    """Load an uploaded CSV into the shared ``df`` and reveal the editor.

    Args:
        file: Value of the ``gr.File`` component: a tempfile wrapper with a
            ``.name`` attribute (Gradio 3.x), a path string (Gradio 4.x), or
            ``None`` when no file was chosen.

    Returns:
        A 6-tuple matching the ``upload_btn.click`` outputs: the table update,
        a status message, then visibility updates for the Save, Download and
        Visualize buttons and the push accordion.

    On any failure the previously loaded data and every component are left
    untouched; only the status message changes.
    """
    global df

    def _unchanged(message):
        # A bare gr.update() leaves the corresponding component as it is.
        return (gr.update(), message, gr.update(), gr.update(), gr.update(), gr.update())

    if file is None:
        return _unchanged("❌ Please choose a CSV file first.")

    # Gradio 3.x hands over a tempfile wrapper, Gradio 4.x a plain path string.
    path = getattr(file, "name", file)
    try:
        new_df = pd.read_csv(path)
    except (OSError, ValueError) as err:
        # pandas' ParserError / EmptyDataError and UnicodeDecodeError are all
        # ValueError subclasses.
        return _unchanged(f"❌ Could not read CSV: {err}")

    if "text" not in new_df.columns or "label" not in new_df.columns:
        return _unchanged("❌ CSV must contain 'text' and 'label' columns.")

    # Only a validated upload replaces the shared data.
    new_df["label"] = new_df["label"].fillna("")
    df = new_df

    return (
        # The table component is created hidden, so reveal it with its data.
        gr.update(value=df[["text", "label"]], visible=True),
        "✅ File uploaded — you can now annotate and use the buttons below.",
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
    )
|
|
|
def save_changes(table):
    """Replace the shared ``df`` with the contents of the edited table.

    Args:
        table: Value of the ``gr.Dataframe`` component, either a DataFrame or
            a list of ``[text, label]`` rows.

    Returns:
        A status message for the UI.
    """
    global df
    new_df = pd.DataFrame(table, columns=["text", "label"])
    # Match upload_csv: blank label cells become "" rather than NaN/None, so
    # value_counts() in the distribution chart still counts unlabeled rows.
    new_df["label"] = new_df["label"].fillna("")
    df = new_df
    return "💾 Changes saved."
|
|
|
def download_csv():
    """Export the shared ``df`` as ``annotated_data.csv`` for download.

    The file is written (without the index) to the current working directory,
    overwriting any previous export.

    Returns:
        The CSV path, for the ``gr.File`` download component.
    """
    target = "annotated_data.csv"
    df.to_csv(target, index=False)
    return target
|
|
|
def create_distribution_figure():
    """Build a Label/Count table plus a horizontal bar chart of label frequencies.

    Reads the shared ``df``. The most frequent label is listed first in the
    table and drawn at the top of the bar chart.

    Returns:
        matplotlib.figure.Figure: A standalone figure (not registered with
        pyplot), usable by ``gr.Plot`` and ``Figure.savefig``.
    """
    label_series = df["label"] if "label" in df.columns else pd.Series(dtype=object)
    # value_counts() already orders by descending count.
    counts = label_series.value_counts()
    # Stringify labels: numeric or mixed-type labels (e.g. numbers alongside
    # the "" placeholders upload_csv fills in) otherwise break Matplotlib's
    # categorical axis.
    labels = [str(label) for label in counts.index]
    values = counts.tolist()

    # Create the Figure directly instead of through pyplot: pyplot keeps every
    # figure it creates alive until explicitly closed, so each click in this
    # long-running server would otherwise leak one.
    fig = Figure(figsize=(8, max(2, len(labels) * 0.4)), tight_layout=True)
    ax_table, ax_bar = fig.subplots(ncols=2, gridspec_kw={"width_ratios": [1, 2]})
    ax_table.axis("off")

    if not labels:
        # Axes.table() cannot build a table from zero rows; show a placeholder.
        ax_bar.axis("off")
        ax_bar.text(
            0.5, 0.5, "No labels to display",
            ha="center", va="center", transform=ax_bar.transAxes,
        )
        return fig

    tbl = ax_table.table(
        cellText=[[label, count] for label, count in zip(labels, values)],
        colLabels=["Label", "Count"],
        loc="center",
    )
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(10)
    tbl.scale(1, 1.2)

    ax_bar.barh(labels, values, color="#222")
    ax_bar.invert_yaxis()  # put the most frequent label at the top
    ax_bar.set_xlabel("Count")
    return fig
|
|
|
def visualize_and_download_chart():
    """Render the label-distribution chart and save a PNG copy of it.

    The PNG is written to ``label_distribution.png`` in the working directory
    at 150 dpi with a tight bounding box.

    Returns:
        tuple: (figure for the ``gr.Plot`` component, PNG path for the
        ``gr.File`` download component).
    """
    png_path = "label_distribution.png"
    figure = create_distribution_figure()
    figure.savefig(png_path, dpi=150, bbox_inches="tight")
    return figure, png_path
|
|
|
def push_to_hub(repo_name, hf_token):
    """Upload the shared ``df`` as ``data.csv`` to a Hugging Face dataset repo.

    The dataset repository is created first if it does not exist yet.

    Args:
        repo_name: Target repo id, ``"username/dataset-name"``. Surrounding
            whitespace is ignored.
        hf_token: Hugging Face access token with write permission. It is
            required, so an empty field never falls back to credentials stored
            on the server.

    Returns:
        A status message for the UI. Failures are reported in the message
        rather than raised, since this runs as a Gradio event handler.
    """
    repo_id = (repo_name or "").strip()
    if not repo_id:
        return "❌ Push failed: enter a repo name (username/dataset-name)."
    if not hf_token:
        return "❌ Push failed: enter a Hugging Face token with write access."

    try:
        api = HfApi()
        api.create_repo(repo_id=repo_id, token=hf_token,
                        repo_type="dataset", exist_ok=True)
        # Upload the CSV bytes over the Hub HTTP API. Unlike the deprecated
        # git-based Repository class, this needs no git/git-lfs and leaves no
        # local clone to clean up between pushes.
        api.upload_file(
            path_or_fileobj=df.to_csv(index=False).encode("utf-8"),
            path_in_repo="data.csv",
            repo_id=repo_id,
            repo_type="dataset",
            token=hf_token,
            commit_message="📑 Update annotated data",
        )
    except Exception as e:  # UI boundary: show any Hub/network error to the user
        return f"❌ Push failed: {e}"

    return f"🚀 Pushed to https://huggingface.co/datasets/{repo_id}"
|
|
|
with gr.Blocks(theme=gr.themes.Default()) as app:
    gr.Markdown(
        "## 🏷️ Label It! Text Annotation Tool\n"
        "Upload a `.csv` (with **text** + **label** columns), "
        "then annotate, export, visualize, or publish."
    )

    with gr.Row():
        file_input = gr.File(label="📁 Upload CSV", file_types=[".csv"])
        upload_btn = gr.Button("Upload")

    # Editable annotation table; starts hidden.
    table = gr.Dataframe(headers=["text", "label"], interactive=True, visible=False)
    status = gr.Textbox(label="Status", interactive=False)

    # Keep the row visible and hide the buttons individually instead.
    # upload_csv toggles the buttons but never the row, and a hidden Row hides
    # its children regardless of their own visibility.
    with gr.Row() as action_row:
        save_btn = gr.Button("💾 Save", visible=False)
        download_btn = gr.Button("⬇️ Download CSV", visible=False)
        visualize_btn = gr.Button("📊 Visualize Distribution", visible=False)

    download_csv_out = gr.File(label="📥 Download CSV")
    chart_plot = gr.Plot(label="Label Distribution")
    download_chart_out = gr.File(label="📥 Download Chart")

    # Starts hidden; upload_csv's last output toggles its visibility.
    with gr.Accordion("📦 Push to Hugging Face Hub", open=False, visible=False) as push_acc:
        repo_in = gr.Textbox(label="Repo (username/dataset-name)")
        token_in = gr.Textbox(label="🔑 HF Token", type="password")
        push_btn = gr.Button("🚀 Push")
        push_status = gr.Textbox(label="Push Status", interactive=False)

    upload_btn.click(
        upload_csv,
        inputs=file_input,
        outputs=[table, status, save_btn, download_btn, visualize_btn, push_acc],
    )
    # NOTE: Download, Visualize and Push read the shared ``df``, which only
    # save_changes refreshes from the table, so unsaved edits are not included.
    save_btn.click(save_changes, inputs=table, outputs=status)
    download_btn.click(download_csv, outputs=download_csv_out)
    visualize_btn.click(visualize_and_download_chart,
                        outputs=[chart_plot, download_chart_out])
    push_btn.click(push_to_hub, inputs=[repo_in, token_in], outputs=push_status)

app.launch()
|
|