import ast
import asyncio
import datetime
import html
import json
import os
import time
from typing import (
    Any,
    List,
    Dict,
    Generator,
    AsyncGenerator,
)
from binascii import b2a_hex

from aworld.config.conf import ClientType
from aworld.core.llm_provider_base import LLMProviderBase
from aworld.models.llm_http_handler import LLMHTTPHandler
from aworld.models.model_response import ModelResponse, LLMResponseError, ToolCall
from aworld.logs.util import logger
from aworld.utils import import_package
from aworld.models.utils import usage_process

MODEL_NAMES = {
    "anthropic": ["claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
    "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini", "gpt-4o-mini"],
}

# Custom JSON encoder to handle ToolCall and other special types
class CustomJSONEncoder(json.JSONEncoder):
    """Custom JSON encoder to handle ToolCall objects and other special types."""

    def default(self, obj):
        # Handle objects with a to_dict method
        if hasattr(obj, 'to_dict') and callable(obj.to_dict):
            return obj.to_dict()
        # Handle objects with a __dict__ attribute (most custom classes)
        if hasattr(obj, '__dict__'):
            return obj.__dict__
        # Let the base class handle it (raises TypeError if not serializable)
        return super().default(obj)
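
# Illustrative usage only (this encoder is not referenced elsewhere in this module):
# serialize a payload that may contain ToolCall objects or other custom classes.
#   payload = json.dumps({"messages": messages}, cls=CustomJSONEncoder)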


class AntProvider(LLMProviderBase):
    """Ant provider implementation."""

    def _init_provider(self):
        """Initialize the Ant provider.

        Returns:
            Ant provider instance.
        """
        import_package("Crypto", install_name="pycryptodome")

        # Get API key
        api_key = self.api_key
        if not api_key:
            env_var = "ANT_API_KEY"
            api_key = os.getenv(env_var, "")
            self.api_key = api_key
            if not api_key:
                raise ValueError(
                    f"ANT API key not found, please set {env_var} environment variable or provide it in the parameters")

        if api_key and api_key.startswith("ak_info:"):
            ak_info_str = api_key[len("ak_info:"):]
            try:
                ak_info = json.loads(ak_info_str)
                for key, value in ak_info.items():
                    os.environ[key] = value
                    if key == "ANT_API_KEY":
                        api_key = value
                        self.api_key = api_key
            except Exception as e:
                logger.warn(f"Invalid ANT API key with 'ak_info:' prefix: {api_key}")

        self.stream_api_key = os.getenv("ANT_STREAM_API_KEY", "")

        base_url = self.base_url
        if not base_url:
            base_url = os.getenv("ANT_ENDPOINT", "https://zdfmng.alipay.com")
            self.base_url = base_url
        self.aes_key = os.getenv("ANT_AES_KEY", "")

        self.is_http_provider = True
        self.kwargs["client_type"] = ClientType.HTTP
        logger.info("Using HTTP provider for Ant")
        self.http_provider = LLMHTTPHandler(
            base_url=base_url,
            api_key=api_key,
            model_name=self.model_name,
        )
        self.is_http_provider = True
        return self.http_provider

    def _init_async_provider(self):
        """Initialize the async Ant provider.

        Returns:
            Async Ant provider instance.
        """
        if not self.provider:
            return self._init_provider()
        return self.provider

    @classmethod
    def supported_models(cls) -> list[str]:
        return [""]

    def _aes_encrypt(self, data, key):
        """AES-CBC encryption helper.

        The plaintext must be a multiple of the AES block size (16 bytes); if it
        is not, it is zero-padded up to the next multiple of 16 before encryption.

        Args:
            data: Data to encrypt.
            key: Encryption key.

        Returns:
            Hex-encoded ciphertext.
        """
        from Crypto.Cipher import AES

        iv = "1234567890123456"
        cipher = AES.new(key.encode('utf-8'), AES.MODE_CBC, iv.encode('utf-8'))
        block_size = AES.block_size
        # Check if data is a multiple of 16; if not, pad with b'\0'
        if len(data) % block_size != 0:
            add = block_size - (len(data) % block_size)
        else:
            add = 0
        data = data.encode('utf-8') + b'\0' * add
        encrypted = cipher.encrypt(data)
        result = b2a_hex(encrypted)
        return result.decode('utf-8')
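
    # Illustrative example (assumes ANT_AES_KEY is a valid 16/24/32-byte key):
    #   self._aes_encrypt('{"serviceName": "..."}', self.aes_key)
    # zero-pads the JSON string to a 16-byte boundary, encrypts it with AES-CBC
    # under a fixed IV, and returns the ciphertext as a hex string.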

    def _build_openai_params(self,
                             messages: List[Dict[str, str]],
                             temperature: float = 0.0,
                             max_tokens: int = None,
                             stop: List[str] = None,
                             **kwargs) -> Dict[str, Any]:
        openai_params = {
            "model": kwargs.get("model_name", self.model_name or ""),
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stop": stop
        }
        supported_params = [
            "frequency_penalty", "logit_bias", "logprobs", "top_logprobs",
            "presence_penalty", "response_format", "seed", "stream", "top_p",
            "user", "function_call", "functions", "tools", "tool_choice"
        ]
        for param in supported_params:
            if param in kwargs:
                openai_params[param] = kwargs[param]
        return openai_params

    def _build_claude_params(self,
                             messages: List[Dict[str, str]],
                             temperature: float = 0.0,
                             max_tokens: int = None,
                             stop: List[str] = None,
                             **kwargs) -> Dict[str, Any]:
        claude_params = {
            "model": kwargs.get("model_name", self.model_name or ""),
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stop": stop
        }
        supported_params = [
            "top_p", "top_k", "reasoning_effort", "tools", "tool_choice"
        ]
        for param in supported_params:
            if param in kwargs:
                claude_params[param] = kwargs[param]
        return claude_params

    def _get_visit_info(self):
        visit_info = {
            "visitDomain": self.kwargs.get("ant_visit_domain") or os.getenv("ANT_VISIT_DOMAIN", "BU_general"),
            "visitBiz": self.kwargs.get("ant_visit_biz") or os.getenv("ANT_VISIT_BIZ", ""),
            "visitBizLine": self.kwargs.get("ant_visit_biz_line") or os.getenv("ANT_VISIT_BIZ_LINE", "")
        }
        if not visit_info["visitBiz"] or not visit_info["visitBizLine"]:
            return None
        return visit_info

    def _get_service_param(self,
                           message_key: str,
                           output_type: str = "request",
                           messages: List[Dict[str, str]] = None,
                           temperature: float = 0.0,
                           max_tokens: int = None,
                           stop: List[str] = None,
                           **kwargs
                           ) -> Dict[str, Any]:
        """Build the service request parameters for the current model and output type.

        Returns:
            Service parameter dict.
        """
        if messages:
            for message in messages:
                if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"]:
                    if message["content"] is None:
                        message["content"] = ""
                    processed_tool_calls = []
                    for tool_call in message["tool_calls"]:
                        if isinstance(tool_call, dict):
                            processed_tool_calls.append(tool_call)
                        elif isinstance(tool_call, ToolCall):
                            processed_tool_calls.append(tool_call.to_dict())
                    message["tool_calls"] = processed_tool_calls

        query_conditions = {
            "messageKey": message_key,
        }
        param = {"cacheInterval": -1}
        visit_info = self._get_visit_info()
        if not visit_info:
            raise LLMResponseError(
                "AntProvider#Invalid visit_info, please set ANT_VISIT_BIZ and ANT_VISIT_BIZ_LINE environment variables or provide them in the parameters",
                self.model_name or "unknown"
            )
        param.update(visit_info)

        if self.model_name.startswith("claude"):
            query_conditions.update(self._build_claude_params(messages, temperature, max_tokens, stop, **kwargs))
            param.update({
                "serviceName": "amazon_claude_chat_completions_dataview",
                "queryConditions": query_conditions,
            })
        elif output_type == "pull":
            param.update({
                "serviceName": "chatgpt_response_query_dataview",
                "queryConditions": query_conditions
            })
        else:
            query_conditions = {
                "model": self.model_name,
                "n": "1",
                "api_key": self.api_key,
                "messageKey": message_key,
                "outputType": "PULL",
                "messages": messages,
            }
            query_conditions.update(self._build_openai_params(messages, temperature, max_tokens, stop, **kwargs))
            param.update({
                "serviceName": "asyn_chatgpt_prompts_completions_query_dataview",
                "queryConditions": query_conditions,
            })
        return param

    def _gen_message_key(self):
        def _timestamp():
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
            return timestamp

        timestamp = _timestamp()
        message_key = "llm_call_%s" % timestamp
        return message_key

    def _build_request_data(self, param: Dict[str, Any]):
        param_data = json.dumps(param)
        encrypted_param_data = self._aes_encrypt(param_data, self.aes_key)
        post_data = {"encryptedParam": encrypted_param_data}
        return post_data
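
    # The body posted to commonQuery/queryData is therefore a single-field JSON
    # object whose value is the AES-encrypted, hex-encoded service parameters,
    # e.g. {"encryptedParam": "<hex ciphertext of the param JSON>"}.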

    def _build_chat_query_request_data(self,
                                       message_key: str,
                                       messages: List[Dict[str, str]],
                                       temperature: float = 0.0,
                                       max_tokens: int = None,
                                       stop: List[str] = None,
                                       **kwargs):
        param = self._get_service_param(message_key, "request", messages, temperature, max_tokens, stop, **kwargs)
        query_data = self._build_request_data(param)
        return query_data

    def _post_chat_query_request(self,
                                 messages: List[Dict[str, str]],
                                 temperature: float = 0.0,
                                 max_tokens: int = None,
                                 stop: List[str] = None,
                                 **kwargs):
        message_key = self._gen_message_key()
        post_data = self._build_chat_query_request_data(message_key,
                                                        messages,
                                                        model_name=self.model_name,
                                                        temperature=temperature,
                                                        max_tokens=max_tokens,
                                                        stop=stop,
                                                        **kwargs)
        response = self.http_provider.sync_call(post_data, endpoint="commonQuery/queryData")
        return message_key, response

    def _valid_chat_result(self, body):
        if "data" not in body or not body["data"]:
            return False
        if "values" not in body["data"] or not body["data"]["values"]:
            return False
        if "response" not in body["data"]["values"] and "data" not in body["data"]["values"]:
            return False
        return True
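
    # Response envelope accepted by _valid_chat_result, as inferred from the
    # checks above and the parsing in _pull_chat_result (not an official schema):
    #   {"success": true, "data": {"values": {"response": "<escaped JSON string>"}}}
    # for OpenAI-style pulls, or a "data" value instead of "response" for Claude models.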

    def _build_chat_pull_request_data(self, message_key):
        param = self._get_service_param(message_key, "pull")
        pull_data = self._build_request_data(param)
        return pull_data

    def _pull_chat_result(self, message_key, response: Dict[str, Any], timeout):
        if self.model_name.startswith("claude"):
            if self._valid_chat_result(response):
                x = response["data"]["values"]["data"]
                ast_str = ast.literal_eval("'" + x + "'")
                result = html.unescape(ast_str)
                data = json.loads(result)
                return data
            else:
                raise LLMResponseError(
                    f"Invalid response from Ant API, response: {response}",
                    self.model_name or "unknown"
                )

        post_data = self._build_chat_pull_request_data(message_key)
        url = 'commonQuery/queryData'
        headers = {
            'Content-Type': 'application/json'
        }

        # Poll until a valid result arrives or the timeout is reached
        start_time = time.time()
        elapsed_time = 0
        while elapsed_time < timeout:
            response = self.http_provider.sync_call(post_data, endpoint=url, headers=headers)
            logger.debug(f"Poll attempt at {elapsed_time}s, response: {response}")
            # Check whether a valid result was received
            if self._valid_chat_result(response):
                x = response["data"]["values"]["response"]
                ast_str = ast.literal_eval("'" + x + "'")
                result = html.unescape(ast_str)
                data = json.loads(result)
                return data
            elif (not response.get("success")) or ("data" in response and response["data"]):
                err_code = response.get("data", {}).get("errorCode", "")
                err_msg = response.get("data", {}).get("errorMessage", "")
                if err_code or err_msg:
                    raise LLMResponseError(
                        f"Request failed: {response}",
                        self.model_name or "unknown"
                    )
            # If there is no result yet, wait 1 second and query again
            time.sleep(1)
            elapsed_time = time.time() - start_time
            logger.debug(f"Polling... Elapsed time: {elapsed_time:.1f}s")

        # Timeout handling
        raise LLMResponseError(
            f"Timeout after {timeout} seconds waiting for response from Ant API",
            self.model_name or "unknown"
        )

    async def _async_pull_chat_result(self, message_key, response: Dict[str, Any], timeout):
        if self.model_name.startswith("claude"):
            if self._valid_chat_result(response):
                x = response["data"]["values"]["data"]
                ast_str = ast.literal_eval("'" + x + "'")
                result = html.unescape(ast_str)
                data = json.loads(result)
                return data
            elif (not response.get("success")) or ("data" in response and response["data"]):
                err_code = response.get("data", {}).get("errorCode", "")
                err_msg = response.get("data", {}).get("errorMessage", "")
                if err_code or err_msg:
                    raise LLMResponseError(
                        f"Request failed: {response}",
                        self.model_name or "unknown"
                    )

        post_data = self._build_chat_pull_request_data(message_key)
        url = 'commonQuery/queryData'
        headers = {
            'Content-Type': 'application/json'
        }

        # Poll until a valid result arrives or the timeout is reached
        start_time = time.time()
        elapsed_time = 0
        while elapsed_time < timeout:
            response = await self.http_provider.async_call(post_data, endpoint=url, headers=headers)
            logger.debug(f"Poll attempt at {elapsed_time}s, response: {response}")
            # Check whether a valid result was received
            if self._valid_chat_result(response):
                x = response["data"]["values"]["response"]
                ast_str = ast.literal_eval("'" + x + "'")
                result = html.unescape(ast_str)
                data = json.loads(result)
                return data
            elif (not response.get("success")) or ("data" in response and response["data"]):
                err_code = response.get("data", {}).get("errorCode", "")
                err_msg = response.get("data", {}).get("errorMessage", "")
                if err_code or err_msg:
                    raise LLMResponseError(
                        f"Request failed: {response}",
                        self.model_name or "unknown"
                    )
            # If there is no result yet, wait 1 second and query again
            await asyncio.sleep(1)
            elapsed_time = time.time() - start_time
            logger.debug(f"Polling... Elapsed time: {elapsed_time:.1f}s")

        # Timeout handling
        raise LLMResponseError(
            f"Timeout after {timeout} seconds waiting for response from Ant API",
            self.model_name or "unknown"
        )

    def _convert_completion_message(self, message: Dict[str, Any], is_finished: bool = False) -> ModelResponse:
        """Convert an Ant completion message into a ModelResponse.

        Args:
            message: Ant completion message.
            is_finished: Whether this is the final chunk of the stream.

        Returns:
            ModelResponse object.
        """
        # Generate a unique ID
        response_id = f"ant-{hash(str(message)) & 0xffffffff:08x}"
        # Get content
        content = message.get("completion", "")
        # Create the message object
        message_dict = {
            "role": "assistant",
            "content": content,
            "is_chunk": True
        }
        # Keep the original contextId and sessionId
        if "contextId" in message:
            message_dict["contextId"] = message["contextId"]
        if "sessionId" in message:
            message_dict["sessionId"] = message["sessionId"]

        usage = {
            "completion_tokens": message.get("completionToken", 0),
            "prompt_tokens": message.get("promptTokens", 0),
            "total_tokens": message.get("completionToken", 0) + message.get("promptTokens", 0)
        }

        # Process tool calls, accumulating streamed arguments per tool-call index
        tool_calls = message.get("toolCalls", [])
        for tool_call in tool_calls:
            index = tool_call.get("index", 0)
            name = tool_call.get("function", {}).get("name")
            arguments = tool_call.get("function", {}).get("arguments")
            if index >= len(self.stream_tool_buffer):
                self.stream_tool_buffer.append({
                    "id": tool_call.get("id"),
                    "type": "function",
                    "function": {
                        "name": name,
                        "arguments": arguments
                    }
                })
            else:
                self.stream_tool_buffer[index]["function"]["arguments"] += arguments

        if is_finished and self.stream_tool_buffer:
            message_dict["tool_calls"] = self.stream_tool_buffer.copy()
            processed_tool_calls = []
            for tool_call in self.stream_tool_buffer:
                processed_tool_calls.append(ToolCall.from_dict(tool_call))
            tool_resp = ModelResponse(
                id=response_id,
                model=self.model_name or "ant",
                content=content,
                tool_calls=processed_tool_calls,
                usage=usage,
                raw_response=message,
                message=message_dict
            )
            self.stream_tool_buffer = []
            return tool_resp

        # Build and return the ModelResponse object directly
        return ModelResponse(
            id=response_id,
            model=self.model_name or "ant",
            content=content,
            tool_calls=None,  # TODO: add tool calls
            usage=usage,
            raw_response=message,
            message=message_dict
        )

    def preprocess_stream_call_message(self, messages: List[Dict[str, str]], ext_params: Dict[str, Any]) -> Dict[str, Any]:
        """Build the streaming request payload from Ant-format messages.

        Args:
            messages: Ant format message list.
            ext_params: Extra parameters merged into the payload when not already present.

        Returns:
            Request payload dict.
        """
        param = {
            "messages": messages,
            "sessionId": "TkQUldjzOgYSKyTrpor3TA==",
            "model": self.model_name,
            "needMemory": False,
            "stream": True,
            "contextId": "contextId_34555fd2d246447fa55a1a259445a427",
            "platform": "AWorld"
        }
        for k in ext_params.keys():
            if k not in param:
                param[k] = ext_params[k]
        return param

    def postprocess_response(self, response: Any) -> ModelResponse:
        """Process the Ant response.

        Args:
            response: Ant response object.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if ((not isinstance(response, dict) and (not hasattr(response, 'choices') or not response.choices))
                or (isinstance(response, dict) and not response.get("choices"))):
            error_msg = ""
            if hasattr(response, 'error') and response.error and isinstance(response.error, dict):
                error_msg = response.error.get('message', '')
            elif hasattr(response, 'msg'):
                error_msg = response.msg
            raise LLMResponseError(
                error_msg if error_msg else "Unknown error",
                self.model_name or "unknown",
                response
            )
        return ModelResponse.from_openai_response(response)

    def postprocess_stream_response(self, chunk: Any) -> ModelResponse:
        """Process an Ant stream response chunk.

        Args:
            chunk: Ant response chunk.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        # Check whether the chunk contains an error
        if hasattr(chunk, 'error') or (isinstance(chunk, dict) and chunk.get('error')):
            error_msg = chunk.error if hasattr(chunk, 'error') else chunk.get('error', 'Unknown error')
            raise LLMResponseError(
                error_msg,
                self.model_name or "unknown",
                chunk
            )
        if isinstance(chunk, dict) and ('completion' in chunk):
            return self._convert_completion_message(chunk)
        # If the chunk is already in OpenAI format, use the standard processing method
        return ModelResponse.from_openai_stream_chunk(chunk)

    def completion(self,
                   messages: List[Dict[str, str]],
                   temperature: float = 0.0,
                   max_tokens: int = None,
                   stop: List[str] = None,
                   **kwargs) -> ModelResponse:
        """Synchronously call Ant to generate a response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if not self.provider:
            raise RuntimeError(
                "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")

        try:
            start_time = time.time()
            message_key, response = self._post_chat_query_request(messages, temperature, max_tokens, stop, **kwargs)
            timeout = kwargs.get("response_timeout", self.kwargs.get("timeout", 180))
            result = self._pull_chat_result(message_key, response, timeout)
            logger.info(f"completion cost time: {time.time() - start_time}s.")
            resp = self.postprocess_response(result)
            usage_process(resp.usage)
            return resp
        except Exception as e:
            if isinstance(e, LLMResponseError):
                raise e
            logger.warn(f"Error in Ant completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))

    async def acompletion(self,
                          messages: List[Dict[str, str]],
                          temperature: float = 0.0,
                          max_tokens: int = None,
                          stop: List[str] = None,
                          **kwargs) -> ModelResponse:
        """Asynchronously call Ant to generate a response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            ModelResponse object.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if not self.async_provider:
            self._init_async_provider()

        start_time = time.time()
        try:
            message_key, response = self._post_chat_query_request(messages, temperature, max_tokens, stop, **kwargs)
            timeout = kwargs.get("response_timeout", self.kwargs.get("timeout", 180))
            result = await self._async_pull_chat_result(message_key, response, timeout)
            logger.info(f"completion cost time: {time.time() - start_time}s.")
            resp = self.postprocess_response(result)
            usage_process(resp.usage)
            return resp
        except Exception as e:
            if isinstance(e, LLMResponseError):
                raise e
            logger.warn(f"Error in async Ant completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))

    def stream_completion(self,
                          messages: List[Dict[str, str]],
                          temperature: float = 0.0,
                          max_tokens: int = None,
                          stop: List[str] = None,
                          **kwargs) -> Generator[ModelResponse, None, None]:
        """Synchronously call Ant to generate a streaming response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            Generator yielding ModelResponse chunks.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if not self.provider:
            raise RuntimeError(
                "Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")

        start_time = time.time()
        # Generate message_key
        timestamp = int(time.time())
        self.message_key = f"llm_call_{timestamp}"
        message_key_literal = self.message_key  # Ensure it's a direct string literal
        self.aes_key = kwargs.get("aes_key", self.aes_key)
        # Add the streaming parameter
        kwargs["stream"] = True
        processed_messages = self.preprocess_stream_call_message(
            messages, self._build_openai_params(messages, temperature, max_tokens, stop, **kwargs))
        if not processed_messages:
            raise LLMResponseError("Failed to get post data", self.model_name or "unknown")

        usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0
        }
        try:
            # Send the request
            # response = self.http_provider.sync_call(processed_messages[0], endpoint="commonQuery/queryData")
            headers = {
                "Content-Type": "application/json",
                "X_ACCESS_KEY": self.stream_api_key
            }
            response_stream = self.http_provider.sync_stream_call(processed_messages, endpoint="chat/completions",
                                                                  headers=headers)
            if response_stream:
                for chunk in response_stream:
                    if not chunk:
                        continue
                    # Process special markers
                    if isinstance(chunk, dict) and "status" in chunk:
                        if chunk["status"] == "done":
                            # Stream completion marker; flush buffered tool calls and end the stream
                            logger.info("Received [DONE] marker, stream completed")
                            yield self._convert_completion_message(chunk, is_finished=True)
                            yield ModelResponse.from_special_marker("done", self.model_name, chunk)
                            break
                        elif chunk["status"] == "revoke":
                            # Revoke marker; the frontend should revoke the displayed content
                            logger.info("Received [REVOKE] marker, content should be revoked")
                            yield ModelResponse.from_special_marker("revoke", self.model_name, chunk)
                            continue
                        elif chunk["status"] == "fail":
                            # Fail marker
                            logger.error("Received [FAIL] marker, request failed")
                            raise LLMResponseError("Request failed", self.model_name or "unknown")
                        elif chunk["status"] == "cancel":
                            # Request was cancelled
                            logger.warning("Received [CANCEL] marker, stream was cancelled")
                            raise LLMResponseError("Stream was cancelled", self.model_name or "unknown")
                        continue
                    # Process normal response chunks
                    resp = self.postprocess_stream_response(chunk)
                    self._accumulate_chunk_usage(usage, resp.usage)
                    yield resp
            usage_process(usage)
            logger.info(f"stream_completion cost time: {time.time() - start_time}s.")
        except Exception as e:
            if isinstance(e, LLMResponseError):
                raise e
            logger.error(f"Error in Ant stream completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))

    async def astream_completion(self,
                                 messages: List[Dict[str, str]],
                                 temperature: float = 0.0,
                                 max_tokens: int = None,
                                 stop: List[str] = None,
                                 **kwargs) -> AsyncGenerator[ModelResponse, None]:
        """Asynchronously call Ant to generate a streaming response.

        Args:
            messages: Message list.
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            AsyncGenerator yielding ModelResponse chunks.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """
        if not self.async_provider:
            self._init_async_provider()

        start_time = time.time()
        # Generate message_key
        timestamp = int(time.time())
        self.message_key = f"llm_call_{timestamp}"
        message_key_literal = self.message_key  # Ensure it's a direct string literal
        self.aes_key = kwargs.get("aes_key", self.aes_key)
        # Add the streaming parameter
        kwargs["stream"] = True
        processed_messages = self.preprocess_stream_call_message(
            messages, self._build_openai_params(messages, temperature, max_tokens, stop, **kwargs))
        if not processed_messages:
            raise LLMResponseError("Failed to get post data", self.model_name or "unknown")

        usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0
        }
        try:
            headers = {
                "Content-Type": "application/json",
                "X_ACCESS_KEY": self.stream_api_key
            }
            logger.info(f"astream_completion request data: {processed_messages}")
            async for chunk in self.http_provider.async_stream_call(processed_messages, endpoint="chat/completions",
                                                                    headers=headers):
                if not chunk:
                    continue
                # Process special markers
                if isinstance(chunk, dict) and "status" in chunk:
                    if chunk["status"] == "done":
                        # Stream completion marker; end the stream
                        logger.info("Received [DONE] marker, stream completed")
                        yield ModelResponse.from_special_marker("done", self.model_name, chunk)
                        break
                    elif chunk["status"] == "revoke":
                        # Revoke marker; the frontend should revoke the displayed content
                        logger.info("Received [REVOKE] marker, content should be revoked")
                        yield ModelResponse.from_special_marker("revoke", self.model_name, chunk)
                        continue
                    elif chunk["status"] == "fail":
                        # Fail marker
                        logger.error("Received [FAIL] marker, request failed")
                        raise LLMResponseError("Request failed", self.model_name or "unknown")
                    elif chunk["status"] == "cancel":
                        # Request was cancelled
                        logger.warning("Received [CANCEL] marker, stream was cancelled")
                        raise LLMResponseError("Stream was cancelled", self.model_name or "unknown")
                    continue
                # Process normal response chunks
                resp = self.postprocess_stream_response(chunk)
                self._accumulate_chunk_usage(usage, resp.usage)
                yield resp
            usage_process(usage)
            logger.info(f"astream_completion cost time: {time.time() - start_time}s.")
        except Exception as e:
            if isinstance(e, LLMResponseError):
                raise e
            logger.warn(f"Error in async Ant stream completion: {e}")
            raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
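

# Illustrative usage sketch (assumptions, not an official example: a valid
# ANT_API_KEY, ANT_AES_KEY, ANT_VISIT_BIZ and ANT_VISIT_BIZ_LINE are set, and
# the provider is constructed the same way as other LLMProviderBase subclasses
# in the aworld framework):
#
#   provider = AntProvider(model_name="gpt-4o")
#   resp = provider.completion(
#       messages=[{"role": "user", "content": "Hello"}],
#       temperature=0.0,
#   )
#   print(resp.content)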