# ZeroShotTagger / app.py
# Uploaded by naveenus — commit dc2875b ("Update app.py", verified)
import json, time, csv, os
import gradio as gr
from transformers import pipeline
# ————————————————
# Load taxonomies
# ————————————————
# Both taxonomy files are decoded explicitly as UTF-8 ("utf-8-sig" also
# tolerates a leading BOM, e.g. from Windows editors) instead of the platform's
# locale encoding, so non-ASCII labels load identically on every OS.
with open("coarse_labels.json", encoding="utf-8-sig") as f:
    # Passed as the Stage 1 candidate labels; expected to be a list of
    # subject names — TODO confirm against the JSON file.
    coarse_labels = json.load(f)
with open("fine_labels.json", encoding="utf-8-sig") as f:
    # Looked up with .get(subject, []) in Stage 2; expected to map each
    # subject name to its list of topic names — TODO confirm.
    fine_map = json.load(f)
# ————————————————
# Model choices
# ————————————————
# Hugging Face Hub model ids offered in the "Choose model" dropdown.
# MODEL_CHOICES[0] is the dropdown's default selection.
MODEL_CHOICES = [
    "facebook/bart-large-mnli",
    "roberta-large-mnli",
    "joeddav/xlm-roberta-large-xnli",
    "valhalla/distilbart-mnli-12-4",
    "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
    # NOTE(review): this entry was previously flagged as a "placeholder", but it
    # is a published Hub model id — confirm it is the intended checkpoint.
    "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
]
# Cache of model name -> zero-shot-classification pipeline, filled on demand
# so each model is loaded at most once per process.
PIPELINES = {}


def get_pipeline(name):
    """Return the cached zero-shot pipeline for *name*, building it on first request."""
    if name in PIPELINES:
        return PIPELINES[name]
    clf = pipeline("zero-shot-classification", model=name)
    PIPELINES[name] = clf
    return clf
# ————————————————
# Ensure log files exist
# ————————————————
LOG_FILE = "logs.csv"
FEEDBACK_FILE = "feedback.csv"


def _ensure_csv(path, header):
    """Create *path* containing only the *header* row, unless it already exists.

    Existing files are left untouched so rows from earlier runs are kept.
    """
    if os.path.exists(path):
        return
    with open(path, "w", newline="") as fh:
        csv.writer(fh).writerow(header)


_ensure_csv(LOG_FILE, ["timestamp", "model", "question", "chosen_subject", "top3_topics", "duration"])
_ensure_csv(FEEDBACK_FILE, ["timestamp", "question", "subject_feedback", "topic_feedback"])
# ————————————————
# Inference functions
# ————————————————
def run_stage1(question, model_name):
    """Stage 1: rank the coarse subjects for *question* with *model_name*.

    Returns a 3-tuple for the Stage 1 outputs:
      * dict of the top-3 subject labels -> score rounded to 3 decimals,
      * a ``gr.update`` that loads those labels into the subject radio and
        pre-selects the first (best-ranked) one,
      * an elapsed-time string such as "⏱ 0.512s".
    An empty or whitespace-only question yields empty outputs and clears the
    radio (both its choices and its selected value).
    """
    if not question or not question.strip():
        # Reset the value too, so a subject from a previous run does not stay
        # selected after its choice has been removed.
        return {}, gr.update(choices=[], value=None), ""
    # perf_counter is monotonic, so the duration is immune to wall-clock jumps.
    start = time.perf_counter()
    # The first call per model also loads it, so that cost is included in the time.
    clf = get_pipeline(model_name)
    out = clf(question, candidate_labels=coarse_labels)
    # Slicing to the "top 3" assumes the pipeline returns labels sorted by
    # score, highest first.
    labels, scores = out["labels"][:3], out["scores"][:3]
    duration = round(time.perf_counter() - start, 3)
    subject_dict = {lbl: round(score, 3) for lbl, score in zip(labels, scores)}
    radio_update = gr.update(choices=labels, value=labels[0])
    return subject_dict, radio_update, f"⏱ {duration}s"
