# NOTE(review): removed non-Python residue ("Spaces:" / "Runtime error" x2),
# apparently pasted from a Hugging Face Spaces status page; it is not valid
# Python and would raise a SyntaxError before the script ran.
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
import argparse | |
from collections import defaultdict | |
import json | |
import logging | |
import os | |
import platform | |
import shutil | |
import time | |
from typing import List, Tuple | |
import gradio as gr | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
import similarities | |
def get_args():
    """Parse command-line options for the demo.

    Returns:
        argparse.Namespace with three string options:
        ``query_faq_examples_json_file`` (FAQ example rows, JSON),
        ``history_json_file`` and ``description_md_file``.
    """
    parser = argparse.ArgumentParser()
    # All options are plain string paths with a file-name default.
    for flag, default in (
        ("--query_faq_examples_json_file", "query_faq_examples.json"),
        ("--history_json_file", "history.json"),
        ("--description_md_file", "description.md"),
    ):
        parser.add_argument(flag, default=default, type=str)
    return parser.parse_args()
def click_query_candidates(qc_query: str, qc_candidates: str, qc_model_choices: str) -> str:
    """Rank newline-separated candidate sentences by similarity to a query.

    Args:
        qc_query: the query sentence.
        qc_candidates: candidate sentences, one per line; blank lines ignored.
        qc_model_choices: sentence-transformers model name to load.

    Returns:
        One "score: S; candidate: C" line per candidate, best first,
        each line terminated by "\\n"; empty string if there are no candidates.
    """
    candidates = [c.strip() for c in qc_candidates.split("\n") if c.strip()]
    if not candidates:
        # Robustness fix: the original would still load the model and encode
        # only the query just to produce an empty result.
        return ""
    model = SentenceTransformer(qc_model_choices)
    embeddings = np.array(model.encode([qc_query] + candidates), dtype=np.float32)
    query_embedding = embeddings[0]
    candidate_embeddings = embeddings[1:]
    # Bug fix: the UI promises cosine similarity, but the original returned the
    # raw dot product (misleadingly named `l2`); SentenceTransformer.encode does
    # not L2-normalize by default. Divide by the norms so the score really is
    # cosine similarity in [-1, 1]; epsilon guards against zero vectors.
    norms = np.linalg.norm(query_embedding) * np.linalg.norm(candidate_embeddings, axis=-1)
    scores = np.sum(query_embedding * candidate_embeddings, axis=-1) / np.maximum(norms, 1e-12)
    ranked = sorted(zip(candidates, scores.tolist()), key=lambda x: x[1], reverse=True)
    # str.join instead of quadratic `result += row`.
    rows = ["score: {}; candidate: {}\n".format(round(score, 4), candidate)
            for candidate, score in ranked]
    return "".join(rows)
