Spaces:
Paused
Paused
File size: 5,126 Bytes
757e2cd 69ce27b 757e2cd b5b6ecb 757e2cd 8c20665 69ce27b 757e2cd 8c20665 757e2cd b5b6ecb 58ecde0 757e2cd b5b6ecb 757e2cd b5b6ecb 757e2cd b5b6ecb 757e2cd b5b6ecb 757e2cd b5b6ecb 757e2cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import transformers
import re
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
from vllm import LLM, SamplingParams
import torch
import gradio as gr
import json
import os
import shutil
import requests
import lancedb
import pandas as pd
# Pick "cuda" when torch reports an available GPU, otherwise "cpu".
# NOTE(review): `device` is not referenced anywhere else in the visible code;
# it is not passed to the vLLM engine below.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Generation settings read by ESRChatBot.predict when it builds SamplingParams.
temperature = 0.6
max_new_tokens = 3000  # forwarded as `max_tokens`
top_p = 0.95
# NOTE(review): despite its name, predict() forwards this value as
# `presence_penalty` - confirm that is the intended sampling knob.
repetition_penalty = 1.2
# NOTE(review): the identifier stops right after the "dataesr/" prefix, with
# no model name following it - looks truncated; confirm the full model id.
model_name = "dataesr/"
# Create the vLLM engine at import time; `max_model_len` is passed as 8128.
llm = LLM(model_name, max_model_len=8128)
# Open the LanceDB database (path is relative to the working directory) and
# the "abstractsC" table that hybrid_search() queries.
db = lancedb.connect("base/lancedb_data")
table = db.open_table("abstractsC")
def hybrid_search(text):
    """Run a hybrid search on the module-level ``table`` and format the hits.

    At most five rows are requested. Rows whose ``hash`` value (compared as a
    string) was already emitted are skipped, so each source appears once.

    Args:
        text: The query passed to ``table.search``.

    Returns:
        A ``(plain, html)`` tuple: ``plain`` is the newline-joined text
        listing injected into the prompt, ``html`` is the same listing as
        ``<div class="source">`` elements wrapped in a ``source_listing`` div.
    """
    query = table.search(text, query_type="hybrid")
    hits = query.limit(5).to_pandas()

    emitted = set()
    plain_entries = []
    html_entries = []
    for _index, hit in hits.iterrows():
        identifier = str(hit['hash'])
        if identifier not in emitted:
            emitted.add(identifier)
            # NOTE(review): the "title" is read from the `hash` column again,
            # so the identifier is printed twice - confirm whether a dedicated
            # title column was intended.
            heading = hit['hash']
            body = hit['text']
            plain_entries.append(f"**{identifier}**\n{heading}\n{body}")
            html_entries.append(
                f'<div class="source" id="{identifier}"><p><b>{identifier}</b> : {heading}<br>{body}</div>'
            )

    plain = "\n".join(plain_entries)
    listing = '<div id="source_listing">' + "".join(html_entries) + "</div>"
    return plain, listing
class ESRChatBot:
    """Retrieval-augmented chatbot: fetch sources, prompt the LLM, render HTML.

    Attributes are limited to ``system_prompt``, which is stored on the
    instance; ``predict`` does not currently read it.
    """

    def __init__(self, system_prompt="Tu es ESR, le chatbot qui donne des réponses sourcées."):
        """Store the system prompt on the instance."""
        self.system_prompt = system_prompt

    def predict(self, user_message):
        """Answer ``user_message`` using sources returned by ``hybrid_search``.

        Args:
            user_message: The user's question or instruction.

        Returns:
            A ``(answer_html, sources_html)`` tuple of HTML fragments: the
            generated answer with references converted to tooltips, and the
            listing of retrieved sources, each under a centred heading.
        """
        fiches, fiches_html = hybrid_search(user_message)

        # NOTE(review): the module-level `repetition_penalty` value is sent as
        # `presence_penalty`; vLLM exposes a separate `repetition_penalty`
        # argument as well - confirm which one is intended before changing it.
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_new_tokens,
            presence_penalty=repetition_penalty,
            stop=["#END#"],
        )

        # Prompt layout: query, retrieved sources, then an open answer section.
        detailed_prompt = f"""### Query ###\n{user_message}\n\n### Source ###\n{fiches}\n\n### Answer ###\n"""
        outputs = llm.generate([detailed_prompt], sampling_params, use_tqdm=False)
        generated_text = outputs[0].outputs[0].text

        # Headings are closed with the same tag that opens them (the closing
        # tag previously named a different heading level).
        generated_text = (
            '<h2 style="text-align:center">Réponse</h2>\n<div class="generation">'
            + format_references(generated_text)
            + "</div>"
        )
        fiches_html = '<h2 style="text-align:center">Sources</h2>\n' + fiches_html
        return generated_text, fiches_html
