from queue import Empty, Queue
from threading import Thread
from typing import Any, Generator

import gradio as gr
from langchain.callbacks.streaming_stdout_final_only import (
    FinalStreamingStdOutCallbackHandler,
)

from agent import agent_executor

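# The agent module is not shown here. A minimal sketch of what agent.py is
# assumed to provide (hypothetical; the model, tools, and prompt are
# illustrative, not the original author's):
#
#   from langchain import hub
#   from langchain.agents import AgentExecutor, create_react_agent
#   from langchain_openai import ChatOpenAI
#
#   llm = ChatOpenAI(streaming=True)  # streaming=True so tokens arrive one by one
#   prompt = hub.pull("hwchase17/react")
#   agent = create_react_agent(llm, tools=[], prompt=prompt)
#   agent_executor = AgentExecutor(agent=agent, tools=[])
#
# However it is built, the LLM must be created with streaming enabled,
# otherwise on_llm_new_token below is never called.
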
class QueueCallback(FinalStreamingStdOutCallbackHandler):
    """Streams the agent's final-answer tokens into a queue instead of stdout."""

    def __init__(self, q: Queue):
        super().__init__()
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Maintain the sliding window of recent tokens that the base class
        # uses to detect the final-answer prefix.
        self.append_to_last_tokens(token)

        # Once the prefix is detected, optionally flush the buffered prefix
        # tokens, then start streaming.
        if self.check_if_answer_reached():
            self.answer_reached = True
            if self.stream_prefix:
                for t in self.last_tokens:
                    self.q.put(t)
            return

        if self.answer_reached:
            self.q.put(token)

    def on_llm_end(self, *args, **kwargs: Any) -> None:
        # No end-of-stream handling needed here: the worker thread in
        # stream() puts the job_done sentinel on the queue once the agent
        # has finished.
        return

def stream(input_text: str) -> Generator:
    """Run the agent in a background thread and yield tokens as they arrive."""
    q = Queue()
    job_done = object()  # sentinel object marking the end of the stream

    def task():
        agent_executor.invoke(
            {"input": input_text}, config={"callbacks": [QueueCallback(q)]}
        )
        q.put(job_done)

    t = Thread(target=task)
    t.start()

    # Consume tokens from the queue until the sentinel arrives. The timeout
    # keeps the loop responsive while the agent is still working.
    while True:
        try:
            next_token = q.get(block=True, timeout=1)
            if next_token is job_done:
                break
            yield next_token
        except Empty:
            continue

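# Example (hypothetical input): consuming the stream outside of Gradio.
#
#   for token in stream("What can this agent do?"):
#       print(token, end="", flush=True)
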
def predict(message, history):
    if len(message) == 0:
        return

    # Convert Gradio's (user, assistant) history pairs into OpenAI-style
    # messages. Note: the agent below is only given the latest message; the
    # converted history is built but not yet passed along.
    history_openai_format = []
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})

    partial_message = ""
    for token in stream(message):
        if len(token) != 0:
            partial_message = partial_message + token
            yield partial_message

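# .queue() turns on Gradio's request queue; in Gradio 3.x this is required
# for functions like predict that stream results via a generator.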
gr.ChatInterface(
    fn=predict,
    title="Problem Definition Toolkit",
    chatbot=gr.Chatbot(height=660),
).queue().launch()
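
# To run (assuming this file is saved as app.py alongside agent.py):
#
#   python app.py
#
# Gradio serves the chat UI at http://127.0.0.1:7860 by default.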