def click_query_faq(qf_query: str,
                    qf_intent_1: str, qf_intent_2: str, qf_intent_3: str, qf_intent_4: str, qf_intent_5: str,
                    qf_model_choices: str, qf_top_k_total: int, qf_top_k_each: int, qf_min_score: float,
                    ) -> str:
    """Match a query against up to five FAQ intent buckets.

    Each ``qf_intent_N`` is a newline-separated list of example sentences for
    that intent. All sentences are embedded with the chosen model; the
    ``qf_top_k_total`` sentences most similar (dot product) to the query are
    kept, each intent's top ``qf_top_k_each`` of those are averaged, and
    intents whose mean score reaches ``qf_min_score`` are returned as
    "intent N: score" lines, best first.

    Fixes vs. the original:
    - ``qf_top_k_total == 0`` (the slider minimum) used to select ALL
      sentences because ``a[-0:]`` is ``a[0:]``; now it selects none.
    - ``qf_top_k_each == 0`` used to feed ``np.mean([])`` -> NaN, and
      ``NaN < min_score`` is False, so every intent leaked through with a
      nan score; such intents are now skipped.
    - All-empty intent boxes used to crash on ``model.encode([])``; now
      returns "".
    - The 5x copy-pasted parsing is collapsed into a loop.
    """
    sentences = []
    labels = []
    for idx, raw in enumerate(
            (qf_intent_1, qf_intent_2, qf_intent_3, qf_intent_4, qf_intent_5), start=1):
        questions = [q.strip() for q in raw.split("\n") if q.strip()]
        sentences += questions
        # Label strings must stay exactly "intent N" -- they are shown in the UI.
        labels += ["intent {}".format(idx)] * len(questions)
    if not sentences:
        return ""

    model = SentenceTransformer(qf_model_choices)
    query_embedding = np.array(model.encode([qf_query]), dtype=np.float32)
    intent_embedding = np.array(model.encode(sentences), dtype=np.float32)
    scores = np.sum(query_embedding * intent_embedding, axis=-1)

    # Indices of the top_k_total highest-scoring sentences, best first.
    # Descending argsort then a head slice is safe for top_k_total == 0.
    top_index = np.argsort(scores, axis=0)[::-1][:qf_top_k_total]

    # Scores arrive in descending order, so each intent's list is pre-sorted.
    intent_to_score_list = defaultdict(list)
    for index in top_index:
        intent_to_score_list[labels[index]].append(scores[index])

    intent_score_list = []
    for intent, score_list in intent_to_score_list.items():
        score_list = score_list[:qf_top_k_each]
        if not score_list:
            continue  # avoid np.mean([]) -> NaN
        mean_score = round(float(np.mean(score_list)), 4)
        if mean_score < qf_min_score:
            continue
        intent_score_list.append((intent, mean_score))

    intent_score_list.sort(key=lambda x: x[1], reverse=True)
    return "\n".join("{}: {}".format(intent, round(score, 4))
                     for intent, score in intent_score_list)
