John6666's picture
Upload 4 files
c500d6d verified
import os
import threading
from typing import Any, Dict, Iterable, List, Union
import gradio as gr
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
# -----------------------------
# Model (HF GGUF)
# -----------------------------
# Every setting below is read once at import time. Each one can be overridden
# by an environment variable of the same name; otherwise the literal default
# shown here is used.
MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "Qwen/Qwen2.5-0.5B-Instruct-GGUF")
MODEL_FILENAME = os.environ.get("MODEL_FILENAME", "qwen2.5-0.5b-instruct-q4_k_m.gguf")
SYSTEM_PROMPT = os.environ.get(
    "SYSTEM_PROMPT",
    "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
)

# Context size handed to the model (n_ctx). Keep modest on free CPU:
# the KV cache grows with context.
N_CTX = int(os.environ.get("N_CTX", "4096"))

# Generation defaults (sampling parameters and reply-length cap).
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.7"))
TOP_P = float(os.environ.get("TOP_P", "0.9"))
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "512"))
# -----------------------------
# Lazy singleton model loader
# -----------------------------
# The model is built on first use rather than at import time. The lock keeps
# simultaneous first callers from downloading / constructing it twice.
_llm: Llama | None = None
_llm_lock = threading.Lock()


def _load_llm() -> Llama:
    """Return the shared ``Llama`` instance, building it on the first call.

    Fast path: once built, the cached instance is returned without touching
    the lock. Otherwise ``_llm`` is re-checked while holding ``_llm_lock``,
    so only one caller runs ``hf_hub_download`` + ``Llama(...)``.

    The model runs with no GPU-offloaded layers (``n_gpu_layers=0``),
    ``N_CTX`` context, ``os.cpu_count()`` threads (4 if unknown) and the
    ``chatml`` chat format.
    """
    global _llm
    model = _llm
    if model is not None:
        return model
    with _llm_lock:
        if _llm is None:
            gguf_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
            # Qwen instruct GGUFs commonly use ChatML-style formatting.
            _llm = Llama(
                model_path=gguf_path,
                n_ctx=N_CTX,
                n_threads=os.cpu_count() or 4,
                n_gpu_layers=0,
                chat_format="chatml",
                verbose=False,
            )
        return _llm
# -----------------------------
# Gradio message normalization
# -----------------------------
Content = Union[str, List[Any], Dict[str, Any]]
def _content_to_text(content: Content) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts: List[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict) and item.get("type") == "text":
parts.append(str(item.get("text", "")))
return "".join(parts).strip()
if isinstance(content, dict):
for k in ("text", "content"):
v = content.get(k)
if isinstance(v, str):
return v
return str(content)
def _history_to_messages(history: Any) -> List[Dict[str, str]]:
if not history:
return []
msgs: List[Dict[str, str]] = []
# Old format: list[(user, assistant), ...]
if isinstance(history, list) and history and isinstance(history[0], (tuple, list)) and len(history[0]) == 2:
for user, assistant in history:
if user:
msgs.append({"role": "user", "content": str(user)})
if assistant:
msgs.append({"role": "assistant", "content": str(assistant)})
return msgs
# Newer format: list[{"role": "...", "content": ...}, ...]
if isinstance(history, list) and history and isinstance(history[0], dict):
for m in history:
role = m.get("role")
if role not in ("user", "assistant", "system"):
continue
text = _content_to_text(m.get("content", ""))
if text:
msgs.append({"role": role, "content": text})
return msgs
return []
def _stream_chat(
    llm: Llama,
    messages: List[Dict[str, str]],
    *,
    temperature: float | None = None,
    top_p: float | None = None,
    max_tokens: int | None = None,
) -> Iterable[str]:
    """Stream a chat completion, yielding the accumulated reply text.

    Each yielded value is the whole reply generated so far, not just the
    newest token.

    Args:
        llm: Loaded model used for generation.
        messages: ``{"role", "content"}`` messages, oldest first.
        temperature: Sampling temperature; ``None`` uses ``TEMPERATURE``.
        top_p: Nucleus-sampling cutoff; ``None`` uses ``TOP_P``.
        max_tokens: Reply length cap; ``None`` uses ``MAX_TOKENS``.

    Yields:
        The cumulative reply after every chunk that carried new text.
        Nothing is yielded if the model produces no text.
    """
    # llama-cpp-python yields OpenAI-like streaming chunks.
    stream = llm.create_chat_completion(
        messages=messages,
        temperature=TEMPERATURE if temperature is None else temperature,
        top_p=TOP_P if top_p is None else top_p,
        max_tokens=MAX_TOKENS if max_tokens is None else max_tokens,
        stream=True,
    )
    partial = ""
    for chunk in stream:
        # Expected shape: {"choices": [{"delta": {"content": "..."}}]}.
        # Some chunks may carry only a role or an empty delta. Skip any chunk
        # without usable text, but only for shape errors; unrelated failures
        # are no longer silently swallowed.
        try:
            delta = chunk["choices"][0].get("delta") or {}
            token = delta.get("content") or ""
        except (KeyError, IndexError, TypeError, AttributeError):
            continue
        if token:
            partial += token
            yield partial
