import gradio as gr
import os
from callback_handler import QueueCallback
from collections.abc import Generator
from queue import Queue, Empty
from threading import Thread
from dotenv import load_dotenv

load_dotenv()

from call_openai import call_openai

import pinecone
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Pinecone

OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
PINECONE_API_KEY = os.environ["PINECONE_API_KEY"]
PINECONE_ENV = os.environ["PINECONE_ENV"]
PINECONE_INDEX = os.environ["PINECONE_INDEX"]
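# The keys above must be supplied either as Space secrets or via a .env file picked up by
# load_dotenv() above. A minimal .env would look something like this (values are placeholders):
#
#   OPENAI_API_KEY=sk-...
#   PINECONE_API_KEY=...
#   PINECONE_ENV=...
#   PINECONE_INDEX=...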
# TOOL
#####################################################################
pinecone.init(
    api_key=PINECONE_API_KEY,
    environment=PINECONE_ENV
)
index = pinecone.Index(PINECONE_INDEX)
embedder = OpenAIEmbeddings()
class PineconeSearch:
    """Similarity search over one Pinecone namespace, returning the top-k chunks as a context string."""
    docsearch = None
    topk = 2

    def __init__(self, namespace, topk):
        self.docsearch = Pinecone.from_existing_index(PINECONE_INDEX, embedder, namespace=namespace)
        self.topk = topk

    def __call__(self, query):
        docs = self.docsearch.similarity_search(query=query, k=self.topk)
        context = "ARTICLES:\n\n"
        for doc in docs:
            context += f"Content:\n{doc.page_content}\n\n"
            context += f"Source: {doc.metadata['url']}\n"
            context += "----"
        return context
def query_tool(category, pinecone_topk, query):
    print(query)
    # Map each UI topic to its Pinecone namespace
    data = {
        "1_D3_receptor": "demo-richter-target-400-30-1",
        "2_dopamine": "demo-richter-target-400-30-2",
        "3_mitochondrial": "demo-richter-target-400-30-3"
    }
    pinecone_namespace = data[category]
    search_tool = PineconeSearch(
        namespace=pinecone_namespace,
        topk=pinecone_topk,
    )
    return search_tool(query)
def print_token_and_price(response):
    inp = sum(response["token_usage"]["prompt_tokens"])
    out = sum(response["token_usage"]["completion_tokens"])
    print(f"Token usage: {inp + out}")
    # gpt-4-1106-preview pricing: $0.01 / 1K prompt tokens, $0.03 / 1K completion tokens
    price = inp / 1000 * 0.01 + out / 1000 * 0.03
    # rough USD -> HUF conversion at ~370 Ft/USD
    print(f"Total price: {price * 370:.2f} Ft")
    print("===================================")
agent_prompt = """You are an expert research assistant. You can access information about articles via your tool.
Use information ONLY from this tool. Do not invent or add any other knowledge; stick strictly to the articles.
Answer the question in a few brief sentences based on the pieces of article you get from your tool.
Cite the sources you used in [brackets] next to the facts, and list them at the end of your answer."""
def stream(input_text, history, user_prompt, topic, topk) -> Generator:
    # Create a Queue
    q = Queue()
    job_done = object()

    # Create a function to call - this will run in a thread
    def task():
        tool_resp = query_tool(topic, topk, str(input_text))
        response = call_openai(
            messages=[{"role": "system", "content": agent_prompt},
                      {"role": "user", "content": input_text},
                      {"role": "user", "content": tool_resp}],
            stream="token",
            model="gpt-4-1106-preview",
            callback=QueueCallback(q)  # pushes each generated token onto q (see the sketch after this function)
        )
        print(response)
        # print_token_and_price(response=response)
        q.put(job_done)

    # Create a thread and start the function
    t = Thread(target=task)
    t.start()

    content = ""
    # Get each new token from the queue and yield it from our generator
    counter = 0
    while True:
        try:
            next_token = q.get(True, timeout=1)
            if next_token is job_done:
                break
            content += next_token
            # Insert a line break roughly every 20 tokens unless the model already emitted one
            counter += 1
            if counter == 20:
                content += "\n"
                counter = 0
            if "\n" in next_token:
                counter = 0
            yield next_token, content
        except Empty:
            continue
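# QueueCallback itself lives in callback_handler.py and is not shown in this file. Assuming it is a
# LangChain-style callback handler, a minimal sketch of the contract that stream() relies on could
# look like the commented code below (illustrative only; everything except on_llm_new_token is an
# assumption):
#
#   from langchain.callbacks.base import BaseCallbackHandler
#
#   class QueueCallback(BaseCallbackHandler):
#       def __init__(self, q):
#           self.q = q
#
#       def on_llm_new_token(self, token, **kwargs):
#           # push every streamed token onto the queue that stream() drains
#           self.q.put(token)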
def ask_llm(message, history, prompt, topic, topk):
    for next_token, content in stream(message, history, prompt, topic, topk):
        yield content
agent_prompt_textbox = gr.Textbox(
    label="Set the behaviour of the agent",
    lines=2,
    value="NOT WORKING"
)
namespace_dropdown = gr.Dropdown(
    ["1_D3_receptor", "2_dopamine", "3_mitochondrial"],
    label="Choose a topic",
    value="1_D3_receptor"
)
topk_slider = gr.Slider(
    minimum=10,
    maximum=100,
    value=70,
    step=10
)
additional_inputs = [agent_prompt_textbox, namespace_dropdown, topk_slider]

chatInterface = gr.ChatInterface(
    fn=ask_llm,
    additional_inputs=additional_inputs,
    additional_inputs_accordion_name="Agent parameters"
).queue().launch()