import time
import types
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Coroutine,
    Dict,
    Iterable,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
    cast,
)
from urllib.parse import urlparse

import httpx
import openai
from openai import AsyncOpenAI, OpenAI
from openai.types.beta.assistant_deleted import AssistantDeleted
from openai.types.file_deleted import FileDeleted
from pydantic import BaseModel
from typing_extensions import overload

import litellm
from litellm import LlmProviders
from litellm._logging import verbose_logger
from litellm.constants import DEFAULT_MAX_RETRIES
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.types.utils import (
    EmbeddingResponse,
    ImageResponse,
    LiteLLMBatch,
    ModelResponse,
    ModelResponseStream,
)
from litellm.utils import (
    CustomStreamWrapper,
    ProviderConfigManager,
    convert_to_model_response_object,
)

from ...types.llms.openai import *
from ..base import BaseLLM
from .chat.o_series_transformation import OpenAIOSeriesConfig
from .common_utils import (
    BaseOpenAILLM,
    OpenAIError,
    drop_params_from_unprocessable_entity_error,
)

openaiOSeriesConfig = OpenAIOSeriesConfig()
| class MistralEmbeddingConfig: | |
| """ | |
| Reference: https://docs.mistral.ai/api/#operation/createEmbedding | |
| """ | |
| def __init__( | |
| self, | |
| ) -> None: | |
| locals_ = locals().copy() | |
| for key, value in locals_.items(): | |
| if key != "self" and value is not None: | |
| setattr(self.__class__, key, value) | |
    @classmethod
    def get_config(cls):
| return { | |
| k: v | |
| for k, v in cls.__dict__.items() | |
| if not k.startswith("__") | |
| and not isinstance( | |
| v, | |
| ( | |
| types.FunctionType, | |
| types.BuiltinFunctionType, | |
| classmethod, | |
| staticmethod, | |
| ), | |
| ) | |
| and v is not None | |
| } | |
| def get_supported_openai_params(self): | |
| return [ | |
| "encoding_format", | |
| ] | |
| def map_openai_params(self, non_default_params: dict, optional_params: dict): | |
| for param, value in non_default_params.items(): | |
| if param == "encoding_format": | |
| optional_params["encoding_format"] = value | |
| return optional_params | |
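# Illustrative sketch (not part of the upstream module; parameter values are
# assumptions): MistralEmbeddingConfig only forwards the params it declares as
# supported, so anything else passed by the caller is silently dropped.
#
#   config = MistralEmbeddingConfig()
#   optional_params = config.map_openai_params(
#       non_default_params={"encoding_format": "float", "user": "abc-123"},
#       optional_params={},
#   )
#   # optional_params == {"encoding_format": "float"}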
| class OpenAIConfig(BaseConfig): | |
| """ | |
| Reference: https://platform.openai.com/docs/api-reference/chat/create | |
    The class `OpenAIConfig` provides configuration for OpenAI's Chat Completions API. Below are the parameters:
| - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition. | |
| - `function_call` (string or object): This optional parameter controls how the model calls functions. | |
| - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs. | |
| - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion. | |
    - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the chat completion. OpenAI has deprecated it in favor of `max_completion_tokens`, and it is not compatible with o1-series models.
    - `max_completion_tokens` (integer or null): An upper bound on the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.
    - `n` (integer or null): This optional parameter sets how many chat completion choices to generate for each input message.
    - `presence_penalty` (number or null): Defaults to 0. Penalizes new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
| - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens. | |
| - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. | |
| - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. | |
| """ | |
| frequency_penalty: Optional[int] = None | |
| function_call: Optional[Union[str, dict]] = None | |
| functions: Optional[list] = None | |
| logit_bias: Optional[dict] = None | |
| max_completion_tokens: Optional[int] = None | |
| max_tokens: Optional[int] = None | |
| n: Optional[int] = None | |
| presence_penalty: Optional[int] = None | |
| stop: Optional[Union[str, list]] = None | |
| temperature: Optional[int] = None | |
| top_p: Optional[int] = None | |
| response_format: Optional[dict] = None | |
| def __init__( | |
| self, | |
| frequency_penalty: Optional[int] = None, | |
| function_call: Optional[Union[str, dict]] = None, | |
| functions: Optional[list] = None, | |
| logit_bias: Optional[dict] = None, | |
| max_completion_tokens: Optional[int] = None, | |
| max_tokens: Optional[int] = None, | |
| n: Optional[int] = None, | |
| presence_penalty: Optional[int] = None, | |
| stop: Optional[Union[str, list]] = None, | |
| temperature: Optional[int] = None, | |
| top_p: Optional[int] = None, | |
| response_format: Optional[dict] = None, | |
| ) -> None: | |
| locals_ = locals().copy() | |
| for key, value in locals_.items(): | |
| if key != "self" and value is not None: | |
| setattr(self.__class__, key, value) | |
    @classmethod
    def get_config(cls):
| return super().get_config() | |
    def get_supported_openai_params(self, model: str) -> list:
        """
        Returns the list of supported OpenAI parameters for a given OpenAI model.

        - If an o-series model, returns the o-series supported params
        - If a gpt-audio model, returns the gpt-audio supported params
        - Otherwise, returns the standard GPT supported params

        Args:
            model (str): OpenAI model name

        Returns:
            list: List of supported OpenAI parameters
        """
| if openaiOSeriesConfig.is_model_o_series_model(model=model): | |
| return openaiOSeriesConfig.get_supported_openai_params(model=model) | |
| elif litellm.openAIGPTAudioConfig.is_model_gpt_audio_model(model=model): | |
| return litellm.openAIGPTAudioConfig.get_supported_openai_params(model=model) | |
| else: | |
| return litellm.openAIGPTConfig.get_supported_openai_params(model=model) | |
| def _map_openai_params( | |
| self, non_default_params: dict, optional_params: dict, model: str | |
| ) -> dict: | |
| supported_openai_params = self.get_supported_openai_params(model) | |
| for param, value in non_default_params.items(): | |
| if param in supported_openai_params: | |
| optional_params[param] = value | |
| return optional_params | |
| def _transform_messages( | |
| self, messages: List[AllMessageValues], model: str | |
| ) -> List[AllMessageValues]: | |
| return messages | |
| def map_openai_params( | |
| self, | |
| non_default_params: dict, | |
| optional_params: dict, | |
| model: str, | |
| drop_params: bool, | |
    ) -> dict:
        """
        Map OpenAI params, routing to the o-series, gpt-audio, or default GPT config based on the model.
        """
| if openaiOSeriesConfig.is_model_o_series_model(model=model): | |
| return openaiOSeriesConfig.map_openai_params( | |
| non_default_params=non_default_params, | |
| optional_params=optional_params, | |
| model=model, | |
| drop_params=drop_params, | |
| ) | |
| elif litellm.openAIGPTAudioConfig.is_model_gpt_audio_model(model=model): | |
| return litellm.openAIGPTAudioConfig.map_openai_params( | |
| non_default_params=non_default_params, | |
| optional_params=optional_params, | |
| model=model, | |
| drop_params=drop_params, | |
| ) | |
| return litellm.openAIGPTConfig.map_openai_params( | |
| non_default_params=non_default_params, | |
| optional_params=optional_params, | |
| model=model, | |
| drop_params=drop_params, | |
| ) | |
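    # Illustrative sketch (assumed model/param values): map_openai_params routes
    # to the o-series, gpt-audio, or default GPT config based on the model name.
    #
    #   params = OpenAIConfig().map_openai_params(
    #       non_default_params={"temperature": 0.2, "max_tokens": 256},
    #       optional_params={},
    #       model="gpt-4o",
    #       drop_params=False,
    #   )
    #   # for a non o-series / non gpt-audio model this delegates to
    #   # litellm.openAIGPTConfig.map_openai_params(...)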
| def get_error_class( | |
| self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] | |
| ) -> BaseLLMException: | |
| return OpenAIError( | |
| status_code=status_code, | |
| message=error_message, | |
| headers=headers, | |
| ) | |
| def transform_request( | |
| self, | |
| model: str, | |
| messages: List[AllMessageValues], | |
| optional_params: dict, | |
| litellm_params: dict, | |
| headers: dict, | |
| ) -> dict: | |
| messages = self._transform_messages(messages=messages, model=model) | |
| return {"model": model, "messages": messages, **optional_params} | |
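    # Illustrative sketch (assumed values): the transformed request is simply the
    # OpenAI-style request body.
    #
    #   OpenAIConfig().transform_request(
    #       model="gpt-4o",
    #       messages=[{"role": "user", "content": "hi"}],
    #       optional_params={"temperature": 0.2},
    #       litellm_params={},
    #       headers={},
    #   )
    #   # -> {"model": "gpt-4o",
    #   #     "messages": [{"role": "user", "content": "hi"}],
    #   #     "temperature": 0.2}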
| def transform_response( | |
| self, | |
| model: str, | |
| raw_response: httpx.Response, | |
| model_response: ModelResponse, | |
| logging_obj: LiteLLMLoggingObj, | |
| request_data: dict, | |
| messages: List[AllMessageValues], | |
| optional_params: dict, | |
| litellm_params: dict, | |
| encoding: Any, | |
| api_key: Optional[str] = None, | |
| json_mode: Optional[bool] = None, | |
| ) -> ModelResponse: | |
| logging_obj.post_call(original_response=raw_response.text) | |
| logging_obj.model_call_details["response_headers"] = raw_response.headers | |
| final_response_obj = cast( | |
| ModelResponse, | |
| convert_to_model_response_object( | |
| response_object=raw_response.json(), | |
| model_response_object=model_response, | |
| hidden_params={"headers": raw_response.headers}, | |
| _response_headers=dict(raw_response.headers), | |
| ), | |
| ) | |
| return final_response_obj | |
| def validate_environment( | |
| self, | |
| headers: dict, | |
| model: str, | |
| messages: List[AllMessageValues], | |
| optional_params: dict, | |
| litellm_params: dict, | |
| api_key: Optional[str] = None, | |
| api_base: Optional[str] = None, | |
| ) -> dict: | |
| return { | |
| "Authorization": f"Bearer {api_key}", | |
| **headers, | |
| } | |
| def get_model_response_iterator( | |
| self, | |
| streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse], | |
| sync_stream: bool, | |
| json_mode: Optional[bool] = False, | |
| ) -> Any: | |
| return OpenAIChatCompletionResponseIterator( | |
| streaming_response=streaming_response, | |
| sync_stream=sync_stream, | |
| json_mode=json_mode, | |
| ) | |
| class OpenAIChatCompletionResponseIterator(BaseModelResponseIterator): | |
    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        """
        Example chunk:

        {'choices': [{'delta': {'content': '', 'role': 'assistant'}, 'finish_reason': None, 'index': 0, 'logprobs': None}], 'created': 1735763082, 'id': 'a83a2b0fbfaf4aab9c2c93cb8ba346d7', 'model': 'mistral-large', 'object': 'chat.completion.chunk'}
        """
| try: | |
| return ModelResponseStream(**chunk) | |
| except Exception as e: | |
| raise e | |
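# Illustrative sketch (constructor arguments are placeholders): parsing the
# sample chunk from the docstring above into a ModelResponseStream.
#
#   chunk = {
#       "choices": [
#           {
#               "delta": {"content": "", "role": "assistant"},
#               "finish_reason": None,
#               "index": 0,
#               "logprobs": None,
#           }
#       ],
#       "created": 1735763082,
#       "id": "a83a2b0fbfaf4aab9c2c93cb8ba346d7",
#       "model": "mistral-large",
#       "object": "chat.completion.chunk",
#   }
#   iterator = OpenAIChatCompletionResponseIterator(
#       streaming_response=iter([]), sync_stream=True
#   )
#   parsed = iterator.chunk_parser(chunk)  # -> ModelResponseStream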
| class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM): | |
| def __init__(self) -> None: | |
| super().__init__() | |
| def _set_dynamic_params_on_client( | |
| self, | |
| client: Union[OpenAI, AsyncOpenAI], | |
| organization: Optional[str] = None, | |
| max_retries: Optional[int] = None, | |
| ): | |
| if organization is not None: | |
| client.organization = organization | |
| if max_retries is not None: | |
| client.max_retries = max_retries | |
| def _get_openai_client( | |
| self, | |
| is_async: bool, | |
| api_key: Optional[str] = None, | |
| api_base: Optional[str] = None, | |
| api_version: Optional[str] = None, | |
| timeout: Union[float, httpx.Timeout] = httpx.Timeout(None), | |
| max_retries: Optional[int] = DEFAULT_MAX_RETRIES, | |
| organization: Optional[str] = None, | |
| client: Optional[Union[OpenAI, AsyncOpenAI]] = None, | |
| ) -> Optional[Union[OpenAI, AsyncOpenAI]]: | |
| client_initialization_params: Dict = locals() | |
| if client is None: | |
| if not isinstance(max_retries, int): | |
| raise OpenAIError( | |
| status_code=422, | |
| message="max retries must be an int. Passed in value: {}".format( | |
| max_retries | |
| ), | |
| ) | |
| cached_client = self.get_cached_openai_client( | |
| client_initialization_params=client_initialization_params, | |
| client_type="openai", | |
| ) | |
| if cached_client: | |
| if isinstance(cached_client, OpenAI) or isinstance( | |
| cached_client, AsyncOpenAI | |
| ): | |
| return cached_client | |
| if is_async: | |
| _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI( | |
| api_key=api_key, | |
| base_url=api_base, | |
| http_client=OpenAIChatCompletion._get_async_http_client(), | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| ) | |
| else: | |
| _new_client = OpenAI( | |
| api_key=api_key, | |
| base_url=api_base, | |
| http_client=OpenAIChatCompletion._get_sync_http_client(), | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| ) | |
| ## SAVE CACHE KEY | |
| self.set_cached_openai_client( | |
| openai_client=_new_client, | |
| client_initialization_params=client_initialization_params, | |
| client_type="openai", | |
| ) | |
| return _new_client | |
| else: | |
| self._set_dynamic_params_on_client( | |
| client=client, | |
| organization=organization, | |
| max_retries=max_retries, | |
| ) | |
| return client | |
| async def make_openai_chat_completion_request( | |
| self, | |
| openai_aclient: AsyncOpenAI, | |
| data: dict, | |
| timeout: Union[float, httpx.Timeout], | |
| logging_obj: LiteLLMLoggingObj, | |
| ) -> Tuple[dict, BaseModel]: | |
| """ | |
| Helper to: | |
| - call chat.completions.create.with_raw_response when litellm.return_response_headers is True | |
| - call chat.completions.create by default | |
| """ | |
| start_time = time.time() | |
| try: | |
| raw_response = ( | |
| await openai_aclient.chat.completions.with_raw_response.create( | |
| **data, timeout=timeout | |
| ) | |
| ) | |
| end_time = time.time() | |
| if hasattr(raw_response, "headers"): | |
| headers = dict(raw_response.headers) | |
| else: | |
| headers = {} | |
| response = raw_response.parse() | |
| return headers, response | |
| except openai.APITimeoutError as e: | |
| end_time = time.time() | |
| time_delta = round(end_time - start_time, 2) | |
| e.message += f" - timeout value={timeout}, time taken={time_delta} seconds" | |
| raise e | |
| except Exception as e: | |
| raise e | |
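    # Illustrative sketch (data/timeout values are assumptions): the helper
    # returns a (response_headers, parsed_response) tuple so callers can log
    # provider headers alongside the parsed ChatCompletion.
    #
    #   headers, chat_completion = await self.make_openai_chat_completion_request(
    #       openai_aclient=openai_aclient,
    #       data={"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]},
    #       timeout=600.0,
    #       logging_obj=logging_obj,
    #   )
    #   # headers is a plain dict of the raw HTTP response headers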
| def make_sync_openai_chat_completion_request( | |
| self, | |
| openai_client: OpenAI, | |
| data: dict, | |
| timeout: Union[float, httpx.Timeout], | |
| logging_obj: LiteLLMLoggingObj, | |
| ) -> Tuple[dict, BaseModel]: | |
| """ | |
| Helper to: | |
| - call chat.completions.create.with_raw_response when litellm.return_response_headers is True | |
| - call chat.completions.create by default | |
| """ | |
| raw_response = None | |
| try: | |
| raw_response = openai_client.chat.completions.with_raw_response.create( | |
| **data, timeout=timeout | |
| ) | |
| if hasattr(raw_response, "headers"): | |
| headers = dict(raw_response.headers) | |
| else: | |
| headers = {} | |
| response = raw_response.parse() | |
| return headers, response | |
| except Exception as e: | |
| if raw_response is not None: | |
| raise Exception( | |
| "error - {}, Received response - {}, Type of response - {}".format( | |
| e, raw_response, type(raw_response) | |
| ) | |
| ) | |
| else: | |
| raise e | |
| def mock_streaming( | |
| self, | |
| response: ModelResponse, | |
| logging_obj: LiteLLMLoggingObj, | |
| model: str, | |
| stream_options: Optional[dict] = None, | |
| ) -> CustomStreamWrapper: | |
| completion_stream = MockResponseIterator(model_response=response) | |
| streaming_response = CustomStreamWrapper( | |
| completion_stream=completion_stream, | |
| model=model, | |
| custom_llm_provider="openai", | |
| logging_obj=logging_obj, | |
| stream_options=stream_options, | |
| ) | |
| return streaming_response | |
| def completion( # type: ignore # noqa: PLR0915 | |
| self, | |
| model_response: ModelResponse, | |
| timeout: Union[float, httpx.Timeout], | |
| optional_params: dict, | |
| litellm_params: dict, | |
| logging_obj: Any, | |
| model: Optional[str] = None, | |
| messages: Optional[list] = None, | |
| print_verbose: Optional[Callable] = None, | |
| api_key: Optional[str] = None, | |
| api_base: Optional[str] = None, | |
| api_version: Optional[str] = None, | |
| dynamic_params: Optional[bool] = None, | |
| azure_ad_token: Optional[str] = None, | |
| acompletion: bool = False, | |
| logger_fn=None, | |
| headers: Optional[dict] = None, | |
| custom_prompt_dict: dict = {}, | |
| client=None, | |
| organization: Optional[str] = None, | |
| custom_llm_provider: Optional[str] = None, | |
| drop_params: Optional[bool] = None, | |
| ): | |
| super().completion() | |
| try: | |
| fake_stream: bool = False | |
| inference_params = optional_params.copy() | |
| stream_options: Optional[dict] = inference_params.pop( | |
| "stream_options", None | |
| ) | |
| stream: Optional[bool] = inference_params.pop("stream", False) | |
| provider_config: Optional[BaseConfig] = None | |
| if custom_llm_provider is not None and model is not None: | |
| provider_config = ProviderConfigManager.get_provider_chat_config( | |
| model=model, provider=LlmProviders(custom_llm_provider) | |
| ) | |
| if provider_config: | |
| fake_stream = provider_config.should_fake_stream( | |
| model=model, custom_llm_provider=custom_llm_provider, stream=stream | |
| ) | |
| if headers: | |
| inference_params["extra_headers"] = headers | |
| if model is None or messages is None: | |
| raise OpenAIError(status_code=422, message="Missing model or messages") | |
| if not isinstance(timeout, float) and not isinstance( | |
| timeout, httpx.Timeout | |
| ): | |
| raise OpenAIError( | |
| status_code=422, | |
| message="Timeout needs to be a float or httpx.Timeout", | |
| ) | |
| if custom_llm_provider is not None and custom_llm_provider != "openai": | |
| model_response.model = f"{custom_llm_provider}/{model}" | |
| for _ in range( | |
| 2 | |
| ): # if call fails due to alternating messages, retry with reformatted message | |
| if provider_config is not None: | |
| data = provider_config.transform_request( | |
| model=model, | |
| messages=messages, | |
| optional_params=inference_params, | |
| litellm_params=litellm_params, | |
| headers=headers or {}, | |
| ) | |
| else: | |
| data = OpenAIConfig().transform_request( | |
| model=model, | |
| messages=messages, | |
| optional_params=inference_params, | |
| litellm_params=litellm_params, | |
| headers=headers or {}, | |
| ) | |
| try: | |
| max_retries = data.pop("max_retries", 2) | |
| if acompletion is True: | |
| if stream is True and fake_stream is False: | |
| return self.async_streaming( | |
| logging_obj=logging_obj, | |
| headers=headers, | |
| data=data, | |
| model=model, | |
| api_base=api_base, | |
| api_key=api_key, | |
| api_version=api_version, | |
| timeout=timeout, | |
| client=client, | |
| max_retries=max_retries, | |
| organization=organization, | |
| drop_params=drop_params, | |
| stream_options=stream_options, | |
| ) | |
| else: | |
| return self.acompletion( | |
| data=data, | |
| headers=headers, | |
| model=model, | |
| logging_obj=logging_obj, | |
| model_response=model_response, | |
| api_base=api_base, | |
| api_key=api_key, | |
| api_version=api_version, | |
| timeout=timeout, | |
| client=client, | |
| max_retries=max_retries, | |
| organization=organization, | |
| drop_params=drop_params, | |
| fake_stream=fake_stream, | |
| ) | |
| elif stream is True and fake_stream is False: | |
| return self.streaming( | |
| logging_obj=logging_obj, | |
| headers=headers, | |
| data=data, | |
| model=model, | |
| api_base=api_base, | |
| api_key=api_key, | |
| api_version=api_version, | |
| timeout=timeout, | |
| client=client, | |
| max_retries=max_retries, | |
| organization=organization, | |
| stream_options=stream_options, | |
| ) | |
| else: | |
| if not isinstance(max_retries, int): | |
| raise OpenAIError( | |
| status_code=422, message="max retries must be an int" | |
| ) | |
| openai_client: OpenAI = self._get_openai_client( # type: ignore | |
| is_async=False, | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| ## LOGGING | |
| logging_obj.pre_call( | |
| input=messages, | |
| api_key=openai_client.api_key, | |
| additional_args={ | |
| "headers": headers, | |
| "api_base": openai_client._base_url._uri_reference, | |
| "acompletion": acompletion, | |
| "complete_input_dict": data, | |
| }, | |
| ) | |
| ( | |
| headers, | |
| response, | |
| ) = self.make_sync_openai_chat_completion_request( | |
| openai_client=openai_client, | |
| data=data, | |
| timeout=timeout, | |
| logging_obj=logging_obj, | |
| ) | |
| logging_obj.model_call_details["response_headers"] = headers | |
| stringified_response = response.model_dump() | |
| logging_obj.post_call( | |
| input=messages, | |
| api_key=api_key, | |
| original_response=stringified_response, | |
| additional_args={"complete_input_dict": data}, | |
| ) | |
| final_response_obj = convert_to_model_response_object( | |
| response_object=stringified_response, | |
| model_response_object=model_response, | |
| _response_headers=headers, | |
| ) | |
| if fake_stream is True: | |
| return self.mock_streaming( | |
| response=cast(ModelResponse, final_response_obj), | |
| logging_obj=logging_obj, | |
| model=model, | |
| stream_options=stream_options, | |
| ) | |
| return final_response_obj | |
| except openai.UnprocessableEntityError as e: | |
| ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800 | |
| if litellm.drop_params is True or drop_params is True: | |
| inference_params = drop_params_from_unprocessable_entity_error( | |
| e, inference_params | |
| ) | |
| else: | |
| raise e | |
| # e.message | |
| except Exception as e: | |
| if print_verbose is not None: | |
| print_verbose(f"openai.py: Received openai error - {str(e)}") | |
| if ( | |
| "Conversation roles must alternate user/assistant" in str(e) | |
| or "user and assistant roles should be alternating" in str(e) | |
| ) and messages is not None: | |
| if print_verbose is not None: | |
| print_verbose("openai.py: REFORMATS THE MESSAGE!") | |
                        # reformat messages to ensure user/assistant roles alternate; if there are two consecutive 'user' or two consecutive 'assistant' messages, insert a blank message of the opposite role for compatibility
| new_messages = [] | |
| for i in range(len(messages) - 1): # type: ignore | |
| new_messages.append(messages[i]) | |
| if messages[i]["role"] == messages[i + 1]["role"]: | |
| if messages[i]["role"] == "user": | |
| new_messages.append( | |
| {"role": "assistant", "content": ""} | |
| ) | |
| else: | |
| new_messages.append({"role": "user", "content": ""}) | |
| new_messages.append(messages[-1]) | |
| messages = new_messages | |
| elif ( | |
| "Last message must have role `user`" in str(e) | |
| ) and messages is not None: | |
| new_messages = messages | |
| new_messages.append({"role": "user", "content": ""}) | |
| messages = new_messages | |
| elif "unknown field: parameter index is not a valid field" in str( | |
| e | |
| ): | |
| litellm.remove_index_from_tool_calls(messages=messages) | |
| else: | |
| raise e | |
| except OpenAIError as e: | |
| raise e | |
| except Exception as e: | |
| status_code = getattr(e, "status_code", 500) | |
| error_headers = getattr(e, "headers", None) | |
| error_text = getattr(e, "text", str(e)) | |
| error_response = getattr(e, "response", None) | |
| error_body = getattr(e, "body", None) | |
| if error_headers is None and error_response: | |
| error_headers = getattr(error_response, "headers", None) | |
| raise OpenAIError( | |
| status_code=status_code, | |
| message=error_text, | |
| headers=error_headers, | |
| body=error_body, | |
| ) | |
| async def acompletion( | |
| self, | |
| data: dict, | |
| model: str, | |
| model_response: ModelResponse, | |
| logging_obj: LiteLLMLoggingObj, | |
| timeout: Union[float, httpx.Timeout], | |
| api_key: Optional[str] = None, | |
| api_base: Optional[str] = None, | |
| api_version: Optional[str] = None, | |
| organization: Optional[str] = None, | |
| client=None, | |
| max_retries=None, | |
| headers=None, | |
| drop_params: Optional[bool] = None, | |
| stream_options: Optional[dict] = None, | |
| fake_stream: bool = False, | |
| ): | |
| response = None | |
| for _ in range( | |
| 2 | |
| ): # if call fails due to alternating messages, retry with reformatted message | |
| try: | |
| openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore | |
| is_async=True, | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| ## LOGGING | |
| logging_obj.pre_call( | |
| input=data["messages"], | |
| api_key=openai_aclient.api_key, | |
| additional_args={ | |
| "headers": { | |
| "Authorization": f"Bearer {openai_aclient.api_key}" | |
| }, | |
| "api_base": openai_aclient._base_url._uri_reference, | |
| "acompletion": True, | |
| "complete_input_dict": data, | |
| }, | |
| ) | |
| headers, response = await self.make_openai_chat_completion_request( | |
| openai_aclient=openai_aclient, | |
| data=data, | |
| timeout=timeout, | |
| logging_obj=logging_obj, | |
| ) | |
| stringified_response = response.model_dump() | |
| logging_obj.post_call( | |
| input=data["messages"], | |
| api_key=api_key, | |
| original_response=stringified_response, | |
| additional_args={"complete_input_dict": data}, | |
| ) | |
| logging_obj.model_call_details["response_headers"] = headers | |
| final_response_obj = convert_to_model_response_object( | |
| response_object=stringified_response, | |
| model_response_object=model_response, | |
| hidden_params={"headers": headers}, | |
| _response_headers=headers, | |
| ) | |
| if fake_stream is True: | |
| return self.mock_streaming( | |
| response=cast(ModelResponse, final_response_obj), | |
| logging_obj=logging_obj, | |
| model=model, | |
| stream_options=stream_options, | |
| ) | |
| return final_response_obj | |
| except openai.UnprocessableEntityError as e: | |
| ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800 | |
| if litellm.drop_params is True or drop_params is True: | |
| data = drop_params_from_unprocessable_entity_error(e, data) | |
| else: | |
| raise e | |
| # e.message | |
| except Exception as e: | |
| exception_response = getattr(e, "response", None) | |
| status_code = getattr(e, "status_code", 500) | |
| exception_body = getattr(e, "body", None) | |
| error_headers = getattr(e, "headers", None) | |
| if error_headers is None and exception_response: | |
| error_headers = getattr(exception_response, "headers", None) | |
| message = getattr(e, "message", str(e)) | |
| raise OpenAIError( | |
| status_code=status_code, | |
| message=message, | |
| headers=error_headers, | |
| body=exception_body, | |
| ) | |
| def streaming( | |
| self, | |
| logging_obj, | |
| timeout: Union[float, httpx.Timeout], | |
| data: dict, | |
| model: str, | |
| api_key: Optional[str] = None, | |
| api_base: Optional[str] = None, | |
| api_version: Optional[str] = None, | |
| organization: Optional[str] = None, | |
| client=None, | |
| max_retries=None, | |
| headers=None, | |
| stream_options: Optional[dict] = None, | |
| ): | |
| data["stream"] = True | |
| data.update( | |
| self.get_stream_options(stream_options=stream_options, api_base=api_base) | |
| ) | |
| openai_client: OpenAI = self._get_openai_client( # type: ignore | |
| is_async=False, | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| ## LOGGING | |
| logging_obj.pre_call( | |
| input=data["messages"], | |
| api_key=api_key, | |
| additional_args={ | |
| "headers": {"Authorization": f"Bearer {openai_client.api_key}"}, | |
| "api_base": openai_client._base_url._uri_reference, | |
| "acompletion": False, | |
| "complete_input_dict": data, | |
| }, | |
| ) | |
| headers, response = self.make_sync_openai_chat_completion_request( | |
| openai_client=openai_client, | |
| data=data, | |
| timeout=timeout, | |
| logging_obj=logging_obj, | |
| ) | |
| logging_obj.model_call_details["response_headers"] = headers | |
| streamwrapper = CustomStreamWrapper( | |
| completion_stream=response, | |
| model=model, | |
| custom_llm_provider="openai", | |
| logging_obj=logging_obj, | |
| stream_options=data.get("stream_options", None), | |
| _response_headers=headers, | |
| ) | |
| return streamwrapper | |
| async def async_streaming( | |
| self, | |
| timeout: Union[float, httpx.Timeout], | |
| data: dict, | |
| model: str, | |
| logging_obj: LiteLLMLoggingObj, | |
| api_key: Optional[str] = None, | |
| api_base: Optional[str] = None, | |
| api_version: Optional[str] = None, | |
| organization: Optional[str] = None, | |
| client=None, | |
| max_retries=None, | |
| headers=None, | |
| drop_params: Optional[bool] = None, | |
| stream_options: Optional[dict] = None, | |
| ): | |
| response = None | |
| data["stream"] = True | |
| data.update( | |
| self.get_stream_options(stream_options=stream_options, api_base=api_base) | |
| ) | |
| for _ in range(2): | |
| try: | |
| openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore | |
| is_async=True, | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| ## LOGGING | |
| logging_obj.pre_call( | |
| input=data["messages"], | |
| api_key=api_key, | |
| additional_args={ | |
| "headers": headers, | |
| "api_base": api_base, | |
| "acompletion": True, | |
| "complete_input_dict": data, | |
| }, | |
| ) | |
| headers, response = await self.make_openai_chat_completion_request( | |
| openai_aclient=openai_aclient, | |
| data=data, | |
| timeout=timeout, | |
| logging_obj=logging_obj, | |
| ) | |
| logging_obj.model_call_details["response_headers"] = headers | |
| streamwrapper = CustomStreamWrapper( | |
| completion_stream=response, | |
| model=model, | |
| custom_llm_provider="openai", | |
| logging_obj=logging_obj, | |
| stream_options=data.get("stream_options", None), | |
| _response_headers=headers, | |
| ) | |
| return streamwrapper | |
| except openai.UnprocessableEntityError as e: | |
| ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800 | |
| if litellm.drop_params is True or drop_params is True: | |
| data = drop_params_from_unprocessable_entity_error(e, data) | |
| else: | |
| raise e | |
| except ( | |
| Exception | |
| ) as e: # need to exception handle here. async exceptions don't get caught in sync functions. | |
| if isinstance(e, OpenAIError): | |
| raise e | |
| error_headers = getattr(e, "headers", None) | |
| status_code = getattr(e, "status_code", 500) | |
| error_response = getattr(e, "response", None) | |
| exception_body = getattr(e, "body", None) | |
| if error_headers is None and error_response: | |
| error_headers = getattr(error_response, "headers", None) | |
| if response is not None and hasattr(response, "text"): | |
| raise OpenAIError( | |
| status_code=status_code, | |
| message=f"{str(e)}\n\nOriginal Response: {response.text}", # type: ignore | |
| headers=error_headers, | |
| body=exception_body, | |
| ) | |
| else: | |
| if type(e).__name__ == "ReadTimeout": | |
| raise OpenAIError( | |
| status_code=408, | |
| message=f"{type(e).__name__}", | |
| headers=error_headers, | |
| body=exception_body, | |
| ) | |
| elif hasattr(e, "status_code"): | |
| raise OpenAIError( | |
| status_code=getattr(e, "status_code", 500), | |
| message=str(e), | |
| headers=error_headers, | |
| body=exception_body, | |
| ) | |
| else: | |
| raise OpenAIError( | |
| status_code=500, | |
| message=f"{str(e)}", | |
| headers=error_headers, | |
| body=exception_body, | |
| ) | |
| def get_stream_options( | |
| self, stream_options: Optional[dict], api_base: Optional[str] | |
| ) -> dict: | |
| """ | |
| Pass `stream_options` to the data dict for OpenAI requests | |
| """ | |
| if stream_options is not None: | |
| return {"stream_options": stream_options} | |
| else: | |
| # by default litellm will include usage for openai endpoints | |
| if api_base is None or urlparse(api_base).hostname == "api.openai.com": | |
| return {"stream_options": {"include_usage": True}} | |
| return {} | |
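    # Illustrative sketch (api_base values are assumptions): usage is only
    # included by default when the request targets api.openai.com.
    #
    #   self.get_stream_options(stream_options=None, api_base=None)
    #   # -> {"stream_options": {"include_usage": True}}
    #
    #   self.get_stream_options(
    #       stream_options=None, api_base="https://my-proxy.example.com/v1"
    #   )
    #   # -> {}
    #
    #   self.get_stream_options(stream_options={"include_usage": False}, api_base=None)
    #   # -> {"stream_options": {"include_usage": False}}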
| # Embedding | |
| async def make_openai_embedding_request( | |
| self, | |
| openai_aclient: AsyncOpenAI, | |
| data: dict, | |
| timeout: Union[float, httpx.Timeout], | |
| logging_obj: LiteLLMLoggingObj, | |
| ): | |
| """ | |
| Helper to: | |
| - call embeddings.create.with_raw_response when litellm.return_response_headers is True | |
| - call embeddings.create by default | |
| """ | |
| try: | |
| raw_response = await openai_aclient.embeddings.with_raw_response.create( | |
| **data, timeout=timeout | |
| ) # type: ignore | |
| headers = dict(raw_response.headers) | |
| response = raw_response.parse() | |
| return headers, response | |
| except Exception as e: | |
| raise e | |
| def make_sync_openai_embedding_request( | |
| self, | |
| openai_client: OpenAI, | |
| data: dict, | |
| timeout: Union[float, httpx.Timeout], | |
| logging_obj: LiteLLMLoggingObj, | |
| ): | |
| """ | |
| Helper to: | |
| - call embeddings.create.with_raw_response when litellm.return_response_headers is True | |
| - call embeddings.create by default | |
| """ | |
| try: | |
| raw_response = openai_client.embeddings.with_raw_response.create( | |
| **data, timeout=timeout | |
| ) # type: ignore | |
| headers = dict(raw_response.headers) | |
| response = raw_response.parse() | |
| return headers, response | |
| except Exception as e: | |
| raise e | |
| async def aembedding( | |
| self, | |
| input: list, | |
| data: dict, | |
| model_response: EmbeddingResponse, | |
| timeout: float, | |
| logging_obj: LiteLLMLoggingObj, | |
| api_key: Optional[str] = None, | |
| api_base: Optional[str] = None, | |
| client: Optional[AsyncOpenAI] = None, | |
| max_retries=None, | |
| ): | |
| try: | |
| openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore | |
| is_async=True, | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| ) | |
| headers, response = await self.make_openai_embedding_request( | |
| openai_aclient=openai_aclient, | |
| data=data, | |
| timeout=timeout, | |
| logging_obj=logging_obj, | |
| ) | |
| logging_obj.model_call_details["response_headers"] = headers | |
| stringified_response = response.model_dump() | |
| ## LOGGING | |
| logging_obj.post_call( | |
| input=input, | |
| api_key=api_key, | |
| additional_args={"complete_input_dict": data}, | |
| original_response=stringified_response, | |
| ) | |
| returned_response: EmbeddingResponse = convert_to_model_response_object( | |
| response_object=stringified_response, | |
| model_response_object=model_response, | |
| response_type="embedding", | |
| _response_headers=headers, | |
| ) # type: ignore | |
| return returned_response | |
| except OpenAIError as e: | |
| ## LOGGING | |
| logging_obj.post_call( | |
| input=input, | |
| api_key=api_key, | |
| additional_args={"complete_input_dict": data}, | |
| original_response=str(e), | |
| ) | |
| raise e | |
| except Exception as e: | |
| ## LOGGING | |
| logging_obj.post_call( | |
| input=input, | |
| api_key=api_key, | |
| additional_args={"complete_input_dict": data}, | |
| original_response=str(e), | |
| ) | |
| status_code = getattr(e, "status_code", 500) | |
| error_headers = getattr(e, "headers", None) | |
| error_text = getattr(e, "text", str(e)) | |
| error_response = getattr(e, "response", None) | |
| if error_headers is None and error_response: | |
| error_headers = getattr(error_response, "headers", None) | |
| raise OpenAIError( | |
| status_code=status_code, message=error_text, headers=error_headers | |
| ) | |
| def embedding( # type: ignore | |
| self, | |
| model: str, | |
| input: list, | |
| timeout: float, | |
| logging_obj, | |
| model_response: EmbeddingResponse, | |
| optional_params: dict, | |
| api_key: Optional[str] = None, | |
| api_base: Optional[str] = None, | |
| client=None, | |
| aembedding=None, | |
| max_retries: Optional[int] = None, | |
| ) -> EmbeddingResponse: | |
| super().embedding() | |
| try: | |
| model = model | |
| data = {"model": model, "input": input, **optional_params} | |
| max_retries = max_retries or litellm.DEFAULT_MAX_RETRIES | |
| if not isinstance(max_retries, int): | |
| raise OpenAIError(status_code=422, message="max retries must be an int") | |
| ## LOGGING | |
| logging_obj.pre_call( | |
| input=input, | |
| api_key=api_key, | |
| additional_args={"complete_input_dict": data, "api_base": api_base}, | |
| ) | |
| if aembedding is True: | |
| return self.aembedding( # type: ignore | |
| data=data, | |
| input=input, | |
| logging_obj=logging_obj, | |
| model_response=model_response, | |
| api_base=api_base, | |
| api_key=api_key, | |
| timeout=timeout, | |
| client=client, | |
| max_retries=max_retries, | |
| ) | |
| openai_client: OpenAI = self._get_openai_client( # type: ignore | |
| is_async=False, | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| ) | |
| ## embedding CALL | |
| headers: Optional[Dict] = None | |
| headers, sync_embedding_response = self.make_sync_openai_embedding_request( | |
| openai_client=openai_client, | |
| data=data, | |
| timeout=timeout, | |
| logging_obj=logging_obj, | |
| ) # type: ignore | |
| ## LOGGING | |
| logging_obj.model_call_details["response_headers"] = headers | |
| logging_obj.post_call( | |
| input=input, | |
| api_key=api_key, | |
| additional_args={"complete_input_dict": data}, | |
| original_response=sync_embedding_response, | |
| ) | |
| response: EmbeddingResponse = convert_to_model_response_object( | |
| response_object=sync_embedding_response.model_dump(), | |
| model_response_object=model_response, | |
| _response_headers=headers, | |
| response_type="embedding", | |
| ) # type: ignore | |
| return response | |
| except OpenAIError as e: | |
| raise e | |
| except Exception as e: | |
| status_code = getattr(e, "status_code", 500) | |
| error_headers = getattr(e, "headers", None) | |
| error_text = getattr(e, "text", str(e)) | |
| error_response = getattr(e, "response", None) | |
| if error_headers is None and error_response: | |
| error_headers = getattr(error_response, "headers", None) | |
| raise OpenAIError( | |
| status_code=status_code, message=error_text, headers=error_headers | |
| ) | |
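    # Illustrative sketch (model, api_key, and logging_obj are assumptions): a
    # synchronous embedding call through this handler.
    #
    #   resp = OpenAIChatCompletion().embedding(
    #       model="text-embedding-3-small",
    #       input=["hello world"],
    #       timeout=600.0,
    #       logging_obj=logging_obj,
    #       model_response=EmbeddingResponse(),
    #       optional_params={},
    #       api_key="sk-...",
    #   )
    #   # resp.data contains one embedding per input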
| async def aimage_generation( | |
| self, | |
| prompt: str, | |
| data: dict, | |
| model_response: ModelResponse, | |
| timeout: float, | |
| logging_obj: Any, | |
| api_key: Optional[str] = None, | |
| api_base: Optional[str] = None, | |
| client=None, | |
| max_retries=None, | |
| ): | |
| response = None | |
| try: | |
| openai_aclient = self._get_openai_client( | |
| is_async=True, | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| ) | |
| response = await openai_aclient.images.generate(**data, timeout=timeout) # type: ignore | |
| stringified_response = response.model_dump() | |
| ## LOGGING | |
| logging_obj.post_call( | |
| input=prompt, | |
| api_key=api_key, | |
| additional_args={"complete_input_dict": data}, | |
| original_response=stringified_response, | |
| ) | |
| return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="image_generation") # type: ignore | |
| except Exception as e: | |
| ## LOGGING | |
| logging_obj.post_call( | |
| input=prompt, | |
| api_key=api_key, | |
| original_response=str(e), | |
| ) | |
| raise e | |
| def image_generation( | |
| self, | |
| model: Optional[str], | |
| prompt: str, | |
| timeout: float, | |
| optional_params: dict, | |
| logging_obj: Any, | |
| api_key: Optional[str] = None, | |
| api_base: Optional[str] = None, | |
| model_response: Optional[ImageResponse] = None, | |
| client=None, | |
| aimg_generation=None, | |
| ) -> ImageResponse: | |
| data = {} | |
| try: | |
| model = model | |
| data = {"model": model, "prompt": prompt, **optional_params} | |
| max_retries = data.pop("max_retries", 2) | |
| if not isinstance(max_retries, int): | |
| raise OpenAIError(status_code=422, message="max retries must be an int") | |
| if aimg_generation is True: | |
| return self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore | |
| openai_client: OpenAI = self._get_openai_client( # type: ignore | |
| is_async=False, | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| ) | |
| ## LOGGING | |
| logging_obj.pre_call( | |
| input=prompt, | |
| api_key=openai_client.api_key, | |
| additional_args={ | |
| "headers": {"Authorization": f"Bearer {openai_client.api_key}"}, | |
| "api_base": openai_client._base_url._uri_reference, | |
| "acompletion": True, | |
| "complete_input_dict": data, | |
| }, | |
| ) | |
| ## COMPLETION CALL | |
| _response = openai_client.images.generate(**data, timeout=timeout) # type: ignore | |
| response = _response.model_dump() | |
| ## LOGGING | |
| logging_obj.post_call( | |
| input=prompt, | |
| api_key=api_key, | |
| additional_args={"complete_input_dict": data}, | |
| original_response=response, | |
| ) | |
| return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore | |
| except OpenAIError as e: | |
| ## LOGGING | |
| logging_obj.post_call( | |
| input=prompt, | |
| api_key=api_key, | |
| additional_args={"complete_input_dict": data}, | |
| original_response=str(e), | |
| ) | |
| raise e | |
| except Exception as e: | |
| ## LOGGING | |
| logging_obj.post_call( | |
| input=prompt, | |
| api_key=api_key, | |
| additional_args={"complete_input_dict": data}, | |
| original_response=str(e), | |
| ) | |
| if hasattr(e, "status_code"): | |
| raise OpenAIError( | |
| status_code=getattr(e, "status_code", 500), message=str(e) | |
| ) | |
| else: | |
| raise OpenAIError(status_code=500, message=str(e)) | |
| def audio_speech( | |
| self, | |
| model: str, | |
| input: str, | |
| voice: str, | |
| optional_params: dict, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| organization: Optional[str], | |
| project: Optional[str], | |
| max_retries: int, | |
| timeout: Union[float, httpx.Timeout], | |
| aspeech: Optional[bool] = None, | |
| client=None, | |
| ) -> HttpxBinaryResponseContent: | |
| if aspeech is not None and aspeech is True: | |
| return self.async_audio_speech( | |
| model=model, | |
| input=input, | |
| voice=voice, | |
| optional_params=optional_params, | |
| api_key=api_key, | |
| api_base=api_base, | |
| organization=organization, | |
| project=project, | |
| max_retries=max_retries, | |
| timeout=timeout, | |
| client=client, | |
| ) # type: ignore | |
| openai_client = self._get_openai_client( | |
| is_async=False, | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| ) | |
| response = cast(OpenAI, openai_client).audio.speech.create( | |
| model=model, | |
| voice=voice, # type: ignore | |
| input=input, | |
| **optional_params, | |
| ) | |
| return HttpxBinaryResponseContent(response=response.response) | |
| async def async_audio_speech( | |
| self, | |
| model: str, | |
| input: str, | |
| voice: str, | |
| optional_params: dict, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| organization: Optional[str], | |
| project: Optional[str], | |
| max_retries: int, | |
| timeout: Union[float, httpx.Timeout], | |
| client=None, | |
| ) -> HttpxBinaryResponseContent: | |
| openai_client = cast( | |
| AsyncOpenAI, | |
| self._get_openai_client( | |
| is_async=True, | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| ), | |
| ) | |
| response = await openai_client.audio.speech.create( | |
| model=model, | |
| voice=voice, # type: ignore | |
| input=input, | |
| **optional_params, | |
| ) | |
| return HttpxBinaryResponseContent(response=response.response) | |
class OpenAIFilesAPI(BaseLLM):
    """
    OpenAI methods to support the Files API:
    - create_file()
    - retrieve_file()
    - list_files()
    - delete_file()
    - file_content()
    - update_file()
    """
| def __init__(self) -> None: | |
| super().__init__() | |
| def get_openai_client( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[Union[OpenAI, AsyncOpenAI]] = None, | |
| _is_async: bool = False, | |
| ) -> Optional[Union[OpenAI, AsyncOpenAI]]: | |
| received_args = locals() | |
| openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None | |
| if client is None: | |
| data = {} | |
| for k, v in received_args.items(): | |
| if k == "self" or k == "client" or k == "_is_async": | |
| pass | |
| elif k == "api_base" and v is not None: | |
| data["base_url"] = v | |
| elif v is not None: | |
| data[k] = v | |
| if _is_async is True: | |
| openai_client = AsyncOpenAI(**data) | |
| else: | |
| openai_client = OpenAI(**data) # type: ignore | |
| else: | |
| openai_client = client | |
| return openai_client | |
| async def acreate_file( | |
| self, | |
| create_file_data: CreateFileRequest, | |
| openai_client: AsyncOpenAI, | |
| ) -> OpenAIFileObject: | |
| response = await openai_client.files.create(**create_file_data) | |
| return OpenAIFileObject(**response.model_dump()) | |
| def create_file( | |
| self, | |
| _is_async: bool, | |
| create_file_data: CreateFileRequest, | |
| api_base: str, | |
| api_key: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[Union[OpenAI, AsyncOpenAI]] = None, | |
| ) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]: | |
| openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| _is_async=_is_async, | |
| ) | |
| if openai_client is None: | |
| raise ValueError( | |
| "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." | |
| ) | |
| if _is_async is True: | |
| if not isinstance(openai_client, AsyncOpenAI): | |
| raise ValueError( | |
| "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." | |
| ) | |
| return self.acreate_file( # type: ignore | |
| create_file_data=create_file_data, openai_client=openai_client | |
| ) | |
| response = cast(OpenAI, openai_client).files.create(**create_file_data) | |
| return OpenAIFileObject(**response.model_dump()) | |
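    # Illustrative sketch (file path, api_key, and purpose are assumptions): sync
    # and async file creation share this entrypoint, toggled by `_is_async`.
    #
    #   files_api = OpenAIFilesAPI()
    #   file_obj = files_api.create_file(
    #       _is_async=False,
    #       create_file_data={"file": open("batch_input.jsonl", "rb"), "purpose": "batch"},
    #       api_base="https://api.openai.com/v1",
    #       api_key="sk-...",
    #       timeout=600.0,
    #       max_retries=2,
    #       organization=None,
    #   )
    #   # with _is_async=True (and an AsyncOpenAI client) the same call returns a
    #   # coroutine that must be awaited.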
| async def afile_content( | |
| self, | |
| file_content_request: FileContentRequest, | |
| openai_client: AsyncOpenAI, | |
| ) -> HttpxBinaryResponseContent: | |
| response = await openai_client.files.content(**file_content_request) | |
| return HttpxBinaryResponseContent(response=response.response) | |
| def file_content( | |
| self, | |
| _is_async: bool, | |
| file_content_request: FileContentRequest, | |
| api_base: str, | |
| api_key: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[Union[OpenAI, AsyncOpenAI]] = None, | |
| ) -> Union[ | |
| HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent] | |
| ]: | |
| openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| _is_async=_is_async, | |
| ) | |
| if openai_client is None: | |
| raise ValueError( | |
| "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." | |
| ) | |
| if _is_async is True: | |
| if not isinstance(openai_client, AsyncOpenAI): | |
| raise ValueError( | |
| "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." | |
| ) | |
| return self.afile_content( # type: ignore | |
| file_content_request=file_content_request, | |
| openai_client=openai_client, | |
| ) | |
| response = cast(OpenAI, openai_client).files.content(**file_content_request) | |
| return HttpxBinaryResponseContent(response=response.response) | |
| async def aretrieve_file( | |
| self, | |
| file_id: str, | |
| openai_client: AsyncOpenAI, | |
| ) -> FileObject: | |
| response = await openai_client.files.retrieve(file_id=file_id) | |
| return response | |
| def retrieve_file( | |
| self, | |
| _is_async: bool, | |
| file_id: str, | |
| api_base: str, | |
| api_key: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[Union[OpenAI, AsyncOpenAI]] = None, | |
| ): | |
| openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| _is_async=_is_async, | |
| ) | |
| if openai_client is None: | |
| raise ValueError( | |
| "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." | |
| ) | |
| if _is_async is True: | |
| if not isinstance(openai_client, AsyncOpenAI): | |
| raise ValueError( | |
| "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." | |
| ) | |
| return self.aretrieve_file( # type: ignore | |
| file_id=file_id, | |
| openai_client=openai_client, | |
| ) | |
| response = openai_client.files.retrieve(file_id=file_id) | |
| return response | |
| async def adelete_file( | |
| self, | |
| file_id: str, | |
| openai_client: AsyncOpenAI, | |
| ) -> FileDeleted: | |
| response = await openai_client.files.delete(file_id=file_id) | |
| return response | |
| def delete_file( | |
| self, | |
| _is_async: bool, | |
| file_id: str, | |
| api_base: str, | |
| api_key: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[Union[OpenAI, AsyncOpenAI]] = None, | |
| ): | |
| openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| _is_async=_is_async, | |
| ) | |
| if openai_client is None: | |
| raise ValueError( | |
| "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." | |
| ) | |
| if _is_async is True: | |
| if not isinstance(openai_client, AsyncOpenAI): | |
| raise ValueError( | |
| "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." | |
| ) | |
| return self.adelete_file( # type: ignore | |
| file_id=file_id, | |
| openai_client=openai_client, | |
| ) | |
| response = openai_client.files.delete(file_id=file_id) | |
| return response | |
| async def alist_files( | |
| self, | |
| openai_client: AsyncOpenAI, | |
| purpose: Optional[str] = None, | |
| ): | |
| if isinstance(purpose, str): | |
| response = await openai_client.files.list(purpose=purpose) | |
| else: | |
| response = await openai_client.files.list() | |
| return response | |
| def list_files( | |
| self, | |
| _is_async: bool, | |
| api_base: str, | |
| api_key: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| purpose: Optional[str] = None, | |
| client: Optional[Union[OpenAI, AsyncOpenAI]] = None, | |
| ): | |
| openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| _is_async=_is_async, | |
| ) | |
| if openai_client is None: | |
| raise ValueError( | |
| "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." | |
| ) | |
| if _is_async is True: | |
| if not isinstance(openai_client, AsyncOpenAI): | |
| raise ValueError( | |
| "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." | |
| ) | |
| return self.alist_files( # type: ignore | |
| purpose=purpose, | |
| openai_client=openai_client, | |
| ) | |
| if isinstance(purpose, str): | |
| response = openai_client.files.list(purpose=purpose) | |
| else: | |
| response = openai_client.files.list() | |
| return response | |
class OpenAIBatchesAPI(BaseLLM):
    """
    OpenAI methods to support the Batches API:
    - create_batch()
    - retrieve_batch()
    - cancel_batch()
    - list_batches()
    """
| def __init__(self) -> None: | |
| super().__init__() | |
| def get_openai_client( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[Union[OpenAI, AsyncOpenAI]] = None, | |
| _is_async: bool = False, | |
| ) -> Optional[Union[OpenAI, AsyncOpenAI]]: | |
| received_args = locals() | |
| openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None | |
| if client is None: | |
| data = {} | |
| for k, v in received_args.items(): | |
| if k == "self" or k == "client" or k == "_is_async": | |
| pass | |
| elif k == "api_base" and v is not None: | |
| data["base_url"] = v | |
| elif v is not None: | |
| data[k] = v | |
| if _is_async is True: | |
| openai_client = AsyncOpenAI(**data) | |
| else: | |
| openai_client = OpenAI(**data) # type: ignore | |
| else: | |
| openai_client = client | |
| return openai_client | |
| async def acreate_batch( | |
| self, | |
| create_batch_data: CreateBatchRequest, | |
| openai_client: AsyncOpenAI, | |
| ) -> LiteLLMBatch: | |
| response = await openai_client.batches.create(**create_batch_data) | |
| return LiteLLMBatch(**response.model_dump()) | |
| def create_batch( | |
| self, | |
| _is_async: bool, | |
| create_batch_data: CreateBatchRequest, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[Union[OpenAI, AsyncOpenAI]] = None, | |
| ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]: | |
| openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| _is_async=_is_async, | |
| ) | |
| if openai_client is None: | |
| raise ValueError( | |
| "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." | |
| ) | |
| if _is_async is True: | |
| if not isinstance(openai_client, AsyncOpenAI): | |
| raise ValueError( | |
| "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." | |
| ) | |
| return self.acreate_batch( # type: ignore | |
| create_batch_data=create_batch_data, openai_client=openai_client | |
| ) | |
| response = cast(OpenAI, openai_client).batches.create(**create_batch_data) | |
| return LiteLLMBatch(**response.model_dump()) | |
| async def aretrieve_batch( | |
| self, | |
| retrieve_batch_data: RetrieveBatchRequest, | |
| openai_client: AsyncOpenAI, | |
| ) -> LiteLLMBatch: | |
| verbose_logger.debug("retrieving batch, args= %s", retrieve_batch_data) | |
| response = await openai_client.batches.retrieve(**retrieve_batch_data) | |
| return LiteLLMBatch(**response.model_dump()) | |
| def retrieve_batch( | |
| self, | |
| _is_async: bool, | |
| retrieve_batch_data: RetrieveBatchRequest, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[OpenAI] = None, | |
| ): | |
| openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| _is_async=_is_async, | |
| ) | |
| if openai_client is None: | |
| raise ValueError( | |
| "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." | |
| ) | |
| if _is_async is True: | |
| if not isinstance(openai_client, AsyncOpenAI): | |
| raise ValueError( | |
| "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." | |
| ) | |
| return self.aretrieve_batch( # type: ignore | |
| retrieve_batch_data=retrieve_batch_data, openai_client=openai_client | |
| ) | |
| response = cast(OpenAI, openai_client).batches.retrieve(**retrieve_batch_data) | |
| return LiteLLMBatch(**response.model_dump()) | |
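| # Illustrative async sketch ("batch_abc123" is a placeholder id; the field name is assumed to | |
| # match the OpenAI batches.retrieve parameter): | |
| #   batch = await OpenAIBatchesAPI().retrieve_batch( | |
| #       _is_async=True, retrieve_batch_data={"batch_id": "batch_abc123"}, | |
| #       api_key=None, api_base=None, timeout=600.0, max_retries=None, organization=None, | |
| #   ) | |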
| async def acancel_batch( | |
| self, | |
| cancel_batch_data: CancelBatchRequest, | |
| openai_client: AsyncOpenAI, | |
| ) -> Batch: | |
| verbose_logger.debug("async cancelling batch, args= %s", cancel_batch_data) | |
| response = await openai_client.batches.cancel(**cancel_batch_data) | |
| return response | |
| def cancel_batch( | |
| self, | |
| _is_async: bool, | |
| cancel_batch_data: CancelBatchRequest, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[OpenAI] = None, | |
| ): | |
| openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| _is_async=_is_async, | |
| ) | |
| if openai_client is None: | |
| raise ValueError( | |
| "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." | |
| ) | |
| if _is_async is True: | |
| if not isinstance(openai_client, AsyncOpenAI): | |
| raise ValueError( | |
| "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." | |
| ) | |
| return self.acancel_batch( # type: ignore | |
| cancel_batch_data=cancel_batch_data, openai_client=openai_client | |
| ) | |
| response = openai_client.batches.cancel(**cancel_batch_data) | |
| return response | |
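| # Note (from the code above): unlike create/retrieve, cancel_batch returns the SDK's raw Batch | |
| # object instead of re-wrapping it in LiteLLMBatch. Illustrative sketch with a placeholder id: | |
| #   cancelled = OpenAIBatchesAPI().cancel_batch( | |
| #       _is_async=False, cancel_batch_data={"batch_id": "batch_abc123"}, | |
| #       api_key=None, api_base=None, timeout=600.0, max_retries=None, organization=None, | |
| #   ) | |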
| async def alist_batches( | |
| self, | |
| openai_client: AsyncOpenAI, | |
| after: Optional[str] = None, | |
| limit: Optional[int] = None, | |
| ): | |
| verbose_logger.debug("listing batches, after= %s, limit= %s", after, limit) | |
| response = await openai_client.batches.list(after=after, limit=limit) # type: ignore | |
| return response | |
| def list_batches( | |
| self, | |
| _is_async: bool, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| after: Optional[str] = None, | |
| limit: Optional[int] = None, | |
| client: Optional[OpenAI] = None, | |
| ): | |
| openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| _is_async=_is_async, | |
| ) | |
| if openai_client is None: | |
| raise ValueError( | |
| "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." | |
| ) | |
| if _is_async is True: | |
| if not isinstance(openai_client, AsyncOpenAI): | |
| raise ValueError( | |
| "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." | |
| ) | |
| return self.alist_batches( # type: ignore | |
| openai_client=openai_client, after=after, limit=limit | |
| ) | |
| response = openai_client.batches.list(after=after, limit=limit) # type: ignore | |
| return response | |
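| # Illustrative pagination sketch: `after` is the cursor (typically the last batch id from the | |
| # previous page) and `limit` caps the page size, mirroring the OpenAI list endpoint. | |
| #   page = OpenAIBatchesAPI().list_batches( | |
| #       _is_async=False, api_key=None, api_base=None, timeout=600.0, max_retries=None, | |
| #       organization=None, after=None, limit=10, | |
| #   ) | |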
| class OpenAIAssistantsAPI(BaseLLM): | |
| def __init__(self) -> None: | |
| super().__init__() | |
| def get_openai_client( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[OpenAI] = None, | |
| ) -> OpenAI: | |
| received_args = locals() | |
| if client is None: | |
| data = {} | |
| for k, v in received_args.items(): | |
| if k == "self" or k == "client": | |
| pass | |
| elif k == "api_base" and v is not None: | |
| data["base_url"] = v | |
| elif v is not None: | |
| data[k] = v | |
| openai_client = OpenAI(**data) # type: ignore | |
| else: | |
| openai_client = client | |
| return openai_client | |
| def async_get_openai_client( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[AsyncOpenAI] = None, | |
| ) -> AsyncOpenAI: | |
| received_args = locals() | |
| if client is None: | |
| data = {} | |
| for k, v in received_args.items(): | |
| if k == "self" or k == "client": | |
| pass | |
| elif k == "api_base" and v is not None: | |
| data["base_url"] = v | |
| elif v is not None: | |
| data[k] = v | |
| openai_client = AsyncOpenAI(**data) # type: ignore | |
| else: | |
| openai_client = client | |
| return openai_client | |
| ### ASSISTANTS ### | |
| async def async_get_assistants( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[AsyncOpenAI], | |
| order: Optional[str] = "desc", | |
| limit: Optional[int] = 20, | |
| before: Optional[str] = None, | |
| after: Optional[str] = None, | |
| ) -> AsyncCursorPage[Assistant]: | |
| openai_client = self.async_get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| request_params = { | |
| "order": order, | |
| "limit": limit, | |
| } | |
| if before: | |
| request_params["before"] = before | |
| if after: | |
| request_params["after"] = after | |
| response = await openai_client.beta.assistants.list(**request_params) # type: ignore | |
| return response | |
| # fmt: off | |
| @overload | |
| def get_assistants( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[AsyncOpenAI], | |
| aget_assistants: Literal[True], | |
| ) -> Coroutine[None, None, AsyncCursorPage[Assistant]]: | |
| ... | |
| @overload | |
| def get_assistants( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[OpenAI], | |
| aget_assistants: Optional[Literal[False]], | |
| ) -> SyncCursorPage[Assistant]: | |
| ... | |
| # fmt: on | |
| def get_assistants( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client=None, | |
| aget_assistants=None, | |
| order: Optional[str] = "desc", | |
| limit: Optional[int] = 20, | |
| before: Optional[str] = None, | |
| after: Optional[str] = None, | |
| ): | |
| if aget_assistants is not None and aget_assistants is True: | |
| return self.async_get_assistants( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| order=order, | |
| limit=limit, | |
| before=before, | |
| after=after, | |
| ) | |
| openai_client = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| request_params = { | |
| "order": order, | |
| "limit": limit, | |
| } | |
| if before: | |
| request_params["before"] = before | |
| if after: | |
| request_params["after"] = after | |
| response = openai_client.beta.assistants.list(**request_params) # type: ignore | |
| return response | |
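| # Illustrative sketch: sync vs. async listing. With aget_assistants=True the call returns a | |
| # coroutine to await; otherwise a SyncCursorPage is returned directly. Parameter values are | |
| # placeholders and assume OPENAI_API_KEY is set in the environment. | |
| #   api = OpenAIAssistantsAPI() | |
| #   page = api.get_assistants(api_key=None, api_base=None, timeout=600.0, max_retries=None, | |
| #                             organization=None, client=None, aget_assistants=False, limit=5) | |
| #   async_page = await api.get_assistants(api_key=None, api_base=None, timeout=600.0, | |
| #                                         max_retries=None, organization=None, client=None, | |
| #                                         aget_assistants=True) | |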
| # Create Assistant | |
| async def async_create_assistants( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[AsyncOpenAI], | |
| create_assistant_data: dict, | |
| ) -> Assistant: | |
| openai_client = self.async_get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| response = await openai_client.beta.assistants.create(**create_assistant_data) | |
| return response | |
| def create_assistants( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| create_assistant_data: dict, | |
| client=None, | |
| async_create_assistants=None, | |
| ): | |
| if async_create_assistants is not None and async_create_assistants is True: | |
| return self.async_create_assistants( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| create_assistant_data=create_assistant_data, | |
| ) | |
| openai_client = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| response = openai_client.beta.assistants.create(**create_assistant_data) | |
| return response | |
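| # Illustrative sketch: create_assistant_data mirrors the OpenAI assistants.create parameters | |
| # (model is required; name/instructions/tools are optional). "gpt-4o-mini" is a placeholder model. | |
| #   assistant = OpenAIAssistantsAPI().create_assistants( | |
| #       api_key=None, api_base=None, timeout=600.0, max_retries=None, organization=None, | |
| #       create_assistant_data={"model": "gpt-4o-mini", "name": "helper", | |
| #                              "instructions": "You answer questions."}, | |
| #   ) | |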
| # Delete Assistant | |
| async def async_delete_assistant( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[AsyncOpenAI], | |
| assistant_id: str, | |
| ) -> AssistantDeleted: | |
| openai_client = self.async_get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| response = await openai_client.beta.assistants.delete(assistant_id=assistant_id) | |
| return response | |
| def delete_assistant( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| assistant_id: str, | |
| client=None, | |
| async_delete_assistants=None, | |
| ): | |
| if async_delete_assistants is not None and async_delete_assistants is True: | |
| return self.async_delete_assistant( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| assistant_id=assistant_id, | |
| ) | |
| openai_client = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| response = openai_client.beta.assistants.delete(assistant_id=assistant_id) | |
| return response | |
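| # Illustrative sketch ("asst_abc123" is a placeholder id); pass async_delete_assistants=True to | |
| # get a coroutine instead of an AssistantDeleted result: | |
| #   deleted = OpenAIAssistantsAPI().delete_assistant( | |
| #       api_key=None, api_base=None, timeout=600.0, max_retries=None, organization=None, | |
| #       assistant_id="asst_abc123", | |
| #   ) | |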
| ### MESSAGES ### | |
| async def a_add_message( | |
| self, | |
| thread_id: str, | |
| message_data: dict, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[AsyncOpenAI] = None, | |
| ) -> OpenAIMessage: | |
| openai_client = self.async_get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore | |
| thread_id, **message_data # type: ignore | |
| ) | |
| # If the SDK response has no `status`, default it to "completed" before re-validating the message. | |
| if getattr(thread_message, "status", None) is None: | |
| thread_message.status = "completed" | |
| response_obj = OpenAIMessage(**thread_message.dict()) | |
| return response_obj | |
| # fmt: off | |
| @overload | |
| def add_message( | |
| self, | |
| thread_id: str, | |
| message_data: dict, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[AsyncOpenAI], | |
| a_add_message: Literal[True], | |
| ) -> Coroutine[None, None, OpenAIMessage]: | |
| ... | |
| @overload | |
| def add_message( | |
| self, | |
| thread_id: str, | |
| message_data: dict, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[OpenAI], | |
| a_add_message: Optional[Literal[False]], | |
| ) -> OpenAIMessage: | |
| ... | |
| # fmt: on | |
| def add_message( | |
| self, | |
| thread_id: str, | |
| message_data: dict, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client=None, | |
| a_add_message: Optional[bool] = None, | |
| ): | |
| if a_add_message is not None and a_add_message is True: | |
| return self.a_add_message( | |
| thread_id=thread_id, | |
| message_data=message_data, | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| openai_client = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore | |
| thread_id, **message_data # type: ignore | |
| ) | |
| # If the SDK response has no `status`, default it to "completed" before re-validating the message. | |
| if getattr(thread_message, "status", None) is None: | |
| thread_message.status = "completed" | |
| response_obj = OpenAIMessage(**thread_message.dict()) | |
| return response_obj | |
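| # Illustrative sketch: message_data follows the OpenAI thread-message parameters and | |
| # "thread_abc123" is a placeholder id. | |
| #   msg = OpenAIAssistantsAPI().add_message( | |
| #       thread_id="thread_abc123", | |
| #       message_data={"role": "user", "content": "Hey, how's it going?"}, | |
| #       api_key=None, api_base=None, timeout=600.0, max_retries=None, organization=None, | |
| #       client=None, a_add_message=False, | |
| #   ) | |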
| async def async_get_messages( | |
| self, | |
| thread_id: str, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[AsyncOpenAI] = None, | |
| ) -> AsyncCursorPage[OpenAIMessage]: | |
| openai_client = self.async_get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| response = await openai_client.beta.threads.messages.list(thread_id=thread_id) | |
| return response | |
| # fmt: off | |
| @overload | |
| def get_messages( | |
| self, | |
| thread_id: str, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[AsyncOpenAI], | |
| aget_messages: Literal[True], | |
| ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]: | |
| ... | |
| @overload | |
| def get_messages( | |
| self, | |
| thread_id: str, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[OpenAI], | |
| aget_messages: Optional[Literal[False]], | |
| ) -> SyncCursorPage[OpenAIMessage]: | |
| ... | |
| # fmt: on | |
| def get_messages( | |
| self, | |
| thread_id: str, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client=None, | |
| aget_messages=None, | |
| ): | |
| if aget_messages is not None and aget_messages is True: | |
| return self.async_get_messages( | |
| thread_id=thread_id, | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| openai_client = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| response = openai_client.beta.threads.messages.list(thread_id=thread_id) | |
| return response | |
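| # Illustrative sketch ("thread_abc123" is a placeholder id); get_thread below follows the same | |
| # sync/async dispatch pattern via its aget_thread flag. | |
| #   msgs = OpenAIAssistantsAPI().get_messages( | |
| #       thread_id="thread_abc123", api_key=None, api_base=None, timeout=600.0, | |
| #       max_retries=None, organization=None, client=None, aget_messages=False, | |
| #   ) | |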
| ### THREADS ### | |
| async def async_create_thread( | |
| self, | |
| metadata: Optional[dict], | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[AsyncOpenAI], | |
| messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], | |
| ) -> Thread: | |
| openai_client = self.async_get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| data = {} | |
| if messages is not None: | |
| data["messages"] = messages # type: ignore | |
| if metadata is not None: | |
| data["metadata"] = metadata # type: ignore | |
| message_thread = await openai_client.beta.threads.create(**data) # type: ignore | |
| return Thread(**message_thread.dict()) | |
| # fmt: off | |
| @overload | |
| def create_thread( | |
| self, | |
| metadata: Optional[dict], | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], | |
| client: Optional[AsyncOpenAI], | |
| acreate_thread: Literal[True], | |
| ) -> Coroutine[None, None, Thread]: | |
| ... | |
| @overload | |
| def create_thread( | |
| self, | |
| metadata: Optional[dict], | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], | |
| client: Optional[OpenAI], | |
| acreate_thread: Optional[Literal[False]], | |
| ) -> Thread: | |
| ... | |
| # fmt: on | |
| def create_thread( | |
| self, | |
| metadata: Optional[dict], | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], | |
| client=None, | |
| acreate_thread=None, | |
| ): | |
| """ | |
| Here's an example: | |
| ``` | |
| from litellm.llms.openai.openai import OpenAIAssistantsAPI, MessageData | |
| # create thread | |
| message: MessageData = {"role": "user", "content": "Hey, how's it going?"} | |
| openai_api.create_thread(messages=[message]) | |
| ``` | |
| """ | |
| if acreate_thread is not None and acreate_thread is True: | |
| return self.async_create_thread( | |
| metadata=metadata, | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| messages=messages, | |
| ) | |
| openai_client = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| data = {} | |
| if messages is not None: | |
| data["messages"] = messages # type: ignore | |
| if metadata is not None: | |
| data["metadata"] = metadata # type: ignore | |
| message_thread = openai_client.beta.threads.create(**data) # type: ignore | |
| return Thread(**message_thread.dict()) | |
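| # Illustrative async variant (sketch): with acreate_thread=True the call returns a coroutine | |
| # that resolves to a Thread. Parameter values are placeholders. | |
| #   thread = await OpenAIAssistantsAPI().create_thread( | |
| #       metadata=None, api_key=None, api_base=None, timeout=600.0, max_retries=None, | |
| #       organization=None, messages=[{"role": "user", "content": "Hello"}], | |
| #       client=None, acreate_thread=True, | |
| #   ) | |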
| async def async_get_thread( | |
| self, | |
| thread_id: str, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[AsyncOpenAI], | |
| ) -> Thread: | |
| openai_client = self.async_get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| response = await openai_client.beta.threads.retrieve(thread_id=thread_id) | |
| return Thread(**response.dict()) | |
| # fmt: off | |
| @overload | |
| def get_thread( | |
| self, | |
| thread_id: str, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[AsyncOpenAI], | |
| aget_thread: Literal[True], | |
| ) -> Coroutine[None, None, Thread]: | |
| ... | |
| @overload | |
| def get_thread( | |
| self, | |
| thread_id: str, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[OpenAI], | |
| aget_thread: Optional[Literal[False]], | |
| ) -> Thread: | |
| ... | |
| # fmt: on | |
| def get_thread( | |
| self, | |
| thread_id: str, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client=None, | |
| aget_thread=None, | |
| ): | |
| if aget_thread is not None and aget_thread is True: | |
| return self.async_get_thread( | |
| thread_id=thread_id, | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| openai_client = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| response = openai_client.beta.threads.retrieve(thread_id=thread_id) | |
| return Thread(**response.dict()) | |
| def delete_thread(self): | |
| pass | |
| ### RUNS ### | |
| async def arun_thread( | |
| self, | |
| thread_id: str, | |
| assistant_id: str, | |
| additional_instructions: Optional[str], | |
| instructions: Optional[str], | |
| metadata: Optional[Dict], | |
| model: Optional[str], | |
| stream: Optional[bool], | |
| tools: Optional[Iterable[AssistantToolParam]], | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client: Optional[AsyncOpenAI], | |
| ) -> Run: | |
| openai_client = self.async_get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore | |
| thread_id=thread_id, | |
| assistant_id=assistant_id, | |
| additional_instructions=additional_instructions, | |
| instructions=instructions, | |
| metadata=metadata, | |
| model=model, | |
| tools=tools, | |
| ) | |
| return response | |
| def async_run_thread_stream( | |
| self, | |
| client: AsyncOpenAI, | |
| thread_id: str, | |
| assistant_id: str, | |
| additional_instructions: Optional[str], | |
| instructions: Optional[str], | |
| metadata: Optional[Dict], | |
| model: Optional[str], | |
| tools: Optional[Iterable[AssistantToolParam]], | |
| event_handler: Optional[AssistantEventHandler], | |
| ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]: | |
| data: Dict[str, Any] = { | |
| "thread_id": thread_id, | |
| "assistant_id": assistant_id, | |
| "additional_instructions": additional_instructions, | |
| "instructions": instructions, | |
| "metadata": metadata, | |
| "model": model, | |
| "tools": tools, | |
| } | |
| if event_handler is not None: | |
| data["event_handler"] = event_handler | |
| return client.beta.threads.runs.stream(**data) # type: ignore | |
| def run_thread_stream( | |
| self, | |
| client: OpenAI, | |
| thread_id: str, | |
| assistant_id: str, | |
| additional_instructions: Optional[str], | |
| instructions: Optional[str], | |
| metadata: Optional[Dict], | |
| model: Optional[str], | |
| tools: Optional[Iterable[AssistantToolParam]], | |
| event_handler: Optional[AssistantEventHandler], | |
| ) -> AssistantStreamManager[AssistantEventHandler]: | |
| data: Dict[str, Any] = { | |
| "thread_id": thread_id, | |
| "assistant_id": assistant_id, | |
| "additional_instructions": additional_instructions, | |
| "instructions": instructions, | |
| "metadata": metadata, | |
| "model": model, | |
| "tools": tools, | |
| } | |
| if event_handler is not None: | |
| data["event_handler"] = event_handler | |
| return client.beta.threads.runs.stream(**data) # type: ignore | |
| # fmt: off | |
| @overload | |
| def run_thread( | |
| self, | |
| thread_id: str, | |
| assistant_id: str, | |
| additional_instructions: Optional[str], | |
| instructions: Optional[str], | |
| metadata: Optional[Dict], | |
| model: Optional[str], | |
| stream: Optional[bool], | |
| tools: Optional[Iterable[AssistantToolParam]], | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client, | |
| arun_thread: Literal[True], | |
| event_handler: Optional[AssistantEventHandler], | |
| ) -> Coroutine[None, None, Run]: | |
| ... | |
| @overload | |
| def run_thread( | |
| self, | |
| thread_id: str, | |
| assistant_id: str, | |
| additional_instructions: Optional[str], | |
| instructions: Optional[str], | |
| metadata: Optional[Dict], | |
| model: Optional[str], | |
| stream: Optional[bool], | |
| tools: Optional[Iterable[AssistantToolParam]], | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client, | |
| arun_thread: Optional[Literal[False]], | |
| event_handler: Optional[AssistantEventHandler], | |
| ) -> Run: | |
| ... | |
| # fmt: on | |
| def run_thread( | |
| self, | |
| thread_id: str, | |
| assistant_id: str, | |
| additional_instructions: Optional[str], | |
| instructions: Optional[str], | |
| metadata: Optional[Dict], | |
| model: Optional[str], | |
| stream: Optional[bool], | |
| tools: Optional[Iterable[AssistantToolParam]], | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| organization: Optional[str], | |
| client=None, | |
| arun_thread=None, | |
| event_handler: Optional[AssistantEventHandler] = None, | |
| ): | |
| if arun_thread is not None and arun_thread is True: | |
| if stream is not None and stream is True: | |
| _client = self.async_get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| return self.async_run_thread_stream( | |
| client=_client, | |
| thread_id=thread_id, | |
| assistant_id=assistant_id, | |
| additional_instructions=additional_instructions, | |
| instructions=instructions, | |
| metadata=metadata, | |
| model=model, | |
| tools=tools, | |
| event_handler=event_handler, | |
| ) | |
| return self.arun_thread( | |
| thread_id=thread_id, | |
| assistant_id=assistant_id, | |
| additional_instructions=additional_instructions, | |
| instructions=instructions, | |
| metadata=metadata, | |
| model=model, | |
| stream=stream, | |
| tools=tools, | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| openai_client = self.get_openai_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| organization=organization, | |
| client=client, | |
| ) | |
| if stream is not None and stream is True: | |
| return self.run_thread_stream( | |
| client=openai_client, | |
| thread_id=thread_id, | |
| assistant_id=assistant_id, | |
| additional_instructions=additional_instructions, | |
| instructions=instructions, | |
| metadata=metadata, | |
| model=model, | |
| tools=tools, | |
| event_handler=event_handler, | |
| ) | |
| response = openai_client.beta.threads.runs.create_and_poll( # type: ignore | |
| thread_id=thread_id, | |
| assistant_id=assistant_id, | |
| additional_instructions=additional_instructions, | |
| instructions=instructions, | |
| metadata=metadata, | |
| model=model, | |
| tools=tools, | |
| ) | |
| return response | |
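| # Illustrative streaming sketch: with stream=True the sync path returns an AssistantStreamManager, | |
| # used as a context manager; with arun_thread=True an AsyncAssistantStreamManager is returned and | |
| # consumed via `async with` instead. "thread_abc123"/"asst_abc123" are placeholder ids. | |
| #   api = OpenAIAssistantsAPI() | |
| #   stream_manager = api.run_thread( | |
| #       thread_id="thread_abc123", assistant_id="asst_abc123", | |
| #       additional_instructions=None, instructions=None, metadata=None, model=None, | |
| #       stream=True, tools=None, api_key=None, api_base=None, timeout=600.0, | |
| #       max_retries=None, organization=None, client=None, arun_thread=False, | |
| #   ) | |
| #   with stream_manager as stream: | |
| #       for event in stream: | |
| #           ...  # handle each AssistantStreamEvent | |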