#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio demo for sentence similarity with sentence-transformers.

Two tabs are provided:

* "query candidates": rank newline-separated candidate sentences against a
  query.
* "query FAQ": score a query against up to five FAQ intents (each given as
  newline-separated example sentences) and report a mean score per intent.
"""
import argparse
from collections import defaultdict
from functools import lru_cache
import json
import logging
import os
import platform
import shutil
import time
from typing import List, Tuple

import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer
import similarities


def get_args():
    """Parse command-line arguments for the demo."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--query_faq_examples_json_file",
        default="query_faq_examples.json",
        type=str
    )
    # NOTE(review): --history_json_file and --description_md_file are parsed
    # but not read anywhere in this script.
    parser.add_argument(
        "--history_json_file",
        default="history.json",
        type=str
    )
    parser.add_argument(
        "--description_md_file",
        default="description.md",
        type=str
    )
    args = parser.parse_args()
    return args


def _split_lines(text: str) -> List[str]:
    """Return the non-blank lines of ``text``, each stripped of surrounding whitespace."""
    return [line.strip() for line in (text or "").split("\n") if line.strip()]


@lru_cache(maxsize=4)
def _load_model(model_name: str) -> SentenceTransformer:
    """Load a SentenceTransformer once per model name and reuse it across clicks."""
    return SentenceTransformer(model_name)


def click_query_candidates(qc_query: str, qc_candidates: str, qc_model_choices: str):
    """Rank candidate sentences by cosine similarity to a query.

    Args:
        qc_query: the query sentence.
        qc_candidates: newline-separated candidate sentences; blank lines are ignored.
        qc_model_choices: name of the sentence-transformers model to use.

    Returns:
        One ``"score: <s>; candidate: <c>\\n"`` line per candidate, sorted by
        score descending, or an empty string when there are no candidates.
    """
    candidates = _split_lines(qc_candidates)
    if not candidates:
        return ""

    model = _load_model(qc_model_choices)
    # Unit-normalized embeddings make the dot product below a cosine similarity,
    # whatever the chosen model does internally.
    embeddings = np.asarray(
        model.encode([qc_query] + candidates, normalize_embeddings=True),
        dtype=np.float32,
    )
    query_embedding = embeddings[0]
    candidates_embeddings = embeddings[1:]
    scores = (candidates_embeddings @ query_embedding).tolist()

    ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
    return "".join(
        "score: {}; candidate: {}\n".format(round(score, 4), candidate)
        for candidate, score in ranked
    )


def click_query_faq(qf_query: str,
                    qf_intent_1: str,
                    qf_intent_2: str,
                    qf_intent_3: str,
                    qf_intent_4: str,
                    qf_intent_5: str,
                    qf_model_choices: str,
                    qf_top_k_total: int,
                    qf_top_k_each: int,
                    qf_min_score: float,
                    ):
    """Score a query against up to five FAQ intents.

    Each intent is a newline-separated list of example sentences. The query is
    compared (cosine similarity) with every example; the ``qf_top_k_total``
    best examples overall are kept, grouped by intent, and each intent's score
    is the mean of its best ``qf_top_k_each`` kept examples.

    Args:
        qf_query: the query sentence.
        qf_intent_1 .. qf_intent_5: newline-separated example sentences per intent.
        qf_model_choices: name of the sentence-transformers model to use.
        qf_top_k_total: how many best-matching examples to keep overall.
        qf_top_k_each: how many of an intent's kept examples to average.
        qf_min_score: intents whose mean score is below this are dropped.

    Returns:
        ``"intent N: <score>"`` lines sorted by score descending, or an empty
        string when there is nothing to score.
    """
    sentences: List[str] = []
    labels: List[str] = []
    intent_texts = [qf_intent_1, qf_intent_2, qf_intent_3, qf_intent_4, qf_intent_5]
    for number, intent_text in enumerate(intent_texts, start=1):
        examples = _split_lines(intent_text)
        sentences.extend(examples)
        labels.extend(["intent {}".format(number)] * len(examples))

    # Gradio sliders may deliver floats; slicing needs ints.
    top_k_total = int(qf_top_k_total)
    top_k_each = int(qf_top_k_each)
    # Nothing to encode, or nothing selected: avoid encode([]) shape errors
    # and np.mean([]) == nan.
    if not sentences or top_k_total <= 0 or top_k_each <= 0:
        return ""

    model = _load_model(qf_model_choices)
    query_embedding = np.asarray(
        model.encode([qf_query], normalize_embeddings=True), dtype=np.float32
    )[0]
    intent_embedding = np.asarray(
        model.encode(sentences, normalize_embeddings=True), dtype=np.float32
    )
    scores = intent_embedding @ query_embedding

    # Best-first order, then take a prefix: unlike `[-k:]`, `[:k]` selects
    # nothing (not everything) when k == 0.
    top_indices = np.argsort(scores)[::-1][:top_k_total]

    intent_to_score_list = defaultdict(list)
    for index in top_indices:
        intent_to_score_list[labels[index]].append(float(scores[index]))

    intent_score_list: List[Tuple[str, float]] = []
    for intent, score_list in intent_to_score_list.items():
        # Scores were appended best-first, so this averages each intent's best hits.
        mean_score = round(float(np.mean(score_list[:top_k_each])), 4)
        if mean_score < qf_min_score:
            continue
        intent_score_list.append((intent, mean_score))

    intent_score_list.sort(key=lambda x: x[1], reverse=True)
    return "\n".join(
        "{}: {}".format(intent, round(score, 4)) for intent, score in intent_score_list
    )


