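# NOTE(review): the original first lines read "Spaces: Running Running" — apparently
# hosting-page status text captured along with the file; it is not part of the program.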
import csv
import json
import os
import time

import gradio as gr
from transformers import pipeline
# ────────────────
# Load taxonomies
# ────────────────
# coarse_labels is passed as-is as the stage-1 candidate labels.
with open("coarse_labels.json", encoding="utf-8") as f:
    coarse_labels = json.load(f)

# fine_map is looked up by subject name and yields the stage-2 candidate labels.
with open("fine_labels.json", encoding="utf-8") as f:
    fine_map = json.load(f)
# ────────────────
# Model choices (6 zero-shot NLI checkpoints)
# ────────────────
# The first entry is the default selection in the model dropdown.
MODEL_CHOICES = [
    "facebook/bart-large-mnli",
    "roberta-large-mnli",
    "joeddav/xlm-roberta-large-xnli",
    "valhalla/distilbart-mnli-12-4",
    "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
    # placeholder — replace with real phantom model
    "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
]
# Cache of already-built pipelines, keyed by model name.
PIPELINES = {}


def get_pipeline(name):
    """Return the zero-shot-classification pipeline for model *name*.

    The pipeline is built on first request and stored in ``PIPELINES`` so
    later calls with the same name reuse the same object.
    """
    if name not in PIPELINES:
        PIPELINES[name] = pipeline("zero-shot-classification", model=name)
    return PIPELINES[name]
# ────────────────
# Ensure log files exist
# ────────────────
LOG_FILE = "logs.csv"
FEEDBACK_FILE = "feedback.csv"

# Create each CSV with its header row only when the file is missing, so
# existing logs are never truncated on restart.
for fn, hdr in [
    (LOG_FILE, ["timestamp", "model", "question", "chosen_subject", "top3_topics", "duration"]),
    (FEEDBACK_FILE, ["timestamp", "question", "subject_feedback", "topic_feedback"]),
]:
    if not os.path.exists(fn):
        with open(fn, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(hdr)
# ────────────────
# Inference functions
# ────────────────
def run_stage1(question, model_name):
    """Classify *question* against the coarse subject labels.

    Args:
        question: Free-text question; blank input short-circuits.
        model_name: Model id handed to ``get_pipeline``.

    Returns:
        A 3-tuple matching the three Gradio outputs wired to this callback:
        a ``{label: score}`` dict for the first three labels the pipeline
        returns (scores rounded to 3 decimals), a radio update offering those
        labels with the first one preselected, and an elapsed-time string.
        For blank input: an empty dict, an update with no choices, and "".
    """
    if not question or not question.strip():
        return {}, gr.update(choices=[]), ""

    start = time.time()
    clf = get_pipeline(model_name)
    out = clf(question, candidate_labels=coarse_labels)
    labels, scores = out["labels"][:3], out["scores"][:3]
    # Note: the timing includes building the pipeline on a cache miss.
    duration = round(time.time() - start, 3)

    # Prepare outputs
    subject_dict = {lbl: round(score, 3) for lbl, score in zip(labels, scores)}
    radio_update = gr.update(choices=labels, value=labels[0])
    time_str = f"⏱ {duration}s"
    return subject_dict, radio_update, time_str
def run_stage2(question, model_name, subject):
    """Classify *question* against the fine-grained topics of *subject*.

    Args:
        question: Free-text question; blank input short-circuits.
        model_name: Model id handed to ``get_pipeline``.
        subject: Key into ``fine_map`` selecting the candidate topic labels.

    Returns:
        A ``(topics, status)`` pair on every path, matching the two Gradio
        outputs wired to this callback. ``topics`` maps the first three
        labels returned by the pipeline to scores rounded to 3 decimals;
        ``status`` is the elapsed time, or a validation message when nothing
        was classified.

    Side effects:
        Appends one row to ``LOG_FILE`` for each successful classification.
    """
    # 1) Validate inputs. These early exits must return exactly two values,
    #    like the success path, or the caller's two outputs cannot be filled.
    if not question or not question.strip():
        return {}, "No question provided"
    fine_labels = fine_map.get(subject, [])
    if not fine_labels:
        return {}, f"No topics found for '{subject}'"

    # 2) Inference (pipeline is built on first use, then reused from cache)
    start = time.time()
    clf = get_pipeline(model_name)
    out = clf(question, candidate_labels=fine_labels)
    labels, scores = out["labels"][:3], out["scores"][:3]
    duration = round(time.time() - start, 3)

    # 3) Logging
    with open(LOG_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            model_name,
            question.replace("\n", " "),  # keep each log entry on one line
            subject,
            ";".join(labels),
            duration,
        ])

    # 4) Return topics + time
    topic_dict = {lbl: round(score, 3) for lbl, score in zip(labels, scores)}
    return topic_dict, f"⏱ {duration}s"
def submit_feedback(question, subject_fb, topic_fb):
    """Append a user correction to ``FEEDBACK_FILE`` and return a status line.

    Args:
        question: The question the feedback refers to; ``None`` is treated
            as an empty string.
        subject_fb: The subject the user considers correct.
        topic_fb: The topic(s) the user considers correct.

    Returns:
        A confirmation message for the feedback status box.
    """
    with open(FEEDBACK_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            (question or "").replace("\n", " "),  # keep each entry on one line
            subject_fb,
            topic_fb,
        ])
    return "✅ Feedback recorded!"
# ────────────────
# Build Gradio UI
# ────────────────
# NOTE(review): the original indentation was lost. The Row is assumed to hold
# the question box and model dropdown only — confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## Hierarchical Zero-Shot Tagger with Subject Toggle & Feedback")

    with gr.Row():
        question_input = gr.Textbox(lines=3, label="Enter your question")
        model_input = gr.Dropdown(choices=MODEL_CHOICES, value=MODEL_CHOICES[0], label="Choose model")

    # Stage 1 UI: coarse subjects; fills the radio that feeds stage 2.
    go_button = gr.Button("Run Stage 1")
    subject_out = gr.Label(num_top_classes=3, label="Top-3 Subjects")
    subj_radio = gr.Radio(choices=[], label="Select Subject for Stage 2")
    stage1_time = gr.Textbox(label="Stage 1 Time")
    go_button.click(
        fn=run_stage1,
        inputs=[question_input, model_input],
        outputs=[subject_out, subj_radio, stage1_time],
    )

    # Stage 2 UI
    go2_button = gr.Button("Run Stage 2")
    topics_out = gr.Label(label="Top-3 Topics")
    stage2_time = gr.Textbox(label="Stage 2 Time")
    go2_button.click(
        fn=run_stage2,
        inputs=[question_input, model_input, subj_radio],
        outputs=[topics_out, stage2_time],
    )

    gr.Markdown("---")
    gr.Markdown("### Feedback / Correction")
    subject_fb = gr.Textbox(label="Correct Subject")
    topic_fb = gr.Textbox(label="Correct Topic(s)")
    fb_button = gr.Button("Submit Feedback")
    fb_status = gr.Textbox(label="")
    fb_button.click(
        fn=submit_feedback,
        inputs=[question_input, subject_fb, topic_fb],
        outputs=[fb_status],
    )

demo.launch(share=True)