# JBAIP / app.py
from fastapi import FastAPI, HTTPException, Request, Response
from pydantic import BaseModel
from typing import List, Optional
from llama_cpp import Llama
from fastapi.responses import PlainTextResponse, JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import logging
import json
import os
import time
import uuid
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("api_logger")


class LoggingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Read the request body up front (it must be buffered manually so it can be logged and replayed)
        body = await request.body()
        logger.info(f"REQUEST: {request.method} {request.url}\nBody: {body.decode('utf-8', errors='replace')}")

        # Rebuild the request with the buffered body for downstream handlers;
        # the receive callable must be async so Starlette can await it.
        async def receive():
            return {"type": "http.request", "body": body}

        request = Request(request.scope, receive)

        # Process the response and buffer its body for logging
        response = await call_next(request)
        response_body = b""
        async for chunk in response.body_iterator:
            response_body += chunk

        # Log response body and status code
        logger.info(f"RESPONSE: Status {response.status_code}\nBody: {response_body.decode('utf-8', errors='replace')}")

        # Rebuild the response so the original status, headers, and media type are preserved
        return Response(
            content=response_body,
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.media_type,
        )


# FastAPI app with middleware
app = FastAPI()
app.add_middleware(LoggingMiddleware)
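# Note: because the middleware above buffers the full request and response bodies for
# logging, streaming responses are collected in memory and not forwarded incrementally.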
llm = None


# Models
class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 256


class GenerateRequest(BaseModel):
    model: str
    prompt: str
    max_tokens: Optional[int] = 256
    temperature: Optional[float] = 0.7


class ModelInfo(BaseModel):
    id: str
    object: str
    type: str
    publisher: str
    arch: str
    compatibility_type: str
    quantization: str
    state: str
    max_context_length: int


AVAILABLE_MODELS = [
    ModelInfo(
        id="codellama-7b-instruct",
        object="model",
        type="llm",
        publisher="lmstudio-community",
        arch="llama",
        compatibility_type="gguf",
        quantization="Q4_K_M",
        state="loaded",
        max_context_length=32768,
    )
]
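# Note (assumption): this metadata is hard-coded rather than read from the loaded GGUF
# file, so it may not match the model actually pointed to by /tmp/model_path.txt; the
# field names appear to mirror LM Studio-style model listings.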


@app.on_event("startup")
def load_model():
    global llm
    model_path_file = "/tmp/model_path.txt"
    if not os.path.exists(model_path_file):
        raise RuntimeError(f"Model path file not found: {model_path_file}")
    with open(model_path_file, "r") as f:
        model_path = f.read().strip()
    if not os.path.exists(model_path):
        raise RuntimeError(f"Model not found at path: {model_path}")
    llm = Llama(model_path=model_path)
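    # Note (assumption): llama_cpp's Llama uses a small default context window (n_ctx=512)
    # unless overridden, so the advertised 32768-token context would require passing it
    # explicitly, e.g. (illustrative, not verified against the deployed model):
    #   llm = Llama(model_path=model_path, n_ctx=32768)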
@app.get("/", response_class=PlainTextResponse)
async def root():
return "Ollama is running"
@app.get("/health")
async def health_check():
return {"status": "ok"}
@app.get("/api/tags")
async def api_tags():
return JSONResponse(content={
"data": [model.dict() for model in AVAILABLE_MODELS]
})
@app.get("/models")
async def list_models():
# Return available models info
return [model.dict() for model in AVAILABLE_MODELS]
@app.get("/api/v0/models")
async def api_models():
return {"data": [model.dict() for model in AVAILABLE_MODELS]}
@app.get("/models/{model_id}")
async def get_model(model_id: str):
for model in AVAILABLE_MODELS:
if model.id == model_id:
return model.dict()
raise HTTPException(status_code=404, detail="Model not found")
@app.post("/chat")
async def chat(req: ChatRequest):
global llm
if llm is None:
return {"error": "Model not initialized."}
# Validate model - simple check
if req.model not in [m.id for m in AVAILABLE_MODELS]:
raise HTTPException(status_code=400, detail="Unsupported model")
# Construct prompt from messages
prompt = ""
for m in req.messages:
prompt += f"{m.role}: {m.content}\n"
prompt += "assistant:"
output = llm(
prompt,
max_tokens=req.max_tokens,
temperature=req.temperature,
stop=["user:", "assistant:"]
)
text = output.get("choices", [{}])[0].get("text", "").strip()
response = {
"id": str(uuid.uuid4()),
"model": req.model,
"choices": [
{
"message": {"role": "assistant", "content": text},
"finish_reason": "stop"
}
]
}
return response
@app.post("/api/v0/generate")
async def api_generate(req: GenerateRequest):
global llm
if llm is None:
raise HTTPException(status_code=503, detail="Model not initialized")
if req.model not in [m.id for m in AVAILABLE_MODELS]:
raise HTTPException(status_code=400, detail="Unsupported model")
output = llm(
req.prompt,
max_tokens=req.max_tokens,
temperature=req.temperature,
stop=["\n\n"] # Or any stop sequence you want
)
text = output.get("choices", [{}])[0].get("text", "").strip()
return {
"id": str(uuid.uuid4()),
"model": req.model,
"choices": [
{
"text": text,
"index": 0,
"finish_reason": "stop"
}
]
}
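

# --- Usage sketch (assumption: not part of the original deployment setup) ---
# The Space presumably starts this app through an external ASGI launcher; the block
# below is only an illustrative way to run it locally. Host and port are placeholders.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request against the /chat endpoint (illustrative):
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"model": "codellama-7b-instruct", "messages": [{"role": "user", "content": "Hello"}]}'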