# labelit / app.py
# (Residue from the Hugging Face file viewer, kept as comments:)
# Suzana's picture
# Update app.py
# 2dccd10 verified
# raw
# history blame
# 4.99 kB
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from huggingface_hub import HfApi, Repository
# Chart defaults applied to every figure this module draws.
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.size"] = 10

# Shared DataFrame read and rebound by the callbacks below via ``global df``.
# NOTE(review): this is process-wide state, so concurrent sessions would
# presumably see each other's data — confirm whether that matters here.
df = pd.DataFrame()
def upload_csv(file):
    """Load an uploaded CSV into the shared DataFrame and reveal the tools.

    Args:
        file: Value received from the ``gr.File`` input. Either an object
            exposing the uploaded file's path as ``.name`` or a plain path
            string; ``None`` when no file was selected.

    Returns:
        A 6-tuple in the order the upload button's outputs are wired:
        table update, status message, then visibility updates for the save
        button, the download button, the visualize button and the push
        accordion.
    """
    global df

    def _failure(message):
        # Clear the table and keep every follow-up action hidden.
        return (
            None,  # table
            message,
            gr.update(visible=False),  # save
            gr.update(visible=False),  # download CSV
            gr.update(visible=False),  # visualize
            gr.update(visible=False),  # push accordion
        )

    if file is None:
        return _failure("❌ Please choose a CSV file before clicking Upload.")

    # Accept both a file wrapper with ``.name`` and a bare path string.
    csv_path = getattr(file, "name", file)
    try:
        loaded = pd.read_csv(csv_path)
    except (OSError, ValueError) as err:
        # ValueError also covers pandas' ParserError / EmptyDataError and
        # UnicodeDecodeError, which all derive from it.
        return _failure(f"❌ Could not read CSV: {err}")

    # Validate before touching the shared state, so a bad upload does not
    # replace data that was loaded earlier.
    if "text" not in loaded.columns or "label" not in loaded.columns:
        return _failure("❌ CSV must contain 'text' and 'label' columns.")

    loaded["label"] = loaded["label"].fillna("")
    df = loaded

    return (
        # Send an explicit update: a bare value would not change the
        # table's visibility if it was created hidden.
        gr.update(value=df[["text", "label"]], visible=True),
        "✅ File uploaded — you can now annotate and use the buttons below.",
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
    )
def save_changes(table):
    """Store the edited table contents as the shared DataFrame.

    Args:
        table: Current contents of the editable table; handed to
            ``pd.DataFrame`` with the columns ``text`` and ``label``.

    Returns:
        A confirmation message for the status box.
    """
    global df
    column_names = ["text", "label"]
    df = pd.DataFrame(table, columns=column_names)
    return "💾 Changes saved."
def download_csv():
    """Write the shared DataFrame to ``annotated_data.csv``.

    The file is written without the index column, relative to the current
    working directory.

    Returns:
        The path of the written file.
    """
    csv_path = "annotated_data.csv"
    df.to_csv(csv_path, index=False)
    return csv_path
def create_distribution_figure():
    """Build a figure showing how often each label occurs.

    The figure has two panels: a Label/Count table on the left and a
    horizontal bar chart on the right, both ordered from the most to the
    least frequent label.

    Returns:
        The matplotlib figure. It is created through ``pyplot``, which keeps
        a reference to it until ``plt.close`` is called on it.
    """
    # value_counts() already orders by descending frequency.
    counts = df["label"].value_counts()
    # Stringify so numeric or mixed-type labels are drawn as categories
    # instead of being placed at numeric y positions (or rejected).
    labels = [str(label) for label in counts.index]
    values = counts.values.tolist()

    fig, (ax_table, ax_bar) = plt.subplots(
        ncols=2,
        gridspec_kw={"width_ratios": [1, 2]},
        # Grow the figure with the number of labels, but never below 2in.
        figsize=(8, max(2, len(labels) * 0.4)),
        tight_layout=True,
    )

    # Table
    ax_table.axis("off")
    rows = [[label, count] for label, count in zip(labels, values)]
    if rows:
        tbl = ax_table.table(
            cellText=rows, colLabels=["Label", "Count"], loc="center"
        )
        tbl.auto_set_font_size(False)
        tbl.set_fontsize(10)
        tbl.scale(1, 1.2)
    else:
        # Axes.table cannot be built from an empty cellText, so show a
        # placeholder when there are no rows to count.
        ax_table.text(
            0.5, 0.5, "No labels to display", ha="center", va="center"
        )

    # Bar chart (inverted so the most frequent label is at the top)
    ax_bar.barh(labels, values, color="#222")
    ax_bar.invert_yaxis()
    ax_bar.set_xlabel("Count")
    return fig
