| from __future__ import annotations |
|
|
| import os |
| import json |
| import typing |
| import contextlib |
|
|
| from anyio import Lock |
| from functools import partial |
| from typing import List, Optional, Union, Dict |
|
|
| import llama_cpp |
|
|
| import anyio |
| from anyio.streams.memory import MemoryObjectSendStream |
| from starlette.concurrency import run_in_threadpool, iterate_in_threadpool |
| from fastapi import Depends, FastAPI, APIRouter, Request, HTTPException, status, Body |
| from fastapi.middleware import Middleware |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.security import HTTPBearer |
| from sse_starlette.sse import EventSourceResponse |
| from starlette_context.plugins import RequestIdPlugin |
| from starlette_context.middleware import RawContextMiddleware |
|
|
| from llama_cpp.server.model import ( |
| LlamaProxy, |
| ) |
| from llama_cpp.server.settings import ( |
| ConfigFileSettings, |
| Settings, |
| ModelSettings, |
| ServerSettings, |
| ) |
| from llama_cpp.server.types import ( |
| CreateCompletionRequest, |
| CreateEmbeddingRequest, |
| CreateChatCompletionRequest, |
| ModelList, |
| TokenizeInputRequest, |
| TokenizeInputResponse, |
| TokenizeInputCountResponse, |
| DetokenizeInputRequest, |
| DetokenizeInputResponse, |
| ) |
| from llama_cpp.server.errors import RouteErrorHandler |
|
|
|
|
# Router shared by every endpoint in this module; RouteErrorHandler is
# installed as the route class for all routes registered on it.
router = APIRouter(route_class=RouteErrorHandler)

# Module-wide server configuration. Remains None until
# set_server_settings() has been called.
_server_settings: Optional[ServerSettings] = None
|
|
|
|
def set_server_settings(server_settings: ServerSettings) -> None:
    """Install ``server_settings`` as the module-wide server configuration.

    Args:
        server_settings: The settings object to expose through the
            module-level ``_server_settings`` global.
    """
    global _server_settings
    _server_settings = server_settings
|
|
|
|
def get_server_settings():
    """Yield the module-wide server settings (``None`` when not yet set).

    Implemented as a single-item generator: it is used both through
    ``Depends(...)`` and by calling ``next()`` on it directly.
    """
    yield _server_settings
|
|
|
|
# Model proxy shared by all requests; None until set_llama_proxy() runs.
_llama_proxy: Optional[LlamaProxy] = None

# Lock pair used by get_llama_proxy(). The outer lock is held only while a
# request is waiting for the inner lock, so `llama_outer_lock.locked()` lets
# the request currently holding the inner lock see that another one is queued.
llama_outer_lock = Lock()
llama_inner_lock = Lock()
|
|
|
|
def set_llama_proxy(model_settings: List[ModelSettings]) -> None:
    """Create the module-wide ``LlamaProxy`` from ``model_settings``.

    Args:
        model_settings: Per-model settings, forwarded to ``LlamaProxy`` as
            its ``models`` argument.
    """
    global _llama_proxy
    _llama_proxy = LlamaProxy(models=model_settings)
|
|
|
|
async def get_llama_proxy():
    """Yield the shared ``LlamaProxy`` while holding the inner lock.

    Locking protocol:

    1. Acquire the outer lock, then wait for the inner lock.
    2. Once the inner lock is obtained, release the outer lock immediately
       and yield the proxy.
    3. Release the inner lock when the consumer is done.

    The outer lock is therefore only held while a request is *waiting*;
    the holder of the inner lock can poll ``llama_outer_lock.locked()`` to
    learn that another request is queued.
    """
    await llama_outer_lock.acquire()
    outer_still_held = True
    try:
        # `async with` acquires the inner lock on entry and releases it on
        # exit, including when the consumer raises.
        async with llama_inner_lock:
            llama_outer_lock.release()
            outer_still_held = False
            yield _llama_proxy
    finally:
        # Only reached with the outer lock still held when acquiring the
        # inner lock failed (e.g. the waiting task was cancelled).
        if outer_still_held:
            llama_outer_lock.release()
