Alexandra Zapko-Willmes commited on
Commit
23cf7a9
·
verified ·
1 Parent(s): 4036329

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -29
app.py CHANGED
@@ -1,43 +1,64 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
  import pandas as pd
 
4
 
5
- MODEL_MAP = {
6
- "MoritzLaurer/deberta-v3-large-zeroshot-v2.0": "MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
7
- "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7": "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
8
- "joeddav/xlm-roberta-large-xnli": "joeddav/xlm-roberta-large-xnli"
9
- }
10
 
11
- def classify_items(model_name, items_text, labels_text):
12
- classifier = pipeline("zero-shot-classification", model=MODEL_MAP[model_name])
13
- items = [item.strip() for item in items_text.split("\n") if item.strip()]
14
- labels = [label.strip() for label in labels_text.split(",") if label.strip()]
15
 
16
- results = []
17
- for item in items:
18
- out = classifier(item, labels, multi_label=True)
19
- scores = {label: prob for label, prob in zip(out["labels"], out["scores"])}
20
- scores["item"] = item
21
- results.append(scores)
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- df = pd.DataFrame(results).fillna(0)
24
- return df, df.to_csv(index=False)
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  with gr.Blocks() as demo:
27
- gr.Markdown("### Zero-Shot Questionnaire Classifier")
 
28
 
29
  with gr.Row():
30
- model_choice = gr.Dropdown(choices=list(MODEL_MAP.keys()), label="Choose a model")
31
-
32
- item_input = gr.Textbox(label="Enter questionnaire items (one per line)", lines=5, placeholder="e.g., I enjoy social gatherings.\nI prefer planning over spontaneity.")
33
- label_input = gr.Textbox(label="Enter response options (comma-separated)", placeholder="e.g., Strongly disagree, Disagree, Neutral, Agree, Strongly agree")
34
-
35
- run_button = gr.Button("Classify")
36
- output_table = gr.Dataframe(label="Classification Results")
37
- download_csv = gr.File(label="Download CSV")
38
 
39
- run_button.click(fn=classify_items,
40
- inputs=[model_choice, item_input, label_input],
41
- outputs=[output_table, download_csv])
42
 
43
  demo.launch()
 
1
  import gradio as gr
2
  from transformers import pipeline
3
  import pandas as pd
4
+ import io
5
 
6
+ # Load once
7
+ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
 
 
 
8
 
9
+ response_table = []
 
 
 
10
 
11
+ def classify_items(questions_text, labels_text):
12
+ questions = [q.strip() for q in questions_text.strip().split("\n") if q.strip()]
13
+ labels = [l.strip() for l in labels_text.strip().split(",") if l.strip()]
14
+
15
+ if not labels or not questions:
16
+ return "Please provide both items and at least two response options.", ""
17
+
18
+ global response_table
19
+ response_table = []
20
+ output_lines = []
21
+
22
+ for i, question in enumerate(questions, 1):
23
+ result = classifier(question, labels, multi_label=False)
24
+ probs = dict(zip(result['labels'], result['scores']))
25
+
26
+ output_lines.append(f"{i}. {question}")
27
+ for label in labels:
28
+ output_lines.append(f"→ {label}: {round(probs.get(label, 0.0), 3)}")
29
+ output_lines.append("")
30
 
31
+ row = {"Item #": i, "Item": question}
32
+ row.update({label: round(probs.get(label, 0.0), 3) for label in labels})
33
+ response_table.append(row)
34
 
35
+ return "\n".join(output_lines), None
36
+
37
+ def download_csv():
38
+ global response_table
39
+ if not response_table:
40
+ return None
41
+ df = pd.DataFrame(response_table)
42
+ csv_buffer = io.StringIO()
43
+ df.to_csv(csv_buffer, index=False)
44
+ return csv_buffer.getvalue()
45
+
46
+ # Gradio UI
47
  with gr.Blocks() as demo:
48
+ gr.Markdown("# 🧠 Zero-Shot Classification for Questionnaire Responses")
49
+ gr.Markdown("Paste questionnaire items (one per line), and provide your own response labels (comma-separated).")
50
 
51
  with gr.Row():
52
+ with gr.Column():
53
+ questions_input = gr.Textbox(label="Questionnaire Items", lines=10, placeholder="e.g.\nI feel in control of my life.\nI enjoy being around others.")
54
+ labels_input = gr.Textbox(label="Response Options (comma-separated)", placeholder="Strongly disagree, Disagree, Neutral, Agree, Strongly agree")
55
+ submit_btn = gr.Button("Classify Items")
56
+ csv_btn = gr.Button("📥 Download CSV")
57
+ with gr.Column():
58
+ output_box = gr.Textbox(label="Classification Output", lines=20)
59
+ file_output = gr.File(label="Download CSV", visible=False)
60
 
61
+ submit_btn.click(fn=classify_items, inputs=[questions_input, labels_input], outputs=[output_box, file_output])
62
+ csv_btn.click(fn=download_csv, inputs=[], outputs=file_output)
 
63
 
64
  demo.launch()