Update main.py
main.py
CHANGED
@@ -1,8 +1,9 @@
 from fastapi import FastAPI, Request
-from pydantic import BaseModel
+from fastapi.responses import StreamingResponse, JSONResponse
+from pydantic import BaseModel, Field
+from typing import List, Optional, Union
 import requests
 import json
-import os

 app = FastAPI()

@@ -10,6 +11,7 @@ API_URL = "https://api.typegpt.net/v1/chat/completions"
 API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
 BACKEND_MODEL = "pixtral-large-latest"

+# Load virtual model -> system prompt mappings
 with open("model_map.json", "r") as f:
     MODEL_PROMPTS = json.load(f)

@@ -19,26 +21,64 @@ class Message(BaseModel):

 class ChatRequest(BaseModel):
     model: str
+    messages: List[Message]
+    stream: Optional[bool] = False
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    stop: Optional[Union[str, List[str]]] = None
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0

+def build_payload(chat: ChatRequest):
+    system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")
+
+    messages = [{"role": "system", "content": system_prompt}] + [
+        {"role": msg.role, "content": msg.content} for msg in chat.messages
+    ]

-    payload = {
+    return {
         "model": BACKEND_MODEL,
+        "messages": messages,
+        "stream": chat.stream,
+        "temperature": chat.temperature,
+        "top_p": chat.top_p,
+        "n": chat.n,
+        "stop": chat.stop,
+        "presence_penalty": chat.presence_penalty,
+        "frequency_penalty": chat.frequency_penalty
     }

+def stream_generator(requested_model, payload, headers):
+    with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
+        for line in r.iter_lines():
+            if line:
+                decoded = line.decode('utf-8')
+                # Rewrite the model name in streaming output
+                if BACKEND_MODEL in decoded:
+                    decoded = decoded.replace(BACKEND_MODEL, requested_model)
+                yield f"data: {decoded}\n\n"
+    yield "data: [DONE]\n\n"
+
+@app.post("/v1/chat/completions")
+async def proxy_chat(request: Request):
+    body = await request.json()
+    chat_request = ChatRequest(**body)
+    payload = build_payload(chat_request)
     headers = {
         "Authorization": f"Bearer {API_KEY}",
         "Content-Type": "application/json"
     }

+    if chat_request.stream:
+        return StreamingResponse(
+            stream_generator(chat_request.model, payload, headers),
+            media_type="text/event-stream"
+        )
+    else:
+        response = requests.post(API_URL, headers=headers, json=payload)
+        data = response.json()
+        # Replace model in final result
+        if "model" in data and data["model"] == BACKEND_MODEL:
+            data["model"] = chat_request.model
+        return JSONResponse(content=data)
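model_map.json itself is not part of this commit. Since build_payload falls back to "You are a helpful assistant." for unknown names, the file only needs entries for the virtual models the proxy should expose; a purely illustrative sketch, with hypothetical model names:

{
  "helpful-coder": "You are an expert programmer. Reply with concise, working code.",
  "strict-translator": "Translate the user's message into English and output nothing else."
}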
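One caveat on stream_generator, assuming the upstream API emits standard OpenAI-style SSE: every line from r.iter_lines() already starts with "data: ", so adding the prefix again produces "data: data: {...}" frames, and the upstream's own [DONE] marker gets forwarded before a second one is appended. A minimal sketch of a stricter generator, under that assumption:

def stream_generator(requested_model, payload, headers):
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
        for line in r.iter_lines():
            if not line:
                continue
            decoded = line.decode("utf-8")
            # Strip the upstream SSE prefix so it is not doubled on re-emit.
            if decoded.startswith("data: "):
                decoded = decoded[len("data: "):]
            # Drop the upstream terminator; a single one is emitted below.
            if decoded.strip() == "[DONE]":
                break
            yield f"data: {decoded.replace(BACKEND_MODEL, requested_model)}\n\n"
    yield "data: [DONE]\n\n"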
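End to end, the proxy behaves like an OpenAI-compatible endpoint. Assuming it runs locally (say, via uvicorn main:app --port 8000) and model_map.json defines the hypothetical "helpful-coder" model above, a non-streaming call could look like:

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed local deployment
    json={
        "model": "helpful-coder",  # hypothetical virtual model from model_map.json
        "messages": [{"role": "user", "content": "Explain SSE in one sentence."}],
        "stream": False,
    },
)
data = resp.json()
print(data.get("model"))  # the proxy reports the virtual name, not pixtral-large-latest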