Alexandra Zapko-Willmes commited on
Commit
86f9780
·
verified ·
1 Parent(s): 842372c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -21
app.py CHANGED
@@ -1,22 +1,58 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- models = {
5
- "Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
6
- "Falcon-7B": "tiiuae/falcon-7b-instruct"
7
- }
8
-
9
- def ask_model(question, model_choice):
10
- client = InferenceClient(models[model_choice])
11
- prompt = f"Answer this questionnaire item: {question} (Strongly disagree - Strongly agree)"
12
- return client.text_generation(prompt=prompt)
13
-
14
- gr.Interface(
15
- fn=ask_model,
16
- inputs=[
17
- gr.Textbox(label="Questionnaire Item"),
18
- gr.Dropdown(list(models.keys()), label="Choose Model")
19
- ],
20
- outputs="text",
21
- title="LLM-Powered Questionnaire"
22
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
+ import pandas as pd
4
+ import io
5
+
6
+ # Load once
7
+ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
8
+
9
+ LIKERT_OPTIONS = ["Strongly disagree", "Disagree", "Neutral", "Agree", "Strongly agree"]
10
+
11
+ response_table = []
12
+
13
+ def classify_likert(questions_text):
14
+ questions = [q.strip() for q in questions_text.strip().split("\n") if q.strip()]
15
+
16
+ global response_table
17
+ response_table = []
18
+ output_lines = []
19
+
20
+ for i, question in enumerate(questions, 1):
21
+ result = classifier(question, LIKERT_OPTIONS, multi_label=False)
22
+ probs = dict(zip(result['labels'], result['scores']))
23
+ output_lines.append(f"{i}. {question}")
24
+ for label in LIKERT_OPTIONS:
25
+ prob = round(probs.get(label, 0.0), 3)
26
+ output_lines.append(f"→ {label}: {prob}")
27
+ output_lines.append("")
28
+
29
+ row = {"Item #": i, "Item": question}
30
+ row.update({label: round(probs.get(label, 0.0), 3) for label in LIKERT_OPTIONS})
31
+ response_table.append(row)
32
+
33
+ return "\n".join(output_lines)
34
+
35
+ def download_csv():
36
+ global response_table
37
+ if not response_table:
38
+ return None
39
+ df = pd.DataFrame(response_table)
40
+ csv_buffer = io.StringIO()
41
+ df.to_csv(csv_buffer, index=False)
42
+ return csv_buffer.getvalue()
43
+
44
+ # Gradio interface
45
+ with gr.Blocks() as demo:
46
+ gr.Markdown("# Likert-Style Zero-Shot Classifier")
47
+ gr.Markdown("Paste questionnaire items. Each will be classified into: Strongly disagree → Strongly agree, with probabilities.")
48
+
49
+ questions_input = gr.Textbox(label="Enter multiple items (one per line)", lines=10, placeholder="e.g.\nI feel in control of my life.\nI enjoy being around others...")
50
+ output_box = gr.Textbox(label="Classification Output", lines=20)
51
+ submit_btn = gr.Button("Classify Items")
52
+ csv_btn = gr.Button("📥 Download CSV")
53
+ file_output = gr.File(label="Download CSV", visible=False)
54
+
55
+ submit_btn.click(fn=classify_likert, inputs=questions_input, outputs=output_box)
56
+ csv_btn.click(fn=download_csv, inputs=[], outputs=file_output)
57
+
58
+ demo.launch()