File size: 4,994 Bytes
20e7095
 
2dccd10
d18e6c8
a4cec6f
20e7095
c91426b
 
 
 
 
3ea3aae
20e7095
 
 
 
 
 
c91426b
2dccd10
 
 
 
 
 
c91426b
9e6c3bb
1d6c7cd
2dccd10
 
c91426b
 
 
 
1d6c7cd
20e7095
2dccd10
20e7095
2dccd10
1d6c7cd
20e7095
 
a4cec6f
c91426b
 
 
20e7095
2dccd10
 
 
 
9e6c3bb
 
c91426b
 
 
 
9e6c3bb
 
 
2dccd10
 
 
9e6c3bb
2dccd10
 
3ea3aae
 
9e6c3bb
2dccd10
c91426b
 
 
9e6c3bb
c91426b
a4cec6f
d18e6c8
 
c91426b
 
 
d18e6c8
c91426b
d18e6c8
 
 
 
3ea3aae
d18e6c8
 
c91426b
d18e6c8
 
 
 
a4cec6f
1d6c7cd
2dccd10
 
 
c91426b
2dccd10
3ea3aae
2dccd10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c91426b
 
2dccd10
 
c91426b
 
2dccd10
c91426b
 
2dccd10
c91426b
 
d18e6c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from huggingface_hub import HfApi, Repository

# Typography applied to every matplotlib figure rendered by this app.
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.size"] = 10

# Module-level working copy of the annotation data. The handlers in this
# file read it or rebind it through ``global df``.
df = pd.DataFrame()

def upload_csv(file):
    """Load an uploaded CSV into the global ``df`` and update the UI.

    Args:
        file: Value delivered by the ``gr.File`` input. Depending on the
            Gradio version this is either a path string or an object with
            a ``name`` attribute holding the path; ``None`` when nothing
            has been selected.

    Returns:
        A 6-tuple matching the outputs wired to the upload button:
        (table update, status message, save button update, download
        button update, visualize button update, push accordion update).
    """
    global df
    hidden = gr.update(visible=False)

    def _unchanged(message):
        # Nothing was loaded, so leave the current data and widgets as they
        # are and only report the problem in the status box.
        noop = gr.update()
        return (noop, message, noop, noop, noop, noop)

    if file is None:
        return _unchanged("❌ Please choose a CSV file before clicking Upload.")

    # Accept both a plain path string and a file wrapper exposing ``.name``.
    csv_path = getattr(file, "name", file)
    try:
        loaded = pd.read_csv(csv_path)
    except (OSError, ValueError) as exc:
        # pandas' ParserError / EmptyDataError and UnicodeDecodeError are all
        # ValueError subclasses; OSError covers unreadable or missing files.
        return _unchanged(f"❌ Could not read CSV: {exc}")

    if "text" not in loaded.columns or "label" not in loaded.columns:
        # Reset the shared frame so the hidden actions can never operate on
        # data that lacks the required columns.
        df = pd.DataFrame()
        return (
            None,  # table
            "❌ CSV must contain 'text' and 'label' columns.",
            hidden,  # save
            hidden,  # download CSV
            hidden,  # visualize
            hidden,  # push accordion
        )

    # Missing labels become empty strings so they are editable in the table.
    loaded["label"] = loaded["label"].fillna("")
    df = loaded
    return (
        # The table component starts hidden, so reveal it together with the
        # data instead of only sending the value.
        gr.update(value=df[["text", "label"]], visible=True),
        "✅ File uploaded — you can now annotate and use the buttons below.",
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
    )

def save_changes(table):
    """Replace the global ``df`` with the current contents of the table.

    Args:
        table: Table data from the editable Dataframe component; it is
            passed to ``pd.DataFrame`` with the ``text`` and ``label``
            columns.

    Returns:
        The confirmation message shown in the status box.
    """
    global df
    edited = pd.DataFrame(table, columns=["text", "label"])
    df = edited
    return "💾 Changes saved."

def download_csv():
    """Write the global ``df`` to ``annotated_data.csv`` and return the path.

    The file is written without the index column, relative to the current
    working directory.
    """
    out_file = "annotated_data.csv"
    df.to_csv(out_file, index=False)
    return out_file

def create_distribution_figure():
    """Build a figure showing how often each label occurs in the global ``df``.

    The left panel is a Label/Count table and the right panel a horizontal
    bar chart, both ordered by descending count.

    Returns:
        A ``matplotlib.figure.Figure``. When ``df`` has no label values the
        figure carries a short placeholder message instead of a table.

    Raises:
        KeyError: If ``df`` has no ``label`` column.
    """
    # Build the figure directly instead of via pyplot: pyplot keeps every
    # figure it creates in a global registry until it is explicitly closed,
    # so a long-running app would accumulate one figure per click.
    from matplotlib.figure import Figure

    counts = df["label"].value_counts().sort_values(ascending=False)
    # Stringify labels so numeric or mixed-type labels are drawn as
    # categories rather than at numeric axis positions.
    labels = [str(label) for label in counts.index]
    values = counts.values.tolist()

    fig = Figure(figsize=(8, max(2, len(labels) * 0.4)), tight_layout=True)
    ax_table, ax_bar = fig.subplots(
        ncols=2,
        gridspec_kw={"width_ratios": [1, 2]},
    )
    ax_table.axis("off")

    if not labels:
        # Axes.table cannot be built from zero rows; show a note instead.
        ax_table.text(0.5, 0.5, "No labels to display",
                      ha="center", va="center")
        ax_bar.set_xlabel("Count")
        return fig

    # Table
    rows = [[label, count] for label, count in zip(labels, values)]
    tbl = ax_table.table(cellText=rows, colLabels=["Label", "Count"],
                         loc="center")
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(10)
    tbl.scale(1, 1.2)

    # Bar chart; the y-axis is inverted so the largest count is on top.
    ax_bar.barh(labels, values, color="#222")
    ax_bar.invert_yaxis()
    ax_bar.set_xlabel("Count")
    return fig

