stardust-eques's picture
Update app.py
737d994 verified
import asyncio
import uuid
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
import gradio as gr
# モジュールレベルで一度だけイベントループを作成
_loop = asyncio.new_event_loop()
asyncio.set_event_loop(_loop)
# 非同期エンジンの初期化
en_args = AsyncEngineArgs(
model="EQUES/JPharmatron-7B-chat",
enforce_eager=True
)
model = AsyncLLMEngine.from_engine_args(en_args)
# 非同期でトークン生成をストリーミングするジェネレータ
async def astream_generate(prompt: str):
previous_text = ""
async for request_output in model.generate(
prompt,
SamplingParams(temperature=0.0, max_tokens=512),
request_id=str(uuid.uuid4())
):
full_text = request_output.outputs[0].text
new_chunk = full_text[len(previous_text):]
previous_text = full_text
yield new_chunk
# Gradio 用の応答関数(同期ジェネレータ)
def respond(user_input, history):
history = history or []
# システムプロンプトと過去履歴を組み立て
base_prompt = "以下は親切で何でも答えてくれるAIアシスタントとの会話です。\n"
for u, b in history:
base_prompt += f"ユーザー: {u}\nアシスタント: {b}\n"
base_prompt += f"ユーザー: {user_input}\nアシスタント:"
# ユーザー発話と空応答を履歴に追加
history.append((user_input, ""))
# 単一のイベントループを使ってストリーミング
agen = astream_generate(base_prompt)
try:
while True:
chunk = _loop.run_until_complete(agen.__anext__())
# 最新の履歴にトークンを追加
history[-1] = (user_input, history[-1][1] + chunk)
yield history, history
except StopAsyncIteration:
return
# Gradio UI 定義
with gr.Blocks() as demo:
gr.Markdown("# 💊 製薬に関する質問をしてみてください。")
gr.Markdown("※ ストリーミングデモ用コードです。")
chatbot = gr.Chatbot()
msg = gr.Textbox(label="あなたのメッセージ")
clear = gr.Button("チャット履歴をクリア")
state = gr.State([])
msg.submit(respond, [msg, state], [chatbot, state])
clear.click(lambda: ([], []), None, [chatbot, state])
# エントリポイント
def main():
demo.launch()
if __name__ == "__main__":
main()