"""Gradio chat demo that streams completions from a local GGUF model via llama-cpp-python."""

import os
import threading
from typing import Any, Dict, Iterable, List, Union

import gradio as gr
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
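
# --- Configuration (all values overridable via environment variables) ---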
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "Qwen/Qwen2.5-0.5B-Instruct-GGUF")
MODEL_FILENAME = os.getenv("MODEL_FILENAME", "qwen2.5-0.5b-instruct-q4_k_m.gguf")

SYSTEM_PROMPT = os.getenv(
    "SYSTEM_PROMPT",
    "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
)
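
# Context window size and sampling defaults.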
N_CTX = int(os.getenv("N_CTX", "4096"))

TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
TOP_P = float(os.getenv("TOP_P", "0.9"))
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "512"))
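
# The model is created lazily on first request; the lock prevents two
# concurrent requests from loading it twice.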
_llm: Llama | None = None
_llm_lock = threading.Lock()


def _load_llm() -> Llama:
    """Create the Llama instance once, fetching the GGUF weights on first use."""
    global _llm
    if _llm is not None:
        return _llm

    with _llm_lock:
        # Re-check under the lock in case another thread won the race.
        if _llm is not None:
            return _llm

        # hf_hub_download caches the file locally, so only the first run
        # actually downloads the weights.
        model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)

        _llm = Llama(
            model_path=model_path,
            n_ctx=N_CTX,
            n_threads=os.cpu_count() or 4,
            n_gpu_layers=0,  # CPU-only; raise if llama-cpp-python was built with GPU support
            chat_format="chatml",  # Qwen instruct models use the ChatML template
            verbose=False,
        )
    return _llm
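

# Gradio may deliver message content as a plain string, a list of parts
# (e.g. {"type": "text"} dicts), or a mapping; normalize all of them to text.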
Content = Union[str, List[Any], Dict[str, Any]]


def _content_to_text(content: Content) -> str:
    """Best-effort conversion of message content to plain text."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Multimodal-style content: keep raw strings and {"type": "text"} parts.
        parts: List[str] = []
        for item in content:
            if isinstance(item, str):
                parts.append(item)
            elif isinstance(item, dict) and item.get("type") == "text":
                parts.append(str(item.get("text", "")))
        return "".join(parts).strip()
    if isinstance(content, dict):
        for k in ("text", "content"):
            v = content.get(k)
            if isinstance(v, str):
                return v
    return str(content)
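

# Depending on the Gradio version and ChatInterface's type= setting, history
# arrives either as (user, assistant) pairs or as role/content dicts.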
def _history_to_messages(history: Any) -> List[Dict[str, str]]:
    """Normalize Gradio chat history into a flat list of role/content dicts."""
    if not history:
        return []

    msgs: List[Dict[str, str]] = []

    # Legacy tuple format: [(user_message, assistant_message), ...]
    if isinstance(history, list) and isinstance(history[0], (tuple, list)) and len(history[0]) == 2:
        for user, assistant in history:
            if user:
                msgs.append({"role": "user", "content": str(user)})
            if assistant:
                msgs.append({"role": "assistant", "content": str(assistant)})
        return msgs

    # Messages format: [{"role": ..., "content": ...}, ...]
    if isinstance(history, list) and isinstance(history[0], dict):
        for m in history:
            role = m.get("role")
            if role not in ("user", "assistant", "system"):
                continue
            text = _content_to_text(m.get("content", ""))
            if text:
                msgs.append({"role": role, "content": text})
        return msgs

    return []
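

# llama-cpp-python's streaming chat completion yields OpenAI-style chunks;
# the incremental text lives under choices[0]["delta"]["content"].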
def _stream_chat(llm: Llama, messages: List[Dict[str, str]]) -> Iterable[str]:
    """Yield the accumulated response text after each new token."""
    stream = llm.create_chat_completion(
        messages=messages,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        max_tokens=MAX_TOKENS,
        stream=True,
    )

    partial = ""
    for chunk in stream:
        token = ""
        try:
            choice = chunk["choices"][0]
            delta = choice.get("delta") or {}
            token = delta.get("content") or ""
        except Exception:
            # Tolerate malformed or non-delta chunks (e.g. the final stop chunk).
            token = ""
        if token:
            partial += token
            yield partial
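

# ChatInterface callback: build the full message list (system prompt +
# trimmed history + new user message) and stream the reply.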
def respond(message: str, history: Any):
    """ChatInterface callback: stream the reply by yielding growing prefixes."""
    llm = _load_llm()

    msgs: List[Dict[str, str]] = [{"role": "system", "content": SYSTEM_PROMPT}]
    prior = _history_to_messages(history)

    # Crude context management: keep only the most recent messages so long
    # conversations still fit in the model's context window.
    if len(prior) > 20:
        prior = prior[-20:]

    msgs.extend(prior)
    msgs.append({"role": "user", "content": message})

    for partial in _stream_chat(llm, msgs):
        yield partial
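

# Because respond is a generator, ChatInterface streams each yielded string
# as the in-progress assistant reply.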
demo = gr.ChatInterface(
    fn=respond,
    title="GGUF Chatbot (llama-cpp-python)",
)


if __name__ == "__main__":
    demo.launch()