import json
import asyncio
from typing import List, Optional, Dict, Any, Generator, AsyncGenerator
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
from sse_starlette.sse import EventSourceResponse

# --- 1. Configuration management ---
class Settings(BaseSettings):
    model_config = SettingsConfigDict(
        env_file=".env", env_file_encoding="utf-8", extra="ignore"
    )

    MODEL_ID: str = Field(
        "unsloth/Qwen3-8B-GGUF", description="Model repository ID on the Hugging Face Hub"
    )
    FILENAME: str = Field("Qwen3-8B-Q8_0.gguf", description="Name of the GGUF model file to download")
    N_CTX: int = Field(4096, description="Context window size of the model")
    N_GPU_LAYERS: int = Field(0, description="Number of layers to offload to the GPU (0 = CPU only)")
    N_THREADS: Optional[int] = Field(
        None, description="Number of CPU threads for inference (None = automatic)"
    )
    VERBOSE: bool = Field(True, description="Whether to enable verbose llama.cpp logging")
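
# Each field above can be overridden via environment variables or a `.env` file,
# thanks to pydantic-settings. A minimal example `.env` (values are illustrative
# only, not shipped with this repository):
#
#   MODEL_ID=unsloth/Qwen3-8B-GGUF
#   FILENAME=Qwen3-8B-Q8_0.gguf
#   N_CTX=8192
#   N_GPU_LAYERS=35
#   N_THREADS=8
#   VERBOSE=false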
settings = Settings()

# --- 2. Model loading ---
def load_model():
    """Download the GGUF model from the Hugging Face Hub and load it into memory."""
    print(f"Downloading model from the Hub: {settings.MODEL_ID}/{settings.FILENAME}...")
    try:
        model_path = hf_hub_download(
            repo_id=settings.MODEL_ID, filename=settings.FILENAME
        )
    except Exception as e:
        print(f"Model download failed: {e}")
        raise RuntimeError(f"Could not download the model from the Hugging Face Hub: {e}") from e

    print("Download complete. Loading model into memory...")
    try:
        model = Llama(
            model_path=model_path,
            n_ctx=settings.N_CTX,
            n_gpu_layers=settings.N_GPU_LAYERS,
            n_threads=settings.N_THREADS,
            verbose=settings.VERBOSE,
        )
        print("Model loaded.")
        return model
    except Exception as e:
        print(f"Model loading failed: {e}")
        raise RuntimeError(f"Could not load the Llama model: {e}") from e
model = load_model()

# --- 3. API service ---
app = FastAPI(
    title="Sparkle-Server - GGUF LLM API",
    description="A high-performance, OpenAI-compatible LLM inference service built on llama-cpp-python and FastAPI.",
    version="1.0.0",
)

# --- 4. API data models (OpenAI-compatible) ---
class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    messages: List[ChatMessage]
    model: str = settings.MODEL_ID
    max_tokens: int = 1024
    temperature: float = 0.7
    stream: bool = False
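
# An example request body accepted by this schema (illustrative values only):
#
#   {
#     "messages": [{"role": "user", "content": "Hello!"}],
#     "max_tokens": 256,
#     "temperature": 0.7,
#     "stream": true
#   }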

# --- 5. Streaming response generator ---
async def stream_generator(
    chat_iterator: Generator[Dict[str, Any], Any, None],
) -> AsyncGenerator[str, None]:
    """Convert the llama-cpp-python output stream into Server-Sent Events (SSE) payloads,
    and print the full response once the stream has ended."""
    full_response_content = []
    for chunk in chat_iterator:
        if "content" in chunk["choices"][0]["delta"]:
            content_piece = chunk["choices"][0]["delta"]["content"]
            full_response_content.append(content_piece)
        # EventSourceResponse adds the "data:" prefix and event separators itself,
        # so yield only the JSON payload to avoid a doubled "data: data:" prefix.
        yield json.dumps(chunk)
        await asyncio.sleep(0)  # Yield control so the event loop can serve other tasks.

    # After the stream ends, print the full response.
    print("\n--- [streaming] Generated response ---")
    print("".join(full_response_content))
    print("--------------------------\n")

# --- 6. API endpoint (OpenAI-compatible) ---
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """
    Handle chat completion requests, supporting both streaming and non-streaming responses.
    """
    if not request.messages:
        raise HTTPException(status_code=400, detail="The messages list must not be empty")

    # Log the incoming request.
    print("\n--- Incoming request ---")
    print(json.dumps(request.model_dump(), indent=2, ensure_ascii=False))
    print("--------------------\n")

    try:
        if request.stream:
            # Streaming response
            chat_iterator = model.create_chat_completion(
                messages=request.model_dump()["messages"],
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                stream=True,
            )
            return EventSourceResponse(stream_generator(chat_iterator))
        else:
            # Non-streaming response
            result = model.create_chat_completion(
                messages=request.model_dump()["messages"],
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                stream=False,
            )
            # Log the generated response.
            print("\n--- [non-streaming] Generated response ---")
            print(json.dumps(result, indent=2, ensure_ascii=False))
            print("--------------------------\n")
            return result
    except Exception as e:
        print(f"Error while handling the request: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@app.get("/")
def read_root():
    return {"message": "Sparkle-Server (GGUF edition) is running. Visit /docs for the API documentation."}