|
|
|
|
# Optional factory passed as `ping_message_factory` to the streaming
# responses below; None until set_ping_message_factory() is called.
_ping_message_factory: typing.Optional[typing.Callable[[], bytes]] = None
|
|
|
|
def set_ping_message_factory(factory: typing.Callable[[], bytes]) -> None:
    """Register ``factory`` as the module-wide ping message factory.

    Args:
        factory: Zero-argument callable stored in the module-level
            ``_ping_message_factory`` global.
    """
    global _ping_message_factory
    _ping_message_factory = factory
|
|
|
|
def create_app(
    settings: Settings | None = None,
    server_settings: ServerSettings | None = None,
    model_settings: List[ModelSettings] | None = None,
):
    """Build the FastAPI application and initialise module-level state.

    Settings are resolved in this order:

    1. If the ``CONFIG_FILE`` environment variable is set, the file is parsed
       (YAML when the name ends in ``.yaml``/``.yml``, JSON otherwise) and its
       contents replace ``server_settings`` and ``model_settings``.
    2. Otherwise, if neither ``server_settings`` nor ``model_settings`` is
       given, both are derived from ``settings`` (or a default ``Settings()``).

    Args:
        settings: Combined settings used as a fallback source.
        server_settings: Explicit server settings.
        model_settings: Explicit list of per-model settings.

    Returns:
        The configured ``FastAPI`` instance.

    Raises:
        ValueError: ``CONFIG_FILE`` is set but the path does not exist.
        AssertionError: Only one of ``server_settings`` / ``model_settings``
            ended up being available.
    """
    config_file = os.environ.get("CONFIG_FILE")
    if config_file is not None:
        if not os.path.exists(config_file):
            raise ValueError(f"Config file {config_file} not found!")
        with open(config_file, "rb") as config_fp:
            if config_file.endswith((".yaml", ".yml")):
                # Imported lazily so PyYAML is only needed for YAML configs.
                import yaml

                # Round-trip through JSON so both formats share one validator.
                raw_config = json.dumps(yaml.safe_load(config_fp))
            else:
                raw_config = config_fp.read()
            config_file_settings = ConfigFileSettings.model_validate_json(raw_config)
            server_settings = ServerSettings.model_validate(config_file_settings)
            model_settings = config_file_settings.models

    if server_settings is None and model_settings is None:
        fallback = Settings() if settings is None else settings
        server_settings = ServerSettings.model_validate(fallback)
        model_settings = [ModelSettings.model_validate(fallback)]

    assert (
        server_settings is not None and model_settings is not None
    ), "server_settings and model_settings must be provided together"

    set_server_settings(server_settings)

    app = FastAPI(
        middleware=[Middleware(RawContextMiddleware, plugins=(RequestIdPlugin(),))],
        title="🦙 llama.cpp Python API",
        version=llama_cpp.__version__,
        root_path=server_settings.root_path,
    )
    cors_options = dict(
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.add_middleware(CORSMiddleware, **cors_options)
    app.include_router(router)

    assert model_settings is not None
    set_llama_proxy(model_settings=model_settings)

    if server_settings.disable_ping_events:
        # Replace the ping factory with one that returns empty bytes.
        set_ping_message_factory(lambda: b"")

    return app
|
|
|
|
def prepare_request_resources(
    body: CreateCompletionRequest | CreateChatCompletionRequest,
    llama_proxy: LlamaProxy,
    body_model: str | None,
    kwargs,
) -> llama_cpp.Llama:
    """Resolve the model for a request and finalise its call arguments.

    ``kwargs`` is mutated in place:

    * ``logit_bias`` is set from ``body.logit_bias`` (token strings are first
      converted to input ids when ``body.logit_bias_type == "tokens"``).
    * ``grammar`` is replaced by a parsed ``LlamaGrammar`` when provided.
    * a ``MinTokensLogitsProcessor`` is added to ``logits_processor`` when
      ``body.min_tokens`` is positive.

    Args:
        body: The parsed request body.
        llama_proxy: Callable proxy that maps a model name to a model.
        body_model: Model name to request from the proxy.
        kwargs: Keyword arguments that will be passed to the model call.

    Returns:
        The model object returned by ``llama_proxy(body_model)``.

    Raises:
        HTTPException: 503 when ``llama_proxy`` is ``None``.
    """
    if llama_proxy is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Service is not available",
        )

    llama = llama_proxy(body_model)

    bias = body.logit_bias
    if bias is not None:
        if body.logit_bias_type == "tokens":
            bias = _logit_bias_tokens_to_input_ids(llama, bias)
        kwargs["logit_bias"] = bias

    if body.grammar is not None:
        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

    if body.min_tokens > 0:
        min_tokens_processors = llama_cpp.LogitsProcessorList(
            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
        )
        if "logits_processor" in kwargs:
            kwargs["logits_processor"].extend(min_tokens_processors)
        else:
            kwargs["logits_processor"] = min_tokens_processors

    return llama
