"""Hierarchical zero-shot tagger: pick a coarse subject, then fine topics, with logging and feedback."""
import csv
import json
import os
import time

import gradio as gr
from transformers import pipeline

# ————————————————
# Load taxonomies
# ————————————————
# coarse_labels: candidate labels passed to the Stage-1 zero-shot pipeline.
# fine_map: looked up as fine_map.get(subject, []) to get Stage-2 topic labels.
with open("coarse_labels.json", encoding="utf-8") as f:
    coarse_labels = json.load(f)
with open("fine_labels.json", encoding="utf-8") as f:
    fine_map = json.load(f)

# ————————————————
# Model choices
# ————————————————
MODEL_CHOICES = [
    "facebook/bart-large-mnli",
    "roberta-large-mnli",
    "joeddav/xlm-roberta-large-xnli",
    "valhalla/distilbart-mnli-12-4",
    "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
    # NOTE(review): this entry was originally annotated as a placeholder to be
    # replaced; confirm it is the intended model.
    "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
]

# Model name -> loaded pipeline, so later requests reuse an already-loaded model.
PIPELINES = {}


def get_pipeline(name):
    """Return the zero-shot-classification pipeline for *name*, loading it on first use."""
    if name not in PIPELINES:
        PIPELINES[name] = pipeline("zero-shot-classification", model=name)
    return PIPELINES[name]


# ————————————————
# Ensure log files exist
# ————————————————
LOG_FILE = "logs.csv"
FEEDBACK_FILE = "feedback.csv"

for fn, hdr in [
    (LOG_FILE, ["timestamp", "model", "question", "chosen_subject", "top3_topics", "duration"]),
    (FEEDBACK_FILE, ["timestamp", "question", "subject_feedback", "topic_feedback"]),
]:
    if not os.path.exists(fn):
        # Write the header only when creating the file, so existing logs are kept.
        with open(fn, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(hdr)


# ————————————————
# Helpers
# ————————————————
def _timestamp():
    """Return the current local time as 'YYYY-MM-DD HH:MM:SS' for CSV rows."""
    return time.strftime("%Y-%m-%d %H:%M:%S")


def _one_line(text):
    """Collapse line breaks in *text* (None treated as '') into single spaces."""
    return " ".join((text or "").splitlines())


def _append_row(path, row):
    """Append *row* to the CSV file at *path* using UTF-8 and csv quoting."""
    # newline="" is required by the csv module to avoid blank rows on Windows.
    with open(path, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow(row)


def _top3(model_name, question, candidate_labels):
    """Classify *question* against *candidate_labels*.

    Returns (labels, scores, duration_seconds) for at most the 3 best labels.
    The duration includes loading the model if this is its first use.
    """
    start = time.perf_counter()
    out = get_pipeline(model_name)(question, candidate_labels=candidate_labels)
    duration = round(time.perf_counter() - start, 3)
    return out["labels"][:3], out["scores"][:3], duration


def _score_dict(labels, scores):
    """Map each label to its score rounded to 3 decimals (format for gr.Label)."""
    return {lbl: round(score, 3) for lbl, score in zip(labels, scores)}


# ————————————————
# Inference functions
# ————————————————
def run_stage1(question, model_name):
    """Stage 1: rank the coarse subjects for *question*.

    Returns a 3-tuple for (subject_out, subj_radio, stage1_time):
    a {subject: score} dict for the top subjects, a radio update listing
    those subjects with the best one preselected, and an elapsed-time string.
    A blank question returns an empty dict, an emptied radio, and "".
    """
    if not question or not question.strip():
        # Also clear the radio's value so a stale subject is not reused in Stage 2.
        return {}, gr.update(choices=[], value=None), ""

    labels, scores, duration = _top3(model_name, question, coarse_labels)
    radio_update = gr.update(choices=labels, value=labels[0] if labels else None)
    return _score_dict(labels, scores), radio_update, f"⏱ {duration}s"


def run_stage2(question, model_name, subject):
    """Stage 2: rank the fine topics of *subject* for *question* and log the run.

    Returns a 2-tuple for (topics_out, stage2_time). On invalid input it
    returns an empty dict and a message shown in the time box; only
    successful runs are appended to LOG_FILE.
    """
    # 1) Validate inputs — every exit returns exactly two values to match the outputs.
    if not question or not question.strip():
        return {}, "No question provided"
    if not subject:
        return {}, "No subject selected — run Stage 1 first"
    fine_labels = fine_map.get(subject, [])
    if not fine_labels:
        return {}, f"No topics found for '{subject}'"

    # 2) Inference (pipeline is cached after first use)
    labels, scores, duration = _top3(model_name, question, fine_labels)

    # 3) Logging
    _append_row(LOG_FILE, [
        _timestamp(),
        model_name,
        _one_line(question),
        subject,
        ";".join(labels),
        duration,
    ])

    # 4) Return topics + time
    return _score_dict(labels, scores), f"⏱ {duration}s"


def submit_feedback(question, subject_fb, topic_fb):
    """Append a user correction for *question* to FEEDBACK_FILE and return a status message.

    Submissions where both correction fields are blank are not recorded.
    """
    if not (subject_fb or "").strip() and not (topic_fb or "").strip():
        return "⚠️ Enter a corrected subject and/or topic before submitting."
    _append_row(FEEDBACK_FILE, [_timestamp(), _one_line(question), subject_fb, topic_fb])
    return "✅ Feedback recorded!"


# ————————————————
# Build Gradio UI
# ————————————————
with gr.Blocks() as demo:
    gr.Markdown("## Hierarchical Zero-Shot Tagger with Subject Toggle & Feedback")

    with gr.Row():
        question_input = gr.Textbox(lines=3, label="Enter your question")
        model_input = gr.Dropdown(
            choices=MODEL_CHOICES, value=MODEL_CHOICES[0], label="Choose model"
        )

    # Stage 1 UI
    go_button = gr.Button("Run Stage 1")
    subject_out = gr.Label(num_top_classes=3, label="Top-3 Subjects")
    subj_radio = gr.Radio(choices=[], label="Select Subject for Stage 2")
    stage1_time = gr.Textbox(label="Stage 1 Time")
    go_button.click(
        fn=run_stage1,
        inputs=[question_input, model_input],
        outputs=[subject_out, subj_radio, stage1_time],
    )

    # Stage 2 UI
    go2_button = gr.Button("Run Stage 2")
    topics_out = gr.Label(label="Top-3 Topics")
    stage2_time = gr.Textbox(label="Stage 2 Time")
    go2_button.click(
        fn=run_stage2,
        inputs=[question_input, model_input, subj_radio],
        outputs=[topics_out, stage2_time],
    )

    # Feedback UI
    gr.Markdown("---")
    gr.Markdown("### Feedback / Correction")
    subject_fb = gr.Textbox(label="Correct Subject")
    topic_fb = gr.Textbox(label="Correct Topic(s)")
    fb_button = gr.Button("Submit Feedback")
    fb_status = gr.Textbox(label="")
    fb_button.click(
        fn=submit_feedback,
        inputs=[question_input, subject_fb, topic_fb],
        outputs=[fb_status],
    )

if __name__ == "__main__":
    demo.launch(share=True)