Alexandra Zapko-Willmes commited on
Commit
cc03b28
·
verified ·
1 Parent(s): 7a32121

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -36
app.py CHANGED
@@ -1,64 +1,60 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
  import pandas as pd
4
  import io
 
5
 
6
- # Load once
7
- classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
 
 
 
 
8
 
9
  response_table = []
10
 
11
- def classify_items(questions_text, labels_text):
 
12
  questions = [q.strip() for q in questions_text.strip().split("\n") if q.strip()]
13
- labels = [l.strip() for l in labels_text.strip().split(",") if l.strip()]
14
-
15
  if not labels or not questions:
16
- return "Please provide both items and at least two response options.", ""
17
 
 
18
  global response_table
19
  response_table = []
20
  output_lines = []
21
 
22
  for i, question in enumerate(questions, 1):
23
  result = classifier(question, labels, multi_label=False)
24
- probs = dict(zip(result['labels'], result['scores']))
25
-
26
  output_lines.append(f"{i}. {question}")
27
- for label in labels:
28
- output_lines.append(f"→ {label}: {round(probs.get(label, 0.0), 3)}")
 
29
  output_lines.append("")
30
-
31
- row = {"Item #": i, "Item": question}
32
- row.update({label: round(probs.get(label, 0.0), 3) for label in labels})
33
  response_table.append(row)
34
 
35
  return "\n".join(output_lines), None
36
 
37
  def download_csv():
38
- global response_table
39
- if not response_table:
40
- return None
41
  df = pd.DataFrame(response_table)
42
- csv_buffer = io.StringIO()
43
- df.to_csv(csv_buffer, index=False)
44
- return csv_buffer.getvalue()
45
 
46
- # Gradio UI
47
  with gr.Blocks() as demo:
48
- gr.Markdown("# 🧠 Zero-Shot Classification for Questionnaire Responses")
49
- gr.Markdown("Paste questionnaire items (one per line), and provide your own response labels (comma-separated).")
50
-
51
- with gr.Row():
52
- with gr.Column():
53
- questions_input = gr.Textbox(label="Questionnaire Items", lines=10, placeholder="e.g.\nI feel in control of my life.\nI enjoy being around others.")
54
- labels_input = gr.Textbox(label="Response Options (comma-separated)", placeholder="Strongly disagree, Disagree, Neutral, Agree, Strongly agree")
55
- submit_btn = gr.Button("Classify Items")
56
- csv_btn = gr.Button("📥 Download CSV")
57
- with gr.Column():
58
- output_box = gr.Textbox(label="Classification Output", lines=20)
59
- file_output = gr.File(label="Download CSV", visible=False)
60
-
61
- submit_btn.click(fn=classify_items, inputs=[questions_input, labels_input], outputs=[output_box, file_output])
62
- csv_btn.click(fn=download_csv, inputs=[], outputs=file_output)
63
 
64
  demo.launch()
 
1
  import gradio as gr
 
2
  import pandas as pd
3
  import io
4
+ from transformers import pipeline
5
 
6
+ # Available zero-shot classification models
7
+ models = {
8
+ "EN: deberta-v3-large-zeroshot": "MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
9
+ "MULTI: mDeBERTa-v3-xnli": "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
10
+ "MULTI: xlm-roberta-xnli": "joeddav/xlm-roberta-large-xnli"
11
+ }
12
 
13
  response_table = []
14
 
15
+ def classify_items(questions_text, labels_text, model_choice):
16
+ labels = [l.strip() for l in labels_text.split(",") if l.strip()]
17
  questions = [q.strip() for q in questions_text.strip().split("\n") if q.strip()]
 
 
18
  if not labels or not questions:
19
+ return "Please enter both questionnaire items and response labels.", None
20
 
21
+ classifier = pipeline("zero-shot-classification", model=models[model_choice])
22
  global response_table
23
  response_table = []
24
  output_lines = []
25
 
26
  for i, question in enumerate(questions, 1):
27
  result = classifier(question, labels, multi_label=False)
28
+ row = {"Item #": i, "Item": question}
 
29
  output_lines.append(f"{i}. {question}")
30
+ for label, score in zip(result["labels"], result["scores"]):
31
+ row[label] = round(score, 3)
32
+ output_lines.append(f"→ {label}: {round(score, 3)}")
33
  output_lines.append("")
 
 
 
34
  response_table.append(row)
35
 
36
  return "\n".join(output_lines), None
37
 
38
  def download_csv():
 
 
 
39
  df = pd.DataFrame(response_table)
40
+ buffer = io.StringIO()
41
+ df.to_csv(buffer, index=False)
42
+ return buffer.getvalue()
43
 
44
+ # Gradio interface
45
  with gr.Blocks() as demo:
46
+ gr.Markdown("## 🧠 Zero-Shot Classification with Model Selection")
47
+ gr.Markdown("Students can enter multiple questionnaire items and define their own response labels. The selected model will classify each item and provide probabilities.")
48
+
49
+ model_dropdown = gr.Dropdown(choices=list(models.keys()), label="Choose a model")
50
+ labels_input = gr.Textbox(label="Response Options (comma-separated)", placeholder="e.g., Strongly disagree, Disagree, Neutral, Agree, Strongly agree")
51
+ questions_input = gr.Textbox(label="Questionnaire Items (one per line)", lines=10)
52
+ output_box = gr.Textbox(label="Model Output", lines=20)
53
+ submit_btn = gr.Button("Classify")
54
+ download_btn = gr.Button("📥 Download CSV")
55
+ file_output = gr.File(label="Download CSV", visible=False)
56
+
57
+ submit_btn.click(fn=classify_items, inputs=[questions_input, labels_input, model_dropdown], outputs=[output_box, file_output])
58
+ download_btn.click(fn=download_csv, inputs=[], outputs=file_output)
 
 
59
 
60
  demo.launch()