|
|
|
|
async def get_event_publisher(
    request: Request,
    inner_send_chan: MemoryObjectSendStream[typing.Any],
    body: CreateCompletionRequest | CreateChatCompletionRequest,
    body_model: str | None,
    llama_call,
    kwargs,
):
    """Run a streaming model call and publish its chunks to ``inner_send_chan``.

    Each chunk is sent as ``{"data": <json>}``; a final ``{"data": "[DONE]"}``
    marks the end of the stream. Streaming is aborted with the backend's
    cancellation exception when the client disconnects, or when
    ``interrupt_requests`` is enabled and ``llama_outer_lock`` is held
    (i.e. another request is waiting for the model).

    Args:
        request: Incoming request, polled for client disconnects.
        inner_send_chan: Send side of the stream consumed by the response.
        body: Parsed request body.
        body_model: Model name to resolve through the proxy.
        llama_call: Unbound callable invoked as ``llama_call(llama, **kwargs)``.
        kwargs: Keyword arguments for ``llama_call``.
    """
    cancelled_exc = anyio.get_cancelled_exc_class()

    server_settings = next(get_server_settings())
    interrupt_requests = False
    if server_settings:
        interrupt_requests = server_settings.interrupt_requests

    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
        async with inner_send_chan:
            try:
                chunks = await run_in_threadpool(llama_call, llama, **kwargs)
                async for chunk in iterate_in_threadpool(chunks):
                    await inner_send_chan.send({"data": json.dumps(chunk)})
                    if await request.is_disconnected():
                        raise cancelled_exc()
                    if interrupt_requests and llama_outer_lock.locked():
                        # Yield the model to the waiting request, but close
                        # the stream cleanly for the current client first.
                        await inner_send_chan.send({"data": "[DONE]"})
                        raise cancelled_exc()
                await inner_send_chan.send({"data": "[DONE]"})
            except cancelled_exc as exc:
                print("disconnected")
                with anyio.move_on_after(1, shield=True):
                    print(
                        f"Disconnected from client (via refresh/close) {request.client}"
                    )
                    raise exc
|
|
|
|
def _logit_bias_tokens_to_input_ids(
    llama: llama_cpp.Llama,
    logit_bias: Dict[str, float],
) -> Dict[str, float]:
    """Re-key a token-text bias mapping by stringified input id.

    Each key of ``logit_bias`` is UTF-8 encoded and tokenized (without BOS,
    with special tokens enabled); every resulting id receives that key's
    score. When several keys produce the same id, the last one wins.

    Args:
        llama: Model whose ``tokenize`` method is used.
        logit_bias: Mapping of token text to bias score.

    Returns:
        Mapping of ``str(input_id)`` to bias score.
    """
    biases_by_id: Dict[str, float] = {}
    for text, score in logit_bias.items():
        input_ids = llama.tokenize(text.encode("utf-8"), add_bos=False, special=True)
        biases_by_id.update(dict.fromkeys(map(str, input_ids), score))
    return biases_by_id
