from typing import (
    List,
    Dict,
    Union,
    Generator,
    AsyncGenerator,
)

from aworld.config import ConfigDict
from aworld.config.conf import AgentConfig, ClientType
from aworld.logs.util import logger
from aworld.core.llm_provider_base import LLMProviderBase
from aworld.models.openai_provider import OpenAIProvider, AzureOpenAIProvider
from aworld.models.anthropic_provider import AnthropicProvider
from aworld.models.ant_provider import AntProvider
from aworld.models.model_response import ModelResponse
# Predefined model names for common providers
MODEL_NAMES = {
    "anthropic": ["claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
    "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini", "gpt-4o-mini"],
    "azure_openai": ["gpt-4", "gpt-4-turbo", "gpt-4o", "gpt-35-turbo"],
}

# Endpoint patterns for identifying providers
ENDPOINT_PATTERNS = {
    "openai": ["api.openai.com"],
    "anthropic": ["api.anthropic.com", "claude-api"],
    "azure_openai": ["openai.azure.com"],
    "ant": ["zdfmng.alipay.com"],
}

# Provider class mapping
PROVIDER_CLASSES = {
    "openai": OpenAIProvider,
    "anthropic": AnthropicProvider,
    "azure_openai": AzureOpenAIProvider,
    "ant": AntProvider,
}


class LLMModel:
    """Unified large-model interface that wraps different provider implementations behind a single completion API."""

    def __init__(self, conf: Union[ConfigDict, AgentConfig] = None, custom_provider: LLMProviderBase = None, **kwargs):
        """Initialize the unified model interface.

        Args:
            conf: Agent configuration; if provided, the model is created from it.
            custom_provider: Custom LLMProviderBase instance; if provided, it is used directly.
            **kwargs: Other parameters, may include:
                - base_url: Model endpoint.
                - api_key: API key.
                - model_name: Model name.
                - temperature: Temperature parameter.
        """
        # If a custom provider instance is supplied, use it directly
        if custom_provider is not None:
            if not isinstance(custom_provider, LLMProviderBase):
                raise TypeError("custom_provider must be an instance of LLMProviderBase")
            self.provider_name = "custom"
            self.provider = custom_provider
            return

        # Get basic parameters
        base_url = kwargs.get("base_url") or (conf.llm_base_url if conf else None)
        model_name = kwargs.get("model_name") or (conf.llm_model_name if conf else None)
        llm_provider = conf.llm_provider if conf_contains_key(conf, "llm_provider") else None

        # Get API key from configuration (if any)
        if conf and conf.llm_api_key:
            kwargs["api_key"] = conf.llm_api_key

        # Identify provider
        self.provider_name = self._identify_provider(llm_provider, base_url, model_name)

        # Fill basic parameters
        kwargs['base_url'] = base_url
        kwargs['model_name'] = model_name

        # Fill parameters for the LLM provider
        kwargs['sync_enabled'] = conf.llm_sync_enabled if conf_contains_key(conf, "llm_sync_enabled") else True
        kwargs['async_enabled'] = conf.llm_async_enabled if conf_contains_key(conf, "llm_async_enabled") else True
        kwargs['client_type'] = conf.llm_client_type if conf_contains_key(conf, "llm_client_type") else ClientType.SDK
        kwargs.update(self._transfer_conf_to_args(conf))

        # Create the model provider based on provider_name
        self._create_provider(**kwargs)

    def _transfer_conf_to_args(self, conf: Union[ConfigDict, AgentConfig] = None) -> dict:
        """Transfer parameters from conf to provider kwargs.

        Args:
            conf: Config object.

        Returns:
            dict: Config entries not already consumed by this class.
        """
        if not conf:
            return {}

        # Get all parameters from conf
        if type(conf).__name__ == 'AgentConfig':
            conf_dict = conf.model_dump()
        else:  # ConfigDict
            conf_dict = conf

        ignored_keys = ["llm_provider", "llm_base_url", "llm_model_name", "llm_api_key", "llm_sync_enabled",
                        "llm_async_enabled", "llm_client_type"]
        args = {}
        # Filter out already-used parameters and keep the rest
        for key, value in conf_dict.items():
            if key not in ignored_keys and value is not None:
                args[key] = value
        return args

    def _identify_provider(self, provider: str = None, base_url: str = None, model_name: str = None) -> str:
        """Identify the LLM provider.

        Identification logic:
        1. If base_url matches a known endpoint pattern, use that provider.
        2. Otherwise, if model_name matches a known model, use that provider.
        3. If a registered provider is explicitly specified and differs from the guess, the specified provider wins.
        4. If nothing can be identified, default to "openai".

        Args:
            provider: Specified provider.
            base_url: Service URL.
            model_name: Model name.

        Returns:
            str: Identified provider.
        """
        # Default provider
        identified_provider = "openai"

        # Identify provider based on base_url
        if base_url:
            for p, patterns in ENDPOINT_PATTERNS.items():
                if any(pattern in base_url for pattern in patterns):
                    identified_provider = p
                    logger.info(f"Identified provider: {identified_provider} based on base_url: {base_url}")
                    return identified_provider

        # Identify provider based on model_name
        if model_name and not base_url:
            for p, models in MODEL_NAMES.items():
                if model_name in models or any(model_name.startswith(model) for model in models):
                    identified_provider = p
                    logger.info(f"Identified provider: {identified_provider} based on model_name: {model_name}")
                    break

        # An explicitly specified, registered provider overrides the guess
        if provider and provider in PROVIDER_CLASSES and identified_provider and identified_provider != provider:
            logger.warning(f"Provider mismatch: {provider} != {identified_provider}, using {provider} as provider")
            identified_provider = provider

        return identified_provider

    def _create_provider(self, **kwargs):
        """Create the provider instance that corresponds to the identified provider.

        Args:
            **kwargs: Parameters, may include:
                - base_url: Model endpoint.
                - api_key: API key.
                - model_name: Model name.
                - temperature: Temperature parameter.
                - timeout: Timeout.
                - max_retries: Maximum number of retries.
        """
        self.provider = PROVIDER_CLASSES[self.provider_name](**kwargs)

    @classmethod
    def supported_providers(cls) -> list[str]:
        """Get the names of all registered providers."""
        return list(PROVIDER_CLASSES.keys())

    def supported_models(self) -> list[str]:
        """Get supported models for the current provider.

        Returns:
            list: Supported models.
        """
        return self.provider.supported_models() if self.provider else []

    async def acompletion(self,
                          messages: List[Dict[str, str]],
                          temperature: float = 0.0,
                          max_tokens: int = None,
                          stop: List[str] = None,
                          **kwargs) -> ModelResponse:
        """Asynchronously call the model to generate a response.

        Args:
            messages: Message list, format is [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            ModelResponse: Unified model response object.
        """
        # Call provider's acompletion method directly
        return await self.provider.acompletion(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            **kwargs
        )

    def completion(self,
                   messages: List[Dict[str, str]],
                   temperature: float = 0.0,
                   max_tokens: int = None,
                   stop: List[str] = None,
                   **kwargs) -> ModelResponse:
        """Synchronously call the model to generate a response.

        Args:
            messages: Message list, format is [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            ModelResponse: Unified model response object.
        """
        # Call provider's completion method directly
        return self.provider.completion(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            **kwargs
        )

    def stream_completion(self,
                          messages: List[Dict[str, str]],
                          temperature: float = 0.0,
                          max_tokens: int = None,
                          stop: List[str] = None,
                          **kwargs) -> Generator[ModelResponse, None, None]:
        """Synchronously call the model to generate a streaming response.

        Args:
            messages: Message list, format is [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            Generator yielding ModelResponse chunks.
        """
        # Call provider's stream_completion method directly
        return self.provider.stream_completion(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            **kwargs
        )

    async def astream_completion(self,
                                 messages: List[Dict[str, str]],
                                 temperature: float = 0.0,
                                 max_tokens: int = None,
                                 stop: List[str] = None,
                                 **kwargs) -> AsyncGenerator[ModelResponse, None]:
        """Asynchronously call the model to generate a streaming response.

        Args:
            messages: Message list, format is [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters, may include:
                - base_url: Specify model endpoint.
                - api_key: API key.
                - model_name: Model name.

        Returns:
            AsyncGenerator yielding ModelResponse chunks.
        """
        # Call provider's astream_completion method directly
        async for chunk in self.provider.astream_completion(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            **kwargs
        ):
            yield chunk

    def speech_to_text(self,
                       audio_file: str,
                       language: str = None,
                       prompt: str = None,
                       **kwargs) -> ModelResponse:
        """Convert speech to text.

        Args:
            audio_file: Path to audio file or file object.
            language: Audio language, optional.
            prompt: Transcription prompt, optional.
            **kwargs: Other parameters.

        Returns:
            ModelResponse: Unified model response object, with content field containing the transcription result.

        Raises:
            LLMResponseError: When an LLM response error occurs.
            NotImplementedError: When the provider does not support speech-to-text conversion.
        """
        return self.provider.speech_to_text(
            audio_file=audio_file,
            language=language,
            prompt=prompt,
            **kwargs
        )

    async def aspeech_to_text(self,
                              audio_file: str,
                              language: str = None,
                              prompt: str = None,
                              **kwargs) -> ModelResponse:
        """Asynchronously convert speech to text.

        Args:
            audio_file: Path to audio file or file object.
            language: Audio language, optional.
            prompt: Transcription prompt, optional.
            **kwargs: Other parameters.

        Returns:
            ModelResponse: Unified model response object, with content field containing the transcription result.

        Raises:
            LLMResponseError: When an LLM response error occurs.
            NotImplementedError: When the provider does not support speech-to-text conversion.
        """
        return await self.provider.aspeech_to_text(
            audio_file=audio_file,
            language=language,
            prompt=prompt,
            **kwargs
        )


def register_llm_provider(provider: str, provider_class: type):
    """Register a custom LLM provider.

    Args:
        provider: Provider name.
        provider_class: Provider class, must inherit from LLMProviderBase.
    """
    if not issubclass(provider_class, LLMProviderBase):
        raise TypeError("provider_class must be a subclass of LLMProviderBase")
    PROVIDER_CLASSES[provider] = provider_class
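
# Illustrative sketch (not part of the library surface): registering a custom
# provider. "EchoProvider" is an assumed subclass written for demonstration; a
# real provider must implement the hooks defined by LLMProviderBase. Once
# registered, it can be selected via the configuration's llm_provider field.
#
#     class EchoProvider(LLMProviderBase):
#         ...  # implement completion/acompletion and related hooks
#
#     register_llm_provider("echo", EchoProvider)
#     print(LLMModel.supported_providers())  # includes "echo"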


def conf_contains_key(conf: Union[ConfigDict, AgentConfig], key: str) -> bool:
    """Check whether conf contains the given key.

    Args:
        conf: Config object.
        key: Key to check.

    Returns:
        bool: Whether conf contains key.
    """
    if not conf:
        return False
    if type(conf).__name__ == 'AgentConfig':
        return hasattr(conf, key)
    else:
        return key in conf


def get_llm_model(conf: Union[ConfigDict, AgentConfig] = None,
                  custom_provider: LLMProviderBase = None,
                  **kwargs) -> Union[LLMModel, 'ChatOpenAI']:
    """Get a unified LLM model instance.

    Args:
        conf: Agent configuration; if provided, the model is created from it.
        custom_provider: Custom LLMProviderBase instance; if provided, it is used directly.
        **kwargs: Other parameters, may include:
            - base_url: Specify model endpoint.
            - api_key: API key.
            - model_name: Model name.
            - temperature: Temperature parameter.

    Returns:
        Unified model interface, or a LangChain ChatOpenAI instance when llm_provider is "chatopenai".
    """
    llm_provider = conf.llm_provider if conf_contains_key(conf, "llm_provider") else None

    # Special case: return a LangChain ChatOpenAI client directly
    if llm_provider == "chatopenai":
        from langchain_openai import ChatOpenAI
        base_url = kwargs.get("base_url") or (
            conf.llm_base_url if conf_contains_key(conf, "llm_base_url") else None)
        model_name = kwargs.get("model_name") or (
            conf.llm_model_name if conf_contains_key(conf, "llm_model_name") else None)
        api_key = kwargs.get("api_key") or (
            conf.llm_api_key if conf_contains_key(conf, "llm_api_key") else None)
        return ChatOpenAI(
            model=model_name,
            temperature=kwargs.get("temperature", conf.llm_temperature if conf_contains_key(
                conf, "llm_temperature") else 0.0),
            base_url=base_url,
            api_key=api_key,
        )

    # Create and return an LLMModel instance
    return LLMModel(conf=conf, custom_provider=custom_provider, **kwargs)
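
# Illustrative sketch: setting llm_provider to "chatopenai" makes get_llm_model
# return a LangChain ChatOpenAI client instead of an LLMModel. The ConfigDict
# construction and field values below are demonstration assumptions only.
#
#     conf = ConfigDict(llm_provider="chatopenai",
#                       llm_model_name="gpt-4o",
#                       llm_api_key="<your-api-key>")
#     chat = get_llm_model(conf)   # -> langchain_openai.ChatOpenAI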


def call_llm_model(
        llm_model: LLMModel,
        messages: List[Dict[str, str]],
        temperature: float = 0.0,
        max_tokens: int = None,
        stop: List[str] = None,
        stream: bool = False,
        **kwargs
) -> Union[ModelResponse, Generator[ModelResponse, None, None]]:
    """Convenience function to call an LLM model.

    Args:
        llm_model: LLM model instance.
        messages: Message list.
        temperature: Temperature parameter.
        max_tokens: Maximum number of tokens to generate.
        stop: List of stop sequences.
        stream: Whether to return a streaming response.
        **kwargs: Other parameters.

    Returns:
        Model response or response generator.
    """
    if stream:
        return llm_model.stream_completion(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            **kwargs
        )
    else:
        return llm_model.completion(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            **kwargs
        )


async def acall_llm_model(
        llm_model: LLMModel,
        messages: List[Dict[str, str]],
        temperature: float = 0.0,
        max_tokens: int = None,
        stop: List[str] = None,
        stream: bool = False,
        **kwargs
) -> ModelResponse:
    """Convenience function to asynchronously call an LLM model.

    Args:
        llm_model: LLM model instance.
        messages: Message list.
        temperature: Temperature parameter.
        max_tokens: Maximum number of tokens to generate.
        stop: List of stop sequences.
        stream: Reserved; for asynchronous streaming use acall_llm_model_stream.
        **kwargs: Other parameters.

    Returns:
        ModelResponse: Model response.
    """
    return await llm_model.acompletion(
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        stop=stop,
        **kwargs
    )


async def acall_llm_model_stream(
        llm_model: LLMModel,
        messages: List[Dict[str, str]],
        temperature: float = 0.0,
        max_tokens: int = None,
        stop: List[str] = None,
        **kwargs
) -> AsyncGenerator[ModelResponse, None]:
    """Convenience function to asynchronously call an LLM model with a streaming response.

    Args:
        llm_model: LLM model instance.
        messages: Message list.
        temperature: Temperature parameter.
        max_tokens: Maximum number of tokens to generate.
        stop: List of stop sequences.
        **kwargs: Other parameters.

    Returns:
        AsyncGenerator yielding ModelResponse chunks.
    """
    async for chunk in llm_model.astream_completion(
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        stop=stop,
        **kwargs
    ):
        yield chunk
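
# Illustrative sketch: consuming the streaming helper from an async context. The
# model name, prompt, and asyncio setup are demonstration assumptions, not part
# of this module's API.
#
#     async def _demo():
#         llm = get_llm_model(model_name="gpt-4o-mini")
#         async for chunk in acall_llm_model_stream(
#                 llm, [{"role": "user", "content": "Stream a short greeting."}]):
#             print(chunk.content or "", end="", flush=True)
#
#     # asyncio.run(_demo())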


def speech_to_text(
        llm_model: LLMModel,
        audio_file: str,
        language: str = None,
        prompt: str = None,
        **kwargs
) -> ModelResponse:
    """Convenience function to convert speech to text.

    Args:
        llm_model: LLM model instance.
        audio_file: Path to audio file or file object.
        language: Audio language, optional.
        prompt: Transcription prompt, optional.
        **kwargs: Other parameters.

    Returns:
        ModelResponse: Unified model response object, with content field containing the transcription result.
    """
    if llm_model.provider_name != "openai":
        raise NotImplementedError(
            f"Speech-to-text is currently only supported for OpenAI-compatible providers, current provider: {llm_model.provider_name}")
    return llm_model.speech_to_text(
        audio_file=audio_file,
        language=language,
        prompt=prompt,
        **kwargs
    )


async def aspeech_to_text(
        llm_model: LLMModel,
        audio_file: str,
        language: str = None,
        prompt: str = None,
        **kwargs
) -> ModelResponse:
    """Convenience function to asynchronously convert speech to text.

    Args:
        llm_model: LLM model instance.
        audio_file: Path to audio file or file object.
        language: Audio language, optional.
        prompt: Transcription prompt, optional.
        **kwargs: Other parameters.

    Returns:
        ModelResponse: Unified model response object, with content field containing the transcription result.
    """
    if llm_model.provider_name != "openai":
        raise NotImplementedError(
            f"Speech-to-text is currently only supported for OpenAI-compatible providers, current provider: {llm_model.provider_name}")
    return await llm_model.aspeech_to_text(
        audio_file=audio_file,
        language=language,
        prompt=prompt,
        **kwargs
    )
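

# Minimal usage sketch (assumptions: an OpenAI-compatible endpoint reachable with
# the key in OPENAI_API_KEY, and "gpt-4o-mini" as an example model name). It is
# illustrative only, not part of the module's public behavior.
if __name__ == "__main__":
    import os

    model = get_llm_model(
        model_name="gpt-4o-mini",
        api_key=os.environ.get("OPENAI_API_KEY"),
    )
    response = call_llm_model(
        model,
        messages=[{"role": "user", "content": "Say hello in one short sentence."}],
        temperature=0.0,
    )
    print(response.content)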