from typing import TYPE_CHECKING, Any, List, Optional, Union

from httpx._models import Headers, Response

import litellm
from litellm.litellm_core_utils.prompt_templates.factory import (
    convert_to_azure_openai_messages,
)
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.azure import (
    API_VERSION_MONTH_SUPPORTED_RESPONSE_FORMAT,
    API_VERSION_YEAR_SUPPORTED_RESPONSE_FORMAT,
)
from litellm.types.utils import ModelResponse
from litellm.utils import supports_response_schema

from ....exceptions import UnsupportedParamsError
from ....types.llms.openai import AllMessageValues
from ...base_llm.chat.transformation import BaseConfig
from ..common_utils import AzureOpenAIError

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    LoggingClass = LiteLLMLoggingObj
else:
    LoggingClass = Any

class AzureOpenAIConfig(BaseConfig):
    """
    Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions

    The class `AzureOpenAIConfig` provides configuration for OpenAI's chat completions API interface, for use with Azure. Below are the parameters:

    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby reducing repetition.
    - `function_call` (string or object): This optional parameter controls how the model calls functions.
    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
    - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the chat completion.
    - `n` (integer or null): This optional parameter sets how many chat completion choices to generate for each input message.
    - `presence_penalty` (number or null): Defaults to 0. Penalizes new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """
    def __init__(
        self,
        frequency_penalty: Optional[float] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                # Stored on the class (not the instance), so non-None values
                # act as shared defaults across instances.
                setattr(self.__class__, key, value)
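    # Illustrative sketch (not in the original source): because __init__ uses
    # setattr on the class, any non-None argument becomes a class-level
    # default shared by later instances:
    #
    #   AzureOpenAIConfig(max_tokens=256, temperature=0.2)
    #   assert AzureOpenAIConfig.max_tokens == 256   # visible on the class
    #   assert AzureOpenAIConfig.temperature == 0.2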
    @classmethod
    def get_config(cls):
        return super().get_config()
    def get_supported_openai_params(self, model: str) -> List[str]:
        return [
            "temperature",
            "n",
            "stream",
            "stream_options",
            "stop",
            "max_tokens",
            "max_completion_tokens",
            "tools",
            "tool_choice",
            "presence_penalty",
            "frequency_penalty",
            "logit_bias",
            "user",
            "function_call",
            "functions",
            "top_p",
            "logprobs",
            "top_logprobs",
            "response_format",
            "seed",
            "extra_headers",
            "parallel_tool_calls",
            "prediction",
            "modalities",
            "audio",
        ]
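    # Illustrative sketch (not in the original source): callers can consult
    # this list before forwarding a parameter, e.g.:
    #
    #   config = AzureOpenAIConfig()
    #   if "response_format" in config.get_supported_openai_params("gpt-4o"):
    #       ...  # safe to map response_format for this model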
    def _is_response_format_supported_model(self, model: str) -> bool:
        """
        - all 4o models are supported
        - check if 'supports_response_schema' is True from get_model_info
        - [TODO] support smart retries for 3.5 models (some supported, some not)
        """
        if "4o" in model:
            return True
        elif supports_response_schema(model):
            return True
        return False
    def _is_response_format_supported_api_version(
        self, api_version_year: str, api_version_month: str
    ) -> bool:
        """
        - check if api_version is supported for response_format
        - returns True if the API version is equal to or newer than the supported version
        """
        api_year = int(api_version_year)
        api_month = int(api_version_month)
        supported_year = int(API_VERSION_YEAR_SUPPORTED_RESPONSE_FORMAT)
        supported_month = int(API_VERSION_MONTH_SUPPORTED_RESPONSE_FORMAT)

        # If the year is greater than the supported year, it's definitely supported
        if api_year > supported_year:
            return True
        # If the year is less than the supported year, it's not supported
        elif api_year < supported_year:
            return False
        # Same year: check if the month is >= the supported month
        else:
            return api_month >= supported_month
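    # Illustrative sketch (not in the original source): with hypothetical
    # supported constants of year "2023" and month "10" (the real values live
    # in litellm.types.llms.azure), the comparison above yields:
    #
    #   config._is_response_format_supported_api_version("2024", "08")  # True
    #   config._is_response_format_supported_api_version("2023", "10")  # True
    #   config._is_response_format_supported_api_version("2023", "08")  # False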
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
        api_version: str = "",
    ) -> dict:
        supported_openai_params = self.get_supported_openai_params(model)

        # api_version is expected as "YYYY-MM-DD" or "YYYY-MM-DD-preview"
        api_version_times = api_version.split("-")
        api_version_year = api_version_times[0]
        api_version_month = api_version_times[1]
        api_version_day = api_version_times[2]
        for param, value in non_default_params.items():
            if param == "tool_choice":
                """
                This parameter requires API version 2023-12-01-preview or later

                tool_choice='required' is not supported as of 2024-05-01-preview
                """
                ## check if api version supports this param ##
                # Lexicographic string comparison is safe here because the
                # components are zero-padded ("2023", "12", "01").
                if (
                    api_version_year < "2023"
                    or (api_version_year == "2023" and api_version_month < "12")
                    or (
                        api_version_year == "2023"
                        and api_version_month == "12"
                        and api_version_day < "01"
                    )
                ):
                    if litellm.drop_params is True or drop_params is True:
                        pass
                    else:
                        raise UnsupportedParamsError(
                            status_code=400,
                            message=f"""Azure does not support 'tool_choice' for api_version={api_version}. This parameter requires 'api_version="2023-12-01-preview"' or later. Azure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions""",
                        )
                elif value == "required" and (
                    api_version_year == "2024" and api_version_month <= "05"
                ):  ## check if tool_choice value is supported ##
                    if litellm.drop_params is True or drop_params is True:
                        pass
                    else:
                        raise UnsupportedParamsError(
                            status_code=400,
                            message=f"Azure does not support '{value}' as a {param} param, for api_version={api_version}. To drop 'tool_choice=required' for calls with this Azure API version, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\nAzure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions",
                        )
                else:
                    optional_params["tool_choice"] = value
            elif param == "response_format" and isinstance(value, dict):
                _is_response_format_supported_model = (
                    self._is_response_format_supported_model(model)
                )
                is_response_format_supported_api_version = (
                    self._is_response_format_supported_api_version(
                        api_version_year, api_version_month
                    )
                )
                is_response_format_supported = (
                    is_response_format_supported_api_version
                    and _is_response_format_supported_model
                )
                optional_params = self._add_response_format_to_tools(
                    optional_params=optional_params,
                    value=value,
                    is_response_format_supported=is_response_format_supported,
                )
            elif param == "tools" and isinstance(value, list):
                optional_params.setdefault("tools", [])
                optional_params["tools"].extend(value)
            elif param in supported_openai_params:
                optional_params[param] = value

        return optional_params
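    # Illustrative sketch (not in the original source): mapping
    # tool_choice="required" against an API version that predates support
    # raises UnsupportedParamsError unless drop_params is set, in which case
    # the param is silently dropped:
    #
    #   config = AzureOpenAIConfig()
    #   config.map_openai_params(
    #       non_default_params={"tool_choice": "required"},
    #       optional_params={},
    #       model="gpt-4o",
    #       drop_params=True,           # drop instead of raising
    #       api_version="2024-02-01",
    #   )
    #   # -> {}  ("required" is unsupported on 2024-02, so it is dropped)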
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        messages = convert_to_azure_openai_messages(messages)
        return {
            "model": model,
            "messages": messages,
            **optional_params,
        }
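    # Illustrative sketch (not in the original source): the transformed
    # request is a plain dict merging the Azure-normalized messages with the
    # previously mapped optional params:
    #
    #   config.transform_request(
    #       model="gpt-4o",
    #       messages=[{"role": "user", "content": "hi"}],
    #       optional_params={"temperature": 0.2},
    #       litellm_params={},
    #       headers={},
    #   )
    #   # -> {"model": "gpt-4o", "messages": [...], "temperature": 0.2}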
    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LoggingClass,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        raise NotImplementedError(
            "Azure OpenAI handler.py has custom logic for transforming response, as it uses the OpenAI SDK."
        )
    def get_mapped_special_auth_params(self) -> dict:
        return {"token": "azure_ad_token"}

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        for param, value in non_default_params.items():
            if param == "token":
                optional_params["azure_ad_token"] = value
        return optional_params
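    # Illustrative sketch (not in the original source): a caller passing the
    # generic "token" auth param has it renamed to the Azure-specific key:
    #
    #   config.map_special_auth_params(
    #       non_default_params={"token": "eyJ..."}, optional_params={}
    #   )
    #   # -> {"azure_ad_token": "eyJ..."}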
    def get_eu_regions(self) -> List[str]:
        """
        Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
        """
        return ["europe", "sweden", "switzerland", "france", "uk"]

    def get_us_regions(self) -> List[str]:
        """
        Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
        """
        return [
            "us",
            "eastus",
            "eastus2",
            "eastus2euap",
            "eastus3",
            "southcentralus",
            "westus",
            "westus2",
            "westus3",
            "westus4",
        ]
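    # Illustrative sketch (not in the original source; the matching strategy
    # shown is an assumption, as region checks happen elsewhere in litellm):
    # these entries look like substring keys, so a deployment region such as
    # "swedencentral" would match the EU entry "sweden" under a check like:
    #
    #   any(r in "swedencentral" for r in config.get_eu_regions())  # True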
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        return AzureOpenAIError(
            message=error_message, status_code=status_code, headers=headers
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        raise NotImplementedError(
            "Azure OpenAI has custom logic for validating environment, as it uses the OpenAI SDK."
        )