from fastapi import FastAPI, HTTPException, Request, Response
from pydantic import BaseModel
from typing import List, Optional
from llama_cpp import Llama
from fastapi.responses import PlainTextResponse, JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import logging
import json
import os
import time
import uuid

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("api_logger")

class LoggingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Read the request body (must be buffered manually)
        body = await request.body()
        logger.info(f"REQUEST: {request.method} {request.url}\nBody: {body.decode('utf-8')}")

        # Rebuild the request with the buffered body for downstream handlers.
        # The receive callable must be a coroutine function; a plain lambda
        # returning a dict would fail when Starlette awaits it.
        async def receive():
            return {"type": "http.request", "body": body}

        request = Request(request.scope, receive=receive)

        # Process the response
        response = await call_next(request)
        response_body = b""
        async for chunk in response.body_iterator:
            response_body += chunk

        # Log response status code and body
        logger.info(f"RESPONSE: Status {response.status_code}\nBody: {response_body.decode('utf-8')}")

        # Rebuild the response so the consumed body is still returned to the client
        return Response(
            content=response_body,
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.media_type
        )

# FastAPI app with middleware
app = FastAPI()
app.add_middleware(LoggingMiddleware)

llm = None

# Models
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 256

class GenerateRequest(BaseModel):
    model: str
    prompt: str
    max_tokens: Optional[int] = 256
    temperature: Optional[float] = 0.7

class ModelInfo(BaseModel):
    id: str
    object: str
    type: str
    publisher: str
    arch: str
    compatibility_type: str
    quantization: str
    state: str
    max_context_length: int

AVAILABLE_MODELS = [
    ModelInfo(
        id="codellama-7b-instruct",
        object="model",
        type="llm",
        publisher="lmstudio-community",
        arch="llama",
        compatibility_type="gguf",
        quantization="Q4_K_M",
        state="loaded",
        max_context_length=32768
    )
]

def load_model():
    global llm
    model_path_file = "/tmp/model_path.txt"
    if not os.path.exists(model_path_file):
        raise RuntimeError(f"Model path file not found: {model_path_file}")
    with open(model_path_file, "r") as f:
        model_path = f.read().strip()
    if not os.path.exists(model_path):
        raise RuntimeError(f"Model not found at path: {model_path}")
    llm = Llama(model_path=model_path)
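
# The extracted source never calls load_model(); wiring it to FastAPI's startup
# event is an assumption about the intended behavior (without it, llm stays None
# and every inference endpoint reports that the model is not initialized).
@app.on_event("startup")
async def _load_model_on_startup():
    try:
        load_model()
        logger.info("Model loaded successfully")
    except RuntimeError as exc:
        logger.error(f"Model load failed: {exc}")
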
# The route decorators below are not present in the extracted source; the paths
# are assumptions based on the handler names and the Ollama/LM Studio-style
# payloads the handlers return.
@app.get("/", response_class=PlainTextResponse)
async def root():
    return "Ollama is running"

@app.get("/health")
async def health_check():
    return {"status": "ok"}

@app.get("/api/tags")
async def api_tags():
    return JSONResponse(content={
        "data": [model.dict() for model in AVAILABLE_MODELS]
    })

@app.get("/v1/models")
async def list_models():
    # Return available models info
    return [model.dict() for model in AVAILABLE_MODELS]

@app.get("/api/v0/models")
async def api_models():
    return {"data": [model.dict() for model in AVAILABLE_MODELS]}

@app.get("/api/v0/models/{model_id}")
async def get_model(model_id: str):
    for model in AVAILABLE_MODELS:
        if model.id == model_id:
            return model.dict()
    raise HTTPException(status_code=404, detail="Model not found")

# Path assumed; the response shape matches the OpenAI-style chat completions format.
@app.post("/v1/chat/completions")
async def chat(req: ChatRequest):
    global llm
    if llm is None:
        raise HTTPException(status_code=503, detail="Model not initialized")
    # Validate model - simple check
    if req.model not in [m.id for m in AVAILABLE_MODELS]:
        raise HTTPException(status_code=400, detail="Unsupported model")
    # Construct a plain-text prompt from the chat messages
    prompt = ""
    for m in req.messages:
        prompt += f"{m.role}: {m.content}\n"
    prompt += "assistant:"
    output = llm(
        prompt,
        max_tokens=req.max_tokens,
        temperature=req.temperature,
        stop=["user:", "assistant:"]
    )
    text = output.get("choices", [{}])[0].get("text", "").strip()
    response = {
        "id": str(uuid.uuid4()),
        "model": req.model,
        "choices": [
            {
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop"
            }
        ]
    }
    return response

# Path assumed; the handler name suggests an Ollama-style generate endpoint.
@app.post("/api/generate")
async def api_generate(req: GenerateRequest):
    global llm
    if llm is None:
        raise HTTPException(status_code=503, detail="Model not initialized")
    if req.model not in [m.id for m in AVAILABLE_MODELS]:
        raise HTTPException(status_code=400, detail="Unsupported model")
    output = llm(
        req.prompt,
        max_tokens=req.max_tokens,
        temperature=req.temperature,
        stop=["\n\n"]  # Or any stop sequence you want
    )
    text = output.get("choices", [{}])[0].get("text", "").strip()
    return {
        "id": str(uuid.uuid4()),
        "model": req.model,
        "choices": [
            {
                "text": text,
                "index": 0,
                "finish_reason": "stop"
            }
        ]
    }
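
# Minimal local entry point for testing; the host and port here are assumptions,
# not values taken from the original deployment's launch configuration.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)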