jpmodeldemo / app.py
oggata's picture
Update app.py
96b5085 verified
import gradio as gr
import requests
import json
import os
from typing import List, Tuple
class JapaneseLLMChat:
def __init__(self):
# 利用可能な日本語LLMモデル
self.models = {
"cyberagent/open-calm-7b": "CyberAgent Open CALM 7B",
"rinna/japanese-gpt-neox-3.6b-instruction-sft": "Rinna GPT-NeoX 3.6B",
"matsuo-lab/weblab-10b-instruction-sft": "Matsuo Lab WebLab 10B",
"stabilityai/japanese-stablelm-instruct-alpha-7b": "Japanese StableLM 7B"
}
# デフォルトモデル
self.current_model = "cyberagent/open-calm-7b"
# HuggingFace API設定
self.api_url = "https://api-inference.huggingface.co/models/"
self.headers = {}
def set_api_key(self, api_key: str):
"""APIキーを設定"""
if api_key.strip():
self.headers = {"Authorization": f"Bearer {api_key}"}
return "✅ APIキーが設定されました"
else:
return "❌ 有効なAPIキーを入力してください"
def set_model(self, model_name: str):
"""使用するモデルを変更"""
self.current_model = model_name
return f"モデルを {self.models[model_name]} に変更しました"
def query_model(self, prompt: str, max_length: int = 200, temperature: float = 0.7) -> str:
"""HuggingFace Inference APIにクエリを送信"""
if not self.headers:
return "❌ APIキーが設定されていません"
url = self.api_url + self.current_model
payload = {
"inputs": prompt,
"parameters": {
"max_length": max_length,
"temperature": temperature,
"do_sample": True,
"top_p": 0.95,
"return_full_text": False
}
}
try:
response = requests.post(url, headers=self.headers, json=payload, timeout=30)
if response.status_code == 200:
result = response.json()
if isinstance(result, list) and len(result) > 0:
generated_text = result[0].get("generated_text", "")
return generated_text.strip()
else:
return "❌ 予期しないレスポンス形式です"
elif response.status_code == 503:
return "⏳ モデルが読み込み中です。しばらく待ってから再試行してください。"
elif response.status_code == 401:
return "❌ APIキーが無効です"
else:
return f"❌ エラーが発生しました (ステータス: {response.status_code})"
except requests.exceptions.Timeout:
return "⏳ リクエストがタイムアウトしました。再試行してください。"
except requests.exceptions.RequestException as e:
return f"❌ 接続エラー: {str(e)}"
def chat_response(self, message: str, history: List[Tuple[str, str]],
max_length: int, temperature: float) -> Tuple[str, List[Tuple[str, str]]]:
"""チャット応答を生成"""
if not message.strip():
return "", history
# 対話履歴を考慮したプロンプト作成
conversation_context = ""
for user_msg, bot_msg in history[-3:]: # 直近3回の会話を含める
conversation_context += f"ユーザー: {user_msg}\nアシスタント: {bot_msg}\n"
# プロンプトの構築
if self.current_model == "rinna/japanese-gpt-neox-3.6b-instruction-sft":
prompt = f"{conversation_context}ユーザー: {message}\nアシスタント:"
elif "instruct" in self.current_model.lower():
prompt = f"以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書いてください。\n\n### 指示:\n日本語で自然な会話を行ってください。\n\n### 入力:\n{conversation_context}ユーザー: {message}\n\n### 応答:\n"
else:
prompt = f"{conversation_context}ユーザー: {message}\nアシスタント:"
# モデルから応答を取得
response = self.query_model(prompt, max_length, temperature)
# 履歴に追加
history.append((message, response))
return "", history
# チャットインスタンスを作成
chat_bot = JapaneseLLMChat()
# Gradio インターフェースの構築
def create_interface():
with gr.Blocks(
title="日本語LLMチャット",
theme=gr.themes.Soft(),
css="""
.gradio-container {
max-width: 1000px !important;
}
"""
) as demo:
gr.Markdown(
"""
# 🤖 日本語LLMチャット
HuggingFace Inference APIを使用した日本語対話システム
"""
)
with gr.Row():
with gr.Column(scale=2):
# APIキー設定
with gr.Group():
gr.Markdown("### 🔑 設定")
api_key_input = gr.Textbox(
label="HuggingFace API Token",
placeholder="hf_xxxxxxxxxxxxxxxxx",
type="password"
)
api_key_btn = gr.Button("APIキーを設定", variant="primary")
api_key_status = gr.Textbox(label="ステータス", interactive=False)
# モデル選択
with gr.Group():
gr.Markdown("### 🧠 モデル選択")
model_dropdown = gr.Dropdown(
choices=[(v, k) for k, v in chat_bot.models.items()],
value="cyberagent/open-calm-7b",
label="使用するモデル"
)
model_status = gr.Textbox(label="現在のモデル", interactive=False,
value=chat_bot.models[chat_bot.current_model])
# パラメータ設定
with gr.Group():
gr.Markdown("### ⚙️ 生成パラメータ")
max_length_slider = gr.Slider(
minimum=50, maximum=500, value=200,
label="最大生成長"
)
temperature_slider = gr.Slider(
minimum=0.1, maximum=2.0, value=0.7,
label="Temperature(創造性)"
)
with gr.Column(scale=3):
# チャットインターフェース
chatbot = gr.Chatbot(
height=500,
label="会話",
show_label=True,
avatar_images=["👤", "🤖"]
)
msg = gr.Textbox(
label="メッセージ",
placeholder="メッセージを入力してください...",
lines=2
)
with gr.Row():
send_btn = gr.Button("送信", variant="primary")
clear_btn = gr.Button("会話をクリア", variant="secondary")
# 使用方法の説明
with gr.Accordion("📖 使用方法", open=False):
gr.Markdown(
"""
1. **APIキーの設定**: HuggingFace(https://huggingface.co/settings/tokens)からAccess Tokenを取得し、上記フィールドに入力してください
2. **モデル選択**: 使用したい日本語LLMを選択してください
3. **パラメータ調整**: 必要に応じて生成パラメータを調整してください
4. **チャット開始**: メッセージを入力して「送信」ボタンをクリックしてください
**注意**:
- 初回使用時はモデルの読み込みに時間がかかる場合があります
- 大きなモデル(7B以上)の使用には有料アカウントが必要な場合があります
"""
)
# イベントハンドラーの設定
api_key_btn.click(
chat_bot.set_api_key,
inputs=[api_key_input],
outputs=[api_key_status]
)
model_dropdown.change(
chat_bot.set_model,
inputs=[model_dropdown],
outputs=[model_status]
)
send_btn.click(
chat_bot.chat_response,
inputs=[msg, chatbot, max_length_slider, temperature_slider],
outputs=[msg, chatbot]
)
msg.submit(
chat_bot.chat_response,
inputs=[msg, chatbot, max_length_slider, temperature_slider],
outputs=[msg, chatbot]
)
clear_btn.click(
lambda: ([], ""),
outputs=[chatbot, msg]
)
return demo
# アプリケーションの起動
if __name__ == "__main__":
demo = create_interface()
demo.launch(
share=True,
server_name="0.0.0.0",
server_port=7860
)