import json
import asyncio
from typing import List, Optional, Dict, Any, Generator, AsyncGenerator

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
from sse_starlette.sse import EventSourceResponse
# --- 1. Configuration management ---
class Settings(BaseSettings):
    model_config = SettingsConfigDict(
        env_file=".env", env_file_encoding="utf-8", extra="ignore"
    )

    MODEL_ID: str = Field(
        "unsloth/Qwen3-8B-GGUF", description="Model repository ID on Hugging Face"
    )
    FILENAME: str = Field("Qwen3-8B-Q8_0.gguf", description="GGUF model file to download")
    N_CTX: int = Field(4096, description="Context window size of the model")
    N_GPU_LAYERS: int = Field(0, description="Number of layers to offload to the GPU (0 = CPU only)")
    N_THREADS: Optional[int] = Field(
        None, description="Number of CPU threads for inference (None = auto)"
    )
    VERBOSE: bool = Field(True, description="Enable llama.cpp's verbose logging")


settings = Settings()
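
# The fields above can be overridden without touching code via a `.env` file
# placed next to this script (loaded by pydantic-settings). A minimal example,
# with illustrative values only:
#
#   MODEL_ID=unsloth/Qwen3-8B-GGUF
#   FILENAME=Qwen3-8B-Q8_0.gguf
#   N_CTX=8192
#   N_GPU_LAYERS=35
#   VERBOSE=false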
# --- 2. Model loading ---
def load_model():
    """Download the GGUF model from the Hugging Face Hub and load it."""
    print(f"Downloading model from the Hub: {settings.MODEL_ID}/{settings.FILENAME}...")
    try:
        model_path = hf_hub_download(
            repo_id=settings.MODEL_ID, filename=settings.FILENAME
        )
    except Exception as e:
        print(f"Model download failed: {e}")
        raise RuntimeError(f"Could not download the model from the Hugging Face Hub: {e}")

    print("Download complete. Loading model into memory...")
    try:
        model = Llama(
            model_path=model_path,
            n_ctx=settings.N_CTX,
            n_gpu_layers=settings.N_GPU_LAYERS,
            n_threads=settings.N_THREADS,
            verbose=settings.VERBOSE,
        )
        print("Model loaded.")
        return model
    except Exception as e:
        print(f"Model loading failed: {e}")
        raise RuntimeError(f"Could not load the Llama model: {e}")


model = load_model()
# --- 3. API service ---
app = FastAPI(
    title="Sparkle-Server - GGUF LLM API",
    description="An OpenAI-compatible, high-performance LLM inference service built on llama-cpp-python and FastAPI.",
    version="1.0.0",
)
# --- 4. API data models (OpenAI-compatible) ---
class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    messages: List[ChatMessage]
    model: str = settings.MODEL_ID
    max_tokens: int = 1024
    temperature: float = 0.7
    stream: bool = False
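
# A request body in the OpenAI chat-completions shape, for reference. The
# `model` field is accepted for compatibility, but this server always serves
# the single loaded GGUF model:
#
#   {
#     "messages": [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hello!"}
#     ],
#     "max_tokens": 256,
#     "temperature": 0.7,
#     "stream": true
#   }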
# --- 5. Streaming response generator ---
async def stream_generator(
    chat_iterator: Generator[Dict[str, Any], Any, None],
) -> AsyncGenerator[str, None]:
    """Convert the llama-cpp-python output stream into Server-Sent Events (SSE)
    and print the full response once the stream ends."""
    full_response_content = []
    for chunk in chat_iterator:
        if "content" in chunk["choices"][0]["delta"]:
            content_piece = chunk["choices"][0]["delta"]["content"]
            full_response_content.append(content_piece)
        # EventSourceResponse adds the "data: " prefix and event framing itself,
        # so yield only the JSON payload; yielding a pre-framed "data: ...\n\n"
        # string here would produce doubled "data: data: ..." events.
        yield json.dumps(chunk)
        await asyncio.sleep(0)  # Yield control so the event loop can serve other tasks
    # OpenAI-style clients stop reading the stream when they see this sentinel.
    yield "[DONE]"

    # After the stream ends, print the full response
    print("\n--- [Streaming] Generated response ---")
    print("".join(full_response_content))
    print("--------------------------\n")
# --- 6. API endpoints (OpenAI-compatible) ---
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """Handle chat completion requests, with streaming and non-streaming responses."""
    if not request.messages:
        raise HTTPException(status_code=400, detail="The messages list must not be empty")

    # Log the incoming request
    print("\n--- Incoming request ---")
    print(json.dumps(request.model_dump(), indent=2, ensure_ascii=False))
    print("--------------------\n")

    try:
        if request.stream:
            # Streaming response
            chat_iterator = model.create_chat_completion(
                messages=request.model_dump()["messages"],
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                stream=True,
            )
            return EventSourceResponse(stream_generator(chat_iterator))
        else:
            # Non-streaming response
            result = model.create_chat_completion(
                messages=request.model_dump()["messages"],
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                stream=False,
            )
            # Log the generated response
            print("\n--- [Non-streaming] Generated response ---")
            print(json.dumps(result, indent=2, ensure_ascii=False))
            print("--------------------------\n")
            return result
    except Exception as e:
        print(f"Error while handling the request: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.get("/")
def read_root():
    return {"message": "Sparkle-Server (GGUF edition) is running. See /docs for the API documentation."}