# labelit / app.py
# Latest commit: "Update app.py" (9e6c3bb, verified) by Suzana · 4.49 kB
import gradio as gr
import pandas as pd
import io
import os
from pathlib import Path
from huggingface_hub import HfApi, Repository
import matplotlib.pyplot as plt
# Working copy of the annotation data. It is a single module-level frame:
# upload_csv / save_changes rebind it, and the download, chart and push
# callbacks read it.
# NOTE(review): being module-level, it is shared by every visitor of the app
# rather than kept per session — consider gr.State if that matters.
df: pd.DataFrame = pd.DataFrame()
def upload_csv(file):
    """Load an uploaded CSV into the shared ``df`` and reveal the editor table.

    Args:
        file: Value of the ``gr.File`` input. Depending on the Gradio version
            this is a filepath string or a tempfile-like object exposing
            ``.name``. It is ``None`` when the button is clicked before a file
            is chosen.

    Returns:
        A ``(table_update, status_message)`` pair for ``[df_table, status]``.
        The global ``df`` is replaced only when the file is readable and has
        both ``text`` and ``label`` columns. On any failure the previous data
        is kept.
    """
    global df
    if file is None:
        return gr.update(visible=False), "❌ Please choose a CSV file first."
    # Newer Gradio passes a path string; older versions pass a tempfile wrapper.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    try:
        loaded = pd.read_csv(path)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
        return gr.update(visible=False), f"❌ Could not read CSV: {err}"
    if "text" not in loaded.columns or "label" not in loaded.columns:
        # Validate before touching the global so a bad file can't clobber good data.
        return gr.update(visible=False), "❌ CSV must contain ‘text’ and ‘label’ columns."
    # Blank labels become "" so they show (and save) as empty, editable cells.
    loaded["label"] = loaded["label"].fillna("")
    df = loaded
    return (
        gr.update(value=df[["text", "label"]], visible=True),
        "✅ File uploaded — you can now edit labels.",
    )
def save_changes(edited_table):
    """Store the edited table back into the shared ``df``.

    Args:
        edited_table: Value of the editable ``gr.Dataframe``. It is anything
            ``pd.DataFrame(..., columns=["text", "label"])`` accepts (a
            DataFrame or a list of rows).

    Returns:
        A status message for the UI.

    If the edited rows match the current ``df`` row for row (same ``text``
    values in the same order), only the ``label`` column is updated. Any
    extra columns from the uploaded CSV (ids, metadata, ...) are kept instead
    of being dropped. Otherwise (rows added, removed, reordered or re-texted),
    ``df`` is replaced by the two-column table, as before.
    """
    global df
    edited = pd.DataFrame(edited_table, columns=["text", "label"])
    same_rows = (
        {"text", "label"}.issubset(df.columns)
        and len(edited) == len(df)
        and edited["text"].tolist() == df["text"].tolist()
    )
    if same_rows:
        df = df.copy()
        # Positional assignment: the edited frame's index need not match df's.
        df["label"] = edited["label"].to_numpy()
    else:
        df = edited
    return "💾 Changes saved."
def download_csv():
    """Export the shared ``df`` to ``annotated_data.csv`` (without the index).

    Returns:
        The file's path as a string, for the ``gr.File`` download output.
    """
    target = Path("annotated_data.csv")
    df.to_csv(target, index=False)
    return str(target)
def create_distribution_figure(df_input):
    """Draw a label-count table next to a horizontal bar chart of the counts.

    Args:
        df_input: Frame whose ``"label"`` column is counted. Missing values
            are not counted (``value_counts`` default). A frame without a
            ``label`` column, e.g. before any upload, counts as having no labels.

    Returns:
        A ``matplotlib.figure.Figure``. With no labels it shows a short
        placeholder message instead of an empty table and chart, which would
        raise.
    """
    # Build the figure directly instead of through pyplot. pyplot keeps every
    # figure alive until it is explicitly closed (a leak in a long-running
    # server), and its "current figure" state is shared across concurrent requests.
    from matplotlib.figure import Figure

    if "label" in df_input.columns:
        counts = df_input["label"].value_counts().sort_values(ascending=False)
    else:
        counts = pd.Series(dtype="int64")
    # Stringify so numeric labels become categories rather than y positions,
    # and mixed int/str labels don't make matplotlib's category axis raise.
    labels = [str(lab) for lab in counts.index]
    values = counts.tolist()

    fig = Figure(figsize=(8, max(2, len(labels) * 0.3)))
    if not labels:
        ax = fig.add_subplot()
        ax.axis("off")
        ax.text(0.5, 0.5, "No labels to display", ha="center", va="center")
        return fig

    ax_table, ax_bar = fig.subplots(
        nrows=1, ncols=2, gridspec_kw={"width_ratios": [1, 2]}
    )

    # Table
    ax_table.axis("off")
    table_data = [[lab, cnt] for lab, cnt in zip(labels, values)]
    tbl = ax_table.table(cellText=table_data, colLabels=["Label", "Count"], loc="center")
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(10)
    tbl.scale(1, 1.5)

    # Bar chart; inverting the y-axis puts the most frequent label on top.
    ax_bar.barh(labels, values)
    ax_bar.invert_yaxis()
    ax_bar.set_xlabel("Count")
    ax_bar.set_ylabel("")
    fig.tight_layout()
    return fig