def format_references(text):
    """Replace ``<ref text="...">id</ref>`` markers with numbered tooltip links.

    Each well-formed reference becomes a ``<span class="tooltip">`` holding a
    ``[n]`` link to ``#id``; ``n`` starts at 1 and increases in order of
    appearance. The quoted text (newlines flattened to spaces, stripped) and
    the id are HTML-escaped before being placed in attributes, so quotes,
    ampersands or angle brackets inside a citation cannot break the markup.

    Text outside references is passed through unchanged. If a reference is
    cut off (no ``">`` or no ``</ref>`` after its start), the remaining text
    is kept and HTML-escaped so it is displayed literally instead of being
    silently dropped.

    Args:
        text: Raw generated text, possibly containing reference markers.

    Returns:
        The text with every complete reference converted to tooltip HTML.
    """
    import html  # local import so this block does not depend on file-level imports

    ref_start_marker = '<ref text="'
    ref_end_marker = '</ref>'
    parts = []
    current_pos = 0
    ref_number = 1
    while True:
        start_pos = text.find(ref_start_marker, current_pos)
        if start_pos == -1:
            # No further references: keep the tail as-is.
            parts.append(text[current_pos:])
            break
        parts.append(text[current_pos:start_pos])

        # End of the text="..." attribute, then the closing </ref> tag.
        end_pos = text.find('">', start_pos)
        ref_end_pos = text.find(ref_end_marker, end_pos) if end_pos != -1 else -1
        if ref_end_pos == -1:
            # Truncated reference: keep the remainder visible, escaped so the
            # broken tag cannot corrupt the surrounding HTML.
            parts.append(html.escape(text[start_pos:], quote=True))
            break

        ref_text = text[start_pos + len(ref_start_marker):end_pos].replace('\n', ' ').strip()
        ref_text_encoded = html.escape(ref_text, quote=True)
        ref_id = html.escape(text[end_pos + 2:ref_end_pos].strip(), quote=True)

        tooltip_html = (
            f'<span class="tooltip" data-refid="{ref_id}" '
            f'data-text="{ref_id}: {ref_text_encoded}">'
            f'<a href="#{ref_id}">[{ref_number}]</a></span>'
        )
        parts.append(tooltip_html)
        current_pos = ref_end_pos + len(ref_end_marker)
        ref_number += 1
    return ''.join(parts)
# Module-level chatbot instance (default system prompt) used by
# gradio_interface() for every request.
ESR_bot = ESRChatBot()
# Stylesheet passed to gr.Blocks. The selectors match markup produced elsewhere
# in this file:
#   .generation - wrapper div around the generated answer (ESRChatBot.predict)
#   .source     - one div per retrieved source (hybrid_search)
#   .tooltip    - reference spans built by format_references; on hover the
#                 `data-text` attribute is shown through the ::after rule
#   :target     - highlights the element whose id matches the URL fragment,
#                 i.e. the source a reference link (href="#id") points to
css = """
.generation {
margin-left:2em;
margin-right:2em;
}
:target {
background-color: #CCF3DF;
}
.source {
float:left;
max-width:17%;
margin-left:2%;
}
.tooltip {
position: relative;
cursor: pointer;
font-variant-position: super;
color: #97999b;
}
.tooltip:hover::after {
content: attr(data-text);
position: absolute;
left: 0;
top: 120%;
white-space: pre-wrap;
width: 500px;
max-width: 500px;
z-index: 1;
background-color: #f9f9f9;
color: #000;
border: 1px solid #ddd;
border-radius: 5px;
padding: 5px;
display: block;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
"""
# Gradio click handler
def gradio_interface(user_message):
    """Forward the textbox content to the shared bot and return its two HTML outputs.

    Args:
        user_message: Text entered by the user.

    Returns:
        The ``(answer_html, sources_html)`` pair produced by
        ``ESR_bot.predict``, in that order.
    """
    prediction = ESR_bot.predict(user_message)
    answer_html, sources_html = prediction
    return (answer_html, sources_html)
# Assemble the Gradio UI: a title, a row with the question column (textbox +
# button) next to the answer column, then a second row for the source listing.
# NOTE(review): the nesting below is reconstructed from the statement order,
# since the original indentation was not available - confirm the layout.
with gr.Blocks(css=css) as demo:
    gr.HTML("""<h1 style="text-align:center">ESR</h1>""")
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(label="Votre question ou votre instruction", lines=3)
            text_button = gr.Button("Interroger ESR")
        with gr.Column(scale=3):
            text_output = gr.HTML(label="La réponse de ESR")
    with gr.Row():
        embedding_output = gr.HTML(label="Les sources utilisées")
    # Clicking the button sends the textbox value to gradio_interface and
    # routes its two return values to the answer and sources panels.
    text_button.click(
        gradio_interface,
        inputs=text_input,
        outputs=[text_output, embedding_output],
    )
# Start the Gradio server only when this file is executed as a script, not
# when it is imported. (A stray trailing "|" after the call, which made the
# statement a syntax error, has been removed.)
if __name__ == "__main__":
    demo.launch()