|
|
|
|
| |
# Bearer-token extractor used by authenticate(). auto_error is disabled so
# that authenticate() itself decides how to respond when no usable
# credentials are supplied (it checks for a falsy `authorization`).
bearer_scheme = HTTPBearer(auto_error=False)
|
|
|
|
async def authenticate(
    settings: Settings = Depends(get_server_settings),
    authorization: Optional[str] = Depends(bearer_scheme),
):
    """Validate the request's bearer token against the configured API key.

    Args:
        settings: Value produced by ``get_server_settings``; only its
            ``api_key`` attribute is read.
        authorization: Value produced by ``bearer_scheme``; when truthy its
            ``credentials`` attribute holds the presented token.

    Returns:
        ``True`` when no API key is configured (authentication disabled),
        otherwise the presented token when it matches the configured key.

    Raises:
        HTTPException: 401 when a key is configured and the token is missing
            or does not match.
    """
    # Function-scope import keeps this block self-contained.
    import secrets

    # No key configured: every request is accepted.
    if settings.api_key is None:
        return True

    # Compare in constant time so response latency does not leak how much of
    # the key a guess got right. Encoding to bytes avoids compare_digest's
    # ASCII-only restriction on str arguments.
    if authorization and secrets.compare_digest(
        authorization.credentials.encode("utf-8"),
        settings.api_key.encode("utf-8"),
    ):
        return authorization.credentials

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API key",
    )
|
|
|
|
# OpenAPI tag attached to the OpenAI-compatible /v1 endpoints.
openai_v1_tag = "OpenAI V1"
|
|
|
|
@router.post(
    "/v1/completions",
    summary="Completion",
    dependencies=[Depends(authenticate)],
    response_model=Union[
        llama_cpp.CreateCompletionResponse,
        str,
    ],
    responses={
        "200": {
            "description": "Successful Response",
            "content": {
                "application/json": {
                    "schema": {
                        "anyOf": [
                            {"$ref": "#/components/schemas/CreateCompletionResponse"}
                        ],
                        "title": "Completion response, when stream=False",
                    }
                },
                "text/event-stream": {
                    "schema": {
                        "type": "string",
                        "title": "Server Side Streaming response, when stream=True. "
                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",
                        "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
                    }
                },
            },
        }
    },
    tags=[openai_v1_tag],
)
@router.post(
    "/v1/engines/copilot-codex/completions",
    include_in_schema=False,
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)
async def create_completion(
    request: Request,
    body: CreateCompletionRequest,
) -> llama_cpp.Completion:
    """Handle a text completion request (streaming or regular).

    A list-valued prompt is accepted only when it holds at most one entry; it
    is unwrapped to that entry (or ``""`` when empty). Requests arriving on
    the copilot-codex path always use the ``"copilot-codex"`` model name.

    Args:
        request: The incoming HTTP request.
        body: The parsed completion request.

    Returns:
        An ``EventSourceResponse`` when ``stream`` is set, otherwise the
        result of calling the model with the request arguments.

    Raises:
        HTTPException: 400 when more than one prompt is supplied, or when the
            client disconnected before the model was invoked.
    """
    if isinstance(body.prompt, list):
        # Explicit check instead of `assert`: assertions are stripped under
        # `python -O`, which would silently drop all but the first prompt.
        if len(body.prompt) > 1:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Only a single prompt per request is supported",
            )
        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""

    body_model = (
        body.model
        if request.url.path != "/v1/engines/copilot-codex/completions"
        else "copilot-codex"
    )

    # Request fields that are not forwarded to the model call.
    exclude = {
        "n",
        "best_of",
        "logit_bias_type",
        "user",
        "min_tokens",
    }
    kwargs = body.model_dump(exclude=exclude)

    # Streaming request: the model is acquired and driven by the event
    # publisher, which feeds the response through a memory object stream.
    if kwargs.get("stream", False):
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(
                get_event_publisher,
                request=request,
                inner_send_chan=send_chan,
                body=body,
                body_model=body_model,
                llama_call=llama_cpp.Llama.__call__,
                kwargs=kwargs,
            ),
            sep="\n",
            ping_message_factory=_ping_message_factory,
        )

    # Regular request: hold the model for the duration of the call.
    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)

        if await request.is_disconnected():
            print(
                f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Client closed request",
            )

        return await run_in_threadpool(llama, **kwargs)