def visualize_and_download_chart():
    """Plot the label distribution of the shared ``df`` and save it as a PNG.

    Returns:
        ``(figure, png_path)``, matching the ``[chart_plot, download_chart]``
        outputs: the figure to display and the file (150 dpi) to download.
    """
    png_path = "label_distribution.png"
    figure = create_distribution_figure(df)
    figure.savefig(png_path, dpi=150)
    return figure, png_path
def push_to_hub(repo_name: str, hf_token: str) -> str:
    """Publish the shared ``df`` as ``data.csv`` in a Hugging Face dataset repo.

    The dataset repo is created if it does not exist. The CSV is serialized in
    memory and uploaded with ``HfApi.upload_file``, so no local git clone (and
    no git-lfs) is needed. This also avoids the old clone-directory cleanup,
    which could not delete the clone's ``.git`` folder and built a filesystem
    path from user input.

    Args:
        repo_name: Target repo id, e.g. ``"username/dataset-name"``.
        hf_token: Hugging Face access token with write access to that repo.

    Returns:
        A status message for the UI. Errors are reported in the message, not raised.
    """
    repo_name = (repo_name or "").strip()
    hf_token = (hf_token or "").strip()
    if not repo_name:
        return "❌ Push failed: please enter a repo name (username/dataset-name)."
    if not hf_token:
        return "❌ Push failed: please enter a Hugging Face token."
    try:
        api = HfApi()
        api.create_repo(repo_id=repo_name, token=hf_token, repo_type="dataset", exist_ok=True)
        api.upload_file(
            path_or_fileobj=df.to_csv(index=False).encode("utf-8"),
            path_in_repo="data.csv",
            repo_id=repo_name,
            repo_type="dataset",
            token=hf_token,
            commit_message="📑 Update annotated data",
        )
    except Exception as e:  # UI boundary: show any Hub/network error to the user
        return f"❌ Push failed: {e}"
    return f"🚀 Pushed to https://huggingface.co/datasets/{repo_name}"
# UI layout and event wiring.
# NOTE(review): the source this was taken from had no indentation, so the
# nesting below (which widgets sit inside each Row/Accordion) is a
# reconstruction — confirm against the intended layout.
with gr.Blocks(theme=gr.themes.Default()) as app:
    gr.Markdown("## 🏷️ Label It! Text Annotation Tool")
    gr.Markdown("Upload a `.csv` with **text** + **label** columns, annotate in-place, then export, visualize, or publish.")
    # Step 1: choose a CSV and load it with the Upload button.
    with gr.Row():
        file_input = gr.File(label="📁 Upload CSV", file_types=[".csv"])
        upload_btn = gr.Button("Upload")
    # Starts hidden; upload_csv makes it visible once a valid file is loaded.
    df_table = gr.Dataframe(headers=["text","label"], label="📝 Editable Table",
                            interactive=True, visible=False)
    status = gr.Textbox(label="Status", interactive=False)
    # Step 2: commit table edits, then export them as CSV.
    with gr.Row():
        save_btn = gr.Button("💾 Save")
        download_btn = gr.Button("⬇️ Download CSV")
        download_out = gr.File(label="📥 Downloaded File")
    # Step 3: label-count chart plus a PNG copy of it.
    with gr.Row():
        visualize_btn = gr.Button("📊 Visualize Distribution")
        chart_plot = gr.Plot(label="Label Distribution")
        download_chart = gr.File(label="📥 Download Chart")
    # Step 4 (optional, collapsed by default): publish to a Hub dataset repo.
    with gr.Accordion("📦 Push to Hugging Face Hub", open=False):
        repo_input = gr.Textbox(label="Repo (username/dataset-name)")
        # type="password" masks the token in the browser.
        token_input = gr.Textbox(label="🔑 HF Token", type="password")
        push_btn = gr.Button("🚀 Push")
        push_status = gr.Textbox(label="Push Status", interactive=False)
    # Bind events.
    # Download / Visualize / Push read the module-level df, which only
    # upload_csv and save_changes update, so table edits must be saved first.
    upload_btn.click(upload_csv, inputs=file_input, outputs=[df_table, status])
    save_btn.click(save_changes, inputs=df_table, outputs=status)
    download_btn.click(download_csv, outputs=download_out)
    visualize_btn.click(visualize_and_download_chart, outputs=[chart_plot, download_chart])
    push_btn.click(push_to_hub, inputs=[repo_input, token_input], outputs=push_status)
# Start the web server only when this file is run as a script
# (e.g. `python app.py`), so importing the module has no side effects.
if __name__ == "__main__":
    app.launch()