# qgyd2021's picture
# update
# 5106748
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from collections import defaultdict
import json
import logging
import os
import platform
import shutil
import time
from typing import List, Tuple
import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer
import similarities
def get_args():
    """Parse the command-line options of the demo.

    Every option is a file path given as a string, with a default relative
    to the working directory.

    Returns:
        argparse.Namespace with ``query_faq_examples_json_file``,
        ``history_json_file`` and ``description_md_file``.
    """
    parser = argparse.ArgumentParser()
    path_options = (
        ("--query_faq_examples_json_file", "query_faq_examples.json"),
        ("--history_json_file", "history.json"),
        ("--description_md_file", "description.md"),
    )
    for flag, default_path in path_options:
        parser.add_argument(flag, default=default_path, type=str)
    return parser.parse_args()
def click_query_candidates(qc_query: str, qc_candidates: str, qc_model_choices: str):
    """Rank newline-separated candidate sentences by cosine similarity to a query.

    Args:
        qc_query: The query sentence.
        qc_candidates: Candidate sentences, one per line. Blank lines and
            surrounding whitespace are ignored.
        qc_model_choices: Name or path passed to ``SentenceTransformer``.

    Returns:
        One line per candidate, formatted ``"score: <s>; candidate: <text>\\n"``
        and ordered by descending score (score rounded to 4 decimals). Empty
        string when there are no non-blank candidates.
    """
    candidates = [line.strip() for line in (qc_candidates or "").split("\n") if line.strip()]
    if not candidates:
        # Nothing to rank; skip loading the model entirely.
        return ""

    model = SentenceTransformer(qc_model_choices)
    # normalize_embeddings=True makes the dot product below a true cosine
    # similarity for any model, not only those ending in a Normalize layer.
    embeddings = np.asarray(
        model.encode([qc_query] + candidates, normalize_embeddings=True),
        dtype=np.float32,
    )
    query_embedding = embeddings[0]
    candidates_embeddings = embeddings[1:]
    # (n, d) * (d,) broadcasts row-wise; summing over d gives one score per candidate.
    scores = np.sum(query_embedding * candidates_embeddings, axis=-1).tolist()

    ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
    return "".join(
        "score: {}; candidate: {}\n".format(round(score, 4), candidate)
        for candidate, score in ranked
    )
def click_query_faq(qf_query: str,
                    qf_intent_1: str, qf_intent_2: str, qf_intent_3: str, qf_intent_4: str, qf_intent_5: str,
                    qf_model_choices: str, qf_top_k_total: int, qf_top_k_each: int, qf_min_score: float,
                    ):
    """Score five FAQ intents against a query and list those above a threshold.

    Each intent is a block of newline-separated example sentences. The query
    and all examples are embedded (L2-normalized, so dot product = cosine).
    The ``qf_top_k_total`` most similar examples overall are kept and grouped
    by intent. Each intent's score is the mean of its best ``qf_top_k_each``
    kept examples.

    Args:
        qf_query: The user query.
        qf_intent_1..qf_intent_5: Example sentences per intent, one per line;
            blank lines are ignored.
        qf_model_choices: Name or path passed to ``SentenceTransformer``.
        qf_top_k_total: How many best-matching examples (across all intents)
            to keep. Values <= 0 keep none.
        qf_top_k_each: How many of an intent's kept examples to average.
            Values <= 0 yield no intents.
        qf_min_score: Intents whose rounded mean score is below this are dropped.

    Returns:
        Lines ``"intent N: <score>"`` joined by ``"\\n"``, sorted by descending
        score. Empty string when no intent qualifies.
    """
    intent_texts = [qf_intent_1, qf_intent_2, qf_intent_3, qf_intent_4, qf_intent_5]
    examples = list()
    labels = list()
    for number, text in enumerate(intent_texts, start=1):
        lines = [line.strip() for line in (text or "").split("\n") if line.strip()]
        examples.extend(lines)
        labels.extend(["intent {}".format(number)] * len(lines))

    # Slider values may arrive as floats; slicing needs ints.
    top_k_total = int(qf_top_k_total)
    top_k_each = int(qf_top_k_each)
    # Guard the degenerate cases:
    # - no examples: broadcasting would fail;
    # - top_k_total <= 0: the old [-k:] slice selected everything;
    # - top_k_each <= 0: np.mean([]) produced NaN.
    if not examples or top_k_total <= 0 or top_k_each <= 0:
        return ""

    model = SentenceTransformer(qf_model_choices)
    query_embedding = np.asarray(model.encode([qf_query], normalize_embeddings=True), dtype=np.float32)
    intent_embedding = np.asarray(model.encode(examples, normalize_embeddings=True), dtype=np.float32)
    # (1, d) * (n, d) broadcasts to (n, d); summing over d gives one score per example.
    scores = np.sum(query_embedding * intent_embedding, axis=-1)

    # Indices of the top_k_total highest scores, best first.
    arg_index = np.argsort(scores, axis=0)[::-1][:top_k_total]

    # Scores are appended in descending order, so each list is already best-first.
    intent_to_score_list = defaultdict(list)
    for index in arg_index:
        intent_to_score_list[labels[index]].append(scores[index])

    intent_score_list = list()
    for intent, score_list in intent_to_score_list.items():
        mean_score = round(float(np.mean(score_list[:top_k_each])), 4)
        if mean_score < qf_min_score:
            continue
        intent_score_list.append((intent, mean_score))

    intent_score_list.sort(key=lambda pair: pair[1], reverse=True)
    return "\n".join("{}: {}".format(intent, score) for intent, score in intent_score_list)