|
|
|
|
@router.post(
    "/v1/embeddings",
    summary="Embedding",
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)
async def create_embedding(
    request: CreateEmbeddingRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
):
    """Compute embeddings for the request using the selected model.

    The request is dumped to keyword arguments (without ``user``) and passed
    to the model's ``create_embedding`` method in a worker thread.

    Args:
        request: The parsed embedding request.
        llama_proxy: Model proxy injected through ``get_llama_proxy``.

    Returns:
        Whatever the model's ``create_embedding`` returns.
    """
    embed = llama_proxy(request.model).create_embedding
    call_args = request.model_dump(exclude={"user"})
    return await run_in_threadpool(embed, **call_args)
|
|
|
|
@router.post(
    "/v1/chat/completions",
    summary="Chat",
    dependencies=[Depends(authenticate)],
    response_model=Union[llama_cpp.ChatCompletion, str],
    responses={
        "200": {
            "description": "Successful Response",
            "content": {
                "application/json": {
                    "schema": {
                        "anyOf": [
                            {
                                "$ref": "#/components/schemas/CreateChatCompletionResponse"
                            }
                        ],
                        "title": "Completion response, when stream=False",
                    }
                },
                "text/event-stream": {
                    "schema": {
                        "type": "string",
                        # The trailing ". " keeps the two concatenated
                        # sentences apart in the rendered schema title.
                        "title": "Server Side Streaming response, when stream=True. "
                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",
                        "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
                    }
                },
            },
        }
    },
    tags=[openai_v1_tag],
)
async def create_chat_completion(
    request: Request,
    body: CreateChatCompletionRequest = Body(
        openapi_examples={
            "normal": {
                "summary": "Chat Completion",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "What is the capital of France?"},
                    ],
                },
            },
            "json_mode": {
                "summary": "JSON Mode",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "Who won the world series in 2020"},
                    ],
                    "response_format": {"type": "json_object"},
                },
            },
            "tool_calling": {
                "summary": "Tool Calling",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "Extract Jason is 30 years old."},
                    ],
                    "tools": [
                        {
                            "type": "function",
                            "function": {
                                "name": "User",
                                "description": "User record",
                                "parameters": {
                                    "type": "object",
                                    "properties": {
                                        "name": {"type": "string"},
                                        "age": {"type": "number"},
                                    },
                                    "required": ["name", "age"],
                                },
                            },
                        }
                    ],
                    "tool_choice": {
                        "type": "function",
                        "function": {
                            "name": "User",
                        },
                    },
                },
            },
            "logprobs": {
                "summary": "Logprobs",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "What is the capital of France?"},
                    ],
                    "logprobs": True,
                    "top_logprobs": 10,
                },
            },
        }
    ),
) -> llama_cpp.ChatCompletion:
    """Handle a chat completion request (streaming or regular).

    Args:
        request: The incoming HTTP request.
        body: The parsed chat completion request.

    Returns:
        An ``EventSourceResponse`` when ``stream`` is set, otherwise the
        result of the model's ``create_chat_completion`` call.

    Raises:
        HTTPException: 400 when the client disconnected before the model was
            invoked.
    """
    body_model = body.model

    # Request fields that are not forwarded to the model call.
    exclude = {
        "n",
        "logit_bias_type",
        "user",
        "min_tokens",
    }
    kwargs = body.model_dump(exclude=exclude)

    # Streaming request: the model is acquired and driven by the event
    # publisher, which feeds the response through a memory object stream.
    if kwargs.get("stream", False):
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(
                get_event_publisher,
                request=request,
                inner_send_chan=send_chan,
                body=body,
                body_model=body_model,
                llama_call=llama_cpp.Llama.create_chat_completion,
                kwargs=kwargs,
            ),
            sep="\n",
            ping_message_factory=_ping_message_factory,
        )

    # Regular request: hold the model for the duration of the call.
    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)

        if await request.is_disconnected():
            print(
                f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Client closed request",
            )

        return await run_in_threadpool(llama.create_chat_completion, **kwargs)
