import os
from typing import Any, Dict, List, Generator, AsyncGenerator

from openai import OpenAI, AsyncOpenAI

from aworld.config.conf import ClientType
from aworld.core.llm_provider_base import LLMProviderBase
from aworld.models.llm_http_handler import LLMHTTPHandler
from aworld.models.model_response import ModelResponse, LLMResponseError
from aworld.logs.util import logger
from aworld.models.utils import usage_process


class OpenAIProvider(LLMProviderBase):
    """OpenAI provider implementation."""

    def _init_provider(self):
        """Initialize OpenAI provider.

        Returns:
            OpenAI provider instance.
        """
        # Get API key
        api_key = self.api_key
        if not api_key:
            env_var = "OPENAI_API_KEY"
            api_key = os.getenv(env_var, "")
            if not api_key:
                raise ValueError(
                    f"OpenAI API key not found, please set {env_var} environment variable or provide it in the parameters")

        base_url = self.base_url
        if not base_url:
            base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")

        self.is_http_provider = False
        if self.kwargs.get("client_type", ClientType.SDK) == ClientType.HTTP:
            logger.info(f"Using HTTP provider for OpenAI")
            self.http_provider = LLMHTTPHandler(
                base_url=base_url,
                api_key=api_key,
                model_name=self.model_name,
                max_retries=self.kwargs.get("max_retries", 3)
            )
            self.is_http_provider = True
            return self.http_provider
        else:
            return OpenAI(
                api_key=api_key,
                base_url=base_url,
                timeout=self.kwargs.get("timeout", 180),
                max_retries=self.kwargs.get("max_retries", 3)
            )

    def _init_async_provider(self):
        """Initialize async OpenAI provider.

        Returns:
            Async OpenAI provider instance.
        """
        # Get API key
        api_key = self.api_key
        if not api_key:
            env_var = "OPENAI_API_KEY"
            api_key = os.getenv(env_var, "")
            if not api_key:
                raise ValueError(
                    f"OpenAI API key not found, please set {env_var} environment variable or provide it in the parameters")

        base_url = self.base_url
        if not base_url:
            base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")

        return AsyncOpenAI(
            api_key=api_key,
            base_url=base_url,
            timeout=self.kwargs.get("timeout", 180),
            max_retries=self.kwargs.get("max_retries", 3)
        )

    @classmethod
    def supported_models(cls) -> list[str]:
        return ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini", "gpt-4o-mini",
                "deepseek-chat", "deepseek-reasoner", r"qwq-.*", r"qwen-.*"]

    def preprocess_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Preprocess messages, use OpenAI format directly.

        Args:
            messages: OpenAI format message list.

        Returns:
            Processed message list.
        """
        for message in messages:
            if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"]:
                if message["content"] is None:
                    message["content"] = ""
                for tool_call in message["tool_calls"]:
                    if "function" not in tool_call and "name" in tool_call and "arguments" in tool_call:
                        tool_call["function"] = {"name": tool_call["name"], "arguments": tool_call["arguments"]}
        return messages

    def postprocess_response(self, response: Any) -> ModelResponse:
        """Process OpenAI response.

        Args:
            response: OpenAI response object.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When LLM response error occurs.
""" if ((not isinstance(response, dict) and (not hasattr(response, 'choices') or not response.choices)) or (isinstance(response, dict) and not response.get("choices"))): error_msg = "" if hasattr(response, 'error') and response.error and isinstance(response.error, dict): error_msg = response.error.get('message', '') elif hasattr(response, 'msg'): error_msg = response.msg raise LLMResponseError( error_msg if error_msg else "Unknown error", self.model_name or "unknown", response ) return ModelResponse.from_openai_response(response) def postprocess_stream_response(self, chunk: Any) -> ModelResponse: """Process OpenAI streaming response chunk. Args: chunk: OpenAI response chunk. Returns: ModelResponse object. Raises: LLMResponseError: When LLM response error occurs. """ # Check if chunk contains error if hasattr(chunk, 'error') or (isinstance(chunk, dict) and chunk.get('error')): error_msg = chunk.error if hasattr(chunk, 'error') else chunk.get('error', 'Unknown error') raise LLMResponseError( error_msg, self.model_name or "unknown", chunk ) # process tool calls if (hasattr(chunk, 'choices') and chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.tool_calls) or ( isinstance(chunk, dict) and chunk.get("choices") and chunk["choices"] and chunk["choices"][0].get("delta", {}).get("tool_calls")): tool_calls = chunk.choices[0].delta.tool_calls if hasattr(chunk, 'choices') else chunk["choices"][0].get("delta", {}).get("tool_calls") for tool_call in tool_calls: index = tool_call.index if hasattr(tool_call, 'index') else tool_call["index"] func_name = tool_call.function.name if hasattr(tool_call, 'function') else tool_call.get("function", {}).get("name") func_args = tool_call.function.arguments if hasattr(tool_call, 'function') else tool_call.get("function", {}).get("arguments") if index >= len(self.stream_tool_buffer): self.stream_tool_buffer.append({ "id": tool_call.id if hasattr(tool_call, 'id') else tool_call.get("id"), "type": "function", "function": { "name": func_name, "arguments": func_args } }) else: self.stream_tool_buffer[index]["function"]["arguments"] += func_args processed_chunk = chunk if hasattr(processed_chunk, 'choices'): processed_chunk.choices[0].delta.tool_calls = None else: processed_chunk["choices"][0]["delta"]["tool_calls"] = None resp = ModelResponse.from_openai_stream_chunk(processed_chunk) if (not resp.content and not resp.usage.get("total_tokens", 0)): return None if (hasattr(chunk, 'choices') and chunk.choices and chunk.choices[0].finish_reason) or ( isinstance(chunk, dict) and chunk.get("choices") and chunk["choices"] and chunk["choices"][0].get( "finish_reason")): finish_reason = chunk.choices[0].finish_reason if hasattr(chunk, 'choices') else chunk["choices"][0].get( "finish_reason") if self.stream_tool_buffer: tool_call_chunk = { "id": chunk.id if hasattr(chunk, 'id') else chunk.get("id"), "model": chunk.model if hasattr(chunk, 'model') else chunk.get("model"), "object": chunk.object if hasattr(chunk, 'object') else chunk.get("object"), "choices": [ { "delta": { "role": "assistant", "content": "", "tool_calls": self.stream_tool_buffer } } ] } self.stream_tool_buffer = [] return ModelResponse.from_openai_stream_chunk(tool_call_chunk) return ModelResponse.from_openai_stream_chunk(chunk) def completion(self, messages: List[Dict[str, str]], temperature: float = 0.0, max_tokens: int = None, stop: List[str] = None, **kwargs) -> ModelResponse: """Synchronously call OpenAI to generate response. Args: messages: Message list. temperature: Temperature parameter. 
    def completion(self, messages: List[Dict[str, str]], temperature: float = 0.0, max_tokens: int = None,
                   stop: List[str] = None, **kwargs) -> ModelResponse:
        """Synchronously call OpenAI to generate response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When LLM response error occurs.
        """
        if not self.provider:
            raise RuntimeError(
                "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")

        processed_messages = self.preprocess_messages(messages)

        try:
            openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
            if self.is_http_provider:
                response = self.http_provider.sync_call(openai_params)
            else:
                response = self.provider.chat.completions.create(**openai_params)

            if (hasattr(response, 'code') and response.code != 0) or (
                    isinstance(response, dict) and response.get("code", 0) != 0):
                error_msg = getattr(response, 'msg', 'Unknown error')
                logger.warn(f"API Error: {error_msg}")
                raise LLMResponseError(error_msg, kwargs.get("model_name", self.model_name or "unknown"), response)

            if not response:
                raise LLMResponseError("Empty response", kwargs.get("model_name", self.model_name or "unknown"))

            resp = self.postprocess_response(response)
            usage_process(resp.usage)
            return resp
        except Exception as e:
            if isinstance(e, LLMResponseError):
                raise e
            logger.warn(f"Error in OpenAI completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))

    def stream_completion(self, messages: List[Dict[str, str]], temperature: float = 0.0, max_tokens: int = None,
                          stop: List[str] = None, **kwargs) -> Generator[ModelResponse, None, None]:
        """Synchronously call OpenAI to generate streaming response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            Generator yielding ModelResponse chunks.

        Raises:
            LLMResponseError: When LLM response error occurs.
        """
        if not self.provider:
            raise RuntimeError(
                "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")

        processed_messages = self.preprocess_messages(messages)
        usage = {
            "completion_tokens": 0,
            "prompt_tokens": 0,
            "total_tokens": 0
        }

        try:
            openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
            openai_params["stream"] = True
            if self.is_http_provider:
                response_stream = self.http_provider.sync_stream_call(openai_params)
            else:
                response_stream = self.provider.chat.completions.create(**openai_params)

            for chunk in response_stream:
                if not chunk:
                    continue
                resp = self.postprocess_stream_response(chunk)
                if resp:
                    self._accumulate_chunk_usage(usage, resp.usage)
                    yield resp
            usage_process(usage)
        except Exception as e:
            logger.warn(f"Error in stream_completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
    async def astream_completion(self, messages: List[Dict[str, str]], temperature: float = 0.0,
                                 max_tokens: int = None, stop: List[str] = None,
                                 **kwargs) -> AsyncGenerator[ModelResponse, None]:
        """Asynchronously call OpenAI to generate streaming response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            AsyncGenerator yielding ModelResponse chunks.

        Raises:
            LLMResponseError: When LLM response error occurs.
        """
        if not self.async_provider:
            raise RuntimeError(
                "Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")

        processed_messages = self.preprocess_messages(messages)
        usage = {
            "completion_tokens": 0,
            "prompt_tokens": 0,
            "total_tokens": 0
        }

        try:
            openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
            openai_params["stream"] = True
            if self.is_http_provider:
                async for chunk in self.http_provider.async_stream_call(openai_params):
                    if not chunk:
                        continue
                    resp = self.postprocess_stream_response(chunk)
                    # Skip chunks that postprocess_stream_response drops (returns None),
                    # mirroring the SDK branch below.
                    if resp:
                        self._accumulate_chunk_usage(usage, resp.usage)
                        yield resp
            else:
                response_stream = await self.async_provider.chat.completions.create(**openai_params)
                async for chunk in response_stream:
                    if not chunk:
                        continue
                    resp = self.postprocess_stream_response(chunk)
                    if resp:
                        self._accumulate_chunk_usage(usage, resp.usage)
                        yield resp
            usage_process(usage)
        except Exception as e:
            logger.warn(f"Error in astream_completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))

    async def acompletion(self, messages: List[Dict[str, str]], temperature: float = 0.0, max_tokens: int = None,
                          stop: List[str] = None, **kwargs) -> ModelResponse:
        """Asynchronously call OpenAI to generate response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When LLM response error occurs.
        """
        if not self.async_provider:
            raise RuntimeError(
                "Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")

        processed_messages = self.preprocess_messages(messages)

        try:
            openai_params = self.get_openai_params(processed_messages, temperature, max_tokens, stop, **kwargs)
            if self.is_http_provider:
                response = await self.http_provider.async_call(openai_params)
            else:
                response = await self.async_provider.chat.completions.create(**openai_params)

            if (hasattr(response, 'code') and response.code != 0) or (
                    isinstance(response, dict) and response.get("code", 0) != 0):
                error_msg = getattr(response, 'msg', 'Unknown error')
                logger.warn(f"API Error: {error_msg}")
                raise LLMResponseError(error_msg, kwargs.get("model_name", self.model_name or "unknown"), response)

            if not response:
                raise LLMResponseError("Empty response", kwargs.get("model_name", self.model_name or "unknown"))

            resp = self.postprocess_response(response)
            usage_process(resp.usage)
            return resp
        except Exception as e:
            if isinstance(e, LLMResponseError):
                raise e
            logger.warn(f"Error in acompletion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))

    def get_openai_params(self, messages: List[Dict[str, str]], temperature: float = 0.0, max_tokens: int = None,
                          stop: List[str] = None, **kwargs) -> Dict[str, Any]:
        """Build the keyword arguments passed to the chat completions API."""
        openai_params = {
            "model": kwargs.get("model_name", self.model_name or ""),
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stop": stop
        }
        # Note: the original list was missing the comma after "web_search_options",
        # which silently concatenated it with "frequency_penalty"; fixed here.
        supported_params = [
            "max_completion_tokens", "meta_data", "modalities", "n", "parallel_tool_calls", "prediction",
            "reasoning_effort", "service_tier", "stream_options", "web_search_options",
            "frequency_penalty", "logit_bias", "logprobs", "top_logprobs", "presence_penalty",
            "response_format", "seed", "stream", "top_p", "user",
            "function_call", "functions", "tools", "tool_choice"
        ]

        for param in supported_params:
            if param in kwargs:
                openai_params[param] = kwargs[param]

        return openai_params
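    # Illustrative note (added commentary, not part of the original source): for a call
    # such as
    #   provider.get_openai_params(messages, temperature=0.2, max_tokens=1024, tools=tools)
    # the resulting dict looks like
    #   {"model": "gpt-4o", "messages": [...], "temperature": 0.2,
    #    "max_tokens": 1024, "stop": None, "tools": [...]}
    # (assuming model_name was set to "gpt-4o"); keyword arguments not listed in
    # supported_params are dropped rather than forwarded to the API.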
    def speech_to_text(self, audio_file: str, language: str = None, prompt: str = None, **kwargs) -> ModelResponse:
        """Convert speech to text.

        Uses OpenAI's speech-to-text API to convert audio files to text.

        Args:
            audio_file: Path to audio file or file object.
            language: Audio language, optional.
            prompt: Transcription prompt, optional.
            **kwargs: Other parameters, may include:
                - model: Transcription model name, defaults to "whisper-1".
                - response_format: Response format, defaults to "text".
                - temperature: Sampling temperature, defaults to 0.

        Returns:
            ModelResponse: Unified model response object, with content field containing the transcription result.

        Raises:
            LLMResponseError: When LLM response error occurs.
        """
        if not self.provider:
            raise RuntimeError(
                "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")

        try:
            # Prepare parameters
            transcription_params = {
                "model": kwargs.get("model", "whisper-1"),
                "response_format": kwargs.get("response_format", "text"),
                "temperature": kwargs.get("temperature", 0)
            }

            # Add optional parameters
            if language:
                transcription_params["language"] = language
            if prompt:
                transcription_params["prompt"] = prompt

            # Open file (if path is provided)
            if isinstance(audio_file, str):
                with open(audio_file, "rb") as file:
                    transcription_response = self.provider.audio.transcriptions.create(
                        file=file,
                        **transcription_params
                    )
            else:
                # If already a file object
                transcription_response = self.provider.audio.transcriptions.create(
                    file=audio_file,
                    **transcription_params
                )

            # Create ModelResponse
            return ModelResponse(
                id=f"stt-{hash(str(transcription_response)) & 0xffffffff:08x}",
                model=transcription_params["model"],
                content=transcription_response.text if hasattr(transcription_response, 'text') else str(
                    transcription_response),
                raw_response=transcription_response,
                message={
                    "role": "assistant",
                    "content": transcription_response.text if hasattr(transcription_response, 'text') else str(
                        transcription_response)
                }
            )
        except Exception as e:
            logger.warn(f"Speech-to-text error: {e}")
            raise LLMResponseError(str(e), kwargs.get("model", "whisper-1"))
    async def aspeech_to_text(self, audio_file: str, language: str = None, prompt: str = None,
                              **kwargs) -> ModelResponse:
        """Asynchronously convert speech to text.

        Uses OpenAI's speech-to-text API to convert audio files to text.

        Args:
            audio_file: Path to audio file or file object.
            language: Audio language, optional.
            prompt: Transcription prompt, optional.
            **kwargs: Other parameters, may include:
                - model: Transcription model name, defaults to "whisper-1".
                - response_format: Response format, defaults to "text".
                - temperature: Sampling temperature, defaults to 0.

        Returns:
            ModelResponse: Unified model response object, with content field containing the transcription result.

        Raises:
            LLMResponseError: When LLM response error occurs.
        """
        if not self.async_provider:
            raise RuntimeError(
                "Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")

        try:
            # Prepare parameters
            transcription_params = {
                "model": kwargs.get("model", "whisper-1"),
                "response_format": kwargs.get("response_format", "text"),
                "temperature": kwargs.get("temperature", 0)
            }

            # Add optional parameters
            if language:
                transcription_params["language"] = language
            if prompt:
                transcription_params["prompt"] = prompt

            # Open file (if path is provided)
            if isinstance(audio_file, str):
                with open(audio_file, "rb") as file:
                    transcription_response = await self.async_provider.audio.transcriptions.create(
                        file=file,
                        **transcription_params
                    )
            else:
                # If already a file object
                transcription_response = await self.async_provider.audio.transcriptions.create(
                    file=audio_file,
                    **transcription_params
                )

            # Create ModelResponse
            return ModelResponse(
                id=f"stt-{hash(str(transcription_response)) & 0xffffffff:08x}",
                model=transcription_params["model"],
                content=transcription_response.text if hasattr(transcription_response, 'text') else str(
                    transcription_response),
                raw_response=transcription_response,
                message={
                    "role": "assistant",
                    "content": transcription_response.text if hasattr(transcription_response, 'text') else str(
                        transcription_response)
                }
            )
        except Exception as e:
            logger.warn(f"Async speech-to-text error: {e}")
            raise LLMResponseError(str(e), kwargs.get("model", "whisper-1"))


class AzureOpenAIProvider(OpenAIProvider):
    """Azure OpenAI provider implementation."""

    def _init_provider(self):
        """Initialize Azure OpenAI provider.

        Returns:
            Azure OpenAI provider instance.
        """
        from langchain_openai import AzureChatOpenAI

        # Get API key
        api_key = self.api_key
        if not api_key:
            env_var = "AZURE_OPENAI_API_KEY"
            api_key = os.getenv(env_var, "")
            if not api_key:
                raise ValueError(
                    f"Azure OpenAI API key not found, please set {env_var} environment variable or provide it in the parameters")

        # Get API version
        api_version = self.kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")

        # Get endpoint
        azure_endpoint = self.base_url
        if not azure_endpoint:
            azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", "")
            if not azure_endpoint:
                raise ValueError(
                    "Azure OpenAI endpoint not found, please set AZURE_OPENAI_ENDPOINT environment variable or provide it in the parameters")

        return AzureChatOpenAI(
            model=self.model_name or "gpt-4o",
            temperature=self.kwargs.get("temperature", 0.0),
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            api_key=api_key
        )
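
# Illustrative usage sketch (added, not part of the original source). It assumes the
# LLMProviderBase constructor accepts api_key, base_url, model_name and sync_enabled
# keyword arguments, as suggested by the attributes and error messages used above;
# the actual constructor signature may differ.
if __name__ == "__main__":
    provider = OpenAIProvider(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1"),
        model_name="gpt-4o",
        sync_enabled=True,
    )
    resp = provider.completion(
        messages=[{"role": "user", "content": "Hello"}],
        temperature=0.0,
    )
    print(resp.content)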