Suzana commited on
Commit
9e6c3bb
·
verified ·
1 Parent(s): 3ea3aae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -48
app.py CHANGED
@@ -9,20 +9,12 @@ import matplotlib.pyplot as plt
9
  # Global DataFrame
10
  df = pd.DataFrame()
11
 
12
- # List of free, recommended models (for future auto-labeling)
13
- DEFAULT_MODELS = [
14
- "mistralai/Mistral-7B-Instruct-v0.2",
15
- "HuggingFaceH4/zephyr-7b-beta",
16
- "tiiuae/falcon-rw-1b",
17
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
18
- ]
19
-
20
  def upload_csv(file):
21
  global df
22
  df = pd.read_csv(file.name)
23
  if "text" not in df.columns or "label" not in df.columns:
24
  return gr.update(visible=False), "❌ CSV must contain ‘text’ and ‘label’ columns."
25
- df["label"] = df["label"].fillna("") # ensure there’s always a label column
26
  return (
27
  gr.update(value=df[["text","label"]], visible=True),
28
  "✅ File uploaded — you can now edit labels."
@@ -39,34 +31,50 @@ def download_csv():
39
  df.to_csv(out_path, index=False)
40
  return out_path
41
 
42
- def visualize_distribution():
43
- global df
44
- if df.empty or "label" not in df.columns:
45
- return None
46
- counts = df["label"].value_counts()
47
- fig, ax = plt.subplots()
48
- counts.plot(kind="bar", ax=ax)
49
- ax.set_title("Label Distribution")
50
- ax.set_xlabel("Label")
51
- ax.set_ylabel("Count")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  plt.tight_layout()
53
  return fig
54
 
 
 
 
 
 
 
 
55
  def push_to_hub(repo_name: str, hf_token: str) -> str:
56
  global df
57
  try:
58
  api = HfApi()
59
- api.create_repo(
60
- repo_id=repo_name,
61
- token=hf_token,
62
- repo_type="dataset",
63
- exist_ok=True
64
- )
65
 
66
  local_dir = Path(f"./{repo_name.replace('/', '_')}")
67
  if local_dir.exists():
68
- for child in local_dir.iterdir():
69
- child.unlink()
70
  local_dir.rmdir()
71
 
72
  repo = Repository(
@@ -80,7 +88,6 @@ def push_to_hub(repo_name: str, hf_token: str) -> str:
80
  df.to_csv(csv_path, index=False)
81
  repo.push_to_hub(commit_message="📑 Update annotated data")
82
  return f"🚀 Pushed to https://huggingface.co/datasets/{repo_name}"
83
-
84
  except Exception as e:
85
  return f"❌ Push failed: {e}"
86
 
@@ -92,27 +99,19 @@ with gr.Blocks(theme=gr.themes.Default()) as app:
92
  file_input = gr.File(label="📁 Upload CSV", file_types=[".csv"])
93
  upload_btn = gr.Button("Upload")
94
 
95
- df_table = gr.Dataframe(
96
- headers=["text","label"],
97
- label="📝 Editable Table",
98
- interactive=True,
99
- visible=False
100
- )
101
  status = gr.Textbox(label="Status", interactive=False)
102
 
103
  with gr.Row():
104
- save_btn = gr.Button("💾 Save")
105
- download_btn = gr.Button("⬇️ Download CSV")
106
- visualize_btn= gr.Button("📊 Visualize Distribution")
107
- download_out = gr.File(label="📥 Downloaded File")
108
- viz_out = gr.Plot(label="Label Distribution")
109
 
110
  with gr.Row():
111
- model_dropdown = gr.Dropdown(
112
- label="🤖 (Future) Auto-Label Model",
113
- choices=DEFAULT_MODELS,
114
- value=DEFAULT_MODELS[0]
115
- )
116
 
117
  with gr.Accordion("📦 Push to Hugging Face Hub", open=False):
118
  repo_input = gr.Textbox(label="Repo (username/dataset-name)")
@@ -121,10 +120,10 @@ with gr.Blocks(theme=gr.themes.Default()) as app:
121
  push_status = gr.Textbox(label="Push Status", interactive=False)
122
 
123
  # Bind events
124
- upload_btn.click(upload_csv, inputs=file_input, outputs=[df_table, status])
125
- save_btn.click( save_changes, inputs=df_table, outputs=status)
126
  download_btn.click(download_csv, outputs=download_out)
127
- visualize_btn.click(visualize_distribution, outputs=viz_out)
128
- push_btn.click( push_to_hub, inputs=[repo_input, token_input], outputs=push_status)
129
 
130
  app.launch()
 
9
  # Global DataFrame
10
  df = pd.DataFrame()
11
 
 
 
 
 
 
 
 
 
12
  def upload_csv(file):
13
  global df
14
  df = pd.read_csv(file.name)
15
  if "text" not in df.columns or "label" not in df.columns:
16
  return gr.update(visible=False), "❌ CSV must contain ‘text’ and ‘label’ columns."
17
+ df["label"] = df["label"].fillna("")
18
  return (
19
  gr.update(value=df[["text","label"]], visible=True),
20
  "✅ File uploaded — you can now edit labels."
 
31
  df.to_csv(out_path, index=False)
32
  return out_path
33
 
34
+ def create_distribution_figure(df_input):
35
+ counts = df_input["label"].value_counts().sort_values(ascending=False)
36
+ labels = counts.index.tolist()
37
+ values = counts.values.tolist()
38
+
39
+ fig, (ax_table, ax_bar) = plt.subplots(
40
+ nrows=1, ncols=2,
41
+ gridspec_kw={"width_ratios": [1, 2]},
42
+ figsize=(8, max(2, len(labels) * 0.3))
43
+ )
44
+
45
+ # Table
46
+ ax_table.axis("off")
47
+ table_data = [[lab, cnt] for lab, cnt in zip(labels, values)]
48
+ tbl = ax_table.table(cellText=table_data, colLabels=["Label","Count"], loc="center")
49
+ tbl.auto_set_font_size(False)
50
+ tbl.set_fontsize(10)
51
+ tbl.scale(1, 1.5)
52
+
53
+ # Bar chart
54
+ ax_bar.barh(labels, values)
55
+ ax_bar.invert_yaxis()
56
+ ax_bar.set_xlabel("Count")
57
+ ax_bar.set_ylabel("")
58
+
59
  plt.tight_layout()
60
  return fig
61
 
62
+ def visualize_and_download_chart():
63
+ global df
64
+ fig = create_distribution_figure(df)
65
+ chart_path = "label_distribution.png"
66
+ fig.savefig(chart_path, dpi=150)
67
+ return fig, chart_path
68
+
69
  def push_to_hub(repo_name: str, hf_token: str) -> str:
70
  global df
71
  try:
72
  api = HfApi()
73
+ api.create_repo(repo_id=repo_name, token=hf_token, repo_type="dataset", exist_ok=True)
 
 
 
 
 
74
 
75
  local_dir = Path(f"./{repo_name.replace('/', '_')}")
76
  if local_dir.exists():
77
+ for child in local_dir.iterdir(): child.unlink()
 
78
  local_dir.rmdir()
79
 
80
  repo = Repository(
 
88
  df.to_csv(csv_path, index=False)
89
  repo.push_to_hub(commit_message="📑 Update annotated data")
90
  return f"🚀 Pushed to https://huggingface.co/datasets/{repo_name}"
 
91
  except Exception as e:
92
  return f"❌ Push failed: {e}"
93
 
 
99
  file_input = gr.File(label="📁 Upload CSV", file_types=[".csv"])
100
  upload_btn = gr.Button("Upload")
101
 
102
+ df_table = gr.Dataframe(headers=["text","label"], label="📝 Editable Table",
103
+ interactive=True, visible=False)
 
 
 
 
104
  status = gr.Textbox(label="Status", interactive=False)
105
 
106
  with gr.Row():
107
+ save_btn = gr.Button("💾 Save")
108
+ download_btn = gr.Button("⬇️ Download CSV")
109
+ download_out = gr.File(label="📥 Downloaded File")
 
 
110
 
111
  with gr.Row():
112
+ visualize_btn = gr.Button("📊 Visualize Distribution")
113
+ chart_plot = gr.Plot(label="Label Distribution")
114
+ download_chart = gr.File(label="📥 Download Chart")
 
 
115
 
116
  with gr.Accordion("📦 Push to Hugging Face Hub", open=False):
117
  repo_input = gr.Textbox(label="Repo (username/dataset-name)")
 
120
  push_status = gr.Textbox(label="Push Status", interactive=False)
121
 
122
  # Bind events
123
+ upload_btn.click(upload_csv, inputs=file_input, outputs=[df_table, status])
124
+ save_btn.click(save_changes, inputs=df_table, outputs=status)
125
  download_btn.click(download_csv, outputs=download_out)
126
+ visualize_btn.click(visualize_and_download_chart, outputs=[chart_plot, download_chart])
127
+ push_btn.click(push_to_hub, inputs=[repo_input, token_input], outputs=push_status)
128
 
129
  app.launch()