import ast
import asyncio
import datetime
import html
import json
import os
import time
from typing import (
Any,
List,
Dict,
Generator,
AsyncGenerator,
)
from binascii import b2a_hex
from aworld.config.conf import ClientType
from aworld.core.llm_provider_base import LLMProviderBase
from aworld.models.llm_http_handler import LLMHTTPHandler
from aworld.models.model_response import ModelResponse, LLMResponseError, ToolCall
from aworld.logs.util import logger
from aworld.utils import import_package
from aworld.models.utils import usage_process
MODEL_NAMES = {
"anthropic": ["claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
"openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini", "gpt-4o-mini"],
}
# Custom JSON encoder to handle ToolCall and other special types
class CustomJSONEncoder(json.JSONEncoder):
"""Custom JSON encoder to handle ToolCall objects and other special types."""
def default(self, obj):
# Handle objects with to_dict method
if hasattr(obj, 'to_dict') and callable(obj.to_dict):
return obj.to_dict()
# Handle objects with __dict__ attribute (most custom classes)
if hasattr(obj, '__dict__'):
return obj.__dict__
# Let the base class handle it (will raise TypeError if not serializable)
return super().default(obj)
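
# Usage sketch (hypothetical payload): CustomJSONEncoder lets json.dumps
# serialize ToolCall objects and other custom classes, e.g.
#   json.dumps({"tool_calls": [tool_call]}, cls=CustomJSONEncoder)
# It falls back to to_dict() when available, then __dict__, then raises.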
class AntProvider(LLMProviderBase):
"""Ant provider implementation.
"""
def _init_provider(self):
"""Initialize Ant provider.
Returns:
Ant provider instance.
"""
import_package("Crypto", install_name="pycryptodome")
# Get API key
        env_var = "ANT_API_KEY"
        api_key = self.api_key
        if not api_key:
            api_key = os.getenv(env_var, "")
            self.api_key = api_key
        if not api_key:
            raise ValueError(
                f"ANT API key not found, please set the {env_var} environment variable or provide it in the parameters")
        if api_key.startswith("ak_info:"):
            ak_info_str = api_key[len("ak_info:"):]
            try:
                ak_info = json.loads(ak_info_str)
                for key, value in ak_info.items():
                    os.environ[key] = value
                    if key == "ANT_API_KEY":
                        api_key = value
                        self.api_key = api_key
            except Exception as e:
                logger.warning(f"Invalid ak_info payload in ANT API key: {e}")
self.stream_api_key = os.getenv("ANT_STREAM_API_KEY", "")
base_url = self.base_url
if not base_url:
base_url = os.getenv("ANT_ENDPOINT", "https://zdfmng.alipay.com")
self.base_url = base_url
self.aes_key = os.getenv("ANT_AES_KEY", "")
        self.is_http_provider = True
        self.kwargs["client_type"] = ClientType.HTTP
        logger.info("Using HTTP provider for Ant")
        self.http_provider = LLMHTTPHandler(
            base_url=base_url,
            api_key=api_key,
            model_name=self.model_name,
        )
        return self.http_provider
    def _init_async_provider(self):
        """Initialize async Ant provider.

        Returns:
            Async Ant provider instance.
        """
        if not self.provider:
            self.provider = self._init_provider()
        return self.provider
@classmethod
def supported_models(cls) -> list[str]:
return [""]
def _aes_encrypt(self, data, key):
"""AES encryption function. If data is not a multiple of 16 [encrypted data must be a multiple of 16!], pad it to a multiple of 16.
Args:
key: Encryption key
data: Data to encrypt
Returns:
Encrypted data
"""
from Crypto.Cipher import AES
iv = "1234567890123456"
cipher = AES.new(key.encode('utf-8'), AES.MODE_CBC, iv.encode('utf-8'))
block_size = AES.block_size
# Check if data is a multiple of 16, if not, pad with b'\0'
if len(data) % block_size != 0:
add = block_size - (len(data) % block_size)
else:
add = 0
data = data.encode('utf-8') + b'\0' * add
encrypted = cipher.encrypt(data)
result = b2a_hex(encrypted)
return result.decode('utf-8')
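
    # Worked padding example (hypothetical values): a 5-byte plaintext gets
    # 11 trailing b'\0' bytes to fill one 16-byte block; a 32-byte plaintext
    # needs no padding. The fixed IV means identical plaintexts encrypt
    # identically; the gateway is assumed to expect exactly this scheme.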
def _build_openai_params(self,
messages: List[Dict[str, str]],
temperature: float = 0.0,
max_tokens: int = None,
stop: List[str] = None,
**kwargs) -> Dict[str, Any]:
openai_params = {
"model": kwargs.get("model_name", self.model_name or ""),
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stop": stop
}
supported_params = [
"frequency_penalty", "logit_bias", "logprobs", "top_logprobs",
"presence_penalty", "response_format", "seed", "stream", "top_p",
"user", "function_call", "functions", "tools", "tool_choice"
]
for param in supported_params:
if param in kwargs:
openai_params[param] = kwargs[param]
return openai_params
def _build_claude_params(self,
messages: List[Dict[str, str]],
temperature: float = 0.0,
max_tokens: int = None,
stop: List[str] = None,
**kwargs) -> Dict[str, Any]:
claude_params = {
"model": kwargs.get("model_name", self.model_name or ""),
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stop": stop
}
supported_params = [
"top_p", "top_k", "reasoning_effort", "tools", "tool_choice"
]
for param in supported_params:
if param in kwargs:
claude_params[param] = kwargs[param]
return claude_params
def _get_visit_info(self):
visit_info = {
"visitDomain": self.kwargs.get("ant_visit_domain") or os.getenv("ANT_VISIT_DOMAIN", "BU_general"),
"visitBiz": self.kwargs.get("ant_visit_biz") or os.getenv("ANT_VISIT_BIZ", ""),
"visitBizLine": self.kwargs.get("ant_visit_biz_line") or os.getenv("ANT_VISIT_BIZ_LINE", "")
}
if not visit_info["visitBiz"] or not visit_info["visitBizLine"]:
return None
return visit_info
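
    # Hypothetical environment setup for the visit metadata required above
    # (requests are rejected when either value is empty):
    #   export ANT_VISIT_BIZ=my_biz
    #   export ANT_VISIT_BIZ_LINE=my_biz_line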
def _get_service_param(self,
message_key: str,
output_type: str = "request",
messages: List[Dict[str, str]] = None,
temperature: float = 0.0,
max_tokens: int = None,
stop: List[str] = None,
**kwargs
) -> Dict[str, Any]:
"""Get service name from model name.
Returns:
Service name.
"""
if messages:
for message in messages:
if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"]:
if message["content"] is None: message["content"] = ""
processed_tool_calls = []
for tool_call in message["tool_calls"]:
if isinstance(tool_call, dict):
processed_tool_calls.append(tool_call)
elif isinstance(tool_call, ToolCall):
processed_tool_calls.append(tool_call.to_dict())
message["tool_calls"] = processed_tool_calls
query_conditions = {
"messageKey": message_key,
}
param = {"cacheInterval": -1, }
visit_info = self._get_visit_info()
if not visit_info:
            raise LLMResponseError(
                "AntProvider#Invalid visit_info, please set the ANT_VISIT_BIZ and ANT_VISIT_BIZ_LINE environment variables or provide them in the parameters",
                self.model_name or "unknown"
            )
param.update(visit_info)
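        # Route by model family and output_type: Claude models use the
        # synchronous Claude service, "pull" polls a previously submitted
        # ChatGPT job, and anything else submits an async ChatGPT job keyed
        # by message_key.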
if self.model_name.startswith("claude"):
query_conditions.update(self._build_claude_params(messages, temperature, max_tokens, stop, **kwargs))
param.update({
"serviceName": "amazon_claude_chat_completions_dataview",
"queryConditions": query_conditions,
})
elif output_type == "pull":
param.update({
"serviceName": "chatgpt_response_query_dataview",
"queryConditions": query_conditions
})
else:
query_conditions = {
"model": self.model_name,
"n": "1",
"api_key": self.api_key,
"messageKey": message_key,
"outputType": "PULL",
"messages": messages,
}
query_conditions.update(self._build_openai_params(messages, temperature, max_tokens, stop, **kwargs))
param.update({
"serviceName": "asyn_chatgpt_prompts_completions_query_dataview",
"queryConditions": query_conditions,
})
return param
    def _gen_message_key(self):
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        return f"llm_call_{timestamp}"
def _build_request_data(self, param: Dict[str, Any]):
param_data = json.dumps(param)
encrypted_param_data = self._aes_encrypt(param_data, self.aes_key)
post_data = {"encryptedParam": encrypted_param_data}
return post_data
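
    # Wire-format sketch: every request body is a single-field envelope,
    #   {"encryptedParam": "<hex of AES-CBC(JSON-encoded param)>"}
    # which the gateway is assumed to decrypt with the shared ANT_AES_KEY
    # before dispatching to the named service.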
def _build_chat_query_request_data(self,
message_key: str,
messages: List[Dict[str, str]],
temperature: float = 0.0,
max_tokens: int = None,
stop: List[str] = None,
**kwargs):
param = self._get_service_param(message_key, "request", messages, temperature, max_tokens, stop, **kwargs)
query_data = self._build_request_data(param)
return query_data
def _post_chat_query_request(self,
messages: List[Dict[str, str]],
temperature: float = 0.0,
max_tokens: int = None,
stop: List[str] = None,
**kwargs):
message_key = self._gen_message_key()
post_data = self._build_chat_query_request_data(message_key,
messages,
model_name=self.model_name,
temperature=temperature,
max_tokens=max_tokens,
stop=stop,
**kwargs)
response = self.http_provider.sync_call(post_data, endpoint="commonQuery/queryData")
return message_key, response
def _valid_chat_result(self, body):
if "data" not in body or not body["data"]:
return False
if "values" not in body["data"] or not body["data"]["values"]:
return False
if "response" not in body["data"]["values"] and "data" not in body["data"]["values"]:
return False
return True
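
    # Expected response shape, inferred from the checks above (sketch):
    #   {"success": true, "data": {"values": {"response": "<escaped JSON>"}}}
    # Claude results arrive under values["data"] instead of values["response"].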
def _build_chat_pull_request_data(self, message_key):
param = self._get_service_param(message_key, "pull")
pull_data = self._build_request_data(param)
return pull_data
def _pull_chat_result(self, message_key, response: Dict[str, Any], timeout):
if self.model_name.startswith("claude"):
if self._valid_chat_result(response):
x = response["data"]["values"]["data"]
ast_str = ast.literal_eval("'" + x + "'")
result = html.unescape(ast_str)
data = json.loads(result)
return data
else:
raise LLMResponseError(
f"Invalid response from Ant API, response: {response}",
self.model_name or "unknown"
)
post_data = self._build_chat_pull_request_data(message_key)
url = 'commonQuery/queryData'
headers = {
'Content-Type': 'application/json'
}
# Start polling until valid result or timeout
start_time = time.time()
elapsed_time = 0
while elapsed_time < timeout:
response = self.http_provider.sync_call(post_data, endpoint=url, headers=headers)
logger.debug(f"Poll attempt at {elapsed_time}s, response: {response}")
# Check if valid result is received
if self._valid_chat_result(response):
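                # The payload is an escaped string literal: ast.literal_eval
                # resolves the backslash escapes, html.unescape restores HTML
                # entities, and the result parses as OpenAI-style JSON.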
x = response["data"]["values"]["response"]
ast_str = ast.literal_eval("'" + x + "'")
result = html.unescape(ast_str)
data = json.loads(result)
return data
elif (not response.get("success")) or ("data" in response and response["data"]):
err_code = response.get("data", {}).get("errorCode", "")
err_msg = response.get("data", {}).get("errorMessage", "")
if err_code or err_msg:
raise LLMResponseError(
f"Request failed: {response}",
self.model_name or "unknown"
)
# If no result, wait 1 second and query again
time.sleep(1)
elapsed_time = time.time() - start_time
logger.debug(f"Polling... Elapsed time: {elapsed_time:.1f}s")
# Timeout handling
raise LLMResponseError(
f"Timeout after {timeout} seconds waiting for response from Ant API",
self.model_name or "unknown"
)
async def _async_pull_chat_result(self, message_key, response: Dict[str, Any], timeout):
if self.model_name.startswith("claude"):
if self._valid_chat_result(response):
x = response["data"]["values"]["data"]
ast_str = ast.literal_eval("'" + x + "'")
result = html.unescape(ast_str)
data = json.loads(result)
return data
elif (not response.get("success")) or ("data" in response and response["data"]):
err_code = response.get("data", {}).get("errorCode", "")
err_msg = response.get("data", {}).get("errorMessage", "")
if err_code or err_msg:
raise LLMResponseError(
f"Request failed: {response}",
self.model_name or "unknown"
)
post_data = self._build_chat_pull_request_data(message_key)
url = 'commonQuery/queryData'
headers = {
'Content-Type': 'application/json'
}
# Start polling until valid result or timeout
start_time = time.time()
elapsed_time = 0
while elapsed_time < timeout:
response = await self.http_provider.async_call(post_data, endpoint=url, headers=headers)
logger.debug(f"Poll attempt at {elapsed_time}s, response: {response}")
# Check if valid result is received
if self._valid_chat_result(response):
x = response["data"]["values"]["response"]
ast_str = ast.literal_eval("'" + x + "'")
result = html.unescape(ast_str)
data = json.loads(result)
return data
elif (not response.get("success")) or ("data" in response and response["data"]):
err_code = response.get("data", {}).get("errorCode", "")
err_msg = response.get("data", {}).get("errorMessage", "")
if err_code or err_msg:
raise LLMResponseError(
f"Request failed: {response}",
self.model_name or "unknown"
)
# If no result, wait 1 second and query again
await asyncio.sleep(1)
elapsed_time = time.time() - start_time
logger.debug(f"Polling... Elapsed time: {elapsed_time:.1f}s")
# Timeout handling
raise LLMResponseError(
f"Timeout after {timeout} seconds waiting for response from Ant API",
self.model_name or "unknown"
)
def _convert_completion_message(self, message: Dict[str, Any], is_finished: bool = False) -> ModelResponse:
"""Convert Ant completion message to OpenAI format.
Args:
message: Ant completion message.
Returns:
OpenAI format message.
"""
# Generate unique ID
response_id = f"ant-{hash(str(message)) & 0xffffffff:08x}"
# Get content
content = message.get("completion", "")
# Create message object
message_dict = {
"role": "assistant",
"content": content,
"is_chunk": True
}
# Keep original contextId and sessionId
if "contextId" in message:
message_dict["contextId"] = message["contextId"]
if "sessionId" in message:
message_dict["sessionId"] = message["sessionId"]
usage = {
"completion_tokens": message.get("completionToken", 0),
"prompt_tokens": message.get("promptTokens", 0),
"total_tokens": message.get("completionToken", 0) + message.get("promptTokens", 0)
}
# process tool calls
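        # Streaming tool calls arrive fragmented: each chunk carries a partial
        # "arguments" string for a given tool-call index, so fragments are
        # appended into stream_tool_buffer and complete ToolCall objects are
        # emitted only on the final chunk (is_finished=True).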
tool_calls = message.get("toolCalls", [])
for tool_call in tool_calls:
index = tool_call.get("index", 0)
name = tool_call.get("function", {}).get("name")
            arguments = tool_call.get("function", {}).get("arguments") or ""
if index >= len(self.stream_tool_buffer):
self.stream_tool_buffer.append({
"id": tool_call.get("id"),
"type": "function",
"function": {
"name": name,
"arguments": arguments
}
})
else:
self.stream_tool_buffer[index]["function"]["arguments"] += arguments
if is_finished and self.stream_tool_buffer:
message_dict["tool_calls"] = self.stream_tool_buffer.copy()
processed_tool_calls = []
for tool_call in self.stream_tool_buffer:
processed_tool_calls.append(ToolCall.from_dict(tool_call))
tool_resp = ModelResponse(
id=response_id,
model=self.model_name or "ant",
content=content,
tool_calls=processed_tool_calls,
usage=usage,
raw_response=message,
message=message_dict
)
self.stream_tool_buffer = []
return tool_resp
# Build and return ModelResponse object directly
return ModelResponse(
id=response_id,
model=self.model_name or "ant",
content=content,
tool_calls=None, # TODO: add tool calls
usage=usage,
raw_response=message,
message=message_dict
)
    def preprocess_stream_call_message(self, messages: List[Dict[str, str]], ext_params: Dict[str, Any]) -> Dict[
            str, Any]:
        """Build the streaming request body in Ant format.

        Args:
            messages: Message list in OpenAI format.
            ext_params: Extra OpenAI-style parameters to merge in.

        Returns:
            Request parameter dict for the streaming endpoint.
        """
param = {
"messages": messages,
"sessionId": "TkQUldjzOgYSKyTrpor3TA==",
"model": self.model_name,
"needMemory": False,
"stream": True,
"contextId": "contextId_34555fd2d246447fa55a1a259445a427",
"platform": "AWorld"
}
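        # Merge extra OpenAI-style params without overriding the base fields above.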
for k in ext_params.keys():
if k not in param:
param[k] = ext_params[k]
return param
def postprocess_response(self, response: Any) -> ModelResponse:
"""Process Ant response.
Args:
response: Ant response object.
Returns:
ModelResponse object.
Raises:
LLMResponseError: When LLM response error occurs.
"""
if ((not isinstance(response, dict) and (not hasattr(response, 'choices') or not response.choices))
or (isinstance(response, dict) and not response.get("choices"))):
error_msg = ""
if hasattr(response, 'error') and response.error and isinstance(response.error, dict):
error_msg = response.error.get('message', '')
elif hasattr(response, 'msg'):
error_msg = response.msg
raise LLMResponseError(
error_msg if error_msg else "Unknown error",
self.model_name or "unknown",
response
)
return ModelResponse.from_openai_response(response)
def postprocess_stream_response(self, chunk: Any) -> ModelResponse:
"""Process Ant stream response chunk.
Args:
chunk: Ant response chunk.
Returns:
ModelResponse object.
Raises:
LLMResponseError: When LLM response error occurs.
"""
# Check if chunk contains error
if hasattr(chunk, 'error') or (isinstance(chunk, dict) and chunk.get('error')):
error_msg = chunk.error if hasattr(chunk, 'error') else chunk.get('error', 'Unknown error')
raise LLMResponseError(
error_msg,
self.model_name or "unknown",
chunk
)
if isinstance(chunk, dict) and ('completion' in chunk):
return self._convert_completion_message(chunk)
# If chunk is already in OpenAI format, use standard processing method
return ModelResponse.from_openai_stream_chunk(chunk)
def completion(self,
messages: List[Dict[str, str]],
temperature: float = 0.0,
max_tokens: int = None,
stop: List[str] = None,
**kwargs) -> ModelResponse:
"""Synchronously call Ant to generate response.
Args:
messages: Message list.
temperature: Temperature parameter.
max_tokens: Maximum number of tokens to generate.
stop: List of stop sequences.
**kwargs: Other parameters.
Returns:
ModelResponse object.
Raises:
LLMResponseError: When LLM response error occurs.
"""
if not self.provider:
raise RuntimeError(
"Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")
try:
start_time = time.time()
message_key, response = self._post_chat_query_request(messages, temperature, max_tokens, stop, **kwargs)
timeout = kwargs.get("response_timeout", self.kwargs.get("timeout", 180))
result = self._pull_chat_result(message_key, response, timeout)
logger.info(f"completion cost time: {time.time() - start_time}s.")
resp = self.postprocess_response(result)
usage_process(resp.usage)
return resp
except Exception as e:
if isinstance(e, LLMResponseError):
raise e
logger.warn(f"Error in Ant completion: {e}")
raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
async def acompletion(self,
messages: List[Dict[str, str]],
temperature: float = 0.0,
max_tokens: int = None,
stop: List[str] = None,
**kwargs) -> ModelResponse:
"""Asynchronously call Ant to generate response.
Args:
messages: Message list.
temperature: Temperature parameter.
max_tokens: Maximum number of tokens to generate.
stop: List of stop sequences.
**kwargs: Other parameters.
Returns:
ModelResponse object.
Raises:
LLMResponseError: When LLM response error occurs.
"""
if not self.async_provider:
self._init_async_provider()
start_time = time.time()
try:
message_key, response = self._post_chat_query_request(messages, temperature, max_tokens, stop, **kwargs)
timeout = kwargs.get("response_timeout", self.kwargs.get("timeout", 180))
result = await self._async_pull_chat_result(message_key, response, timeout)
logger.info(f"completion cost time: {time.time() - start_time}s.")
resp = self.postprocess_response(result)
usage_process(resp.usage)
return resp
except Exception as e:
if isinstance(e, LLMResponseError):
raise e
logger.warn(f"Error in async Ant completion: {e}")
raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
def stream_completion(self,
messages: List[Dict[str, str]],
temperature: float = 0.0,
max_tokens: int = None,
stop: List[str] = None,
**kwargs) -> Generator[ModelResponse, None, None]:
"""Synchronously call Ant to generate streaming response.
Args:
messages: Message list.
temperature: Temperature parameter.
max_tokens: Maximum number of tokens to generate.
stop: List of stop sequences.
**kwargs: Other parameters.
Returns:
Generator yielding ModelResponse chunks.
Raises:
LLMResponseError: When LLM response error occurs.
"""
if not self.provider:
raise RuntimeError(
"Sync provider not initialized. Make sure 'sync_enabled' parameter is set to True in initialization.")
start_time = time.time()
        # Generate message_key
        timestamp = int(time.time())
        self.message_key = f"llm_call_{timestamp}"
self.aes_key = kwargs.get("aes_key", self.aes_key)
# Add streaming parameter
kwargs["stream"] = True
        processed_messages = self.preprocess_stream_call_message(messages,
                                                                 self._build_openai_params(messages, temperature,
                                                                                           max_tokens, stop, **kwargs))
if not processed_messages:
raise LLMResponseError("Failed to get post data", self.model_name or "unknown")
usage = {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
try:
            # Send request
headers = {
"Content-Type": "application/json",
"X_ACCESS_KEY": self.stream_api_key
}
response_stream = self.http_provider.sync_stream_call(processed_messages, endpoint="chat/completions",
headers=headers)
if response_stream:
for chunk in response_stream:
if not chunk:
continue
# Process special markers
if isinstance(chunk, dict) and "status" in chunk:
if chunk["status"] == "done":
# Stream completion marker, can choose to end
logger.info("Received [DONE] marker, stream completed")
yield self._convert_completion_message(chunk, is_finished=True)
yield ModelResponse.from_special_marker("done", self.model_name, chunk)
break
elif chunk["status"] == "revoke":
# Revoke marker, need to notify the frontend to revoke the displayed content
logger.info("Received [REVOKE] marker, content should be revoked")
yield ModelResponse.from_special_marker("revoke", self.model_name, chunk)
continue
elif chunk["status"] == "fail":
# Fail marker
logger.error("Received [FAIL] marker, request failed")
raise LLMResponseError("Request failed", self.model_name or "unknown")
elif chunk["status"] == "cancel":
# Request was cancelled
logger.warning("Received [CANCEL] marker, stream was cancelled")
raise LLMResponseError("Stream was cancelled", self.model_name or "unknown")
continue
# Process normal response chunks
resp = self.postprocess_stream_response(chunk)
self._accumulate_chunk_usage(usage, resp.usage)
yield resp
usage_process(usage)
logger.info(f"stream_completion cost time: {time.time() - start_time}s.")
except Exception as e:
if isinstance(e, LLMResponseError):
raise e
logger.error(f"Error in Ant stream completion: {e}")
raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
async def astream_completion(self,
messages: List[Dict[str, str]],
temperature: float = 0.0,
max_tokens: int = None,
stop: List[str] = None,
**kwargs) -> AsyncGenerator[ModelResponse, None]:
"""Asynchronously call Ant to generate streaming response.
Args:
messages: Message list.
temperature: Temperature parameter.
max_tokens: Maximum number of tokens to generate.
stop: List of stop sequences.
**kwargs: Other parameters.
Returns:
AsyncGenerator yielding ModelResponse chunks.
Raises:
LLMResponseError: When LLM response error occurs.
"""
if not self.async_provider:
self._init_async_provider()
start_time = time.time()
        # Generate message_key
        timestamp = int(time.time())
        self.message_key = f"llm_call_{timestamp}"
self.aes_key = kwargs.get("aes_key", self.aes_key)
# Add streaming parameter
kwargs["stream"] = True
        processed_messages = self.preprocess_stream_call_message(messages,
                                                                 self._build_openai_params(messages, temperature,
                                                                                           max_tokens, stop, **kwargs))
if not processed_messages:
raise LLMResponseError("Failed to get post data", self.model_name or "unknown")
usage = {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
try:
headers = {
"Content-Type": "application/json",
"X_ACCESS_KEY": self.stream_api_key
}
logger.info(f"astream_completion request data: {processed_messages}")
async for chunk in self.http_provider.async_stream_call(processed_messages, endpoint="chat/completions",
headers=headers):
if not chunk:
continue
# Process special markers
if isinstance(chunk, dict) and "status" in chunk:
if chunk["status"] == "done":
# Stream completion marker, can choose to end
logger.info("Received [DONE] marker, stream completed")
yield ModelResponse.from_special_marker("done", self.model_name, chunk)
break
elif chunk["status"] == "revoke":
# Revoke marker, need to notify the frontend to revoke the displayed content
logger.info("Received [REVOKE] marker, content should be revoked")
yield ModelResponse.from_special_marker("revoke", self.model_name, chunk)
continue
elif chunk["status"] == "fail":
# Fail marker
logger.error("Received [FAIL] marker, request failed")
raise LLMResponseError("Request failed", self.model_name or "unknown")
elif chunk["status"] == "cancel":
# Request was cancelled
logger.warning("Received [CANCEL] marker, stream was cancelled")
raise LLMResponseError("Stream was cancelled", self.model_name or "unknown")
continue
# Process normal response chunks
resp = self.postprocess_stream_response(chunk)
self._accumulate_chunk_usage(usage, resp.usage)
yield resp
usage_process(usage)
logger.info(f"astream_completion cost time: {time.time() - start_time}s.")
except Exception as e:
if isinstance(e, LLMResponseError):
raise e
logger.warn(f"Error in async Ant stream completion: {e}")
raise LLMResponseError(str(e), kwargs.get("model_name", self.model_name or "unknown"))
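
# Minimal usage sketch, not executed on import. It assumes ANT_API_KEY,
# ANT_AES_KEY, ANT_VISIT_BIZ and ANT_VISIT_BIZ_LINE are set in the
# environment; the constructor keywords follow LLMProviderBase and the
# model name is illustrative.
if __name__ == "__main__":
    provider = AntProvider(model_name="gpt-4o", sync_enabled=True)
    resp = provider.completion(
        messages=[{"role": "user", "content": "Say hello."}],
        temperature=0.0,
    )
    print(resp.content)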