def visualize_and_download_chart():
    """Render the label distribution and save it as a PNG.

    Returns:
        A ``(figure, path)`` pair: the figure for the plot output and the
        path of ``label_distribution.png`` for the file output.
    """
    fig = create_distribution_figure()
    out_path = "label_distribution.png"
    try:
        fig.savefig(out_path, dpi=150, bbox_inches="tight")
    finally:
        # Release pyplot's reference so figures do not pile up over the
        # lifetime of the server; the Figure object itself stays usable.
        plt.close(fig)
    return fig, out_path
def push_to_hub(repo_name, hf_token):
    """Upload the shared DataFrame as ``data.csv`` to a Hub dataset repo.

    The repo is created if it does not exist yet. The CSV is uploaded
    straight from memory through the HTTP API, so no local git clone is
    needed and repeated pushes to the same repo work.

    Args:
        repo_name: Target repo id in the form ``username/dataset-name``.
        hf_token: Hugging Face access token used for both API calls.

    Returns:
        A status string: the dataset URL on success, otherwise a message
        starting with ``❌ Push failed`` that describes the problem.
    """
    repo_id = (repo_name or "").strip()
    token = (hf_token or "").strip()
    if not repo_id:
        return "❌ Push failed: please enter a repo name (username/dataset-name)."
    if not token:
        return "❌ Push failed: please enter a Hugging Face token."

    try:
        api = HfApi()
        api.create_repo(
            repo_id=repo_id,
            token=token,
            repo_type="dataset",
            exist_ok=True,
        )
        api.upload_file(
            path_or_fileobj=df.to_csv(index=False).encode("utf-8"),
            path_in_repo="data.csv",
            repo_id=repo_id,
            repo_type="dataset",
            token=token,
            commit_message="📑 Update annotated data",
        )
    except Exception as e:
        # UI boundary: report any failure in the status box instead of
        # letting the callback raise.
        return f"❌ Push failed: {e}"
    return f"🚀 Pushed to https://huggingface.co/datasets/{repo_id}"
# UI definition: upload a CSV, edit it in a table, then export, chart or
# publish it. Callbacks are the module-level functions defined above.
with gr.Blocks(theme=gr.themes.Default()) as app:
    gr.Markdown(
        "## 🏷️ Label It! Text Annotation Tool\n"
        "Upload a `.csv` (with **text** + **label** columns), "
        "then annotate, export, visualize, or publish."
    )

    # Step 1: Upload
    with gr.Row():
        file_input = gr.File(label="📁 Upload CSV", file_types=[".csv"])
        upload_btn = gr.Button("Upload")

    # Editable table. Visible from the start so it never depends on a
    # callback to be revealed.
    table = gr.Dataframe(headers=["text", "label"], interactive=True)
    status = gr.Textbox(label="Status", interactive=False)

    # Step 2 buttons (hidden initially). The buttons themselves carry the
    # hidden state because the upload callback toggles them individually;
    # a hidden parent row would keep them hidden regardless.
    with gr.Row() as action_row:
        save_btn = gr.Button("💾 Save", visible=False)
        download_btn = gr.Button("⬇️ Download CSV", visible=False)
        visualize_btn = gr.Button("📊 Visualize Distribution", visible=False)

    download_csv_out = gr.File(label="📥 Download CSV")
    chart_plot = gr.Plot(label="Label Distribution")
    download_chart_out = gr.File(label="📥 Download Chart")

    # Push accordion (hidden until a file has been uploaded)
    push_acc = gr.Accordion(
        "📦 Push to Hugging Face Hub", open=False, visible=False
    )
    with push_acc:
        repo_in = gr.Textbox(label="Repo (username/dataset-name)")
        token_in = gr.Textbox(label="🔑 HF Token", type="password")
        push_btn = gr.Button("🚀 Push")
        push_status = gr.Textbox(label="Push Status", interactive=False)

    # Event bindings. The outputs of the upload click must stay in the
    # order of the tuple returned by upload_csv.
    upload_btn.click(
        upload_csv,
        inputs=file_input,
        outputs=[
            table,
            status,
            save_btn,
            download_btn,
            visualize_btn,
            push_acc,
        ],
    )
    save_btn.click(save_changes, inputs=table, outputs=status)
    download_btn.click(download_csv, outputs=download_csv_out)
    visualize_btn.click(
        visualize_and_download_chart,
        outputs=[chart_plot, download_chart_out],
    )
    push_btn.click(
        push_to_hub, inputs=[repo_in, token_in], outputs=push_status
    )

app.launch()