File size: 5,424 Bytes
daa1bb4
f7f30cc
4d2b16a
f7f30cc
bcdcebb
daa1bb4
89d8cc9
daa1bb4
 
bcdcebb
ad95a9a
89d8cc9
 
 
ad95a9a
bcdcebb
 
0c75fa7
bcdcebb
ad95a9a
36f72ba
daa1bb4
bcdcebb
89d8cc9
db3a9bd
 
 
 
 
 
 
 
 
 
 
 
 
daa1bb4
 
db3a9bd
bcdcebb
db3a9bd
daa1bb4
 
f7f30cc
 
 
 
 
 
 
 
bcdcebb
ad95a9a
f7f30cc
 
a7b9a59
db3a9bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7f30cc
bcdcebb
a7b9a59
f7f30cc
 
 
 
 
 
a7b9a59
bcdcebb
 
06ea63c
4d2b16a
f7f30cc
06ea63c
89d8cc9
 
06ea63c
4d2b16a
06ea63c
 
89d8cc9
4d2b16a
06ea63c
89d8cc9
4d2b16a
06ea63c
4d2b16a
89d8cc9
 
 
f7f30cc
06ea63c
f7f30cc
 
89d8cc9
 
 
 
a7b9a59
89d8cc9
 
 
 
 
 
 
 
06ea63c
89d8cc9
 
 
06ea63c
89d8cc9
 
 
06ea63c
bcdcebb
89d8cc9
a7b9a59
ad95a9a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import json
import logging
from typing import List, Literal, Optional, Union

import requests
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

app = FastAPI()

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("proxy")

# Configuration
API_URL = "https://api.typegpt.net/v1/chat/completions"
API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
BACKEND_MODEL = "gpt-4o-mini"

# Load system prompt mappings
with open("model_map.json", "r", encoding="utf-8") as f:
    MODEL_PROMPTS = json.load(f)

# Request schema

# Define ContentType for vision
class ContentImage(BaseModel):
    type: str  # must be "image_url"
    image_url: dict  # {"url": "https://..." or "data:image/...;base64,..."}

class ContentText(BaseModel):
    type: str  # must be "text"
    text: str

ContentType = Union[ContentText, ContentImage]

# Message model allows BOTH old and new formats
class Message(BaseModel):
    role: str
    content: Union[str, List[ContentType]]  # str (legacy) or list of ContentType

# ChatRequest model
class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0

# Build request to backend with injected system prompt
def build_payload(chat: ChatRequest):
    """Translate an incoming ChatRequest into the backend request body.

    Drops any client-supplied system messages, prepends the per-model
    system prompt from MODEL_PROMPTS, rewrites the model name to the
    fixed BACKEND_MODEL, and passes the sampling parameters through.
    """
    prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")
    out_messages = [{"role": "system", "content": prompt}]

    for msg in chat.messages:
        # Client system prompts are discarded; the injected one wins.
        if msg.role == "system":
            continue

        body = msg.content

        # Legacy format: a single string passes through unchanged.
        if isinstance(body, str):
            out_messages.append({"role": msg.role, "content": body})
            continue

        if not isinstance(body, list):
            logger.warning(f"Unknown message content format: {body}")
            continue

        # Multimodal format: serialize each known part, skip the rest.
        parts = []
        for item in body:
            if item.type == "text":
                parts.append({"type": "text", "text": item.text})
            elif item.type == "image_url":
                parts.append({"type": "image_url", "image_url": item.image_url})
            else:
                logger.warning(f"Unknown content type: {item.type}, skipping.")
        out_messages.append({"role": msg.role, "content": parts})

    return {
        "model": BACKEND_MODEL,
        "messages": out_messages,
        "stream": chat.stream,
        "temperature": chat.temperature,
        "top_p": chat.top_p,
        "n": chat.n,
        "stop": chat.stop,
        "presence_penalty": chat.presence_penalty,
        "frequency_penalty": chat.frequency_penalty
    }

# Stream generator without forcing UTF-8
def stream_generator(requested_model: str, payload: dict, headers: dict):
    """Proxy the backend's SSE stream, relabeling the model name.

    Yields raw bytes: each `data:` event is re-emitted as-is except that
    chunks reporting BACKEND_MODEL are rewritten to carry the model name
    the client originally requested.
    """
    prefix = b"data:"
    # timeout guards against a hung backend holding the connection open
    # forever: 10s to connect, 300s max silence between stream chunks.
    with requests.post(API_URL, headers=headers, json=payload, stream=True,
                       timeout=(10, 300)) as r:
        for line in r.iter_lines(decode_unicode=False):  # keep as bytes
            if not line:
                continue
            if not line.startswith(prefix):
                logger.debug("Non-data stream line skipped: %s", line)
                continue
            # Slice by len(prefix) (5), not a hard-coded 6: the space after
            # "data:" is optional in SSE, and assuming it clips the first
            # byte of the JSON payload when it is absent.
            content = line[len(prefix):].strip()
            if content == b"[DONE]":
                yield b"data: [DONE]\n\n"
                continue
            try:
                json_obj = json.loads(content.decode("utf-8"))
                if json_obj.get("model") == BACKEND_MODEL:
                    json_obj["model"] = requested_model
                yield f"data: {json.dumps(json_obj)}\n\n".encode("utf-8")
            except json.JSONDecodeError:
                logger.warning("Invalid JSON in stream chunk: %s", content)

# Main endpoint
# Main endpoint
@app.post("/v1/chat/completions")
async def proxy_chat(request: Request):
    """OpenAI-compatible chat completions proxy.

    Validates the body, injects the per-model system prompt, forwards the
    request to the backend, and relabels the model name in the response so
    the client sees the model it asked for. All failures collapse to a
    generic 500 (including validation errors).
    """
    try:
        body = await request.json()
        chat_request = ChatRequest(**body)
        payload = build_payload(chat_request)

        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json"
        }

        if chat_request.stream:
            return StreamingResponse(
                stream_generator(chat_request.model, payload, headers),
                media_type="text/event-stream"
            )

        # NOTE(review): requests is blocking and stalls the event loop for
        # the duration of the upstream call; consider httpx.AsyncClient or
        # run_in_executor. The timeout prevents an indefinite hang.
        response = requests.post(API_URL, headers=headers, json=payload,
                                 timeout=(10, 120))
        response.raise_for_status()  # surface backend 4xx/5xx as an error
        data = response.json()
        if data.get("model") == BACKEND_MODEL:
            data["model"] = chat_request.model
        return JSONResponse(content=data)

    except Exception as e:
        # logger.exception preserves the traceback, not just the message
        logger.exception("Error in /v1/chat/completions: %s", str(e))
        return JSONResponse(
            content={"error": "Internal server error."},
            status_code=500
        )