Spaces:
Sleeping
Sleeping
import gradio as gr | |
import requests | |
import json | |
import os | |
from typing import List, Tuple | |
class JapaneseLLMChat: | |
def __init__(self): | |
# 利用可能な日本語LLMモデル | |
self.models = { | |
"cyberagent/open-calm-7b": "CyberAgent Open CALM 7B", | |
"rinna/japanese-gpt-neox-3.6b-instruction-sft": "Rinna GPT-NeoX 3.6B", | |
"matsuo-lab/weblab-10b-instruction-sft": "Matsuo Lab WebLab 10B", | |
"stabilityai/japanese-stablelm-instruct-alpha-7b": "Japanese StableLM 7B" | |
} | |
# デフォルトモデル | |
self.current_model = "cyberagent/open-calm-7b" | |
# HuggingFace API設定 | |
self.api_url = "https://api-inference.huggingface.co/models/" | |
self.headers = {} | |
def set_api_key(self, api_key: str): | |
"""APIキーを設定""" | |
if api_key.strip(): | |
self.headers = {"Authorization": f"Bearer {api_key}"} | |
return "✅ APIキーが設定されました" | |
else: | |
return "❌ 有効なAPIキーを入力してください" | |
def set_model(self, model_name: str): | |
"""使用するモデルを変更""" | |
self.current_model = model_name | |
return f"モデルを {self.models[model_name]} に変更しました" | |
def query_model(self, prompt: str, max_length: int = 200, temperature: float = 0.7) -> str: | |
"""HuggingFace Inference APIにクエリを送信""" | |
if not self.headers: | |
return "❌ APIキーが設定されていません" | |
url = self.api_url + self.current_model | |
payload = { | |
"inputs": prompt, | |
"parameters": { | |
"max_length": max_length, | |
"temperature": temperature, | |
"do_sample": True, | |
"top_p": 0.95, | |
"return_full_text": False | |
} | |
} | |
try: | |
response = requests.post(url, headers=self.headers, json=payload, timeout=30) | |
if response.status_code == 200: | |
result = response.json() | |
if isinstance(result, list) and len(result) > 0: | |
generated_text = result[0].get("generated_text", "") | |
return generated_text.strip() | |
else: | |
return "❌ 予期しないレスポンス形式です" | |
elif response.status_code == 503: | |
return "⏳ モデルが読み込み中です。しばらく待ってから再試行してください。" | |
elif response.status_code == 401: | |
return "❌ APIキーが無効です" | |
else: | |
return f"❌ エラーが発生しました (ステータス: {response.status_code})" | |
except requests.exceptions.Timeout: | |
return "⏳ リクエストがタイムアウトしました。再試行してください。" | |
except requests.exceptions.RequestException as e: | |
return f"❌ 接続エラー: {str(e)}" | |
def chat_response(self, message: str, history: List[Tuple[str, str]], | |
max_length: int, temperature: float) -> Tuple[str, List[Tuple[str, str]]]: | |
"""チャット応答を生成""" | |
if not message.strip(): | |
return "", history | |
# 対話履歴を考慮したプロンプト作成 | |
conversation_context = "" | |
for user_msg, bot_msg in history[-3:]: # 直近3回の会話を含める | |
conversation_context += f"ユーザー: {user_msg}\nアシスタント: {bot_msg}\n" | |
# プロンプトの構築 | |
if self.current_model == "rinna/japanese-gpt-neox-3.6b-instruction-sft": | |
prompt = f"{conversation_context}ユーザー: {message}\nアシスタント:" | |
elif "instruct" in self.current_model.lower(): | |
prompt = f"以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書いてください。\n\n### 指示:\n日本語で自然な会話を行ってください。\n\n### 入力:\n{conversation_context}ユーザー: {message}\n\n### 応答:\n" | |
else: | |
prompt = f"{conversation_context}ユーザー: {message}\nアシスタント:" | |
# モデルから応答を取得 | |
response = self.query_model(prompt, max_length, temperature) | |
# 履歴に追加 | |
history.append((message, response)) | |
return "", history | |
# チャットインスタンスを作成 | |
chat_bot = JapaneseLLMChat() | |
# Gradio インターフェースの構築 | |
def create_interface(): | |
with gr.Blocks( | |
title="日本語LLMチャット", | |
theme=gr.themes.Soft(), | |
css=""" | |
.gradio-container { | |
max-width: 1000px !important; | |
} | |
""" | |
) as demo: | |
gr.Markdown( | |
""" | |
# 🤖 日本語LLMチャット | |
HuggingFace Inference APIを使用した日本語対話システム | |
""" | |
) | |
with gr.Row(): | |
with gr.Column(scale=2): | |
# APIキー設定 | |
with gr.Group(): | |
gr.Markdown("### 🔑 設定") | |
api_key_input = gr.Textbox( | |
label="HuggingFace API Token", | |
placeholder="hf_xxxxxxxxxxxxxxxxx", | |
type="password" | |
) | |
api_key_btn = gr.Button("APIキーを設定", variant="primary") | |
api_key_status = gr.Textbox(label="ステータス", interactive=False) | |
# モデル選択 | |
with gr.Group(): | |
gr.Markdown("### 🧠 モデル選択") | |
model_dropdown = gr.Dropdown( | |
choices=[(v, k) for k, v in chat_bot.models.items()], | |
value="cyberagent/open-calm-7b", | |
label="使用するモデル" | |
) | |
model_status = gr.Textbox(label="現在のモデル", interactive=False, | |
value=chat_bot.models[chat_bot.current_model]) | |
# パラメータ設定 | |
with gr.Group(): | |
gr.Markdown("### ⚙️ 生成パラメータ") | |
max_length_slider = gr.Slider( | |
minimum=50, maximum=500, value=200, | |
label="最大生成長" | |
) | |
temperature_slider = gr.Slider( | |
minimum=0.1, maximum=2.0, value=0.7, | |
label="Temperature(創造性)" | |
) | |
with gr.Column(scale=3): | |
# チャットインターフェース | |
chatbot = gr.Chatbot( | |
height=500, | |
label="会話", | |
show_label=True, | |
avatar_images=["👤", "🤖"] | |
) | |
msg = gr.Textbox( | |
label="メッセージ", | |
placeholder="メッセージを入力してください...", | |
lines=2 | |
) | |
with gr.Row(): | |
send_btn = gr.Button("送信", variant="primary") | |
clear_btn = gr.Button("会話をクリア", variant="secondary") | |
# 使用方法の説明 | |
with gr.Accordion("📖 使用方法", open=False): | |
gr.Markdown( | |
""" | |
1. **APIキーの設定**: HuggingFace(https://huggingface.co/settings/tokens)からAccess Tokenを取得し、上記フィールドに入力してください | |
2. **モデル選択**: 使用したい日本語LLMを選択してください | |
3. **パラメータ調整**: 必要に応じて生成パラメータを調整してください | |
4. **チャット開始**: メッセージを入力して「送信」ボタンをクリックしてください | |
**注意**: | |
- 初回使用時はモデルの読み込みに時間がかかる場合があります | |
- 大きなモデル(7B以上)の使用には有料アカウントが必要な場合があります | |
""" | |
) | |
# イベントハンドラーの設定 | |
api_key_btn.click( | |
chat_bot.set_api_key, | |
inputs=[api_key_input], | |
outputs=[api_key_status] | |
) | |
model_dropdown.change( | |
chat_bot.set_model, | |
inputs=[model_dropdown], | |
outputs=[model_status] | |
) | |
send_btn.click( | |
chat_bot.chat_response, | |
inputs=[msg, chatbot, max_length_slider, temperature_slider], | |
outputs=[msg, chatbot] | |
) | |
msg.submit( | |
chat_bot.chat_response, | |
inputs=[msg, chatbot, max_length_slider, temperature_slider], | |
outputs=[msg, chatbot] | |
) | |
clear_btn.click( | |
lambda: ([], ""), | |
outputs=[chatbot, msg] | |
) | |
return demo | |
# アプリケーションの起動 | |
if __name__ == "__main__": | |
demo = create_interface() | |
demo.launch( | |
share=True, | |
server_name="0.0.0.0", | |
server_port=7860 | |
) |