Alexandra Zapko-Willmes commited on
Commit
4036329
·
verified ·
1 Parent(s): 86f9780

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -46
app.py CHANGED
@@ -1,58 +1,43 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
  import pandas as pd
4
- import io
5
 
6
- # Load once
7
- classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
 
 
 
8
 
9
- LIKERT_OPTIONS = ["Strongly disagree", "Disagree", "Neutral", "Agree", "Strongly agree"]
 
 
 
10
 
11
- response_table = []
 
 
 
 
 
 
 
 
12
 
13
- def classify_likert(questions_text):
14
- questions = [q.strip() for q in questions_text.strip().split("\n") if q.strip()]
15
-
16
- global response_table
17
- response_table = []
18
- output_lines = []
19
-
20
- for i, question in enumerate(questions, 1):
21
- result = classifier(question, LIKERT_OPTIONS, multi_label=False)
22
- probs = dict(zip(result['labels'], result['scores']))
23
- output_lines.append(f"{i}. {question}")
24
- for label in LIKERT_OPTIONS:
25
- prob = round(probs.get(label, 0.0), 3)
26
- output_lines.append(f"→ {label}: {prob}")
27
- output_lines.append("")
28
-
29
- row = {"Item #": i, "Item": question}
30
- row.update({label: round(probs.get(label, 0.0), 3) for label in LIKERT_OPTIONS})
31
- response_table.append(row)
32
-
33
- return "\n".join(output_lines)
34
-
35
- def download_csv():
36
- global response_table
37
- if not response_table:
38
- return None
39
- df = pd.DataFrame(response_table)
40
- csv_buffer = io.StringIO()
41
- df.to_csv(csv_buffer, index=False)
42
- return csv_buffer.getvalue()
43
-
44
- # Gradio interface
45
  with gr.Blocks() as demo:
46
- gr.Markdown("# Likert-Style Zero-Shot Classifier")
47
- gr.Markdown("Paste questionnaire items. Each will be classified into: Strongly disagree → Strongly agree, with probabilities.")
 
 
 
 
 
48
 
49
- questions_input = gr.Textbox(label="Enter multiple items (one per line)", lines=10, placeholder="e.g.\nI feel in control of my life.\nI enjoy being around others...")
50
- output_box = gr.Textbox(label="Classification Output", lines=20)
51
- submit_btn = gr.Button("Classify Items")
52
- csv_btn = gr.Button("📥 Download CSV")
53
- file_output = gr.File(label="Download CSV", visible=False)
54
 
55
- submit_btn.click(fn=classify_likert, inputs=questions_input, outputs=output_box)
56
- csv_btn.click(fn=download_csv, inputs=[], outputs=file_output)
 
57
 
58
  demo.launch()
 
1
  import gradio as gr
2
  from transformers import pipeline
3
  import pandas as pd
 
4
 
5
+ MODEL_MAP = {
6
+ "MoritzLaurer/deberta-v3-large-zeroshot-v2.0": "MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
7
+ "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7": "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
8
+ "joeddav/xlm-roberta-large-xnli": "joeddav/xlm-roberta-large-xnli"
9
+ }
10
 
11
+ def classify_items(model_name, items_text, labels_text):
12
+ classifier = pipeline("zero-shot-classification", model=MODEL_MAP[model_name])
13
+ items = [item.strip() for item in items_text.split("\n") if item.strip()]
14
+ labels = [label.strip() for label in labels_text.split(",") if label.strip()]
15
 
16
+ results = []
17
+ for item in items:
18
+ out = classifier(item, labels, multi_label=True)
19
+ scores = {label: prob for label, prob in zip(out["labels"], out["scores"])}
20
+ scores["item"] = item
21
+ results.append(scores)
22
+
23
+ df = pd.DataFrame(results).fillna(0)
24
+ return df, df.to_csv(index=False)
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  with gr.Blocks() as demo:
27
+ gr.Markdown("### Zero-Shot Questionnaire Classifier")
28
+
29
+ with gr.Row():
30
+ model_choice = gr.Dropdown(choices=list(MODEL_MAP.keys()), label="Choose a model")
31
+
32
+ item_input = gr.Textbox(label="Enter questionnaire items (one per line)", lines=5, placeholder="e.g., I enjoy social gatherings.\nI prefer planning over spontaneity.")
33
+ label_input = gr.Textbox(label="Enter response options (comma-separated)", placeholder="e.g., Strongly disagree, Disagree, Neutral, Agree, Strongly agree")
34
 
35
+ run_button = gr.Button("Classify")
36
+ output_table = gr.Dataframe(label="Classification Results")
37
+ download_csv = gr.File(label="Download CSV")
 
 
38
 
39
+ run_button.click(fn=classify_items,
40
+ inputs=[model_choice, item_input, label_input],
41
+ outputs=[output_table, download_csv])
42
 
43
  demo.launch()