def run_stage2(question, model_name, subject):
    """Stage 2: rank the fine-grained topics of *subject* for *question*.

    Returns a 2-tuple matching the Stage 2 outputs ``[topics_out, stage2_time]``:
    a dict of the top-3 topic labels -> score (rounded to 3 decimals), and a
    status string ("⏱ <seconds>s" on success, an explanation otherwise).
    Each successful run appends one row to LOG_FILE.
    """
    # 1) Validate inputs. Every early return yields exactly two values, the
    #    same arity as the success path and the two outputs wired in the UI.
    if not question or not question.strip():
        return {}, "No question provided"
    if not subject:
        # The subject radio stays empty until Stage 1 has been run.
        return {}, "No subject selected - run Stage 1 first"
    fine_labels = fine_map.get(subject, [])
    if not fine_labels:
        return {}, f"No topics found for '{subject}'"

    # 2) Inference. get_pipeline loads the model lazily, so the first call per
    #    model includes its load time. perf_counter is a monotonic clock.
    start = time.perf_counter()
    clf = get_pipeline(model_name)
    out = clf(question, candidate_labels=fine_labels)
    labels, scores = out["labels"][:3], out["scores"][:3]
    duration = round(time.perf_counter() - start, 3)

    # 3) Logging. Written as UTF-8 so non-ASCII questions (the model list
    #    includes multilingual/XNLI checkpoints) cannot fail under a narrower
    #    platform default encoding.
    with open(LOG_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            model_name,
            question.replace("\n", " "),
            subject,
            ";".join(labels),
            duration,
        ])

    # 4) Return topics + time
    topic_dict = {lbl: round(score, 3) for lbl, score in zip(labels, scores)}
    return topic_dict, f"⏱ {duration}s"
def submit_feedback(question, subject_fb, topic_fb):
    """Append a user correction to FEEDBACK_FILE and return a status message.

    A row (timestamp, question, subject correction, topic correction) is written
    only when at least one of *subject_fb* / *topic_fb* contains non-whitespace
    text; otherwise nothing is recorded and a hint is returned instead.
    """
    if not (subject_fb or "").strip() and not (topic_fb or "").strip():
        return "Nothing to record: enter a corrected subject or topic."
    # UTF-8 so non-ASCII questions/corrections cannot fail under a narrower
    # platform default encoding.
    with open(FEEDBACK_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            (question or "").replace("\n", " "),
            subject_fb,
            topic_fb,
        ])
    # The literal was mojibake ("βœ…") of the check-mark emoji.
    return "✅ Feedback recorded!"
# ————————————————
# Build Gradio UI
# ————————————————
with gr.Blocks() as demo:
    gr.Markdown("## Hierarchical Zero-Shot Tagger with Subject Toggle & Feedback")

    # Inputs shared by both stages and by the feedback form.
    # NOTE(review): the original indentation was lost; this assumes only the
    # question box and the model dropdown sit side by side in the Row.
    with gr.Row():
        question_input = gr.Textbox(lines=3, label="Enter your question")
        model_input = gr.Dropdown(choices=MODEL_CHOICES, value=MODEL_CHOICES[0], label="Choose model")

    # Stage 1: run_stage1 fills the subject label, repopulates the radio with
    # the top-3 subjects, and reports the elapsed time.
    go_button = gr.Button("Run Stage 1")
    subject_out = gr.Label(num_top_classes=3, label="Top-3 Subjects")
    subj_radio = gr.Radio(choices=[], label="Select Subject for Stage 2")
    stage1_time = gr.Textbox(label="Stage 1 Time")
    go_button.click(
        fn=run_stage1,
        inputs=[question_input, model_input],
        outputs=[subject_out, subj_radio, stage1_time],
    )

    # Stage 2: fine topics for the subject currently selected in the radio.
    go2_button = gr.Button("Run Stage 2")
    topics_out = gr.Label(label="Top-3 Topics")
    stage2_time = gr.Textbox(label="Stage 2 Time")
    go2_button.click(
        fn=run_stage2,
        inputs=[question_input, model_input, subj_radio],
        outputs=[topics_out, stage2_time],
    )

    # Feedback form: corrections are appended to the feedback CSV by submit_feedback.
    gr.Markdown("---")
    gr.Markdown("### Feedback / Correction")
    subject_fb = gr.Textbox(label="Correct Subject")
    topic_fb = gr.Textbox(label="Correct Topic(s)")
    fb_button = gr.Button("Submit Feedback")
    fb_status = gr.Textbox(label="")
    fb_button.click(
        fn=submit_feedback,
        inputs=[question_input, subject_fb, topic_fb],
        outputs=[fb_status],
    )

# Launch only when executed as a script (`python app.py`), so importing this
# module (tests, tooling, Gradio's reload mode that looks up `demo`) does not
# start a server as a side effect.
if __name__ == "__main__":
    demo.launch(share=True)