Update main.py
Browse files
main.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from fastapi import FastAPI, Request
|
2 |
from fastapi.responses import StreamingResponse, JSONResponse
|
3 |
-
from pydantic import BaseModel
|
4 |
from typing import List, Optional, Union
|
5 |
import requests
|
6 |
import json
|
@@ -49,15 +49,23 @@ def build_payload(chat: ChatRequest):
|
|
49 |
"frequency_penalty": chat.frequency_penalty
|
50 |
}
|
51 |
|
52 |
-
def stream_generator(requested_model, payload, headers):
|
53 |
with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
|
54 |
-
for line in r.iter_lines():
|
55 |
-
if line:
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
yield "data: [DONE]\n\n"
|
62 |
|
63 |
@app.post("/v1/chat/completions")
|
@@ -78,7 +86,6 @@ async def proxy_chat(request: Request):
|
|
78 |
else:
|
79 |
response = requests.post(API_URL, headers=headers, json=payload)
|
80 |
data = response.json()
|
81 |
-
# Replace model in final result
|
82 |
if "model" in data and data["model"] == BACKEND_MODEL:
|
83 |
data["model"] = chat_request.model
|
84 |
return JSONResponse(content=data)
|
|
|
1 |
from fastapi import FastAPI, Request
|
2 |
from fastapi.responses import StreamingResponse, JSONResponse
|
3 |
+
from pydantic import BaseModel
|
4 |
from typing import List, Optional, Union
|
5 |
import requests
|
6 |
import json
|
|
|
49 |
"frequency_penalty": chat.frequency_penalty
|
50 |
}
|
51 |
|
52 |
+
def stream_generator(requested_model: str, payload: dict, headers: dict):
    """Relay the backend's server-sent-event stream to the client.

    Each upstream ``data:`` line is parsed as JSON; when its ``model`` field
    equals ``BACKEND_MODEL`` it is replaced with ``requested_model`` before
    the event is re-emitted. Payloads that are not valid JSON are forwarded
    unchanged. Exactly one ``data: [DONE]`` terminator is emitted at the end,
    whether or not the backend sent its own.

    Args:
        requested_model: Model name the client asked for; substituted for
            ``BACKEND_MODEL`` in streamed chunks.
        payload: JSON body posted to ``API_URL``.
        headers: HTTP headers sent with the backend request.

    Yields:
        SSE-formatted strings of the form ``"data: ...\\n\\n"``.
    """
    prefix = "data:"
    done_marker = "[DONE]"

    with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
        # Event streams are UTF-8 by specification. Without an explicit
        # charset, requests either guesses ISO-8859-1 for text/* responses
        # (mangling non-ASCII content) or leaves the lines as bytes, which
        # would break the str-based prefix check below.
        r.encoding = "utf-8"

        for line in r.iter_lines(decode_unicode=True):
            if not line:
                continue

            if not line.startswith(prefix):
                # Non-data lines are wrapped as data events, as before.
                yield f"data: {line}\n\n"
                continue

            # The space after "data:" is optional in SSE, so slice by the
            # prefix actually matched rather than a fixed six characters.
            content = line[len(prefix):].strip()

            if content == done_marker:
                # The terminator is emitted once below; forwarding the
                # backend's copy as well would send it twice.
                break

            try:
                json_obj = json.loads(content)
            except json.JSONDecodeError:
                yield f"data: {content}\n\n"
                continue

            # Valid JSON is not necessarily an object (e.g. a bare number),
            # so only look for "model" on dicts.
            if isinstance(json_obj, dict) and json_obj.get("model") == BACKEND_MODEL:
                json_obj["model"] = requested_model
            yield f"data: {json.dumps(json_obj)}\n\n"

    yield "data: [DONE]\n\n"
|
70 |
|
71 |
@app.post("/v1/chat/completions")
|
|
|
86 |
else:
|
87 |
response = requests.post(API_URL, headers=headers, json=payload)
|
88 |
data = response.json()
|
|
|
89 |
if "model" in data and data["model"] == BACKEND_MODEL:
|
90 |
data["model"] = chat_request.model
|
91 |
return JSONResponse(content=data)
|