def visualize_and_download_chart():
    """Render the label distribution and save it as a PNG.

    Returns:
        A ``(figure, path)`` pair: the figure from
        ``create_distribution_figure`` and the path of the PNG written to
        ``label_distribution.png`` at 150 dpi with a tight bounding box.
    """
    chart_file = "label_distribution.png"
    figure = create_distribution_figure()
    figure.savefig(chart_file, dpi=150, bbox_inches="tight")
    return figure, chart_file

def push_to_hub(repo_name, hf_token):
    """Publish the global ``df`` as ``data.csv`` in a Hub dataset repo.

    The CSV is uploaded straight from memory with ``HfApi.upload_file``.
    This avoids cloning the repo locally: the previous clone-based flow
    left a directory containing ``.git`` behind, and its cleanup loop called
    ``unlink()`` on that sub-directory, so every push after the first failed.

    Args:
        repo_name: Target repo id in ``username/dataset-name`` form.
        hf_token: Hugging Face access token. Blank or ``None`` is passed on
            as ``None`` so the client can fall back to stored credentials.

    Returns:
        A status string: the dataset URL on success, otherwise a message
        starting with "❌". Errors are reported, never raised, because the
        result is shown in the push status box.
    """
    repo_id = (repo_name or "").strip()
    if not repo_id:
        return "❌ Please enter a repository name (username/dataset-name)."
    if df.empty:
        return "❌ Nothing to push — upload a CSV first."

    token = (hf_token or "").strip() or None
    try:
        api = HfApi()
        api.create_repo(repo_id=repo_id, token=token,
                        repo_type="dataset", exist_ok=True)
        api.upload_file(
            path_or_fileobj=df.to_csv(index=False).encode("utf-8"),
            path_in_repo="data.csv",
            repo_id=repo_id,
            repo_type="dataset",
            token=token,
            commit_message="📑 Update annotated data",
        )
        return f"🚀 Pushed to https://huggingface.co/datasets/{repo_id}"
    except Exception as e:
        # UI boundary: surface any client or network failure as text.
        return f"❌ Push failed: {e}"

# UI layout and event wiring for the annotation tool.
with gr.Blocks(theme=gr.themes.Default()) as app:
    gr.Markdown("## 🏷️ Label It! Text Annotation Tool\n"
                "Upload a `.csv` (with **text** + **label** columns), "
                "then annotate, export, visualize, or publish.")

    # Step 1: Upload
    with gr.Row():
        file_input = gr.File(label="📁 Upload CSV", file_types=[".csv"])
        upload_btn = gr.Button("Upload")

    # Editable table (hidden until a valid CSV is loaded)
    table = gr.Dataframe(headers=["text","label"], interactive=True, visible=False)
    status = gr.Textbox(label="Status", interactive=False)

    # Step 2 buttons and their outputs (hidden initially)
    with gr.Row(visible=False) as action_row:
        save_btn     = gr.Button("💾 Save")
        download_btn = gr.Button("⬇️ Download CSV")
        visualize_btn= gr.Button("📊 Visualize Distribution")
        download_csv_out   = gr.File(label="📥 Download CSV")
        chart_plot         = gr.Plot(label="Label Distribution")
        download_chart_out = gr.File(label="📥 Download Chart")

    # Push accordion
    push_acc = gr.Accordion("📦 Push to Hugging Face Hub", open=False, visible=False)
    with push_acc:
        repo_in   = gr.Textbox(label="Repo (username/dataset-name)")
        token_in  = gr.Textbox(label="🔑 HF Token", type="password")
        push_btn  = gr.Button("🚀 Push")
        push_status = gr.Textbox(label="Push Status", interactive=False)

    def _reveal_after_upload():
        """Show the table and action row only when ``df`` holds usable data.

        ``upload_csv`` toggles the individual buttons, but they live inside
        ``action_row``; while that container stays hidden nothing inside it
        can appear. The same applies to the table, which starts hidden.
        """
        ready = {"text", "label"}.issubset(df.columns)
        return gr.update(visible=ready), gr.update(visible=ready)

    # Event bindings
    upload_btn.click(
        upload_csv,
        inputs=file_input,
        outputs=[table, status,
                 save_btn, download_btn, visualize_btn, push_acc]
    ).then(
        # Chained step: the six outputs above are fixed by upload_csv's
        # return shape, so container visibility is handled separately.
        _reveal_after_upload,
        outputs=[table, action_row]
    )
    save_btn.click(save_changes, inputs=table, outputs=status)
    download_btn.click(download_csv, outputs=download_csv_out)
    visualize_btn.click(visualize_and_download_chart,
                        outputs=[chart_plot, download_chart_out])
    push_btn.click(push_to_hub, inputs=[repo_in, token_in], outputs=push_status)

app.launch()