def _load_query_faq_examples(path: str) -> List[list]:
    """Read the FAQ example JSON file and turn each entry into a gr.Examples row.

    Each entry must provide the keys ``query``, ``intent1``..``intent5``,
    ``model``, ``top_k_total``, ``top_k_each`` and ``min_score``. The intent
    values are joined with newlines (so presumably lists of sentences — confirm
    against the data file) to fit the TextArea inputs.
    """
    with open(path, "r", encoding="utf-8") as f:
        raw_examples = json.load(f)

    rows = list()
    for example in raw_examples:
        rows.append([
            example["query"],
            *("\n".join(example["intent{}".format(i)]) for i in range(1, 6)),
            example["model"],
            example["top_k_total"],
            example["top_k_each"],
            example["min_score"],
        ])
    return rows


def main():
    """Build and launch the Gradio demo with "query candidates" and "query FAQ" tabs."""
    args = get_args()

    brief_description = """
## Sentence Similarity
"""

    # examples
    query_faq_examples = _load_query_faq_examples(args.query_faq_examples_json_file)

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)
        with gr.Row():
            with gr.Column(scale=5):
                with gr.Tabs():
                    with gr.TabItem("query candidates"):
                        gr.Markdown(value="此标签用于 sentence-transformers 句向量余弦相似度匹配。")
                        with gr.Row():
                            with gr.Column(scale=1):
                                qc_candidates = gr.TextArea(label="candidates")
                                # Dedicated output box: results used to overwrite
                                # qc_candidates, destroying the user's input.
                                qc_result = gr.TextArea(label="result")
                            with gr.Column(scale=1):
                                qc_model_choices = gr.Dropdown(
                                    choices=["sentence-transformers/all-MiniLM-L6-v2"],
                                    value="sentence-transformers/all-MiniLM-L6-v2",
                                    label="model"
                                )
                                qc_query = gr.Textbox(label="query")
                                qc_query_candidates_button = gr.Button("query", variant="primary")
                                _ = gr.ClearButton([qc_query, qc_candidates, qc_result])
                        gr.Examples(
                            examples=[
                                [
                                    "how can i track my order ?",
                                    "what's the track id ?\nhow much the price ?\ndid I paid for it ?",
                                    "sentence-transformers/all-MiniLM-L6-v2"
                                ]
                            ],
                            inputs=[qc_query, qc_candidates, qc_model_choices]
                        )
                    with gr.TabItem("query FAQ"):
                        gr.Markdown(value="此标签基于 sentence-transformers 句向量的余弦相似度做 FAQ 匹配。")
                        with gr.Row():
                            qf_intent_1 = gr.TextArea(label="FAQ intent 1")
                            qf_intent_2 = gr.TextArea(label="FAQ intent 2")
                            qf_intent_3 = gr.TextArea(label="FAQ intent 3")
                        with gr.Row():
                            qf_intent_4 = gr.TextArea(label="FAQ intent 4")
                            qf_intent_5 = gr.TextArea(label="FAQ intent 5")
                            qf_intent = gr.TextArea(label="output intent")
                        with gr.Row():
                            qf_query = gr.Textbox(label="query")
                            qf_model_choices = gr.Dropdown(
                                choices=["sentence-transformers/all-MiniLM-L6-v2"],
                                value="sentence-transformers/all-MiniLM-L6-v2",
                                label="model"
                            )
                            qf_top_k_total = gr.Slider(minimum=0, maximum=30, value=10, step=1, label="top_k_total")
                            qf_top_k_each = gr.Slider(minimum=0, maximum=30, value=5, step=1, label="top_k_each")
                            qf_min_score = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.1, label="min_score")
                        with gr.Row():
                            qf_query_faq_button = gr.Button("query", variant="primary")
                            _ = gr.ClearButton(components=[
                                qf_query,
                                qf_intent_1, qf_intent_2, qf_intent_3, qf_intent_4, qf_intent_5,
                                qf_model_choices,
                            ])
                        gr.Examples(
                            examples=query_faq_examples,
                            inputs=[
                                qf_query,
                                qf_intent_1, qf_intent_2, qf_intent_3, qf_intent_4, qf_intent_5,
                                qf_model_choices,
                                qf_top_k_total, qf_top_k_each, qf_min_score
                            ]
                        )

        # click event (must be registered inside the Blocks context)
        qc_query_candidates_button.click(
            click_query_candidates,
            inputs=[qc_query, qc_candidates, qc_model_choices],
            outputs=[qc_result]
        )
        qf_query_faq_button.click(
            click_query_faq,
            inputs=[
                qf_query,
                qf_intent_1, qf_intent_2, qf_intent_3, qf_intent_4, qf_intent_5,
                qf_model_choices, qf_top_k_total, qf_top_k_each, qf_min_score
            ],
            outputs=[qf_intent]
        )

    # The original `False if platform.system() == "Windows" else False` was
    # always False; state that directly.
    blocks.queue().launch(share=False)
    return
if __name__ == '__main__':
    # Launch the UI only when run as a script, never on import.
    main()