Alexandra Zapko-Willmes commited on
Commit
14bd759
·
verified ·
1 Parent(s): 561e518

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -18
app.py CHANGED
@@ -6,31 +6,33 @@ import io
6
  # Load once
7
  classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
8
 
9
- LIKERT_OPTIONS = ["Strongly disagree", "Disagree", "Neutral", "Agree", "Strongly agree"]
10
-
11
  response_table = []
12
 
13
- def classify_likert(questions_text):
14
  questions = [q.strip() for q in questions_text.strip().split("\n") if q.strip()]
 
15
 
 
 
 
16
  global response_table
17
  response_table = []
18
  output_lines = []
19
 
20
  for i, question in enumerate(questions, 1):
21
- result = classifier(question, LIKERT_OPTIONS, multi_label=False)
22
  probs = dict(zip(result['labels'], result['scores']))
 
23
  output_lines.append(f"{i}. {question}")
24
- for label in LIKERT_OPTIONS:
25
- prob = round(probs.get(label, 0.0), 3)
26
- output_lines.append(f"→ {label}: {prob}")
27
  output_lines.append("")
28
 
29
  row = {"Item #": i, "Item": question}
30
- row.update({label: round(probs.get(label, 0.0), 3) for label in LIKERT_OPTIONS})
31
  response_table.append(row)
32
 
33
- return "\n".join(output_lines)
34
 
35
  def download_csv():
36
  global response_table
@@ -41,18 +43,22 @@ def download_csv():
41
  df.to_csv(csv_buffer, index=False)
42
  return csv_buffer.getvalue()
43
 
44
- # Gradio interface
45
  with gr.Blocks() as demo:
46
- gr.Markdown("# Likert-Style Zero-Shot Classifier")
47
- gr.Markdown("Paste questionnaire items. Each will be classified into: Strongly disagree Strongly agree, with probabilities.")
48
 
49
- questions_input = gr.Textbox(label="Enter multiple items (one per line)", lines=10, placeholder="e.g.\nI feel in control of my life.\nI enjoy being around others...")
50
- output_box = gr.Textbox(label="Classification Output", lines=20)
51
- submit_btn = gr.Button("Classify Items")
52
- csv_btn = gr.Button("📥 Download CSV")
53
- file_output = gr.File(label="Download CSV", visible=False)
 
 
 
 
54
 
55
- submit_btn.click(fn=classify_likert, inputs=questions_input, outputs=output_box)
56
  csv_btn.click(fn=download_csv, inputs=[], outputs=file_output)
57
 
58
  demo.launch()
 
6
  # Load once
7
  classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
8
 
 
 
9
  response_table = []
10
 
11
+ def classify_items(questions_text, labels_text):
12
  questions = [q.strip() for q in questions_text.strip().split("\n") if q.strip()]
13
+ labels = [l.strip() for l in labels_text.strip().split(",") if l.strip()]
14
 
15
+ if not labels or not questions:
16
+ return "Please provide both items and at least two response options.", ""
17
+
18
  global response_table
19
  response_table = []
20
  output_lines = []
21
 
22
  for i, question in enumerate(questions, 1):
23
+ result = classifier(question, labels, multi_label=False)
24
  probs = dict(zip(result['labels'], result['scores']))
25
+
26
  output_lines.append(f"{i}. {question}")
27
+ for label in labels:
28
+ output_lines.append(f"→ {label}: {round(probs.get(label, 0.0), 3)}")
 
29
  output_lines.append("")
30
 
31
  row = {"Item #": i, "Item": question}
32
+ row.update({label: round(probs.get(label, 0.0), 3) for label in labels})
33
  response_table.append(row)
34
 
35
+ return "\n".join(output_lines), None
36
 
37
  def download_csv():
38
  global response_table
 
43
  df.to_csv(csv_buffer, index=False)
44
  return csv_buffer.getvalue()
45
 
46
+ # Gradio UI
47
  with gr.Blocks() as demo:
48
+ gr.Markdown("# 🧠 Zero-Shot Classification for Questionnaire Responses")
49
+ gr.Markdown("Paste questionnaire items (one per line), and provide your own response labels (comma-separated).")
50
 
51
+ with gr.Row():
52
+ with gr.Column():
53
+ questions_input = gr.Textbox(label="Questionnaire Items", lines=10, placeholder="e.g.\nI feel in control of my life.\nI enjoy being around others.")
54
+ labels_input = gr.Textbox(label="Response Options (comma-separated)", placeholder="Strongly disagree, Disagree, Neutral, Agree, Strongly agree")
55
+ submit_btn = gr.Button("Classify Items")
56
+ csv_btn = gr.Button("📥 Download CSV")
57
+ with gr.Column():
58
+ output_box = gr.Textbox(label="Classification Output", lines=20)
59
+ file_output = gr.File(label="Download CSV", visible=False)
60
 
61
+ submit_btn.click(fn=classify_items, inputs=[questions_input, labels_input], outputs=[output_box, file_output])
62
  csv_btn.click(fn=download_csv, inputs=[], outputs=file_output)
63
 
64
  demo.launch()