from typing import ( List, Dict, Union, Generator, AsyncGenerator, ) from aworld.config import ConfigDict from aworld.config.conf import AgentConfig, ClientType from aworld.logs.util import logger from aworld.core.llm_provider_base import LLMProviderBase from aworld.models.openai_provider import OpenAIProvider, AzureOpenAIProvider from aworld.models.anthropic_provider import AnthropicProvider from aworld.models.ant_provider import AntProvider from aworld.models.model_response import ModelResponse # Predefined model names for common providers MODEL_NAMES = { "anthropic": ["claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620", "claude-3-opus-20240229"], "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini", "gpt-4o-mini"], "azure_openai": ["gpt-4", "gpt-4-turbo", "gpt-4o", "gpt-35-turbo"], } # Endpoint patterns for identifying providers ENDPOINT_PATTERNS = { "openai": ["api.openai.com"], "anthropic": ["api.anthropic.com", "claude-api"], "azure_openai": ["openai.azure.com"], "ant": ["zdfmng.alipay.com"], } # Provider class mapping PROVIDER_CLASSES = { "openai": OpenAIProvider, "anthropic": AnthropicProvider, "azure_openai": AzureOpenAIProvider, "ant": AntProvider, } class LLMModel: """Unified large model interface, encapsulates different model implementations, provides a unified completion method. """ def __init__(self, conf: Union[ConfigDict, AgentConfig] = None, custom_provider: LLMProviderBase = None, **kwargs): """Initialize unified model interface. Args: conf: Agent configuration, if provided, create model based on configuration. custom_provider: Custom LLMProviderBase instance, if provided, use it directly. **kwargs: Other parameters, may include: - base_url: Specify model endpoint. - api_key: API key. - model_name: Model name. - temperature: Temperature parameter. """ # If custom_provider instance is provided, use it directly if custom_provider is not None: if not isinstance(custom_provider, LLMProviderBase): raise TypeError( "custom_provider must be an instance of LLMProviderBase") self.provider_name = "custom" self.provider = custom_provider return # Get basic parameters base_url = kwargs.get("base_url") or ( conf.llm_base_url if conf else None) model_name = kwargs.get("model_name") or ( conf.llm_model_name if conf else None) llm_provider = conf.llm_provider if conf_contains_key( conf, "llm_provider") else None # Get API key from configuration (if any) if conf and conf.llm_api_key: kwargs["api_key"] = conf.llm_api_key # Identify provider self.provider_name = self._identify_provider( llm_provider, base_url, model_name) # Fill basic parameters kwargs['base_url'] = base_url kwargs['model_name'] = model_name # Fill parameters for llm provider kwargs['sync_enabled'] = conf.llm_sync_enabled if conf_contains_key( conf, "llm_sync_enabled") else True kwargs['async_enabled'] = conf.llm_async_enabled if conf_contains_key( conf, "llm_async_enabled") else True kwargs['client_type'] = conf.llm_client_type if conf_contains_key( conf, "llm_client_type") else ClientType.SDK kwargs.update(self._transfer_conf_to_args(conf)) # Create model provider based on provider_name self._create_provider(**kwargs) def _transfer_conf_to_args(self, conf: Union[ConfigDict, AgentConfig] = None) -> dict: """ Transfer parameters from conf to args Args: conf: config object """ if not conf: return {} # Get all parameters from conf if type(conf).__name__ == 'AgentConfig': conf_dict = conf.model_dump() else: # ConfigDict conf_dict = conf ignored_keys = ["llm_provider", "llm_base_url", "llm_model_name", "llm_api_key", "llm_sync_enabled", "llm_async_enabled", "llm_client_type"] args = {} # Filter out used parameters and add remaining parameters to args for key, value in conf_dict.items(): if key not in ignored_keys and value is not None: args[key] = value return args def _identify_provider(self, provider: str = None, base_url: str = None, model_name: str = None) -> str: """Identify LLM provider. Identification logic: 1. If provider is specified and doesn't need to be overridden, use the specified provider. 2. If base_url is provided, try to identify provider based on base_url. 3. If model_name is provided, try to identify provider based on model_name. 4. If none can be identified, default to "openai". Args: provider: Specified provider. base_url: Service URL. model_name: Model name. Returns: str: Identified provider. """ # Default provider identified_provider = "openai" # Identify provider based on base_url if base_url: for p, patterns in ENDPOINT_PATTERNS.items(): if any(pattern in base_url for pattern in patterns): identified_provider = p logger.info( f"Identified provider: {identified_provider} based on base_url: {base_url}") return identified_provider # Identify provider based on model_name if model_name and not base_url: for p, models in MODEL_NAMES.items(): if model_name in models or any(model_name.startswith(model) for model in models): identified_provider = p logger.info( f"Identified provider: {identified_provider} based on model_name: {model_name}") break if provider and provider in PROVIDER_CLASSES and identified_provider and identified_provider != provider: logger.warning( f"Provider mismatch: {provider} != {identified_provider}, using {provider} as provider") identified_provider = provider return identified_provider def _create_provider(self, **kwargs): """Return the corresponding provider instance based on provider. Args: **kwargs: Parameters, may include: - base_url: Model endpoint. - api_key: API key. - model_name: Model name. - temperature: Temperature parameter. - timeout: Timeout. - max_retries: Maximum number of retries. """ self.provider = PROVIDER_CLASSES[self.provider_name](**kwargs) @classmethod def supported_providers(cls) -> list[str]: return list(PROVIDER_CLASSES.keys()) def supported_models(self) -> list[str]: """Get supported models for the current provider. Returns: list: Supported models. """ return self.provider.supported_models() if self.provider else [] async def acompletion(self, messages: List[Dict[str, str]], temperature: float = 0.0, max_tokens: int = None, stop: List[str] = None, **kwargs) -> ModelResponse: """Asynchronously call model to generate response. Args: messages: Message list, format is [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]. temperature: Temperature parameter. max_tokens: Maximum number of tokens to generate. stop: List of stop sequences. **kwargs: Other parameters. Returns: ModelResponse: Unified model response object. """ # Call provider's acompletion method directly return await self.provider.acompletion( messages=messages, temperature=temperature, max_tokens=max_tokens, stop=stop, **kwargs ) def completion(self, messages: List[Dict[str, str]], temperature: float = 0.0, max_tokens: int = None, stop: List[str] = None, **kwargs) -> ModelResponse: """Synchronously call model to generate response. Args: messages: Message list, format is [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]. temperature: Temperature parameter. max_tokens: Maximum number of tokens to generate. stop: List of stop sequences. **kwargs: Other parameters. Returns: ModelResponse: Unified model response object. """ # Call provider's completion method directly return self.provider.completion( messages=messages, temperature=temperature, max_tokens=max_tokens, stop=stop, **kwargs ) def stream_completion(self, messages: List[Dict[str, str]], temperature: float = 0.0, max_tokens: int = None, stop: List[str] = None, **kwargs) -> Generator[ModelResponse, None, None]: """Synchronously call model to generate streaming response. Args: messages: Message list, format is [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]. temperature: Temperature parameter. max_tokens: Maximum number of tokens to generate. stop: List of stop sequences. **kwargs: Other parameters. Returns: Generator yielding ModelResponse chunks. """ # Call provider's stream_completion method directly return self.provider.stream_completion( messages=messages, temperature=temperature, max_tokens=max_tokens, stop=stop, **kwargs ) async def astream_completion(self, messages: List[Dict[str, str]], temperature: float = 0.0, max_tokens: int = None, stop: List[str] = None, **kwargs) -> AsyncGenerator[ModelResponse, None]: """Asynchronously call model to generate streaming response. Args: messages: Message list, format is [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]. temperature: Temperature parameter. max_tokens: Maximum number of tokens to generate. stop: List of stop sequences. **kwargs: Other parameters, may include: - base_url: Specify model endpoint. - api_key: API key. - model_name: Model name. Returns: AsyncGenerator yielding ModelResponse chunks. """ # Call provider's astream_completion method directly async for chunk in self.provider.astream_completion( messages=messages, temperature=temperature, max_tokens=max_tokens, stop=stop, **kwargs ): yield chunk def speech_to_text(self, audio_file: str, language: str = None, prompt: str = None, **kwargs) -> ModelResponse: """Convert speech to text. Args: audio_file: Path to audio file or file object. language: Audio language, optional. prompt: Transcription prompt, optional. **kwargs: Other parameters. Returns: ModelResponse: Unified model response object, with content field containing the transcription result. Raises: LLMResponseError: When LLM response error occurs. NotImplementedError: When provider does not support speech to text conversion. """ return self.provider.speech_to_text( audio_file=audio_file, language=language, prompt=prompt, **kwargs ) async def aspeech_to_text(self, audio_file: str, language: str = None, prompt: str = None, **kwargs) -> ModelResponse: """Asynchronously convert speech to text. Args: audio_file: Path to audio file or file object. language: Audio language, optional. prompt: Transcription prompt, optional. **kwargs: Other parameters. Returns: ModelResponse: Unified model response object, with content field containing the transcription result. Raises: LLMResponseError: When LLM response error occurs. NotImplementedError: When provider does not support speech to text conversion. """ return await self.provider.aspeech_to_text( audio_file=audio_file, language=language, prompt=prompt, **kwargs ) def register_llm_provider(provider: str, provider_class: type): """Register a custom LLM provider. Args: provider: Provider name. provider_class: Provider class, must inherit from LLMProviderBase. """ if not issubclass(provider_class, LLMProviderBase): raise TypeError("provider_class must be a subclass of LLMProviderBase") PROVIDER_CLASSES[provider] = provider_class def conf_contains_key(conf: Union[ConfigDict, AgentConfig], key: str) -> bool: """Check if conf contains key. Args: conf: Config object. key: Key to check. Returns: bool: Whether conf contains key. """ if not conf: return False if type(conf).__name__ == 'AgentConfig': return hasattr(conf, key) else: return key in conf def get_llm_model(conf: Union[ConfigDict, AgentConfig] = None, custom_provider: LLMProviderBase = None, **kwargs) -> Union[LLMModel, 'ChatOpenAI']: """Get a unified LLM model instance. Args: conf: Agent configuration, if provided, create model based on configuration. custom_provider: Custom LLMProviderBase instance, if provided, use it directly. **kwargs: Other parameters, may include: - base_url: Specify model endpoint. - api_key: API key. - model_name: Model name. - temperature: Temperature parameter. Returns: Unified model interface. """ # Create and return LLMModel instance directly llm_provider = conf.llm_provider if conf_contains_key( conf, "llm_provider") else None if (llm_provider == "chatopenai"): from langchain_openai import ChatOpenAI base_url = kwargs.get("base_url") or ( conf.llm_base_url if conf_contains_key(conf, "llm_base_url") else None) model_name = kwargs.get("model_name") or ( conf.llm_model_name if conf_contains_key(conf, "llm_model_name") else None) api_key = kwargs.get("api_key") or ( conf.llm_api_key if conf_contains_key(conf, "llm_api_key") else None) return ChatOpenAI( model=model_name, temperature=kwargs.get("temperature", conf.llm_temperature if conf_contains_key( conf, "llm_temperature") else 0.0), base_url=base_url, api_key=api_key, ) return LLMModel(conf=conf, custom_provider=custom_provider, **kwargs) def call_llm_model( llm_model: LLMModel, messages: List[Dict[str, str]], temperature: float = 0.0, max_tokens: int = None, stop: List[str] = None, stream: bool = False, **kwargs ) -> Union[ModelResponse, Generator[ModelResponse, None, None]]: """Convenience function to call LLM model. Args: llm_model: LLM model instance. messages: Message list. temperature: Temperature parameter. max_tokens: Maximum number of tokens to generate. stop: List of stop sequences. stream: Whether to return a streaming response. **kwargs: Other parameters. Returns: Model response or response generator. """ if stream: return llm_model.stream_completion( messages=messages, temperature=temperature, max_tokens=max_tokens, stop=stop, **kwargs ) else: return llm_model.completion( messages=messages, temperature=temperature, max_tokens=max_tokens, stop=stop, **kwargs ) async def acall_llm_model( llm_model: LLMModel, messages: List[Dict[str, str]], temperature: float = 0.0, max_tokens: int = None, stop: List[str] = None, stream: bool = False, **kwargs ) -> ModelResponse: """Convenience function to asynchronously call LLM model. Args: llm_model: LLM model instance. messages: Message list. temperature: Temperature parameter. max_tokens: Maximum number of tokens to generate. stop: List of stop sequences. stream: Whether to return a streaming response. **kwargs: Other parameters. Returns: Model response or response generator. """ return await llm_model.acompletion( messages=messages, temperature=temperature, max_tokens=max_tokens, stop=stop, **kwargs ) async def acall_llm_model_stream( llm_model: LLMModel, messages: List[Dict[str, str]], temperature: float = 0.0, max_tokens: int = None, stop: List[str] = None, **kwargs ) -> AsyncGenerator[ModelResponse, None]: async for chunk in llm_model.astream_completion( messages=messages, temperature=temperature, max_tokens=max_tokens, stop=stop, **kwargs ): yield chunk def speech_to_text( llm_model: LLMModel, audio_file: str, language: str = None, prompt: str = None, **kwargs ) -> ModelResponse: """Convenience function to convert speech to text. Args: llm_model: LLM model instance. audio_file: Path to audio file or file object. language: Audio language, optional. prompt: Transcription prompt, optional. **kwargs: Other parameters. Returns: ModelResponse: Unified model response object, with content field containing the transcription result. """ if llm_model.provider_name != "openai": raise NotImplementedError( f"Speech-to-text functionality is currently only supported for OpenAI compatible provider, current provider: {llm_model.provider_name}") return llm_model.speech_to_text( audio_file=audio_file, language=language, prompt=prompt, **kwargs ) async def aspeech_to_text( llm_model: LLMModel, audio_file: str, language: str = None, prompt: str = None, **kwargs ) -> ModelResponse: """Convenience function to asynchronously convert speech to text. Args: llm_model: LLM model instance. audio_file: Path to audio file or file object. language: Audio language, optional. prompt: Transcription prompt, optional. **kwargs: Other parameters. Returns: ModelResponse: Unified model response object, with content field containing the transcription result. """ if llm_model.provider_name != "openai": raise NotImplementedError( f"Speech-to-text functionality is currently only supported for OpenAI compatible provider, current provider: {llm_model.provider_name}") return await llm_model.aspeech_to_text( audio_file=audio_file, language=language, prompt=prompt, **kwargs )