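"""FastAPI service that wraps a local llama_cpp model behind a small,
Ollama/LM Studio-style HTTP API (model listing, /chat, /api/v0/generate).

The endpoint shapes below mirror what the handlers in this file actually
return; they are not a complete implementation of either upstream API.
"""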
from fastapi import FastAPI, HTTPException, Request, Response
from pydantic import BaseModel
from typing import List, Optional
from llama_cpp import Llama
from fastapi.responses import PlainTextResponse, JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import logging
import json
import os
import time
import uuid
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("api_logger")
class LoggingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Read the request body (it must be buffered manually so it can be logged).
        body = await request.body()
        logger.info(f"REQUEST: {request.method} {request.url}\nBody: {body.decode('utf-8', errors='replace')}")

        # Rebuild the request so downstream handlers can still read the body.
        # The receive callable passed to Request must be async and return an ASGI message.
        async def receive():
            return {"type": "http.request", "body": body}

        request = Request(request.scope, receive=receive)

        # Process the request and buffer the response body so it can be logged.
        response = await call_next(request)
        response_body = b""
        async for chunk in response.body_iterator:
            response_body += chunk

        # Log response status code and body.
        logger.info(f"RESPONSE: Status {response.status_code}\nBody: {response_body.decode('utf-8', errors='replace')}")

        # Rebuild the response, since its body iterator has been consumed.
        return Response(
            content=response_body,
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.media_type,
        )
# FastAPI app with middleware
app = FastAPI()
app.add_middleware(LoggingMiddleware)
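# Global llama_cpp model handle; populated by the startup hook below.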
llm = None
# Models
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 256

class GenerateRequest(BaseModel):
    model: str
    prompt: str
    max_tokens: Optional[int] = 256
    temperature: Optional[float] = 0.7

class ModelInfo(BaseModel):
    id: str
    object: str
    type: str
    publisher: str
    arch: str
    compatibility_type: str
    quantization: str
    state: str
    max_context_length: int
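# Static catalogue advertised by the model-listing endpoints below.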
AVAILABLE_MODELS = [
    ModelInfo(
        id="codellama-7b-instruct",
        object="model",
        type="llm",
        publisher="lmstudio-community",
        arch="llama",
        compatibility_type="gguf",
        quantization="Q4_K_M",
        state="loaded",
        max_context_length=32768,
    )
]
@app.on_event("startup")
def load_model():
    global llm
    model_path_file = "/tmp/model_path.txt"
    if not os.path.exists(model_path_file):
        raise RuntimeError(f"Model path file not found: {model_path_file}")
    with open(model_path_file, "r") as f:
        model_path = f.read().strip()
    if not os.path.exists(model_path):
        raise RuntimeError(f"Model not found at path: {model_path}")
    llm = Llama(model_path=model_path)
@app.get("/", response_class=PlainTextResponse)
async def root():
return "Ollama is running"
@app.get("/health")
async def health_check():
return {"status": "ok"}
@app.get("/api/tags")
async def api_tags():
return JSONResponse(content={
"data": [model.dict() for model in AVAILABLE_MODELS]
})
@app.get("/models")
async def list_models():
# Return available models info
return [model.dict() for model in AVAILABLE_MODELS]
@app.get("/api/v0/models")
async def api_models():
return {"data": [model.dict() for model in AVAILABLE_MODELS]}
@app.get("/models/{model_id}")
async def get_model(model_id: str):
for model in AVAILABLE_MODELS:
if model.id == model_id:
return model.dict()
raise HTTPException(status_code=404, detail="Model not found")
@app.post("/chat")
async def chat(req: ChatRequest):
global llm
if llm is None:
return {"error": "Model not initialized."}
    # Validate model - simple check
    if req.model not in [m.id for m in AVAILABLE_MODELS]:
        raise HTTPException(status_code=400, detail="Unsupported model")

    # Construct prompt from messages
    prompt = ""
    for m in req.messages:
        prompt += f"{m.role}: {m.content}\n"
    prompt += "assistant:"

    output = llm(
        prompt,
        max_tokens=req.max_tokens,
        temperature=req.temperature,
        stop=["user:", "assistant:"],
    )
    text = output.get("choices", [{}])[0].get("text", "").strip()

    response = {
        "id": str(uuid.uuid4()),
        "model": req.model,
        "choices": [
            {
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop",
            }
        ],
    }
    return response
@app.post("/api/v0/generate")
async def api_generate(req: GenerateRequest):
global llm
if llm is None:
raise HTTPException(status_code=503, detail="Model not initialized")
if req.model not in [m.id for m in AVAILABLE_MODELS]:
raise HTTPException(status_code=400, detail="Unsupported model")
output = llm(
req.prompt,
max_tokens=req.max_tokens,
temperature=req.temperature,
stop=["\n\n"] # Or any stop sequence you want
)
text = output.get("choices", [{}])[0].get("text", "").strip()
return {
"id": str(uuid.uuid4()),
"model": req.model,
"choices": [
{
"text": text,
"index": 0,
"finish_reason": "stop"
}
]
} |
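# Example usage (a sketch: assumes this file is named app.py and is served with
# uvicorn on port 8000, e.g. `uvicorn app:app --host 0.0.0.0 --port 8000`):
#
#   curl http://localhost:8000/models
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"model": "codellama-7b-instruct", "messages": [{"role": "user", "content": "Hello"}]}'
#   curl -X POST http://localhost:8000/api/v0/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model": "codellama-7b-instruct", "prompt": "def add(a, b):"}'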