import os
from typing import Any, Dict, List, Generator, AsyncGenerator, Optional
from openai import OpenAI, AsyncOpenAI
from aworld.config.conf import ClientType
from aworld.core.llm_provider_base import LLMProviderBase
from aworld.models.llm_http_handler import LLMHTTPHandler
from aworld.models.model_response import ModelResponse, LLMResponseError
from aworld.logs.util import logger
from aworld.models.utils import usage_process
class OpenAIProvider(LLMProviderBase):
"""OpenAI provider implementation.
"""
def _init_provider(self):
"""Initialize OpenAI provider.
Returns:
OpenAI provider instance.
"""
# Get API key
api_key = self.api_key
if not api_key:
env_var = "OPENAI_API_KEY"
api_key = os.getenv(env_var, "")
if not api_key:
raise ValueError(
f"OpenAI API key not found, please set {env_var} environment variable or provide it in the parameters")
base_url = self.base_url
if not base_url:
base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
self.is_http_provider = False
if self.kwargs.get("client_type", ClientType.SDK) == ClientType.HTTP:
logger.info(f"Using HTTP provider for OpenAI")
self.http_provider = LLMHTTPHandler(
base_url=base_url,
api_key=api_key,
model_name=self.model_name,
max_retries=self.kwargs.get("max_retries", 3)
)
self.is_http_provider = True
return self.http_provider
else:
return OpenAI(
api_key=api_key,
base_url=base_url,
timeout=self.kwargs.get("timeout", 180),
max_retries=self.kwargs.get("max_retries", 3)
)
def _init_async_provider(self):
"""Initialize async OpenAI provider.
Returns:
Async OpenAI provider instance.
"""
# Get API key
api_key = self.api_key
if not api_key:
env_var = "OPENAI_API_KEY"
api_key = os.getenv(env_var, "")
if not api_key:
raise ValueError(
f"OpenAI API key not found, please set {env_var} environment variable or provide it in the parameters")
base_url = self.base_url
if not base_url:
base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
return AsyncOpenAI(
api_key=api_key,
base_url=base_url,
timeout=self.kwargs.get("timeout", 180),
max_retries=self.kwargs.get("max_retries", 3)
)
@classmethod
    def supported_models(cls) -> List[str]:
return ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini", "gpt-4o-mini", "deepseek-chat", "deepseek-reasoner",
r"qwq-.*", r"qwen-.*"]
def preprocess_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
"""Preprocess messages, use OpenAI format directly.
Args:
messages: OpenAI format message list.
Returns:
Processed message list.
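
        Note:
            Assistant messages that carry flattened tool calls (a top-level "name"
            and "arguments" but no "function" key) are rewritten in place to the
            nested OpenAI shape, and a None content is replaced with an empty string.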
"""
for message in messages:
if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"]:
if message["content"] is None: message["content"] = ""
for tool_call in message["tool_calls"]:
if "function" not in tool_call and "name" in tool_call and "arguments" in tool_call:
tool_call["function"] = {"name": tool_call["name"], "arguments": tool_call["arguments"]}
return messages
def postprocess_response(self, response: Any) -> ModelResponse:
"""Process OpenAI response.
Args:
response: OpenAI response object.
Returns:
ModelResponse object.
Raises:
LLMResponseError: When LLM response error occurs.
"""
if ((not isinstance(response, dict) and (not hasattr(response, 'choices') or not response.choices))
or (isinstance(response, dict) and not response.get("choices"))):
error_msg = ""
if hasattr(response, 'error') and response.error and isinstance(response.error, dict):
error_msg = response.error.get('message', '')
elif hasattr(response, 'msg'):
error_msg = response.msg
raise LLMResponseError(
error_msg if error_msg else "Unknown error",
self.model_name or "unknown",
response
)
return ModelResponse.from_openai_response(response)
    def postprocess_stream_response(self, chunk: Any) -> Optional[ModelResponse]:
"""Process OpenAI streaming response chunk.
Args:
chunk: OpenAI response chunk.
Returns:
            ModelResponse object, or None when the chunk carries neither content nor usage.
Raises:
LLMResponseError: When LLM response error occurs.
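
        Note:
            Tool-call deltas are accumulated in self.stream_tool_buffer across
            chunks and emitted as a single synthetic chunk once a finish_reason
            arrives; chunks that carry neither content nor usage return None.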
"""
# Check if chunk contains error
if hasattr(chunk, 'error') or (isinstance(chunk, dict) and chunk.get('error')):
error_msg = chunk.error if hasattr(chunk, 'error') else chunk.get('error', 'Unknown error')
raise LLMResponseError(
error_msg,
self.model_name or "unknown",
chunk
)
# process tool calls
if (hasattr(chunk, 'choices') and chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.tool_calls) or (
isinstance(chunk, dict) and chunk.get("choices") and chunk["choices"] and chunk["choices"][0].get("delta", {}).get("tool_calls")):
tool_calls = chunk.choices[0].delta.tool_calls if hasattr(chunk, 'choices') else chunk["choices"][0].get("delta", {}).get("tool_calls")
for tool_call in tool_calls:
index = tool_call.index if hasattr(tool_call, 'index') else tool_call["index"]
func_name = tool_call.function.name if hasattr(tool_call, 'function') else tool_call.get("function", {}).get("name")
func_args = tool_call.function.arguments if hasattr(tool_call, 'function') else tool_call.get("function", {}).get("arguments")
if index >= len(self.stream_tool_buffer):
self.stream_tool_buffer.append({
"id": tool_call.id if hasattr(tool_call, 'id') else tool_call.get("id"),
"type": "function",
"function": {
"name": func_name,
"arguments": func_args
}
})
else:
                    self.stream_tool_buffer[index]["function"]["arguments"] += func_args or ""
processed_chunk = chunk
if hasattr(processed_chunk, 'choices'):
processed_chunk.choices[0].delta.tool_calls = None
else:
processed_chunk["choices"][0]["delta"]["tool_calls"] = None
resp = ModelResponse.from_openai_stream_chunk(processed_chunk)
        if not resp.content and not (resp.usage or {}).get("total_tokens", 0):
return None
if (hasattr(chunk, 'choices') and chunk.choices and chunk.choices[0].finish_reason) or (
isinstance(chunk, dict) and chunk.get("choices") and chunk["choices"] and chunk["choices"][0].get(
"finish_reason")):
finish_reason = chunk.choices[0].finish_reason if hasattr(chunk, 'choices') else chunk["choices"][0].get(
"finish_reason")
if self.stream_tool_buffer:
tool_call_chunk = {
"id": chunk.id if hasattr(chunk, 'id') else chunk.get("id"),
"model": chunk.model if hasattr(chunk, 'model') else chunk.get("model"),
"object": chunk.object if hasattr(chunk, 'object') else chunk.get("object"),
"choices": [
{
"delta": {
"role": "assistant",
"content": "",
"tool_calls": self.stream_tool_buffer
}
}
]
}
self.stream_tool_buffer = []
return ModelResponse.from_openai_stream_chunk(tool_call_chunk)
return ModelResponse.from_openai_stream_chunk(chunk)
def completion(self,
messages: List[Dict[str, str]],
temperature: float = 0.0,
max_tokens: int = None,
stop: List[str] = None,
**kwargs) -> ModelResponse:
"""Synchronously call OpenAI to generate response.
Args:
messages: Message list.
temperature: Temperature parameter.
max_tokens: Maximum number of tokens to generate.
stop: List of stop sequences.
**kwargs: Other parameters.
Returns:
ModelResponse object.
Raises:
LLMResponseError: When LLM response error occurs.
"""
if not self.provider:
raise RuntimeError(
"Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")
processed_messages = self.preprocess_messages(messages)
try:
openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
if self.is_http_provider:
response = self.http_provider.sync_call(openai_params)
else:
response = self.provider.chat.completions.create(**openai_params)
if (hasattr(response, 'code') and response.code != 0) or (
isinstance(response, dict) and response.get("code", 0) != 0):
                error_msg = (response.get("msg", "Unknown error") if isinstance(response, dict)
                             else getattr(response, 'msg', 'Unknown error'))
logger.warn(f"API Error: {error_msg}")
raise LLMResponseError(error_msg, kwargs.get("model_name", self.model_name or "unknown"), response)
if not response:
raise LLMResponseError("Empty response", kwargs.get("model_name", self.model_name or "unknown"))
resp = self.postprocess_response(response)
usage_process(resp.usage)
return resp
except Exception as e:
if isinstance(e, LLMResponseError):
raise e
logger.warn(f"Error in OpenAI completion: {e}")
raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
def stream_completion(self,
messages: List[Dict[str, str]],
temperature: float = 0.0,
max_tokens: int = None,
stop: List[str] = None,
**kwargs) -> Generator[ModelResponse, None, None]:
"""Synchronously call OpenAI to generate streaming response.
Args:
messages: Message list.
temperature: Temperature parameter.
max_tokens: Maximum number of tokens to generate.
stop: List of stop sequences.
**kwargs: Other parameters.
Returns:
Generator yielding ModelResponse chunks.
Raises:
LLMResponseError: When LLM response error occurs.
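
        Example (illustrative):

            for chunk in provider.stream_completion([{"role": "user", "content": "Hi"}]):
                print(chunk.content or "", end="")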
"""
if not self.provider:
raise RuntimeError(
"Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")
processed_messages = self.preprocess_messages(messages)
        usage = {
"completion_tokens": 0,
"prompt_tokens": 0,
"total_tokens": 0
}
try:
openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
openai_params["stream"] = True
if self.is_http_provider:
response_stream = self.http_provider.sync_stream_call(openai_params)
else:
response_stream = self.provider.chat.completions.create(**openai_params)
for chunk in response_stream:
if not chunk:
continue
resp = self.postprocess_stream_response(chunk)
if resp:
self._accumulate_chunk_usage(usage, resp.usage)
yield resp
usage_process(usage)
except Exception as e:
logger.warn(f"Error in stream_completion: {e}")
raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
async def astream_completion(self,
messages: List[Dict[str, str]],
temperature: float = 0.0,
max_tokens: int = None,
stop: List[str] = None,
**kwargs) -> AsyncGenerator[ModelResponse, None]:
"""Asynchronously call OpenAI to generate streaming response.
Args:
messages: Message list.
temperature: Temperature parameter.
max_tokens: Maximum number of tokens to generate.
stop: List of stop sequences.
**kwargs: Other parameters.
Returns:
AsyncGenerator yielding ModelResponse chunks.
Raises:
LLMResponseError: When LLM response error occurs.
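
        Example (illustrative):

            async for chunk in provider.astream_completion([{"role": "user", "content": "Hi"}]):
                print(chunk.content or "", end="")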
"""
if not self.async_provider:
raise RuntimeError(
"Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")
processed_messages = self.preprocess_messages(messages)
usage = {
"completion_tokens": 0,
"prompt_tokens": 0,
"total_tokens": 0
}
try:
openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
openai_params["stream"] = True
if self.is_http_provider:
async for chunk in self.http_provider.async_stream_call(openai_params):
if not chunk:
continue
                    resp = self.postprocess_stream_response(chunk)
                    if resp:
                        self._accumulate_chunk_usage(usage, resp.usage)
                        yield resp
else:
response_stream = await self.async_provider.chat.completions.create(**openai_params)
async for chunk in response_stream:
if not chunk:
continue
resp = self.postprocess_stream_response(chunk)
if resp:
self._accumulate_chunk_usage(usage, resp.usage)
yield resp
usage_process(usage)
except Exception as e:
logger.warn(f"Error in astream_completion: {e}")
raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
async def acompletion(self,
messages: List[Dict[str, str]],
temperature: float = 0.0,
max_tokens: int = None,
stop: List[str] = None,
**kwargs) -> ModelResponse:
"""Asynchronously call OpenAI to generate response.
Args:
messages: Message list.
temperature: Temperature parameter.
max_tokens: Maximum number of tokens to generate.
stop: List of stop sequences.
**kwargs: Other parameters.
Returns:
ModelResponse object.
Raises:
LLMResponseError: When LLM response error occurs.
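
        Example (illustrative):

            resp = await provider.acompletion([{"role": "user", "content": "Hi"}])
            print(resp.content)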
"""
if not self.async_provider:
raise RuntimeError(
"Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")
processed_messages = self.preprocess_messages(messages)
try:
openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
if self.is_http_provider:
response = await self.http_provider.async_call(openai_params)
else:
response = await self.async_provider.chat.completions.create(**openai_params)
if (hasattr(response, 'code') and response.code != 0) or (
isinstance(response, dict) and response.get("code", 0) != 0):
                error_msg = (response.get("msg", "Unknown error") if isinstance(response, dict)
                             else getattr(response, 'msg', 'Unknown error'))
logger.warn(f"API Error: {error_msg}")
raise LLMResponseError(error_msg, kwargs.get("model_name", self.model_name or "unknown"), response)
if not response:
raise LLMResponseError("Empty response", kwargs.get("model_name", self.model_name or "unknown"))
resp = self.postprocess_response(response)
usage_process(resp.usage)
return resp
except Exception as e:
if isinstance(e, LLMResponseError):
raise e
logger.warn(f"Error in acompletion: {e}")
raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
def get_openai_params(self,
messages: List[Dict[str, str]],
temperature: float = 0.0,
max_tokens: int = None,
stop: List[str] = None,
**kwargs) -> Dict[str, Any]:
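        """Assemble the keyword arguments for a chat.completions request.

        Builds the base payload (model, messages, temperature, max_tokens, stop) and
        copies any recognised optional parameters (tools, tool_choice, response_format,
        stream, ...) straight from kwargs.

        Example (illustrative; my_tools is a hypothetical tools list):

            params = provider.get_openai_params(messages, temperature=0.2, tools=my_tools)
            # params["model"] == provider.model_name; params["tools"] is my_tools
        """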
openai_params = {
"model": kwargs.get("model_name", self.model_name or ""),
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stop": stop
}
supported_params = [
"max_completion_tokens", "meta_data", "modalities", "n", "parallel_tool_calls",
"prediction", "reasoning_effort", "service_tier", "stream_options", "web_search_options"
"frequency_penalty", "logit_bias", "logprobs", "top_logprobs",
"presence_penalty", "response_format", "seed", "stream", "top_p",
"user", "function_call", "functions", "tools", "tool_choice"
]
for param in supported_params:
if param in kwargs:
openai_params[param] = kwargs[param]
return openai_params
def speech_to_text(self,
audio_file: str,
language: str = None,
prompt: str = None,
**kwargs) -> ModelResponse:
"""Convert speech to text.
Uses OpenAI's speech-to-text API to convert audio files to text.
Args:
audio_file: Path to audio file or file object.
language: Audio language, optional.
prompt: Transcription prompt, optional.
**kwargs: Other parameters, may include:
- model: Transcription model name, defaults to "whisper-1".
- response_format: Response format, defaults to "text".
- temperature: Sampling temperature, defaults to 0.
Returns:
ModelResponse: Unified model response object, with content field containing the transcription result.
Raises:
LLMResponseError: When LLM response error occurs.
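
        Example (illustrative; the audio path is hypothetical):

            resp = provider.speech_to_text("meeting.wav", language="en")
            print(resp.content)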
"""
if not self.provider:
raise RuntimeError(
"Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")
try:
# Prepare parameters
transcription_params = {
"model": kwargs.get("model", "whisper-1"),
"response_format": kwargs.get("response_format", "text"),
"temperature": kwargs.get("temperature", 0)
}
# Add optional parameters
if language:
transcription_params["language"] = language
if prompt:
transcription_params["prompt"] = prompt
# Open file (if path is provided)
if isinstance(audio_file, str):
with open(audio_file, "rb") as file:
transcription_response = self.provider.audio.transcriptions.create(
file=file,
**transcription_params
)
else:
# If already a file object
transcription_response = self.provider.audio.transcriptions.create(
file=audio_file,
**transcription_params
)
# Create ModelResponse
return ModelResponse(
id=f"stt-{hash(str(transcription_response)) & 0xffffffff:08x}",
model=transcription_params["model"],
content=transcription_response.text if hasattr(transcription_response, 'text') else str(
transcription_response),
raw_response=transcription_response,
message={
"role": "assistant",
"content": transcription_response.text if hasattr(transcription_response, 'text') else str(
transcription_response)
}
)
except Exception as e:
logger.warn(f"Speech-to-text error: {e}")
raise LLMResponseError(str(e), kwargs.get("model", "whisper-1"))
async def aspeech_to_text(self,
audio_file: str,
language: str = None,
prompt: str = None,
**kwargs) -> ModelResponse:
"""Asynchronously convert speech to text.
Uses OpenAI's speech-to-text API to convert audio files to text.
Args:
audio_file: Path to audio file or file object.
language: Audio language, optional.
prompt: Transcription prompt, optional.
**kwargs: Other parameters, may include:
- model: Transcription model name, defaults to "whisper-1".
- response_format: Response format, defaults to "text".
- temperature: Sampling temperature, defaults to 0.
Returns:
ModelResponse: Unified model response object, with content field containing the transcription result.
Raises:
LLMResponseError: When LLM response error occurs.
"""
if not self.async_provider:
raise RuntimeError(
"Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")
try:
# Prepare parameters
transcription_params = {
"model": kwargs.get("model", "whisper-1"),
"response_format": kwargs.get("response_format", "text"),
"temperature": kwargs.get("temperature", 0)
}
# Add optional parameters
if language:
transcription_params["language"] = language
if prompt:
transcription_params["prompt"] = prompt
# Open file (if path is provided)
if isinstance(audio_file, str):
with open(audio_file, "rb") as file:
transcription_response = await self.async_provider.audio.transcriptions.create(
file=file,
**transcription_params
)
else:
# If already a file object
transcription_response = await self.async_provider.audio.transcriptions.create(
file=audio_file,
**transcription_params
)
# Create ModelResponse
return ModelResponse(
id=f"stt-{hash(str(transcription_response)) & 0xffffffff:08x}",
model=transcription_params["model"],
content=transcription_response.text if hasattr(transcription_response, 'text') else str(
transcription_response),
raw_response=transcription_response,
message={
"role": "assistant",
"content": transcription_response.text if hasattr(transcription_response, 'text') else str(
transcription_response)
}
)
except Exception as e:
logger.warn(f"Async speech-to-text error: {e}")
raise LLMResponseError(str(e), kwargs.get("model", "whisper-1"))
class AzureOpenAIProvider(OpenAIProvider):
"""Azure OpenAI provider implementation.
"""
def _init_provider(self):
"""Initialize Azure OpenAI provider.
Returns:
Azure OpenAI provider instance.
"""
from langchain_openai import AzureChatOpenAI
# Get API key
api_key = self.api_key
if not api_key:
env_var = "AZURE_OPENAI_API_KEY"
api_key = os.getenv(env_var, "")
if not api_key:
raise ValueError(
f"Azure OpenAI API key not found, please set {env_var} environment variable or provide it in the parameters")
# Get API version
api_version = self.kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")
# Get endpoint
azure_endpoint = self.base_url
if not azure_endpoint:
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", "")
if not azure_endpoint:
raise ValueError(
"Azure OpenAI endpoint not found, please set AZURE_OPENAI_ENDPOINT environment variable or provide it in the parameters")
return AzureChatOpenAI(
model=self.model_name or "gpt-4o",
temperature=self.kwargs.get("temperature", 0.0),
api_version=api_version,
azure_endpoint=azure_endpoint,
api_key=api_key
)