from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from models.chat_completion import ChatRequest
from huggingface_hub import InferenceClient
import json

router = APIRouter()

def generate_stream(response):
    try:
        for chunk in response:
            try:
                # Attempt to process and yield the chunk as a server-sent event
                yield f"data: {json.dumps(chunk.__dict__, separators=(',', ':'))}\n\n"
            except Exception as e:
                # Optional: log the error for debugging
                print(f"Error during stream processing: {e}")
                # Stop sending chunks if an error occurs
                break
    finally:
        # Ensure the [DONE] message is always sent, even if an error occurred
        yield "data: [DONE]\n\n"

# NOTE: the route path below is an assumption; adjust it (and any prefix used
# when including the router in your app) to match your API layout.
@router.post("/chat/completions")
async def chat_completion(body: ChatRequest):
    client = InferenceClient(model=body.model)
    res = client.chat_completion(
        messages=body.messages,
        frequency_penalty=body.frequency_penalty,
        logit_bias=body.logit_bias,
        logprobs=body.logprobs,
        max_tokens=body.max_tokens,
        n=body.n,
        presence_penalty=body.presence_penalty,
        response_format=body.response_format,
        seed=body.seed,
        stop=body.stop,
        stream=body.stream,
        stream_options=body.stream_options,
        temperature=body.temperature,
        top_logprobs=body.top_logprobs,
        top_p=body.top_p,
        tool_choice=body.tool_choice,
        tool_prompt=body.tool_prompt,
        tools=body.tools
    )
    if not body.stream:
        # Non-streaming: return the full completion serialized as a JSON string
        return json.dumps(res.__dict__, indent=2)
    else:
        # Streaming: forward the chunks to the client as server-sent events
        return StreamingResponse(generate_stream(res), media_type="text/event-stream")
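
The ChatRequest model imported from models.chat_completion is not shown in this snippet. A minimal sketch of what it might look like, assuming a Pydantic model that mirrors the OpenAI-style parameters passed to chat_completion above (field names, types, and defaults here are assumptions, not the original definition):

    # models/chat_completion.py -- hypothetical sketch, not the original definition
    from typing import Any, Dict, List, Optional, Union
    from pydantic import BaseModel

    class ChatRequest(BaseModel):
        model: str
        messages: List[Dict[str, Any]]
        frequency_penalty: Optional[float] = None
        logit_bias: Optional[List[float]] = None
        logprobs: Optional[bool] = None
        max_tokens: Optional[int] = None
        n: Optional[int] = None
        presence_penalty: Optional[float] = None
        response_format: Optional[Dict[str, Any]] = None
        seed: Optional[int] = None
        stop: Optional[Union[str, List[str]]] = None
        stream: bool = False
        stream_options: Optional[Dict[str, Any]] = None
        temperature: Optional[float] = None
        top_logprobs: Optional[int] = None
        top_p: Optional[float] = None
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None
        tool_prompt: Optional[str] = None
        tools: Optional[List[Dict[str, Any]]] = None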
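
For a quick end-to-end check, the endpoint can be exercised from Python. The sketch below assumes the app is served locally on port 8000 and the router is included without a prefix, so the route is /chat/completions; the model name is just an example. Adjust the URL and payload to your setup:

    # Hypothetical client-side check of the streaming endpoint
    import requests

    payload = {
        "model": "HuggingFaceH4/zephyr-7b-beta",  # example model; any model available to the Inference API
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 64,
        "stream": True,
    }

    with requests.post("http://localhost:8000/chat/completions", json=payload, stream=True) as resp:
        for line in resp.iter_lines():
            if line:
                # Each non-empty line is an SSE event of the form "data: {...}", ending with "data: [DONE]"
                print(line.decode("utf-8"))

With "stream": False in the payload, the route instead returns the whole completion at once, serialized by the json.dumps call in the non-streaming branch above.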