File size: 2,414 Bytes
991dafd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f46b42
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from queue import Empty, Queue
from threading import Thread
from langchain.callbacks.streaming_stdout_final_only import (
    FinalStreamingStdOutCallbackHandler,
)
from agent import agent_executor

from typing import Generator

import gradio as gr


class QueneCallback(FinalStreamingStdOutCallbackHandler):
    def __init__(self, q):
        super().__init__()
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs: any) -> None:
        # Remember the last n tokens, where n = len(answer_prefix_tokens)
        self.append_to_last_tokens(token)

        # Check if the last n tokens match the answer_prefix_tokens list ...
        if self.check_if_answer_reached():
            self.answer_reached = True
            if self.stream_prefix:
                for t in self.last_tokens:
                    self.q.put(t)
            return

        # ... if yes, then print tokens from now on
        if self.answer_reached:
            self.q.put(token)

    def on_llm_end(self, *args, **kwargs: any) -> None:
        return self.q.empty()


def stream(input_text) -> Generator:
    # Create a Queue
    q = Queue()
    job_done = object()

    # Create a funciton to call - this will run in a thread
    def task():
        agent_executor.invoke(
            {"input": input_text}, config={"callbacks": [QueneCallback(q)]}
        )
        q.put(job_done)

    # Create a thread and start the function
    t = Thread(target=task)
    t.start()

    # Get each new token from the queue and yield for our generator
    while True:
        try:
            next_token = q.get(True, timeout=1)
            if next_token is job_done:
                break
            yield next_token
        except Empty:
            continue


def predict(message, history):
    if len(message) == 0:
        return
    history_openai_format = []
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})
    partial_message = ""

    for token in stream(message):
        if len(token) != 0:
            partial_message = partial_message + token
            yield partial_message


gr.ChatInterface(
    fn=predict,
    title="问题定义工具包",
    chatbot=gr.Chatbot(height=660),
).queue().launch()