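"""Minimal OpenAI/Ollama-style HTTP facade around a local llama.cpp model.

A FastAPI app that loads a single GGUF model (path read from /tmp/model_path.txt
at startup), logs every request and response body via middleware, and exposes
/chat, /api/v0/generate, and model-listing endpoints backed by llama_cpp.
"""
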
from fastapi import FastAPI, HTTPException, Request, Response
from pydantic import BaseModel
from typing import List, Optional
from llama_cpp import Llama
from fastapi.responses import PlainTextResponse, JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware


import logging
import json
import os
import time
import uuid

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("api_logger")

class LoggingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Read and log the request body (it must be buffered manually so it can be replayed)
        body = await request.body()
        logger.info(f"REQUEST: {request.method} {request.url}\nBody: {body.decode('utf-8', errors='replace')}")

        # Rebuild the request so downstream handlers can re-read the already-consumed body.
        # The receive callable is awaited by Starlette, so it must be an async function.
        async def receive():
            return {"type": "http.request", "body": body}

        request = Request(request.scope, receive=receive)

        # Process the response
        response = await call_next(request)
        response_body = b""
        async for chunk in response.body_iterator:
            response_body += chunk

        # Log response status code and body
        logger.info(f"RESPONSE: Status {response.status_code}\nBody: {response_body.decode('utf-8', errors='replace')}")

        # Rebuild the response, since the original body iterator has been consumed above
        return Response(
            content=response_body,
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.media_type,
            background=response.background
        )

# FastAPI app with middleware
app = FastAPI()
app.add_middleware(LoggingMiddleware)

llm = None  # Populated at startup by load_model()

# Models
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 256

class GenerateRequest(BaseModel):
    model: str
    prompt: str
    max_tokens: Optional[int] = 256
    temperature: Optional[float] = 0.7


class ModelInfo(BaseModel):
    id: str
    object: str
    type: str
    publisher: str
    arch: str
    compatibility_type: str
    quantization: str
    state: str
    max_context_length: int

# Static catalogue advertised by the model-listing endpoints; a single local model is exposed.
AVAILABLE_MODELS = [
    ModelInfo(
        id="codellama-7b-instruct",
        object="model",
        type="llm",
        publisher="lmstudio-community",
        arch="llama",
        compatibility_type="gguf",
        quantization="Q4_K_M",
        state="loaded",
        max_context_length=32768
    )
]


@app.on_event("startup")
def load_model():
    global llm
    model_path_file = "/tmp/model_path.txt"
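    # This file is expected to contain a single line: the absolute path to a GGUF model,
    # e.g. (hypothetical path): /models/codellama-7b-instruct.Q4_K_M.gguf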
    if not os.path.exists(model_path_file):
        raise RuntimeError(f"Model path file not found: {model_path_file}")
    with open(model_path_file, "r") as f:
        model_path = f.read().strip()
    if not os.path.exists(model_path):
        raise RuntimeError(f"Model not found at path: {model_path}")
    llm = Llama(model_path=model_path)

@app.get("/", response_class=PlainTextResponse)
async def root():
    return "Ollama is running"

@app.get("/health")
async def health_check():
    return {"status": "ok"}

@app.get("/api/tags")
async def api_tags():
    return JSONResponse(content={
        "data": [model.dict() for model in AVAILABLE_MODELS]
    })

@app.get("/models")
async def list_models():
    # Return available models info
    return [model.dict() for model in AVAILABLE_MODELS]

@app.get("/api/v0/models")
async def api_models():
    return {"data": [model.dict() for model in AVAILABLE_MODELS]}

@app.get("/models/{model_id}")
async def get_model(model_id: str):
    for model in AVAILABLE_MODELS:
        if model.id == model_id:
            return model.dict()
    raise HTTPException(status_code=404, detail="Model not found")

@app.post("/chat")
async def chat(req: ChatRequest):
    global llm
    if llm is None:
        raise HTTPException(status_code=503, detail="Model not initialized")

    # Validate model - simple check
    if req.model not in [m.id for m in AVAILABLE_MODELS]:
        raise HTTPException(status_code=400, detail="Unsupported model")

    # Construct prompt from messages
    prompt = ""
    for m in req.messages:
        prompt += f"{m.role}: {m.content}\n"
    prompt += "assistant:"

    output = llm(
        prompt,
        max_tokens=req.max_tokens,
        temperature=req.temperature,
        stop=["user:", "assistant:"]
    )
    text = output.get("choices", [{}])[0].get("text", "").strip()

    response = {
        "id": str(uuid.uuid4()),
        "model": req.model,
        "choices": [
            {
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop"
            }
        ]
    }
    return response

@app.post("/api/v0/generate")
async def api_generate(req: GenerateRequest):
    global llm
    if llm is None:
        raise HTTPException(status_code=503, detail="Model not initialized")

    if req.model not in [m.id for m in AVAILABLE_MODELS]:
        raise HTTPException(status_code=400, detail="Unsupported model")

    output = llm(
        req.prompt,
        max_tokens=req.max_tokens,
        temperature=req.temperature,
        stop=["\n\n"]  # Or any stop sequence you want
    )
    text = output.get("choices", [{}])[0].get("text", "").strip()

    return {
        "id": str(uuid.uuid4()),
        "model": req.model,
        "choices": [
            {
                "text": text,
                "index": 0,
                "finish_reason": "stop"
            }
        ]
    }
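
# Example usage (a sketch; assumes this file is saved as main.py and served locally
# on port 8000 with `uvicorn main:app --host 0.0.0.0 --port 8000`):
#
#   curl -s http://localhost:8000/health
#
#   curl -s -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"model": "codellama-7b-instruct", "messages": [{"role": "user", "content": "Hello"}]}'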