import os
from typing import Any, Dict, List, Generator, AsyncGenerator

from openai import OpenAI, AsyncOpenAI

from aworld.config.conf import ClientType
from aworld.core.llm_provider_base import LLMProviderBase
from aworld.models.llm_http_handler import LLMHTTPHandler
from aworld.models.model_response import ModelResponse, LLMResponseError
from aworld.logs.util import logger
from aworld.models.utils import usage_process


class OpenAIProvider(LLMProviderBase):
    """OpenAI provider implementation."""

    def _init_provider(self):
        """Initialize OpenAI provider.

        Returns:
            OpenAI provider instance.
        """
        # Get API key
        api_key = self.api_key
        if not api_key:
            env_var = "OPENAI_API_KEY"
            api_key = os.getenv(env_var, "")
            if not api_key:
                raise ValueError(
                    f"OpenAI API key not found, please set {env_var} environment variable or provide it in the parameters")

        base_url = self.base_url
        if not base_url:
            base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")

        self.is_http_provider = False
        if self.kwargs.get("client_type", ClientType.SDK) == ClientType.HTTP:
            logger.info("Using HTTP provider for OpenAI")
            self.http_provider = LLMHTTPHandler(
                base_url=base_url,
                api_key=api_key,
                model_name=self.model_name,
                max_retries=self.kwargs.get("max_retries", 3)
            )
            self.is_http_provider = True
            return self.http_provider
        else:
            return OpenAI(
                api_key=api_key,
                base_url=base_url,
                timeout=self.kwargs.get("timeout", 180),
                max_retries=self.kwargs.get("max_retries", 3)
            )

    def _init_async_provider(self):
        """Initialize async OpenAI provider.

        Returns:
            Async OpenAI provider instance.
        """
        # Get API key
        api_key = self.api_key
        if not api_key:
            env_var = "OPENAI_API_KEY"
            api_key = os.getenv(env_var, "")
            if not api_key:
                raise ValueError(
                    f"OpenAI API key not found, please set {env_var} environment variable or provide it in the parameters")

        base_url = self.base_url
        if not base_url:
            base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")

        return AsyncOpenAI(
            api_key=api_key,
            base_url=base_url,
            timeout=self.kwargs.get("timeout", 180),
            max_retries=self.kwargs.get("max_retries", 3)
        )

    @classmethod
    def supported_models(cls) -> list[str]:
        return ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini", "gpt-4o-mini", "deepseek-chat", "deepseek-reasoner",
                r"qwq-.*", r"qwen-.*"]

    def preprocess_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Preprocess messages, use OpenAI format directly.

        Args:
            messages: OpenAI format message list.

        Returns:
            Processed message list.
        """
        for message in messages:
            if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"]:
                if message["content"] is None:
                    message["content"] = ""
                # Normalize flat tool calls ({"name": ..., "arguments": ...}) into the
                # OpenAI shape with a nested "function" object.
                for tool_call in message["tool_calls"]:
                    if "function" not in tool_call and "name" in tool_call and "arguments" in tool_call:
                        tool_call["function"] = {"name": tool_call["name"], "arguments": tool_call["arguments"]}
        return messages
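
    # Illustrative sketch (comments only, not executed): how preprocess_messages
    # normalizes a flat assistant tool call into the nested OpenAI "function" shape.
    # The message values below are hypothetical example data, not part of the provider.
    #
    #   msgs = [{"role": "assistant", "content": None,
    #            "tool_calls": [{"id": "call_1", "name": "get_weather",
    #                            "arguments": "{\"city\": \"Paris\"}"}]}]
    #   provider.preprocess_messages(msgs)
    #   # -> content becomes "" and the tool call gains
    #   #    {"function": {"name": "get_weather", "arguments": "{\"city\": \"Paris\"}"}}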

    def postprocess_response(self, response: Any) -> ModelResponse:
        """Process OpenAI response.

        Args:
            response: OpenAI response object.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When LLM response error occurs.
        """
        if ((not isinstance(response, dict) and (not hasattr(response, 'choices') or not response.choices))
                or (isinstance(response, dict) and not response.get("choices"))):
            error_msg = ""
            if hasattr(response, 'error') and response.error and isinstance(response.error, dict):
                error_msg = response.error.get('message', '')
            elif hasattr(response, 'msg'):
                error_msg = response.msg
            raise LLMResponseError(
                error_msg if error_msg else "Unknown error",
                self.model_name or "unknown",
                response
            )
        return ModelResponse.from_openai_response(response)

    def postprocess_stream_response(self, chunk: Any) -> ModelResponse:
        """Process OpenAI streaming response chunk.

        Args:
            chunk: OpenAI response chunk.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When LLM response error occurs.
        """
        # Check if chunk contains error
        if hasattr(chunk, 'error') or (isinstance(chunk, dict) and chunk.get('error')):
            error_msg = chunk.error if hasattr(chunk, 'error') else chunk.get('error', 'Unknown error')
            raise LLMResponseError(
                error_msg,
                self.model_name or "unknown",
                chunk
            )

        # Process tool calls: arguments for one tool call arrive spread across many chunks,
        # so buffer fragments by index and emit complete calls only once a finish_reason arrives.
        if (hasattr(chunk, 'choices') and chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.tool_calls) or (
                isinstance(chunk, dict) and chunk.get("choices") and chunk["choices"] and chunk["choices"][0].get("delta", {}).get("tool_calls")):
            tool_calls = chunk.choices[0].delta.tool_calls if hasattr(chunk, 'choices') else chunk["choices"][0].get("delta", {}).get("tool_calls")
            for tool_call in tool_calls:
                index = tool_call.index if hasattr(tool_call, 'index') else tool_call["index"]
                func_name = tool_call.function.name if hasattr(tool_call, 'function') else tool_call.get("function", {}).get("name")
                func_args = tool_call.function.arguments if hasattr(tool_call, 'function') else tool_call.get("function", {}).get("arguments")
                if index >= len(self.stream_tool_buffer):
                    # First fragment for this tool call: record id, name and initial arguments.
                    self.stream_tool_buffer.append({
                        "id": tool_call.id if hasattr(tool_call, 'id') else tool_call.get("id"),
                        "type": "function",
                        "function": {
                            "name": func_name,
                            "arguments": func_args
                        }
                    })
                else:
                    # Subsequent fragments only extend the buffered arguments string.
                    self.stream_tool_buffer[index]["function"]["arguments"] += func_args

            # Strip the partial tool calls from the chunk so callers never see fragments.
            processed_chunk = chunk
            if hasattr(processed_chunk, 'choices'):
                processed_chunk.choices[0].delta.tool_calls = None
            else:
                processed_chunk["choices"][0]["delta"]["tool_calls"] = None
            resp = ModelResponse.from_openai_stream_chunk(processed_chunk)
            if not resp.content and not resp.usage.get("total_tokens", 0):
                return None

        if (hasattr(chunk, 'choices') and chunk.choices and chunk.choices[0].finish_reason) or (
                isinstance(chunk, dict) and chunk.get("choices") and chunk["choices"] and chunk["choices"][0].get(
                    "finish_reason")):
            finish_reason = chunk.choices[0].finish_reason if hasattr(chunk, 'choices') else chunk["choices"][0].get(
                "finish_reason")
            if self.stream_tool_buffer:
                # Flush the buffered tool calls as a single synthetic chunk.
                tool_call_chunk = {
                    "id": chunk.id if hasattr(chunk, 'id') else chunk.get("id"),
                    "model": chunk.model if hasattr(chunk, 'model') else chunk.get("model"),
                    "object": chunk.object if hasattr(chunk, 'object') else chunk.get("object"),
                    "choices": [
                        {
                            "delta": {
                                "role": "assistant",
                                "content": "",
                                "tool_calls": self.stream_tool_buffer
                            }
                        }
                    ]
                }
                self.stream_tool_buffer = []
                return ModelResponse.from_openai_stream_chunk(tool_call_chunk)

        return ModelResponse.from_openai_stream_chunk(chunk)
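
    # Illustrative sketch (comments only, not executed): given streamed tool-call
    # fragments such as
    #   chunk 1: delta.tool_calls = [{index: 0, id: "call_1", function: {name: "get_weather", arguments: "{\"ci"}}]
    #   chunk 2: delta.tool_calls = [{index: 0, function: {arguments: "ty\": \"Paris\"}"}}]
    #   chunk 3: finish_reason = "tool_calls"
    # postprocess_stream_response buffers the fragments, returns None for the
    # content-less intermediate chunks, and on chunk 3 emits one ModelResponse whose
    # delta carries the fully assembled tool call.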

    def completion(self,
                   messages: List[Dict[str, str]],
                   temperature: float = 0.0,
                   max_tokens: int = None,
                   stop: List[str] = None,
                   **kwargs) -> ModelResponse:
        """Synchronously call OpenAI to generate response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When LLM response error occurs.
        """
        if not self.provider:
            raise RuntimeError(
                "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")

        processed_messages = self.preprocess_messages(messages)
        try:
            openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
            if self.is_http_provider:
                response = self.http_provider.sync_call(openai_params)
            else:
                response = self.provider.chat.completions.create(**openai_params)

            if (hasattr(response, 'code') and response.code != 0) or (
                    isinstance(response, dict) and response.get("code", 0) != 0):
                error_msg = getattr(response, 'msg', 'Unknown error')
                logger.warn(f"API Error: {error_msg}")
                raise LLMResponseError(error_msg, kwargs.get("model_name", self.model_name or "unknown"), response)

            if not response:
                raise LLMResponseError("Empty response", kwargs.get("model_name", self.model_name or "unknown"))

            resp = self.postprocess_response(response)
            usage_process(resp.usage)
            return resp
        except Exception as e:
            if isinstance(e, LLMResponseError):
                raise e
            logger.warn(f"Error in OpenAI completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))

    def stream_completion(self,
                          messages: List[Dict[str, str]],
                          temperature: float = 0.0,
                          max_tokens: int = None,
                          stop: List[str] = None,
                          **kwargs) -> Generator[ModelResponse, None, None]:
        """Synchronously call OpenAI to generate streaming response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            Generator yielding ModelResponse chunks.

        Raises:
            LLMResponseError: When LLM response error occurs.
        """
        if not self.provider:
            raise RuntimeError(
                "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")

        processed_messages = self.preprocess_messages(messages)
        usage = {
            "completion_tokens": 0,
            "prompt_tokens": 0,
            "total_tokens": 0
        }
        try:
            openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
            openai_params["stream"] = True
            if self.is_http_provider:
                response_stream = self.http_provider.sync_stream_call(openai_params)
            else:
                response_stream = self.provider.chat.completions.create(**openai_params)

            for chunk in response_stream:
                if not chunk:
                    continue
                resp = self.postprocess_stream_response(chunk)
                if resp:
                    self._accumulate_chunk_usage(usage, resp.usage)
                    yield resp
            usage_process(usage)
        except Exception as e:
            logger.warn(f"Error in stream_completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))

    async def astream_completion(self,
                                 messages: List[Dict[str, str]],
                                 temperature: float = 0.0,
                                 max_tokens: int = None,
                                 stop: List[str] = None,
                                 **kwargs) -> AsyncGenerator[ModelResponse, None]:
        """Asynchronously call OpenAI to generate streaming response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            AsyncGenerator yielding ModelResponse chunks.

        Raises:
            LLMResponseError: When LLM response error occurs.
        """
        if not self.async_provider:
            raise RuntimeError(
                "Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")

        processed_messages = self.preprocess_messages(messages)
        usage = {
            "completion_tokens": 0,
            "prompt_tokens": 0,
            "total_tokens": 0
        }
        try:
            openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
            openai_params["stream"] = True
            if self.is_http_provider:
                async for chunk in self.http_provider.async_stream_call(openai_params):
                    if not chunk:
                        continue
                    resp = self.postprocess_stream_response(chunk)
                    # Skip empty chunks (postprocess_stream_response returns None for them),
                    # mirroring the SDK streaming path below.
                    if resp:
                        self._accumulate_chunk_usage(usage, resp.usage)
                        yield resp
            else:
                response_stream = await self.async_provider.chat.completions.create(**openai_params)
                async for chunk in response_stream:
                    if not chunk:
                        continue
                    resp = self.postprocess_stream_response(chunk)
                    if resp:
                        self._accumulate_chunk_usage(usage, resp.usage)
                        yield resp
            usage_process(usage)
        except Exception as e:
            logger.warn(f"Error in astream_completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))

    async def acompletion(self,
                          messages: List[Dict[str, str]],
                          temperature: float = 0.0,
                          max_tokens: int = None,
                          stop: List[str] = None,
                          **kwargs) -> ModelResponse:
        """Asynchronously call OpenAI to generate response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When LLM response error occurs.
        """
        if not self.async_provider:
            raise RuntimeError(
                "Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")

        processed_messages = self.preprocess_messages(messages)
        try:
            openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
            if self.is_http_provider:
                response = await self.http_provider.async_call(openai_params)
            else:
                response = await self.async_provider.chat.completions.create(**openai_params)

            if (hasattr(response, 'code') and response.code != 0) or (
                    isinstance(response, dict) and response.get("code", 0) != 0):
                error_msg = getattr(response, 'msg', 'Unknown error')
                logger.warn(f"API Error: {error_msg}")
                raise LLMResponseError(error_msg, kwargs.get("model_name", self.model_name or "unknown"), response)

            if not response:
                raise LLMResponseError("Empty response", kwargs.get("model_name", self.model_name or "unknown"))

            resp = self.postprocess_response(response)
            usage_process(resp.usage)
            return resp
        except Exception as e:
            if isinstance(e, LLMResponseError):
                raise e
            logger.warn(f"Error in acompletion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))

    def get_openai_params(self,
                          messages: List[Dict[str, str]],
                          temperature: float = 0.0,
                          max_tokens: int = None,
                          stop: List[str] = None,
                          **kwargs) -> Dict[str, Any]:
        openai_params = {
            "model": kwargs.get("model_name", self.model_name or ""),
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stop": stop
        }
        # Pass through any additional OpenAI chat-completions parameters supplied by the caller.
        supported_params = [
            "max_completion_tokens", "meta_data", "modalities", "n", "parallel_tool_calls",
            "prediction", "reasoning_effort", "service_tier", "stream_options", "web_search_options",
            "frequency_penalty", "logit_bias", "logprobs", "top_logprobs",
            "presence_penalty", "response_format", "seed", "stream", "top_p",
            "user", "function_call", "functions", "tools", "tool_choice"
        ]
        for param in supported_params:
            if param in kwargs:
                openai_params[param] = kwargs[param]
        return openai_params
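
    # Illustrative sketch (comments only, not executed): for a call such as
    #   provider.completion(messages, temperature=0.2, tools=my_tools, tool_choice="auto")
    # get_openai_params builds roughly
    #   {"model": "<model_name>", "messages": messages, "temperature": 0.2,
    #    "max_tokens": None, "stop": None, "tools": my_tools, "tool_choice": "auto"}
    # where `my_tools` is a hypothetical list of OpenAI tool definitions supplied by the caller.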

    def speech_to_text(self,
                       audio_file: str,
                       language: str = None,
                       prompt: str = None,
                       **kwargs) -> ModelResponse:
        """Convert speech to text.

        Uses OpenAI's speech-to-text API to convert audio files to text.

        Args:
            audio_file: Path to audio file or file object.
            language: Audio language, optional.
            prompt: Transcription prompt, optional.
            **kwargs: Other parameters, may include:
                - model: Transcription model name, defaults to "whisper-1".
                - response_format: Response format, defaults to "text".
                - temperature: Sampling temperature, defaults to 0.

        Returns:
            ModelResponse: Unified model response object, with content field containing the transcription result.

        Raises:
            LLMResponseError: When LLM response error occurs.
        """
        if not self.provider:
            raise RuntimeError(
                "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")

        try:
            # Prepare parameters
            transcription_params = {
                "model": kwargs.get("model", "whisper-1"),
                "response_format": kwargs.get("response_format", "text"),
                "temperature": kwargs.get("temperature", 0)
            }

            # Add optional parameters
            if language:
                transcription_params["language"] = language
            if prompt:
                transcription_params["prompt"] = prompt

            # Open file (if path is provided)
            if isinstance(audio_file, str):
                with open(audio_file, "rb") as file:
                    transcription_response = self.provider.audio.transcriptions.create(
                        file=file,
                        **transcription_params
                    )
            else:
                # If already a file object
                transcription_response = self.provider.audio.transcriptions.create(
                    file=audio_file,
                    **transcription_params
                )

            # Create ModelResponse
            return ModelResponse(
                id=f"stt-{hash(str(transcription_response)) & 0xffffffff:08x}",
                model=transcription_params["model"],
                content=transcription_response.text if hasattr(transcription_response, 'text') else str(
                    transcription_response),
                raw_response=transcription_response,
                message={
                    "role": "assistant",
                    "content": transcription_response.text if hasattr(transcription_response, 'text') else str(
                        transcription_response)
                }
            )
        except Exception as e:
            logger.warn(f"Speech-to-text error: {e}")
            raise LLMResponseError(str(e), kwargs.get("model", "whisper-1"))

    async def aspeech_to_text(self,
                              audio_file: str,
                              language: str = None,
                              prompt: str = None,
                              **kwargs) -> ModelResponse:
        """Asynchronously convert speech to text.

        Uses OpenAI's speech-to-text API to convert audio files to text.

        Args:
            audio_file: Path to audio file or file object.
            language: Audio language, optional.
            prompt: Transcription prompt, optional.
            **kwargs: Other parameters, may include:
                - model: Transcription model name, defaults to "whisper-1".
                - response_format: Response format, defaults to "text".
                - temperature: Sampling temperature, defaults to 0.

        Returns:
            ModelResponse: Unified model response object, with content field containing the transcription result.

        Raises:
            LLMResponseError: When LLM response error occurs.
        """
        if not self.async_provider:
            raise RuntimeError(
                "Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")

        try:
            # Prepare parameters
            transcription_params = {
                "model": kwargs.get("model", "whisper-1"),
                "response_format": kwargs.get("response_format", "text"),
                "temperature": kwargs.get("temperature", 0)
            }

            # Add optional parameters
            if language:
                transcription_params["language"] = language
            if prompt:
                transcription_params["prompt"] = prompt

            # Open file (if path is provided)
            if isinstance(audio_file, str):
                with open(audio_file, "rb") as file:
                    transcription_response = await self.async_provider.audio.transcriptions.create(
                        file=file,
                        **transcription_params
                    )
            else:
                # If already a file object
                transcription_response = await self.async_provider.audio.transcriptions.create(
                    file=audio_file,
                    **transcription_params
                )

            # Create ModelResponse
            return ModelResponse(
                id=f"stt-{hash(str(transcription_response)) & 0xffffffff:08x}",
                model=transcription_params["model"],
                content=transcription_response.text if hasattr(transcription_response, 'text') else str(
                    transcription_response),
                raw_response=transcription_response,
                message={
                    "role": "assistant",
                    "content": transcription_response.text if hasattr(transcription_response, 'text') else str(
                        transcription_response)
                }
            )
        except Exception as e:
            logger.warn(f"Async speech-to-text error: {e}")
            raise LLMResponseError(str(e), kwargs.get("model", "whisper-1"))


class AzureOpenAIProvider(OpenAIProvider):
    """Azure OpenAI provider implementation."""

    def _init_provider(self):
        """Initialize Azure OpenAI provider.

        Returns:
            Azure OpenAI provider instance.
        """
        # Imported locally so langchain_openai is only required when the Azure provider is used.
        from langchain_openai import AzureChatOpenAI

        # Get API key
        api_key = self.api_key
        if not api_key:
            env_var = "AZURE_OPENAI_API_KEY"
            api_key = os.getenv(env_var, "")
            if not api_key:
                raise ValueError(
                    f"Azure OpenAI API key not found, please set {env_var} environment variable or provide it in the parameters")

        # Get API version
        api_version = self.kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")

        # Get endpoint
        azure_endpoint = self.base_url
        if not azure_endpoint:
            azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", "")
            if not azure_endpoint:
                raise ValueError(
                    "Azure OpenAI endpoint not found, please set AZURE_OPENAI_ENDPOINT environment variable or provide it in the parameters")

        return AzureChatOpenAI(
            model=self.model_name or "gpt-4o",
            temperature=self.kwargs.get("temperature", 0.0),
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            api_key=api_key
        )
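

# Illustrative usage sketch (comments only, not executed). The constructor arguments
# below are assumptions about LLMProviderBase's signature, which is defined outside
# this file; adjust them to the actual base-class parameters.
#
#   provider = OpenAIProvider(api_key="sk-...", base_url="https://api.openai.com/v1",
#                             model_name="gpt-4o", sync_enabled=True)
#   resp = provider.completion([{"role": "user", "content": "Hello"}], temperature=0.0)
#   print(resp.content)
#
#   # Streaming: iterate the generator and concatenate chunk contents.
#   for chunk in provider.stream_completion([{"role": "user", "content": "Hello"}]):
#       print(chunk.content, end="")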