rkihacker committed on
Commit ad95a9a · verified · 1 Parent(s): 36f72ba

Update main.py

Files changed (1)
  1. main.py +16 -10
main.py CHANGED
@@ -8,16 +8,16 @@ import logging
 
 app = FastAPI()
 
-# Setup logging
+# Logging setup
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger("proxy")
 
-# TypeGPT API settings
+# Configuration
 API_URL = "https://api.typegpt.net/v1/chat/completions"
 API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
 BACKEND_MODEL = "pixtral-large-latest"
 
-# Load model -> system prompt mappings
+# Load system prompt mappings
 with open("model_map.json", "r", encoding="utf-8") as f:
     MODEL_PROMPTS = json.load(f)
 
@@ -37,10 +37,9 @@ class ChatRequest(BaseModel):
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
 
-# Construct payload with enforced system prompt
+# Build request to backend with injected system prompt
 def build_payload(chat: ChatRequest):
     system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")
-    # Strip user system messages
     filtered_messages = [msg for msg in chat.messages if msg.role != "system"]
     payload_messages = [{"role": "system", "content": system_prompt}] + [
         {"role": msg.role, "content": msg.content} for msg in filtered_messages
@@ -57,7 +56,7 @@ def build_payload(chat: ChatRequest):
         "frequency_penalty": chat.frequency_penalty
     }
 
-# Properly streamed UTF-8 chunks with model rewrite
+# Streaming chunk handler with model replacement and UTF-8 fix
 def stream_generator(requested_model: str, payload: dict, headers: dict):
     with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
         for line in r.iter_lines(decode_unicode=True):
@@ -72,13 +71,13 @@ def stream_generator(requested_model: str, payload: dict, headers: dict):
                     json_obj = json.loads(content)
                     if json_obj.get("model") == BACKEND_MODEL:
                         json_obj["model"] = requested_model
-                    yield f"data: {json.dumps(json_obj, ensure_ascii=False)}\n\n"
+                    yield "data: " + json.dumps(json_obj, ensure_ascii=False) + "\n\n"
                 except json.JSONDecodeError:
                     logger.warning("Invalid JSON in stream chunk: %s", content)
             else:
                 logger.debug("Non-data stream line skipped: %s", line)
 
-# Main endpoint
+# Main API endpoint
 @app.post("/v1/chat/completions")
 async def proxy_chat(request: Request):
     try:
@@ -102,8 +101,15 @@ async def proxy_chat(request: Request):
         data = response.json()
         if "model" in data and data["model"] == BACKEND_MODEL:
             data["model"] = chat_request.model
-        return JSONResponse(content=data, media_type="application/json; charset=utf-8")
+        return JSONResponse(
+            content=data,
+            media_type="application/json; charset=utf-8",
+            headers={"Content-Type": "application/json; charset=utf-8"}
+        )
 
     except Exception as e:
         logger.error("Error in /v1/chat/completions: %s", str(e))
-        return JSONResponse(content={"error": "Internal server error."}, status_code=500)
+        return JSONResponse(
+            content={"error": "Internal server error."},
+            status_code=500
+        )
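
Not part of the diff, but context that may help reviewers: main.py loads model_map.json at import time and build_payload looks prompts up via MODEL_PROMPTS.get(chat.model, ...), so the file is presumably a flat JSON object mapping the model names this proxy exposes to the system prompts it injects. A minimal sketch that writes a file of that assumed shape (the model names and prompts below are made up for illustration):

# Sketch only: produces a model_map.json of the assumed shape
# {exposed model name -> system prompt}; names and prompts are hypothetical.
import json

example_map = {
    "my-assistant": "You are a helpful assistant.",
    "my-translator": "You translate the user's text into French.",
}

with open("model_map.json", "w", encoding="utf-8") as f:
    json.dump(example_map, f, ensure_ascii=False, indent=2)

Any requested model missing from the map falls back to the default prompt hard-coded in build_payload.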
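
A quick client-side sketch of what this change should produce once the app is running, assuming it is served locally (for example with uvicorn main:app --port 8000) and that "my-assistant" is one of the mapped model names; the host, port, model name, and the presence of a stream flag on ChatRequest are assumptions, not part of the commit. Both the plain JSON response and each streamed "data: ..." chunk should report the requested model name instead of pixtral-large-latest, with non-ASCII characters intact:

# Sketch only: exercises the proxy locally; host, port, and model name are assumptions.
import json
import requests

BASE = "http://localhost:8000"
BODY = {
    "model": "my-assistant",
    "messages": [{"role": "user", "content": "Say hello in French."}],
}

# Non-streaming call: the "model" field should be rewritten to "my-assistant".
resp = requests.post(f"{BASE}/v1/chat/completions", json={**BODY, "stream": False})
print(resp.json().get("model"))

# Streaming call: chunks arrive as SSE lines prefixed with "data: ";
# the rewritten model name and UTF-8 content should pass through unchanged.
with requests.post(f"{BASE}/v1/chat/completions", json={**BODY, "stream": True}, stream=True) as r:
    for line in r.iter_lines(decode_unicode=True):
        if line and line.startswith("data: ") and "[DONE]" not in line:
            chunk = json.loads(line[len("data: "):])
            print(chunk.get("model"))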