Update main.py
main.py (CHANGED)
@@ -8,16 +8,16 @@ import logging
 
 app = FastAPI()
 
-#
+# Logging setup
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger("proxy")
 
-#
+# Configuration
 API_URL = "https://api.typegpt.net/v1/chat/completions"
 API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
 BACKEND_MODEL = "pixtral-large-latest"
 
-# Load
+# Load system prompt mappings
 with open("model_map.json", "r", encoding="utf-8") as f:
     MODEL_PROMPTS = json.load(f)
 
@@ -37,10 +37,9 @@ class ChatRequest(BaseModel):
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
 
-#
+# Build request to backend with injected system prompt
 def build_payload(chat: ChatRequest):
     system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")
-    # Strip user system messages
     filtered_messages = [msg for msg in chat.messages if msg.role != "system"]
     payload_messages = [{"role": "system", "content": system_prompt}] + [
         {"role": msg.role, "content": msg.content} for msg in filtered_messages
@@ -57,7 +56,7 @@ def build_payload(chat: ChatRequest):
         "frequency_penalty": chat.frequency_penalty
     }
 
-#
+# Streaming chunk handler with model replacement and UTF-8 fix
 def stream_generator(requested_model: str, payload: dict, headers: dict):
     with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
         for line in r.iter_lines(decode_unicode=True):
@@ -72,13 +71,13 @@ def stream_generator(requested_model: str, payload: dict, headers: dict):
                     json_obj = json.loads(content)
                     if json_obj.get("model") == BACKEND_MODEL:
                         json_obj["model"] = requested_model
-                    yield
+                    yield "data: " + json.dumps(json_obj, ensure_ascii=False) + "\n\n"
                 except json.JSONDecodeError:
                     logger.warning("Invalid JSON in stream chunk: %s", content)
             else:
                 logger.debug("Non-data stream line skipped: %s", line)
 
-# Main endpoint
+# Main API endpoint
 @app.post("/v1/chat/completions")
 async def proxy_chat(request: Request):
     try:
@@ -102,8 +101,15 @@ async def proxy_chat(request: Request):
         data = response.json()
         if "model" in data and data["model"] == BACKEND_MODEL:
             data["model"] = chat_request.model
-        return JSONResponse(
+        return JSONResponse(
+            content=data,
+            media_type="application/json; charset=utf-8",
+            headers={"Content-Type": "application/json; charset=utf-8"}
+        )
 
     except Exception as e:
         logger.error("Error in /v1/chat/completions: %s", str(e))
-        return JSONResponse(
+        return JSONResponse(
+            content={"error": "Internal server error."},
+            status_code=500
+        )
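
model_map.json itself is not part of this diff; judging from MODEL_PROMPTS.get(chat.model, "You are a helpful assistant."), it is presumably a flat JSON object mapping each public-facing model name to the system prompt the proxy injects. A hypothetical example (the model names and prompts below are invented for illustration):

{
  "my-coder-model": "You are an expert programming assistant. Answer with concise, working code.",
  "my-writer-model": "You are a creative writing assistant. Match the user's tone and language."
}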
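
The "UTF-8 fix" named in the new stream_generator comment comes from passing ensure_ascii=False to json.dumps when re-serializing each chunk; without it, non-ASCII characters are emitted as \uXXXX escapes. A standalone illustration (the sample chunk content is made up):

import json

chunk = {"model": "pixtral-large-latest", "content": "héllo, 世界"}

print(json.dumps(chunk))
# {"model": "pixtral-large-latest", "content": "h\u00e9llo, \u4e16\u754c"}

print(json.dumps(chunk, ensure_ascii=False))
# {"model": "pixtral-large-latest", "content": "héllo, 世界"}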
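
A minimal sketch for exercising the updated non-streaming path, assuming the proxy is served locally on port 8000 and that "my-coder-model" is a key in model_map.json (both are assumptions, not taken from the diff). After this change the "model" field in the JSON response should echo the requested name rather than the backend model:

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-coder-model",  # hypothetical key from model_map.json
        "messages": [{"role": "user", "content": "Say hello."}],
    },
)
data = resp.json()
print(data.get("model"))  # expected: "my-coder-model", not "pixtral-large-latest"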
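
For the streaming path, stream_generator re-emits SSE "data: " lines with the model name rewritten. A client-side sketch, assuming the request body accepts a "stream" flag (the full ChatRequest definition is not shown in this diff):

import requests

with requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-coder-model",  # hypothetical key from model_map.json
        "messages": [{"role": "user", "content": "Say hello."}],
        "stream": True,             # assumed field, not visible in the diff
    },
    stream=True,
) as r:
    for line in r.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            print(line[len("data: "):])  # each chunk should report the requested model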