# Cheap first-pass cap on how many prior messages are considered at all.
_HISTORY_MAX_MESSAGES = 20
# Rough allowance for the ChatML wrapper around each message
# (``<|im_start|>{role}\n`` ... ``<|im_end|>\n``). An estimate, not exact.
_CHATML_OVERHEAD_TOKENS = 8


def _message_tokens(llm: Llama, msg: Dict[str, str]) -> int:
    """Estimate how many prompt tokens *msg* occupies once chat-templated."""
    ids = llm.tokenize(str(msg["content"]).encode("utf-8"), add_bos=False)
    return len(ids) + _CHATML_OVERHEAD_TOKENS


def _fit_history(
    llm: Llama,
    pinned: List[Dict[str, str]],
    prior: List[Dict[str, str]],
) -> List[Dict[str, str]]:
    """Return the newest slice of *prior* that fits in the context window.

    The *pinned* messages (system prompt and current user turn) are always
    sent, so their cost is charged first, and ``MAX_TOKENS`` is reserved for
    the reply. Oldest prior messages are dropped until the rest fits. Any
    assistant messages then left at the front are dropped too, so the kept
    history does not open mid-exchange.
    """
    prior = prior[-_HISTORY_MAX_MESSAGES:]
    budget = N_CTX - MAX_TOKENS - sum(_message_tokens(llm, m) for m in pinned)
    costs = [_message_tokens(llm, m) for m in prior]
    total = sum(costs)
    start = 0
    while start < len(prior) and total > budget:
        total -= costs[start]
        start += 1
    while start < len(prior) and prior[start]["role"] == "assistant":
        start += 1
    return prior[start:]


def respond(message: str, history: Any):
    """Gradio chat handler: stream the model's reply to *message*.

    The prompt is the system prompt, then as much recent history as fits in
    ``N_CTX`` (after reserving ``MAX_TOKENS`` for the reply), then the new
    user message.

    Previously history was capped only by message count. With long replies
    that could overflow the context window and make llama-cpp reject the
    request. If the system prompt plus *message* alone are too long,
    generation may still fail.

    Yields:
        The cumulative reply text, as produced by ``_stream_chat``.
    """
    llm = _load_llm()
    system_msg = {"role": "system", "content": SYSTEM_PROMPT}
    user_msg = {"role": "user", "content": message}
    prior = _fit_history(llm, [system_msg, user_msg], _history_to_messages(history))
    msgs: List[Dict[str, str]] = [system_msg, *prior, user_msg]
    yield from _stream_chat(llm, msgs)
# Chat UI backed by the streaming ``respond`` generator.
demo = gr.ChatInterface(
    respond,
    title="GGUF Chatbot (llama-cpp-python)",
)

# Start the server only when executed directly, not when imported.
if __name__ == "__main__":
    demo.launch()