def _load_query_faq_examples(filename: str) -> List[list]:
    """Read FAQ examples from JSON into rows ordered like the "query FAQ" inputs.

    Each JSON object must have ``query``, ``intent1`` .. ``intent5`` (lists of
    sentences), ``model``, ``top_k_total``, ``top_k_each`` and ``min_score``.
    """
    with open(filename, "r", encoding="utf-8") as f:
        raw_examples = json.load(f)

    rows = []
    for example in raw_examples:
        rows.append([
            example["query"],
            *("\n".join(example["intent{}".format(i)]) for i in range(1, 6)),
            example["model"],
            example["top_k_total"],
            example["top_k_each"],
            example["min_score"],
        ])
    return rows


def main():
    """Build the Gradio UI and launch it."""
    args = get_args()

    brief_description = """
## Sentence Similarity
"""

    # examples
    query_faq_examples = _load_query_faq_examples(args.query_faq_examples_json_file)

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)

        with gr.Row():
            with gr.Column(scale=5):
                with gr.Tabs():
                    with gr.TabItem("query candidates"):
                        gr.Markdown(value="此标签用于 sentence-transformers 句向量余弦相似度匹配。")
                        with gr.Row():
                            with gr.Column(scale=1):
                                qc_candidates = gr.TextArea(label="candidates")
                                # Separate output so results do not overwrite the candidates input.
                                qc_result = gr.TextArea(label="output")
                            with gr.Column(scale=1):
                                qc_model_choices = gr.Dropdown(
                                    choices=["sentence-transformers/all-MiniLM-L6-v2"],
                                    value="sentence-transformers/all-MiniLM-L6-v2",
                                    label="model"
                                )
                                qc_query = gr.Textbox(label="query")
                                qc_query_candidates_button = gr.Button("query", variant="primary")
                                _ = gr.ClearButton([qc_query, qc_candidates, qc_result])
                        gr.Examples(
                            examples=[
                                [
                                    "how can i track my order ?",
                                    "what's the track id ?\nhow much the price ?\ndid I paid for it ?",
                                    "sentence-transformers/all-MiniLM-L6-v2"
                                ]
                            ],
                            inputs=[qc_query, qc_candidates, qc_model_choices]
                        )

                    with gr.TabItem("query FAQ"):
                        gr.Markdown(value="此标签基于 sentence-transformers 句向量的余弦相似度做 FAQ 匹配。")
                        with gr.Row():
                            qf_intent_1 = gr.TextArea(label="FAQ intent 1")
                            qf_intent_2 = gr.TextArea(label="FAQ intent 2")
                            qf_intent_3 = gr.TextArea(label="FAQ intent 3")
                        with gr.Row():
                            qf_intent_4 = gr.TextArea(label="FAQ intent 4")
                            qf_intent_5 = gr.TextArea(label="FAQ intent 5")
                            qf_intent = gr.TextArea(label="output intent")
                        with gr.Row():
                            qf_query = gr.Textbox(label="query")
                            qf_model_choices = gr.Dropdown(
                                choices=["sentence-transformers/all-MiniLM-L6-v2"],
                                value="sentence-transformers/all-MiniLM-L6-v2",
                                label="model"
                            )
                            qf_top_k_total = gr.Slider(minimum=0, maximum=30, value=10, step=1, label="top_k_total")
                            qf_top_k_each = gr.Slider(minimum=0, maximum=30, value=5, step=1, label="top_k_each")
                            qf_min_score = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.1, label="min_score")
                        with gr.Row():
                            qf_query_faq_button = gr.Button("query", variant="primary")
                            _ = gr.ClearButton(components=[
                                qf_query,
                                qf_intent_1, qf_intent_2, qf_intent_3, qf_intent_4, qf_intent_5,
                                qf_model_choices,
                            ])
                        gr.Examples(
                            examples=query_faq_examples,
                            inputs=[
                                qf_query,
                                qf_intent_1, qf_intent_2, qf_intent_3, qf_intent_4, qf_intent_5,
                                qf_model_choices,
                                qf_top_k_total, qf_top_k_each, qf_min_score
                            ]
                        )

        # click event
        qc_query_candidates_button.click(
            click_query_candidates,
            inputs=[qc_query, qc_candidates, qc_model_choices],
            outputs=[qc_result]
        )
        qf_query_faq_button.click(
            click_query_faq,
            inputs=[
                qf_query,
                qf_intent_1, qf_intent_2, qf_intent_3, qf_intent_4, qf_intent_5,
                qf_model_choices,
                qf_top_k_total, qf_top_k_each, qf_min_score
            ],
            outputs=[qf_intent]
        )

    # The original `False if platform.system() == "Windows" else False` was
    # False on every platform; keep that behavior without the dead conditional.
    blocks.queue().launch(share=False)


if __name__ == '__main__':
    main()