rkihacker committed on
Commit
4d2b16a
·
verified ·
1 Parent(s): f7f30cc

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +17 -10
main.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI, Request
2
  from fastapi.responses import StreamingResponse, JSONResponse
3
- from pydantic import BaseModel, Field
4
  from typing import List, Optional, Union
5
  import requests
6
  import json
@@ -49,15 +49,23 @@ def build_payload(chat: ChatRequest):
49
  "frequency_penalty": chat.frequency_penalty
50
  }
51
 
52
def stream_generator(requested_model, payload, headers):
    """Relay the backend's streamed completion as server-sent events.

    POSTs ``payload`` to ``API_URL`` with ``stream=True`` and yields one
    ``data: ...`` frame per non-empty upstream line. When a frame carries a
    JSON object whose ``model`` field equals ``BACKEND_MODEL``, that field is
    replaced with ``requested_model``; every other frame is forwarded with
    its payload untouched.

    Args:
        requested_model: Model name to report back to the client.
        payload: JSON-serialisable request body sent to the backend.
        headers: HTTP headers sent to the backend.

    Yields:
        str: SSE frames, each terminated by a blank line; the final frame is
        always a single ``data: [DONE]``.
    """
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
        for line in r.iter_lines():
            if not line:
                continue
            decoded = line.decode("utf-8")
            # Upstream lines already carry the SSE "data:" field name; strip
            # it so the frame is not emitted as "data: data: ...".
            if decoded.startswith("data:"):
                decoded = decoded[len("data:"):].strip()
            if decoded == "[DONE]":
                # The terminator is emitted exactly once, after the loop.
                break
            try:
                event = json.loads(decoded)
            except json.JSONDecodeError:
                event = None
            # Only rewrite the top-level "model" field. A substring replace
            # over the whole line could also alter generated text that
            # happens to mention the backend model name.
            if isinstance(event, dict) and event.get("model") == BACKEND_MODEL:
                event["model"] = requested_model
                decoded = json.dumps(event)
            yield f"data: {decoded}\n\n"
    yield "data: [DONE]\n\n"
62
 
63
  @app.post("/v1/chat/completions")
@@ -78,7 +86,6 @@ async def proxy_chat(request: Request):
78
  else:
79
  response = requests.post(API_URL, headers=headers, json=payload)
80
  data = response.json()
81
- # Replace model in final result
82
  if "model" in data and data["model"] == BACKEND_MODEL:
83
  data["model"] = chat_request.model
84
  return JSONResponse(content=data)
 
1
  from fastapi import FastAPI, Request
2
  from fastapi.responses import StreamingResponse, JSONResponse
3
+ from pydantic import BaseModel
4
  from typing import List, Optional, Union
5
  import requests
6
  import json
 
49
  "frequency_penalty": chat.frequency_penalty
50
  }
51
 
52
def stream_generator(requested_model: str, payload: dict, headers: dict):
    """Relay the backend's SSE stream, reporting the client's model name.

    POSTs ``payload`` to ``API_URL`` with ``stream=True`` and re-emits every
    non-empty upstream line as a ``data: ...`` frame. For ``data:`` lines
    whose payload is a JSON object with ``model == BACKEND_MODEL``, the
    ``model`` field is replaced with ``requested_model`` before the object is
    re-serialised. Payloads that are not valid JSON are forwarded as-is.

    Args:
        requested_model: Model name to report back to the client.
        payload: JSON-serialisable request body sent to the backend.
        headers: HTTP headers sent to the backend.

    Yields:
        str: SSE frames, each terminated by a blank line; the final frame is
        always a single ``data: [DONE]``.
    """
    prefix = "data:"
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
        # Event streams are UTF-8. Without this, requests either falls back
        # to ISO-8859-1 for text/* responses (garbling non-ASCII text) or
        # leaves the encoding unset, in which case iter_lines yields bytes
        # and the str comparisons below raise TypeError.
        r.encoding = "utf-8"
        for line in r.iter_lines(decode_unicode=True):
            if not line:
                continue
            if not line.startswith(prefix):
                yield f"data: {line}\n\n"
                continue
            # The space after "data:" is optional, so slice off only the
            # field name and let strip() remove any separator.
            content = line[len(prefix):].strip()
            if content == "[DONE]":
                # The terminator is emitted exactly once, after the loop.
                break
            try:
                json_obj = json.loads(content)
            except json.JSONDecodeError:
                yield f"data: {content}\n\n"
                continue
            # json.loads may return a non-dict (number, list, string);
            # only objects can carry a "model" field.
            if isinstance(json_obj, dict) and json_obj.get("model") == BACKEND_MODEL:
                json_obj["model"] = requested_model
            yield f"data: {json.dumps(json_obj)}\n\n"
    yield "data: [DONE]\n\n"
70
 
71
  @app.post("/v1/chat/completions")
 
86
  else:
87
  response = requests.post(API_URL, headers=headers, json=payload)
88
  data = response.json()
 
89
  if "model" in data and data["model"] == BACKEND_MODEL:
90
  data["model"] = chat_request.model
91
  return JSONResponse(content=data)