| from __future__ import annotations |
|
|
| import sys |
| import traceback |
| import time |
| from re import compile, Match, Pattern |
| from typing import Callable, Coroutine, Optional, Tuple, Union, Dict |
| from typing_extensions import TypedDict |
|
|
|
|
| from fastapi import ( |
| Request, |
| Response, |
| HTTPException, |
| ) |
| from fastapi.responses import JSONResponse |
| from fastapi.routing import APIRoute |
|
|
| from llama_cpp.server.types import ( |
| CreateCompletionRequest, |
| CreateEmbeddingRequest, |
| CreateChatCompletionRequest, |
| ) |
|
|
|
|
class ErrorResponse(TypedDict):
    """OpenAI style error response"""


    # Human-readable description of what went wrong.
    message: str
    # Error category, e.g. "invalid_request_error" or "internal_server_error".
    type: str
    # Name of the request parameter the error relates to, or None.
    param: Optional[str]
    # Machine-readable error code, e.g. "context_length_exceeded", or None.
    code: Optional[str]
|
|
|
|
class ErrorResponseFormatters:
    """Collection of formatters for error responses.

    Each formatter takes the parsed request body and the regex match over
    the exception text, and produces the HTTP status code together with an
    OpenAI-style error payload.

    Args:
        request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
            Request body
        match (Match[str]): Match object from regex pattern

    Returns:
        Tuple[int, ErrorResponse]: Status code and error response
    """

    @staticmethod
    def context_length_exceeded(
        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
        match,
    ) -> Tuple[int, ErrorResponse]:
        """Formatter for context length exceeded error"""

        # group(1) = tokens requested in the prompt, group(2) = model's window.
        prompt_tokens = int(match.group(1))
        context_window = int(match.group(2))
        completion_tokens = request.max_tokens
        # Chat requests carry a "messages" attribute; plain completions do not.
        # The wording differs between the two to match OpenAI's responses.
        template = (
            (
                "This model's maximum context length is {} tokens. "
                "However, you requested {} tokens "
                "({} in the messages, {} in the completion). "
                "Please reduce the length of the messages or completion."
            )
            if hasattr(request, "messages")
            else (
                "This model's maximum context length is {} tokens, "
                "however you requested {} tokens "
                "({} in your prompt; {} for the completion). "
                "Please reduce your prompt; or completion length."
            )
        )
        # max_tokens may be None; treat it as 0 when totalling the request.
        total_requested = prompt_tokens + (completion_tokens or 0)
        return 400, ErrorResponse(
            message=template.format(
                context_window,
                total_requested,
                prompt_tokens,
                completion_tokens,
            ),
            type="invalid_request_error",
            param="messages",
            code="context_length_exceeded",
        )

    @staticmethod
    def model_not_found(
        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
        match,
    ) -> Tuple[int, ErrorResponse]:
        """Formatter for model_not_found error"""

        # group(1) = the model path reported missing by llama.cpp.
        missing_path = str(match.group(1))
        return 400, ErrorResponse(
            message=f"The model `{missing_path}` does not exist",
            type="invalid_request_error",
            param=None,
            code="model_not_found",
        )
|
|
|
|
class RouteErrorHandler(APIRoute):
    """Custom APIRoute that handles application errors and exceptions.

    Wraps the default FastAPI route handler so that any exception raised
    while serving a request is converted into an OpenAI-style JSON error
    response instead of propagating as an unformatted 500.
    """

    # Maps a regex over the exception text to a formatter that produces
    # (status_code, ErrorResponse) for known llama.cpp error messages.
    pattern_and_formatters: Dict[
        "Pattern[str]",
        Callable[
            [
                Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
                "Match[str]",
            ],
            Tuple[int, ErrorResponse],
        ],
    ] = {
        compile(
            r"Requested tokens \((\d+)\) exceed context window of (\d+)"
        ): ErrorResponseFormatters.context_length_exceeded,
        compile(
            r"Model path does not exist: (.+)"
        ): ErrorResponseFormatters.model_not_found,
    }

    def error_message_wrapper(
        self,
        error: Exception,
        body: Optional[
            Union[
                "CreateChatCompletionRequest",
                "CreateCompletionRequest",
                "CreateEmbeddingRequest",
            ]
        ] = None,
    ) -> Tuple[int, ErrorResponse]:
        """Wraps error message in OpenAI style error response.

        Args:
            error: The exception raised by the route handler.
            body: Parsed request body, if it could be reconstructed.

        Returns:
            Tuple[int, ErrorResponse]: HTTP status code and error payload.
        """
        if body is not None and isinstance(
            body,
            (
                CreateCompletionRequest,
                CreateChatCompletionRequest,
            ),
        ):
            # Completion/chat requests produce well-known llama.cpp error
            # messages that we translate into specific OpenAI error codes.
            for pattern, callback in self.pattern_and_formatters.items():
                match = pattern.search(str(error))
                if match is not None:
                    return callback(body, match)

        # Unrecognized error: log the traceback server-side and return a
        # generic 500 carrying only the exception message.
        print(f"Exception: {str(error)}", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)
        return 500, ErrorResponse(
            message=str(error),
            type="internal_server_error",
            param=None,
            code=None,
        )

    def get_route_handler(
        self,
    ) -> Callable[[Request], Coroutine[None, None, Response]]:
        """Defines custom route handler that catches exceptions and formats
        in OpenAI style error response"""

        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            try:
                start_sec = time.perf_counter()
                response = await original_route_handler(request)
                elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
                # Mirror OpenAI's processing-time header on successful responses.
                response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
                return response
            except HTTPException as unauthorized:
                # Pass HTTP exceptions (e.g. authentication failures)
                # through unchanged.
                raise unauthorized
            except Exception as exc:
                body: Optional[
                    Union[
                        CreateChatCompletionRequest,
                        CreateCompletionRequest,
                        CreateEmbeddingRequest,
                    ]
                ]
                try:
                    # BUGFIX: request.json() itself raises on an empty or
                    # non-JSON body; it must live inside this try so we
                    # still return a formatted error instead of crashing.
                    json_body = await request.json()
                    if "messages" in json_body:
                        # Chat completion request
                        body = CreateChatCompletionRequest(**json_body)
                    elif "prompt" in json_body:
                        # Text completion request
                        body = CreateCompletionRequest(**json_body)
                    else:
                        # Embedding request
                        body = CreateEmbeddingRequest(**json_body)
                except Exception:
                    # Invalid request body format: fall back to the
                    # generic internal-server-error response.
                    body = None

                (
                    status_code,
                    error_message,
                ) = self.error_message_wrapper(error=exc, body=body)
                return JSONResponse(
                    {"error": error_message},
                    status_code=status_code,
                )

        return custom_route_handler
|
|