Spaces:
Sleeping
Sleeping
File size: 9,605 Bytes
fecb2f9 96b5085 fecb2f9 96b5085 fecb2f9 96b5085 fecb2f9 96b5085 fecb2f9 96b5085 fecb2f9 96b5085 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
import gradio as gr
import requests
import json
import os
from typing import List, Tuple
class JapaneseLLMChat:
def __init__(self):
# 利用可能な日本語LLMモデル
self.models = {
"cyberagent/open-calm-7b": "CyberAgent Open CALM 7B",
"rinna/japanese-gpt-neox-3.6b-instruction-sft": "Rinna GPT-NeoX 3.6B",
"matsuo-lab/weblab-10b-instruction-sft": "Matsuo Lab WebLab 10B",
"stabilityai/japanese-stablelm-instruct-alpha-7b": "Japanese StableLM 7B"
}
# デフォルトモデル
self.current_model = "cyberagent/open-calm-7b"
# HuggingFace API設定
self.api_url = "https://api-inference.huggingface.co/models/"
self.headers = {}
def set_api_key(self, api_key: str):
"""APIキーを設定"""
if api_key.strip():
self.headers = {"Authorization": f"Bearer {api_key}"}
return "✅ APIキーが設定されました"
else:
return "❌ 有効なAPIキーを入力してください"
def set_model(self, model_name: str):
"""使用するモデルを変更"""
self.current_model = model_name
return f"モデルを {self.models[model_name]} に変更しました"
def query_model(self, prompt: str, max_length: int = 200, temperature: float = 0.7) -> str:
"""HuggingFace Inference APIにクエリを送信"""
if not self.headers:
return "❌ APIキーが設定されていません"
url = self.api_url + self.current_model
payload = {
"inputs": prompt,
"parameters": {
"max_length": max_length,
"temperature": temperature,
"do_sample": True,
"top_p": 0.95,
"return_full_text": False
}
}
try:
response = requests.post(url, headers=self.headers, json=payload, timeout=30)
if response.status_code == 200:
result = response.json()
if isinstance(result, list) and len(result) > 0:
generated_text = result[0].get("generated_text", "")
return generated_text.strip()
else:
return "❌ 予期しないレスポンス形式です"
elif response.status_code == 503:
return "⏳ モデルが読み込み中です。しばらく待ってから再試行してください。"
elif response.status_code == 401:
return "❌ APIキーが無効です"
else:
return f"❌ エラーが発生しました (ステータス: {response.status_code})"
except requests.exceptions.Timeout:
return "⏳ リクエストがタイムアウトしました。再試行してください。"
except requests.exceptions.RequestException as e:
return f"❌ 接続エラー: {str(e)}"
def chat_response(self, message: str, history: List[Tuple[str, str]],
max_length: int, temperature: float) -> Tuple[str, List[Tuple[str, str]]]:
"""チャット応答を生成"""
if not message.strip():
return "", history
# 対話履歴を考慮したプロンプト作成
conversation_context = ""
for user_msg, bot_msg in history[-3:]: # 直近3回の会話を含める
conversation_context += f"ユーザー: {user_msg}\nアシスタント: {bot_msg}\n"
# プロンプトの構築
if self.current_model == "rinna/japanese-gpt-neox-3.6b-instruction-sft":
prompt = f"{conversation_context}ユーザー: {message}\nアシスタント:"
elif "instruct" in self.current_model.lower():
prompt = f"以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書いてください。\n\n### 指示:\n日本語で自然な会話を行ってください。\n\n### 入力:\n{conversation_context}ユーザー: {message}\n\n### 応答:\n"
else:
prompt = f"{conversation_context}ユーザー: {message}\nアシスタント:"
# モデルから応答を取得
response = self.query_model(prompt, max_length, temperature)
# 履歴に追加
history.append((message, response))
return "", history
# チャットインスタンスを作成
chat_bot = JapaneseLLMChat()
# Gradio インターフェースの構築
def create_interface():
with gr.Blocks(
title="日本語LLMチャット",
theme=gr.themes.Soft(),
css="""
.gradio-container {
max-width: 1000px !important;
}
"""
) as demo:
gr.Markdown(
"""
# 🤖 日本語LLMチャット
HuggingFace Inference APIを使用した日本語対話システム
"""
)
with gr.Row():
with gr.Column(scale=2):
# APIキー設定
with gr.Group():
gr.Markdown("### 🔑 設定")
api_key_input = gr.Textbox(
label="HuggingFace API Token",
placeholder="hf_xxxxxxxxxxxxxxxxx",
type="password"
)
api_key_btn = gr.Button("APIキーを設定", variant="primary")
api_key_status = gr.Textbox(label="ステータス", interactive=False)
# モデル選択
with gr.Group():
gr.Markdown("### 🧠 モデル選択")
model_dropdown = gr.Dropdown(
choices=[(v, k) for k, v in chat_bot.models.items()],
value="cyberagent/open-calm-7b",
label="使用するモデル"
)
model_status = gr.Textbox(label="現在のモデル", interactive=False,
value=chat_bot.models[chat_bot.current_model])
# パラメータ設定
with gr.Group():
gr.Markdown("### ⚙️ 生成パラメータ")
max_length_slider = gr.Slider(
minimum=50, maximum=500, value=200,
label="最大生成長"
)
temperature_slider = gr.Slider(
minimum=0.1, maximum=2.0, value=0.7,
label="Temperature(創造性)"
)
with gr.Column(scale=3):
# チャットインターフェース
chatbot = gr.Chatbot(
height=500,
label="会話",
show_label=True,
avatar_images=["👤", "🤖"]
)
msg = gr.Textbox(
label="メッセージ",
placeholder="メッセージを入力してください...",
lines=2
)
with gr.Row():
send_btn = gr.Button("送信", variant="primary")
clear_btn = gr.Button("会話をクリア", variant="secondary")
# 使用方法の説明
with gr.Accordion("📖 使用方法", open=False):
gr.Markdown(
"""
1. **APIキーの設定**: HuggingFace(https://huggingface.co/settings/tokens)からAccess Tokenを取得し、上記フィールドに入力してください
2. **モデル選択**: 使用したい日本語LLMを選択してください
3. **パラメータ調整**: 必要に応じて生成パラメータを調整してください
4. **チャット開始**: メッセージを入力して「送信」ボタンをクリックしてください
**注意**:
- 初回使用時はモデルの読み込みに時間がかかる場合があります
- 大きなモデル(7B以上)の使用には有料アカウントが必要な場合があります
"""
)
# イベントハンドラーの設定
api_key_btn.click(
chat_bot.set_api_key,
inputs=[api_key_input],
outputs=[api_key_status]
)
model_dropdown.change(
chat_bot.set_model,
inputs=[model_dropdown],
outputs=[model_status]
)
send_btn.click(
chat_bot.chat_response,
inputs=[msg, chatbot, max_length_slider, temperature_slider],
outputs=[msg, chatbot]
)
msg.submit(
chat_bot.chat_response,
inputs=[msg, chatbot, max_length_slider, temperature_slider],
outputs=[msg, chatbot]
)
clear_btn.click(
lambda: ([], ""),
outputs=[chatbot, msg]
)
return demo
# アプリケーションの起動
if __name__ == "__main__":
demo = create_interface()
demo.launch(
share=True,
server_name="0.0.0.0",
server_port=7860
) |