Update app.py
Browse files
app.py
CHANGED
@@ -9,20 +9,12 @@ import matplotlib.pyplot as plt
|
|
9 |
# Global DataFrame
|
10 |
df = pd.DataFrame()
|
11 |
|
12 |
-
# List of free, recommended models (for future auto-labeling)
|
13 |
-
DEFAULT_MODELS = [
|
14 |
-
"mistralai/Mistral-7B-Instruct-v0.2",
|
15 |
-
"HuggingFaceH4/zephyr-7b-beta",
|
16 |
-
"tiiuae/falcon-rw-1b",
|
17 |
-
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
18 |
-
]
|
19 |
-
|
20 |
def upload_csv(file):
|
21 |
global df
|
22 |
df = pd.read_csv(file.name)
|
23 |
if "text" not in df.columns or "label" not in df.columns:
|
24 |
return gr.update(visible=False), "❌ CSV must contain ‘text’ and ‘label’ columns."
|
25 |
-
df["label"] = df["label"].fillna("")
|
26 |
return (
|
27 |
gr.update(value=df[["text","label"]], visible=True),
|
28 |
"✅ File uploaded — you can now edit labels."
|
@@ -39,34 +31,50 @@ def download_csv():
|
|
39 |
df.to_csv(out_path, index=False)
|
40 |
return out_path
|
41 |
|
42 |
-
def
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
fig,
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
plt.tight_layout()
|
53 |
return fig
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
def push_to_hub(repo_name: str, hf_token: str) -> str:
|
56 |
global df
|
57 |
try:
|
58 |
api = HfApi()
|
59 |
-
api.create_repo(
|
60 |
-
repo_id=repo_name,
|
61 |
-
token=hf_token,
|
62 |
-
repo_type="dataset",
|
63 |
-
exist_ok=True
|
64 |
-
)
|
65 |
|
66 |
local_dir = Path(f"./{repo_name.replace('/', '_')}")
|
67 |
if local_dir.exists():
|
68 |
-
for child in local_dir.iterdir():
|
69 |
-
child.unlink()
|
70 |
local_dir.rmdir()
|
71 |
|
72 |
repo = Repository(
|
@@ -80,7 +88,6 @@ def push_to_hub(repo_name: str, hf_token: str) -> str:
|
|
80 |
df.to_csv(csv_path, index=False)
|
81 |
repo.push_to_hub(commit_message="📑 Update annotated data")
|
82 |
return f"🚀 Pushed to https://huggingface.co/datasets/{repo_name}"
|
83 |
-
|
84 |
except Exception as e:
|
85 |
return f"❌ Push failed: {e}"
|
86 |
|
@@ -92,27 +99,19 @@ with gr.Blocks(theme=gr.themes.Default()) as app:
|
|
92 |
file_input = gr.File(label="📁 Upload CSV", file_types=[".csv"])
|
93 |
upload_btn = gr.Button("Upload")
|
94 |
|
95 |
-
df_table = gr.Dataframe(
|
96 |
-
|
97 |
-
label="📝 Editable Table",
|
98 |
-
interactive=True,
|
99 |
-
visible=False
|
100 |
-
)
|
101 |
status = gr.Textbox(label="Status", interactive=False)
|
102 |
|
103 |
with gr.Row():
|
104 |
-
save_btn
|
105 |
-
download_btn
|
106 |
-
|
107 |
-
download_out = gr.File(label="📥 Downloaded File")
|
108 |
-
viz_out = gr.Plot(label="Label Distribution")
|
109 |
|
110 |
with gr.Row():
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
value=DEFAULT_MODELS[0]
|
115 |
-
)
|
116 |
|
117 |
with gr.Accordion("📦 Push to Hugging Face Hub", open=False):
|
118 |
repo_input = gr.Textbox(label="Repo (username/dataset-name)")
|
@@ -121,10 +120,10 @@ with gr.Blocks(theme=gr.themes.Default()) as app:
|
|
121 |
push_status = gr.Textbox(label="Push Status", interactive=False)
|
122 |
|
123 |
# Bind events
|
124 |
-
upload_btn.click(upload_csv,
|
125 |
-
save_btn.click(
|
126 |
download_btn.click(download_csv, outputs=download_out)
|
127 |
-
visualize_btn.click(
|
128 |
-
push_btn.click(
|
129 |
|
130 |
app.launch()
|
|
|
9 |
# Global DataFrame
|
10 |
df = pd.DataFrame()
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
def upload_csv(file):
|
13 |
global df
|
14 |
df = pd.read_csv(file.name)
|
15 |
if "text" not in df.columns or "label" not in df.columns:
|
16 |
return gr.update(visible=False), "❌ CSV must contain ‘text’ and ‘label’ columns."
|
17 |
+
df["label"] = df["label"].fillna("")
|
18 |
return (
|
19 |
gr.update(value=df[["text","label"]], visible=True),
|
20 |
"✅ File uploaded — you can now edit labels."
|
|
|
31 |
df.to_csv(out_path, index=False)
|
32 |
return out_path
|
33 |
|
34 |
+
def create_distribution_figure(df_input):
|
35 |
+
counts = df_input["label"].value_counts().sort_values(ascending=False)
|
36 |
+
labels = counts.index.tolist()
|
37 |
+
values = counts.values.tolist()
|
38 |
+
|
39 |
+
fig, (ax_table, ax_bar) = plt.subplots(
|
40 |
+
nrows=1, ncols=2,
|
41 |
+
gridspec_kw={"width_ratios": [1, 2]},
|
42 |
+
figsize=(8, max(2, len(labels) * 0.3))
|
43 |
+
)
|
44 |
+
|
45 |
+
# Table
|
46 |
+
ax_table.axis("off")
|
47 |
+
table_data = [[lab, cnt] for lab, cnt in zip(labels, values)]
|
48 |
+
tbl = ax_table.table(cellText=table_data, colLabels=["Label","Count"], loc="center")
|
49 |
+
tbl.auto_set_font_size(False)
|
50 |
+
tbl.set_fontsize(10)
|
51 |
+
tbl.scale(1, 1.5)
|
52 |
+
|
53 |
+
# Bar chart
|
54 |
+
ax_bar.barh(labels, values)
|
55 |
+
ax_bar.invert_yaxis()
|
56 |
+
ax_bar.set_xlabel("Count")
|
57 |
+
ax_bar.set_ylabel("")
|
58 |
+
|
59 |
plt.tight_layout()
|
60 |
return fig
|
61 |
|
62 |
+
def visualize_and_download_chart():
|
63 |
+
global df
|
64 |
+
fig = create_distribution_figure(df)
|
65 |
+
chart_path = "label_distribution.png"
|
66 |
+
fig.savefig(chart_path, dpi=150)
|
67 |
+
return fig, chart_path
|
68 |
+
|
69 |
def push_to_hub(repo_name: str, hf_token: str) -> str:
|
70 |
global df
|
71 |
try:
|
72 |
api = HfApi()
|
73 |
+
api.create_repo(repo_id=repo_name, token=hf_token, repo_type="dataset", exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
local_dir = Path(f"./{repo_name.replace('/', '_')}")
|
76 |
if local_dir.exists():
|
77 |
+
for child in local_dir.iterdir(): child.unlink()
|
|
|
78 |
local_dir.rmdir()
|
79 |
|
80 |
repo = Repository(
|
|
|
88 |
df.to_csv(csv_path, index=False)
|
89 |
repo.push_to_hub(commit_message="📑 Update annotated data")
|
90 |
return f"🚀 Pushed to https://huggingface.co/datasets/{repo_name}"
|
|
|
91 |
except Exception as e:
|
92 |
return f"❌ Push failed: {e}"
|
93 |
|
|
|
99 |
file_input = gr.File(label="📁 Upload CSV", file_types=[".csv"])
|
100 |
upload_btn = gr.Button("Upload")
|
101 |
|
102 |
+
df_table = gr.Dataframe(headers=["text","label"], label="📝 Editable Table",
|
103 |
+
interactive=True, visible=False)
|
|
|
|
|
|
|
|
|
104 |
status = gr.Textbox(label="Status", interactive=False)
|
105 |
|
106 |
with gr.Row():
|
107 |
+
save_btn = gr.Button("💾 Save")
|
108 |
+
download_btn = gr.Button("⬇️ Download CSV")
|
109 |
+
download_out = gr.File(label="📥 Downloaded File")
|
|
|
|
|
110 |
|
111 |
with gr.Row():
|
112 |
+
visualize_btn = gr.Button("📊 Visualize Distribution")
|
113 |
+
chart_plot = gr.Plot(label="Label Distribution")
|
114 |
+
download_chart = gr.File(label="📥 Download Chart")
|
|
|
|
|
115 |
|
116 |
with gr.Accordion("📦 Push to Hugging Face Hub", open=False):
|
117 |
repo_input = gr.Textbox(label="Repo (username/dataset-name)")
|
|
|
120 |
push_status = gr.Textbox(label="Push Status", interactive=False)
|
121 |
|
122 |
# Bind events
|
123 |
+
upload_btn.click(upload_csv, inputs=file_input, outputs=[df_table, status])
|
124 |
+
save_btn.click(save_changes, inputs=df_table, outputs=status)
|
125 |
download_btn.click(download_csv, outputs=download_out)
|
126 |
+
visualize_btn.click(visualize_and_download_chart, outputs=[chart_plot, download_chart])
|
127 |
+
push_btn.click(push_to_hub, inputs=[repo_input, token_input], outputs=push_status)
|
128 |
|
129 |
app.launch()
|