File size: 2,428 Bytes
8177664
 
 
 
 
 
 
 
 
 
 
737d994
8177664
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import asyncio
import uuid
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
import gradio as gr

# モジュールレベルで一度だけイベントループを作成
_loop = asyncio.new_event_loop()
asyncio.set_event_loop(_loop)

# 非同期エンジンの初期化
en_args = AsyncEngineArgs(
    model="EQUES/JPharmatron-7B-chat",
    enforce_eager=True
)
model = AsyncLLMEngine.from_engine_args(en_args)

# 非同期でトークン生成をストリーミングするジェネレータ
async def astream_generate(prompt: str):
    previous_text = ""
    async for request_output in model.generate(
        prompt,
        SamplingParams(temperature=0.0, max_tokens=512),
        request_id=str(uuid.uuid4())
    ):
        full_text = request_output.outputs[0].text
        new_chunk = full_text[len(previous_text):]
        previous_text = full_text
        yield new_chunk

# Gradio 用の応答関数(同期ジェネレータ)
def respond(user_input, history):
    history = history or []
    # システムプロンプトと過去履歴を組み立て
    base_prompt = "以下は親切で何でも答えてくれるAIアシスタントとの会話です。\n"
    for u, b in history:
        base_prompt += f"ユーザー: {u}\nアシスタント: {b}\n"
    base_prompt += f"ユーザー: {user_input}\nアシスタント:"

    # ユーザー発話と空応答を履歴に追加
    history.append((user_input, ""))

    # 単一のイベントループを使ってストリーミング
    agen = astream_generate(base_prompt)
    try:
        while True:
            chunk = _loop.run_until_complete(agen.__anext__())
            # 最新の履歴にトークンを追加
            history[-1] = (user_input, history[-1][1] + chunk)
            yield history, history
    except StopAsyncIteration:
        return

# Gradio UI 定義
with gr.Blocks() as demo:
    gr.Markdown("# 💊 製薬に関する質問をしてみてください。")
    gr.Markdown("※ ストリーミングデモ用コードです。")

    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="あなたのメッセージ")
    clear = gr.Button("チャット履歴をクリア")
    state = gr.State([])

    msg.submit(respond, [msg, state], [chatbot, state])
    clear.click(lambda: ([], []), None, [chatbot, state])

# エントリポイント
def main():
    demo.launch()

if __name__ == "__main__":
    main()