import json
import asyncio
from typing import List, Optional, Dict, Any, Generator, AsyncGenerator

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
from sse_starlette.sse import EventSourceResponse


# --- 1. Configuration management ---
class Settings(BaseSettings):
    model_config = SettingsConfigDict(
        env_file=".env", env_file_encoding="utf-8", extra="ignore"
    )

    MODEL_ID: str = Field(
        "unsloth/Qwen3-8B-GGUF", description="Model repository ID on the Hugging Face Hub"
    )
    FILENAME: str = Field(
        "Qwen3-8B-Q8_0.gguf", description="GGUF model file to download"
    )
    N_CTX: int = Field(4096, description="Context window size of the model")
    N_GPU_LAYERS: int = Field(
        0, description="Number of layers to offload to the GPU (0 = CPU only)"
    )
    N_THREADS: Optional[int] = Field(
        None, description="Number of CPU threads used for inference (None = auto)"
    )
    VERBOSE: bool = Field(True, description="Enable llama.cpp verbose logging")


settings = Settings()


# --- 2. Model loading ---
def load_model():
    """Download the GGUF model from the Hugging Face Hub and load it."""
    print(f"Downloading model from the Hub: {settings.MODEL_ID}/{settings.FILENAME}...")
    try:
        model_path = hf_hub_download(
            repo_id=settings.MODEL_ID, filename=settings.FILENAME
        )
    except Exception as e:
        print(f"Model download failed: {e}")
        raise RuntimeError(f"Could not download the model from the Hugging Face Hub: {e}")

    print("Download complete. Loading the model into memory...")
    try:
        model = Llama(
            model_path=model_path,
            n_ctx=settings.N_CTX,
            n_gpu_layers=settings.N_GPU_LAYERS,
            n_threads=settings.N_THREADS,
            verbose=settings.VERBOSE,
        )
        print("Model loaded.")
        return model
    except Exception as e:
        print(f"Model loading failed: {e}")
        raise RuntimeError(f"Could not load the Llama model: {e}")


model = load_model()


# --- 3. API service setup ---
app = FastAPI(
    title="Sparkle-Server - GGUF LLM API",
    description="An OpenAI-compatible, high-performance LLM inference service built on llama-cpp-python and FastAPI.",
    version="1.0.0",
)


# --- 4. API data models (OpenAI-compatible) ---
class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    messages: List[ChatMessage]
    model: str = settings.MODEL_ID
    max_tokens: int = 1024
    temperature: float = 0.7
    stream: bool = False


# --- 5. Streaming response generator ---
async def stream_generator(
    chat_iterator: Generator[Dict[str, Any], Any, None],
) -> AsyncGenerator[str, None]:
    """Convert the llama-cpp-python output stream into Server-Sent Events (SSE)
    and print the full response once the stream ends."""
    full_response_content = []
    for chunk in chat_iterator:
        if "content" in chunk["choices"][0]["delta"]:
            content_piece = chunk["choices"][0]["delta"]["content"]
            full_response_content.append(content_piece)
        # EventSourceResponse adds the "data: " prefix and trailing newlines itself,
        # so yield only the JSON payload to avoid a doubled "data: data: ..." frame.
        yield json.dumps(chunk)
        await asyncio.sleep(0)  # yield control so the event loop can serve other tasks

    # After the stream ends, print the full response
    print("\n--- [streaming] generated response ---")
    print("".join(full_response_content))
    print("--------------------------\n")
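
# For reference, each item yielded by create_chat_completion(stream=True) is an
# OpenAI-style chunk dict. A rough sketch of its shape (illustrative only; the
# exact fields can vary between llama-cpp-python versions):
#
#   {
#       "id": "chatcmpl-...",
#       "object": "chat.completion.chunk",
#       "created": 1700000000,
#       "model": "...",
#       "choices": [
#           {"index": 0, "delta": {"content": "Hello"}, "finish_reason": None}
#       ],
#   }
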

# --- 6. API endpoints (OpenAI-compatible) ---
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """Handle chat completion requests, with both streaming and non-streaming responses."""
    if not request.messages:
        raise HTTPException(status_code=400, detail="The messages list must not be empty")

    # Log the incoming request
    print("\n--- Incoming request ---")
    print(json.dumps(request.model_dump(), indent=2, ensure_ascii=False))
    print("--------------------\n")

    try:
        if request.stream:
            # Streaming response
            chat_iterator = model.create_chat_completion(
                messages=request.model_dump()["messages"],
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                stream=True,
            )
            return EventSourceResponse(stream_generator(chat_iterator))
        else:
            # Non-streaming response
            result = model.create_chat_completion(
                messages=request.model_dump()["messages"],
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                stream=False,
            )
            # Log the generated response
            print("\n--- [non-streaming] generated response ---")
            print(json.dumps(result, indent=2, ensure_ascii=False))
            print("--------------------------\n")
            return result
    except Exception as e:
        print(f"Error while handling the request: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


@app.get("/")
def read_root():
    return {"message": "Sparkle-Server (GGUF edition) is running. Visit /docs for the API documentation."}
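
# A minimal entry point for running the server locally (a sketch; assumes the
# `uvicorn` package is installed, and the host/port values are illustrative):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is up (illustrative):
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}], "stream": false}'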