from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from dataclasses import asdict
import json

from models.chat_completion import ChatRequest

router = APIRouter()


def generate_stream(response):
    # Re-emit each chunk from the Hugging Face client as a Server-Sent
    # Events "data:" frame. asdict() recursively converts the nested
    # output dataclasses into JSON-serializable dicts, which plain
    # chunk.__dict__ does not.
    for chunk in response:
        yield f"data: {json.dumps(asdict(chunk), separators=(',', ':'))}\n\n"
    # OpenAI-compatible clients typically expect a terminal sentinel frame.
    yield "data: [DONE]\n\n"


@router.post("/v1/chat/completions", tags=["Chat Completion"])
def chat_completion(body: ChatRequest):
    # A sync handler: FastAPI runs it in a threadpool, so the blocking
    # InferenceClient call does not stall the event loop.
    client = InferenceClient(model=body.model)
    res = client.chat_completion(
        messages=body.messages,
        frequency_penalty=body.frequency_penalty,
        logit_bias=body.logit_bias,
        logprobs=body.logprobs,
        max_tokens=body.max_tokens,
        n=body.n,
        presence_penalty=body.presence_penalty,
        response_format=body.response_format,
        seed=body.seed,
        stop=body.stop,
        stream=body.stream,
        stream_options=body.stream_options,
        temperature=body.temperature,
        top_logprobs=body.top_logprobs,
        top_p=body.top_p,
        tool_choice=body.tool_choice,
        tool_prompt=body.tool_prompt,
        tools=body.tools,
    )
    if body.stream:
        # Streaming: res is an iterator of chat-completion chunks.
        return StreamingResponse(generate_stream(res), media_type="text/event-stream")
    # Non-streaming: return a plain dict and let FastAPI serialize it once;
    # json.dumps() here would double-encode the payload as a JSON string.
    return asdict(res)
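

# Example request against this route (a sketch, not part of the module:
# it assumes the app is served locally on port 8000, that the optional
# ChatRequest fields have defaults, and that the model id below is a
# hypothetical placeholder for one available via the Inference API):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/v1/chat/completions",
#       json={
#           "model": "HuggingFaceH4/zephyr-7b-beta",  # hypothetical model id
#           "messages": [{"role": "user", "content": "Say hello."}],
#           "stream": False,
#       },
#   )
#   print(resp.json()["choices"][0]["message"]["content"])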