from fastapi import FastAPI, HTTPException, Depends, Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
from health import check_docker_health, check_gpu_availability
from typing import List, Optional, Union
import time
import logging
import json

from auth import get_api_key, rate_limiter, api_requests, request_duration
from model_manager import ModelManager

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Docker Model Runner OpenAI-Compatible API")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize model manager
model_manager = ModelManager()


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(256, gt=0)
    stream: Optional[bool] = False  # accepted for OpenAI compatibility; streaming is not implemented


class CompletionRequest(BaseModel):
    model: str
    prompt: str
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(256, gt=0)
    stream: Optional[bool] = False  # accepted for OpenAI compatibility; streaming is not implemented


class EmbeddingRequest(BaseModel):
    model: str
    input: Union[str, List[str]]
    encoding_format: Optional[str] = "float"


@app.post("/v1/chat/completions")
async def create_chat_completion(
    request: ChatCompletionRequest,
    api_key: str = Depends(get_api_key)
):
    """Create a chat completion."""
    rate_limiter.check(api_key)
    api_requests.labels(endpoint="chat_completions").inc()

    with request_duration.time():
        try:
            # Convert Pydantic message models into plain dicts for the backend
            formatted_messages = [
                {"role": msg.role, "content": msg.content}
                for msg in request.messages
            ]

            response = model_manager.run_model(
                request.model,
                formatted_messages,
                temperature=request.temperature,
                max_tokens=request.max_tokens
            )

            # Shape the response to match OpenAI's chat.completion schema
            return {
                "id": f"chatcmpl-{int(time.time()*1000)}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": request.model,
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response["output"]
                    },
                    "finish_reason": "stop"
                }],
                "usage": response.get("usage", {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0
                })
            }
        except Exception as e:
            logger.error(f"Chat completion error: {e}")
            raise HTTPException(status_code=500, detail=str(e))


@app.post("/v1/completions")
async def create_completion(
    request: CompletionRequest,
    api_key: str = Depends(get_api_key)
):
    """Create a text completion."""
    rate_limiter.check(api_key)
    api_requests.labels(endpoint="completions").inc()

    with request_duration.time():
        try:
            response = model_manager.run_model(
                request.model,
                request.prompt,
                temperature=request.temperature,
                max_tokens=request.max_tokens
            )

            return {
                "id": f"cmpl-{int(time.time()*1000)}",
                "object": "text_completion",
                "created": int(time.time()),
                "model": request.model,
                "choices": [{
                    "text": response["output"],
                    "index": 0,
                    "finish_reason": "stop"
                }],
                "usage": response.get("usage", {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0
                })
            }
        except Exception as e:
            logger.error(f"Completion error: {e}")
            raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/embeddings")
async def create_embedding(
    request: EmbeddingRequest,
    api_key: str = Depends(get_api_key)
):
    """Create embeddings for text."""
    rate_limiter.check(api_key)
    api_requests.labels(endpoint="embeddings").inc()

    with request_duration.time():
        try:
            # Normalize a single string into a one-element list
            inputs = request.input if isinstance(request.input, list) else [request.input]

            response = model_manager.run_model(
                request.model,
                inputs
            )

            return {
                "object": "list",
                "data": [
                    {
                        "object": "embedding",
                        "embedding": emb,
                        "index": i
                    }
                    for i, emb in enumerate(response["embeddings"])
                ],
                "model": request.model,
                "usage": response.get("usage", {
                    "prompt_tokens": 0,
                    "total_tokens": 0
                })
            }
        except Exception as e:
            logger.error(f"Embedding error: {e}")
            raise HTTPException(status_code=500, detail=str(e))


@app.get("/v1/models")
async def list_models(
    api_key: str = Depends(get_api_key)
):
    """List available models."""
    api_requests.labels(endpoint="models").inc()
    return model_manager.list_models()


@app.get("/metrics")
async def metrics():
    """Expose Prometheus metrics."""
    return Response(
        content=generate_latest(),
        media_type=CONTENT_TYPE_LATEST
    )


@app.get("/health")
async def health_check():
    """Check the health of the API and its dependencies."""
    docker_health = check_docker_health()
    gpu_status = check_gpu_availability()

    health_status = {
        "status": "healthy" if docker_health["status"] == "healthy" else "unhealthy",
        "docker": docker_health,
        "gpu": gpu_status,
        "api_version": "1.0.0"
    }

    status_code = 200 if health_status["status"] == "healthy" else 503
    return Response(
        content=json.dumps(health_status),
        media_type="application/json",
        status_code=status_code
    )
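

# -----------------------------------------------------------------------------
# Usage sketch. Everything below is an assumption, not part of the original
# service: it presumes the module is saved as main.py, that uvicorn is
# installed, that the model name is a placeholder for whatever ModelManager
# serves, and that auth.get_api_key reads a bearer token -- adjust the header
# to however get_api_key is actually wired.
#
#   curl http://localhost:8000/v1/chat/completions \
#       -H "Authorization: Bearer $API_KEY" \
#       -H "Content-Type: application/json" \
#       -d '{"model": "MODEL_NAME", "messages": [{"role": "user", "content": "Hello"}]}'
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Convenience entry point for local runs; a containerized deployment would
    # typically start the server from the image's CMD instead.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)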