# Gongyi/app.py
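"""Gradio chat app that streams a LangChain agent's final answer token by token.

A queue-backed callback handler captures tokens as the LLM emits them, and a
generator drains the queue so gr.ChatInterface can render the growing reply.
"""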
from queue import Empty, Queue
from threading import Thread
from langchain.callbacks.streaming_stdout_final_only import (
    FinalStreamingStdOutCallbackHandler,
)
from agent import agent_executor
from typing import Any, Generator
import gradio as gr


class QueueCallback(FinalStreamingStdOutCallbackHandler):
    """Callback handler that puts the agent's final-answer tokens onto a queue."""

    def __init__(self, q: Queue):
        super().__init__()
        self.q = q
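    # FinalStreamingStdOutCallbackHandler buffers tokens until it sees its
    # answer prefix (by default the tokens of "Final Answer:"), so only the
    # final answer is streamed, not the agent's intermediate reasoning.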
    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Remember the last n tokens, where n = len(answer_prefix_tokens).
        self.append_to_last_tokens(token)
        # Check if the last n tokens match the answer_prefix_tokens list ...
        if self.check_if_answer_reached():
            self.answer_reached = True
            if self.stream_prefix:
                for t in self.last_tokens:
                    self.q.put(t)
            return
        # ... if yes, then stream tokens to the queue from now on.
        if self.answer_reached:
            self.q.put(token)
    def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
        # Nothing to do here: the worker thread enqueues the end-of-stream
        # sentinel itself once agent_executor.invoke returns.
        return


def stream(input_text: str) -> Generator:
    # Queue shared between the agent's worker thread and this generator.
    q = Queue()
    # Unique sentinel object that marks the end of the stream.
    job_done = object()

    # Create a function to call - this will run in a thread.
    def task():
        try:
            agent_executor.invoke(
                {"input": input_text}, config={"callbacks": [QueueCallback(q)]}
            )
        finally:
            # Always enqueue the sentinel, even if the agent raises, so the
            # consumer loop below can terminate.
            q.put(job_done)
    # Create a thread and start the function.
    t = Thread(target=task)
    t.start()

    # Get each new token from the queue and yield it from our generator.
    while True:
        try:
            next_token = q.get(True, timeout=1)
            if next_token is job_done:
                break
            yield next_token
        except Empty:
            # No token within the timeout; poll again until the sentinel arrives.
            continue
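
# A minimal way to exercise stream() outside Gradio (assumes agent_executor is
# wired to a streaming-capable LLM; the prompt below is just an example):
#
#     for tok in stream("How should we define the problem?"):
#         print(tok, end="", flush=True)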


def predict(message, history):
    if len(message) == 0:
        return

    # Rebuild the conversation in OpenAI message format. Note that this list is
    # currently unused: stream() only receives the latest user message.
    history_openai_format = []
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})

    # Accumulate tokens and yield the growing reply so Gradio re-renders it.
    partial_message = ""
    for token in stream(message):
        if len(token) != 0:
            partial_message = partial_message + token
            yield partial_message


gr.ChatInterface(
    fn=predict,
    title="问题定义工具包",  # "Problem Definition Toolkit"
    chatbot=gr.Chatbot(height=660),
).queue().launch()
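
# Run with `python app.py`; gr.ChatInterface serves the chat UI on Gradio's
# default port (7860) unless launch() is given a different server_port.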