|
|
|
|
@router.get(
    "/v1/models",
    summary="Models",
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)
async def get_models(
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> ModelList:
    """List the model aliases exposed by the proxy.

    Args:
        llama_proxy: Model proxy injected through ``get_llama_proxy``;
            iterating it produces the model aliases.

    Returns:
        A list-shaped payload with one entry per model alias.
    """
    model_entries = [
        {
            "id": model_alias,
            "object": "model",
            "owned_by": "me",
            "permissions": [],
        }
        for model_alias in llama_proxy
    ]
    return {"object": "list", "data": model_entries}
|
|
|
|
# OpenAPI tag attached to the non-OpenAI /extras endpoints.
extras_tag = "Extras"
|
|
|
|
@router.post(
    "/extras/tokenize",
    summary="Tokenize",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)
async def tokenize(
    body: TokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> TokenizeInputResponse:
    """Tokenize ``body.input`` with the selected model.

    The input is UTF-8 encoded and tokenized with special tokens enabled.

    Args:
        body: Request carrying the model name and the text to tokenize.
        llama_proxy: Model proxy injected through ``get_llama_proxy``.

    Returns:
        A response wrapping the resulting tokens.
    """
    llama = llama_proxy(body.model)
    encoded_input = body.input.encode("utf-8")
    return TokenizeInputResponse(tokens=llama.tokenize(encoded_input, special=True))
|
|
|
|
@router.post(
    "/extras/tokenize/count",
    summary="Tokenize Count",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)
async def count_query_tokens(
    body: TokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> TokenizeInputCountResponse:
    """Count the tokens ``body.input`` produces with the selected model.

    The input is UTF-8 encoded and tokenized with special tokens enabled.

    Args:
        body: Request carrying the model name and the text to tokenize.
        llama_proxy: Model proxy injected through ``get_llama_proxy``.

    Returns:
        A response wrapping the number of tokens.
    """
    llama = llama_proxy(body.model)
    encoded_input = body.input.encode("utf-8")
    token_count = len(llama.tokenize(encoded_input, special=True))
    return TokenizeInputCountResponse(count=token_count)
|
|
|
|
@router.post(
    "/extras/detokenize",
    summary="Detokenize",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)
async def detokenize(
    body: DetokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> DetokenizeInputResponse:
    """Convert ``body.tokens`` back to text with the selected model.

    The model's detokenized bytes are decoded as strict UTF-8.

    Args:
        body: Request carrying the model name and the tokens to decode.
        llama_proxy: Model proxy injected through ``get_llama_proxy``.

    Returns:
        A response wrapping the decoded text.
    """
    llama = llama_proxy(body.model)
    raw_text = llama.detokenize(body.tokens)
    return DetokenizeInputResponse(text=raw_text.decode("utf-8"))
|
|