def main():
    """Build and launch the Gradio demo UI.

    Tab 1 ("query candidates") ranks newline-separated candidate sentences
    against a query; tab 2 ("query FAQ") matches a query against up to five
    intent buckets. FAQ example rows are loaded from the JSON file named by
    ``--query_faq_examples_json_file``.

    NOTE(review): ``args.history_json_file`` and ``args.description_md_file``
    are parsed by get_args() but never used here -- confirm whether they are
    dead options or belong to a removed feature.
    """
    args = get_args()
    # Page header markdown.
    brief_description = """
    ## Sentence Similarity
    """
    # examples
    # Flatten each JSON record into the positional row layout gr.Examples
    # expects for the FAQ tab: query, five intent text areas (one sentence
    # per line), model name, then the three slider values.
    query_faq_examples = list()
    with open(args.query_faq_examples_json_file, "r", encoding="utf-8") as f:
        query_faq_examples_ = json.load(f)
    for example in query_faq_examples_:
        query = example["query"]
        # intentN are assumed to be lists of sentences in the JSON -- they are
        # joined with "\n" below; TODO confirm against the example file.
        intent1 = example["intent1"]
        intent2 = example["intent2"]
        intent3 = example["intent3"]
        intent4 = example["intent4"]
        intent5 = example["intent5"]
        model = example["model"]
        top_k_total = example["top_k_total"]
        top_k_each = example["top_k_each"]
        min_score = example["min_score"]
        row = [
            query,
            "\n".join(intent1),
            "\n".join(intent2),
            "\n".join(intent3),
            "\n".join(intent4),
            "\n".join(intent5),
            model,
            top_k_total,
            top_k_each,
            min_score
        ]
        query_faq_examples.append(row)
    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)
        with gr.Row():
            with gr.Column(scale=5):
                with gr.Tabs():
                    with gr.TabItem("query candidates"):
                        # (zh) "This tab does cosine-similarity matching on
                        # sentence-transformers sentence embeddings."
                        gr.Markdown(value="此标签用于 sentence-transformers 句向量余弦相似度匹配。")
                        with gr.Row():
                            with gr.Column(scale=1):
                                qc_candidates = gr.TextArea(label="candidates")
                            with gr.Column(scale=1):
                                qc_model_choices = gr.Dropdown(
                                    choices=["sentence-transformers/all-MiniLM-L6-v2"],
                                    value="sentence-transformers/all-MiniLM-L6-v2",
                                    label="model"
                                )
                                qc_query = gr.Textbox(label="query")
                                qc_query_candidates_button = gr.Button("query", variant="primary")
                                _ = gr.ClearButton([qc_query, qc_candidates])
                        gr.Examples(
                            examples=[
                                [
                                    "how can i track my order ?",
                                    "what's the track id ?\nhow much the price ?\ndid I paid for it ?",
                                    "sentence-transformers/all-MiniLM-L6-v2"
                                ]
                            ],
                            inputs=[qc_query, qc_candidates, qc_model_choices]
                        )
                    with gr.TabItem("query FAQ"):
                        # (zh) "This tab does FAQ matching via cosine
                        # similarity of sentence-transformers embeddings."
                        gr.Markdown(value="此标签基于 sentence-transformers 句向量的余弦相似度做 FAQ 匹配。")
                        with gr.Row():
                            qf_intent_1 = gr.TextArea(label="FAQ intent 1")
                            qf_intent_2 = gr.TextArea(label="FAQ intent 2")
                            qf_intent_3 = gr.TextArea(label="FAQ intent 3")
                        with gr.Row():
                            qf_intent_4 = gr.TextArea(label="FAQ intent 4")
                            qf_intent_5 = gr.TextArea(label="FAQ intent 5")
                            # Output area: click_query_faq writes its result here.
                            qf_intent = gr.TextArea(label="output intent")
                        with gr.Row():
                            qf_query = gr.Textbox(label="query")
                            qf_model_choices = gr.Dropdown(
                                choices=["sentence-transformers/all-MiniLM-L6-v2"],
                                value="sentence-transformers/all-MiniLM-L6-v2",
                                label="model"
                            )
                            qf_top_k_total = gr.Slider(minimum=0, maximum=30, value=10, step=1, label="top_k_total")
                            qf_top_k_each = gr.Slider(minimum=0, maximum=30, value=5, step=1, label="top_k_each")
                            qf_min_score = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.1, label="min_score")
                        with gr.Row():
                            qf_query_faq_button = gr.Button("query", variant="primary")
                            _ = gr.ClearButton(components=[
                                qf_query,
                                qf_intent_1, qf_intent_2, qf_intent_3, qf_intent_4, qf_intent_5,
                                qf_model_choices,
                            ])
                        gr.Examples(
                            examples=query_faq_examples,
                            inputs=[
                                qf_query,
                                qf_intent_1, qf_intent_2, qf_intent_3, qf_intent_4, qf_intent_5,
                                qf_model_choices,
                                qf_top_k_total, qf_top_k_each, qf_min_score
                            ]
                        )
        # click event
        # NOTE: the candidates result overwrites the qc_candidates input box.
        qc_query_candidates_button.click(
            click_query_candidates,
            inputs=[qc_query, qc_candidates, qc_model_choices],
            outputs=[qc_candidates]
        )
        qf_query_faq_button.click(
            click_query_faq,
            inputs=[
                qf_query,
                qf_intent_1, qf_intent_2, qf_intent_3, qf_intent_4, qf_intent_5,
                qf_model_choices, qf_top_k_total, qf_top_k_each, qf_min_score
            ],
            outputs=[qf_intent]
        )
    blocks.queue().launch(
        # NOTE(review): this expression evaluates to False on EVERY platform
        # (both branches are False). Presumably the non-Windows branch was
        # meant to be True -- confirm intent before changing behavior.
        share=False if platform.system() == "Windows" else False
    )
    return
# Script entry point: only launch the UI when run directly, not on import.
if __name__ == "__main__":
    main()