# arena/completion.py
import os
import json
import time
import yaml
import random
import shortuuid
import requests
from typing import Optional
import boto3
from glob import glob
from tqdm import tqdm
# API setting constants
API_MAX_RETRY = 50
API_RETRY_SLEEP = 10
API_ERROR_OUTPUT = None
registered_api_completion = {}
registered_engine_completion = {}
def register_api(api_type):
def decorator(func):
registered_api_completion[api_type] = func
return func
return decorator
def register_engine(engine_type):
def decorator(func):
registered_engine_completion[engine_type] = func
return func
return decorator
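# Usage sketch (illustrative): handlers registered below are looked up by the
# api/engine type named in a model's config, e.g.
#   completion_fn = registered_api_completion["openai"]
#   output = completion_fn(model=..., messages=[...], temperature=0.7,
#                          max_tokens=1024, api_dict=endpoint)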
def load_questions(question_file: str):
"""Load questions from a file."""
questions = []
with open(question_file, "r") as ques_file:
for line in ques_file:
            if line.strip():
questions.append(json.loads(line))
return questions
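# Each question line is a JSON object; fields consumed elsewhere in this
# module include "uid", "instruction", and "environment", e.g. (illustrative):
#   {"uid": "q_001", "instruction": "Write a function ...", "environment": "..."}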
def load_model_answers(answer_dir: str):
"""Load model answers.
The return value is a python dict of type:
Dict[model_name: str -> Dict[uid: int -> answer: dict]]
"""
if not os.path.exists(answer_dir):
return {}
filenames = []
for folder in os.listdir(answer_dir):
if not os.path.isdir(os.path.join(answer_dir, folder)):
continue
if not os.path.exists(os.path.join(answer_dir, folder, "generation.jsonl")):
continue
filenames.append(os.path.join(answer_dir, folder, "generation.jsonl"))
filenames.sort()
model_answers = {}
for filename in filenames:
# Use parent directory name as model name
model_name = os.path.basename(os.path.dirname(filename))
answer = {}
with open(filename) as fin:
for line in fin:
line = json.loads(line)
answer[line["uid"]] = line
model_answers[model_name] = answer
return model_answers
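# Expected layout for the loader above: answer_dir/<model_name>/generation.jsonl,
# one JSON answer per line, keyed by "uid".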
def load_model_judgements(answer_dir: str):
"""Load model judgements.
The return value is a python dict of type:
Dict[model_name: str -> Dict[uid: int -> answer: dict]]
"""
filenames = glob(os.path.join(answer_dir, "*.jsonl"))
filenames.sort()
model_answers = {}
for filename in filenames:
        model_name = os.path.basename(filename)[:-6]  # strip the ".jsonl" suffix
answer = {}
with open(filename) as fin:
for line in fin:
line = json.loads(line)
answer[line["uid"]] = line
model_answers[model_name] = answer
return model_answers
def load_model_answers_and_execution_results(data_dir: str):
"""Load model answers and execution results.
The return value is a python dict of type:
Dict[model_name: str -> Dict[uid: int -> answer: dict]]
"""
filenames = []
for folder in os.listdir(data_dir):
if not os.path.isdir(os.path.join(data_dir, folder)):
continue
if not os.path.exists(os.path.join(data_dir, folder, "execution_results.jsonl")):
continue
filenames.append(os.path.join(data_dir, folder, "execution_results.jsonl"))
filenames.sort()
model_answers = {}
for filename in filenames:
# Use parent directory name as model name
model_name = os.path.basename(os.path.dirname(filename))
answer = {}
with open(filename) as fin:
for line in fin:
line = json.loads(line)
answer[line["uid"]] = line
model_answers[model_name] = answer
return model_answers
def load_id_to_model_answers(answer_dir: str):
"""Load model answers.
The return value is a python dict of type:
Dict[model_name: str -> Dict[uid: int -> answer: dict]]
"""
filenames = glob(os.path.join(answer_dir, "*.jsonl"))
filenames.sort()
model_answers = {}
for filename in filenames:
        model_name = os.path.basename(filename)[:-6]  # strip the ".jsonl" suffix
with open(filename) as fin:
for line in fin:
line = json.loads(line)
if line["uid"] in model_answers:
model_answers[line["uid"]][model_name] = line
else:
model_answers[line["uid"]] = {model_name: line}
return model_answers
def get_endpoint(endpoint_list):
    if endpoint_list is None:
        return None
    # randomly pick one
    return random.choice(endpoint_list)
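# Each entry of endpoint_list is an api_dict consumed by the handlers below,
# e.g. (illustrative):
#   {"api_base": "https://...", "api_key": "...", "model_name": "..."}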
# load config args from config yaml files
def make_config(config_file: str) -> dict:
with open(config_file, "r") as f:
config_kwargs = yaml.safe_load(os.path.expandvars(f.read()))
return config_kwargs
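# Example config snippet (illustrative); ${VAR} references are expanded from
# the environment by os.path.expandvars before YAML parsing:
#   endpoints:
#     - api_base: https://api.openai.com/v1
#       api_key: ${OPENAI_API_KEY}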
@register_api("openai")
def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=None, **kwargs):
import openai
if api_dict:
client = openai.OpenAI(
base_url=api_dict["api_base"],
api_key=api_dict["api_key"],
)
else:
client = openai.OpenAI()
if api_dict and "model_name" in api_dict:
model = api_dict["model_name"]
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
completion = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
)
output = {
"answer": completion.choices[0].message.content
}
break
except openai.RateLimitError as e:
time.sleep(API_RETRY_SLEEP)
except openai.BadRequestError as e:
break
except KeyError:
break
return output
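# Non-streaming handlers share a contract: return a dict such as
# {"answer": ...} (optionally with "thought") on success, or API_ERROR_OUTPUT
# (None) once retries are exhausted; *_streaming variants yield text chunks.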
@register_api("openai_streaming")
def chat_completion_openai_streaming(model, messages, temperature, max_tokens, api_dict=None, **kwargs):
"""Streaming version of OpenAI completion that yields tokens as they arrive"""
import openai
if api_dict:
client = openai.OpenAI(
base_url=api_dict["api_base"],
api_key=api_dict["api_key"],
)
else:
client = openai.OpenAI()
if api_dict and "model_name" in api_dict:
model = api_dict["model_name"]
try:
stream = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True
)
for chunk in stream:
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
except Exception as e:
print(f"Error in streaming completion: {e}")
yield f"Error: {str(e)}"
@register_api("openai_thinking")
def chat_completion_openai_thinking(model, messages, api_dict=None, **kwargs):
import openai
if api_dict:
client = openai.OpenAI(
api_key=api_dict["api_key"],
base_url=api_dict["api_base"],
)
else:
client = openai.OpenAI()
output = API_ERROR_OUTPUT
for i in range(API_MAX_RETRY):
try:
completion = client.chat.completions.create(
model=model,
messages=messages,
                reasoning_effort=kwargs.get("reasoning_effort", "medium"),
)
output = {
"answer": completion.choices[0].message.content
}
break
except openai.RateLimitError as e:
time.sleep(API_RETRY_SLEEP)
except openai.BadRequestError as e:
break
except KeyError:
break
return output
@register_api("deepseek_reasoner")
def chat_completion_deepseek_reasoner(messages, api_dict, **kwargs):
import urllib.request
chat_endpoint_headers = {
"User-Agent": "curl/8.7.1",
"Authorization": "Bearer {}".format(api_dict['api_key']),
"Content-Type": "application/json",
"Accept": "application/json",
}
chat_endpoint_url = "https://api.deepseek.com/chat/completions"
req_body = {
"messages": messages,
"model": "deepseek-reasoner",
"stream": False,
}
req_data = json.dumps(req_body).encode("utf-8")
output = API_ERROR_OUTPUT
for i in range(API_MAX_RETRY):
try:
            req = urllib.request.Request(
                chat_endpoint_url,
                headers=chat_endpoint_headers.copy(),
                data=req_data,
            )
with urllib.request.urlopen(req) as res:
res_data = res.read()
res_body = json.loads(res_data.decode("utf-8"))
output = {
"thought": res_body["choices"][0]["message"]["reasoning_content"],
"answer": res_body["choices"][0]["message"]["content"],
}
break
except Exception as e:
time.sleep(API_RETRY_SLEEP)
return output
@register_api("deepseek")
def chat_completion_deepseek(messages, max_tokens, api_dict, **kwargs):
import urllib.request
chat_endpoint_headers = {
"User-Agent": "curl/8.7.1",
"Authorization": "Bearer {}".format(api_dict['api_key']),
"Content-Type": "application/json",
"Accept": "application/json",
}
chat_endpoint_url = "https://api.deepseek.com/chat/completions"
req_body = {
"messages": messages,
"model": "deepseek-chat",
"stream": False,
"max_tokens": max_tokens,
}
req_data = json.dumps(req_body).encode("utf-8")
output = API_ERROR_OUTPUT
for i in range(API_MAX_RETRY):
try:
            req = urllib.request.Request(
                chat_endpoint_url,
                headers=chat_endpoint_headers.copy(),
                data=req_data,
            )
with urllib.request.urlopen(req) as res:
res_data = res.read()
res_body = json.loads(res_data.decode("utf-8"))
output = {
"answer": res_body["choices"][0]["message"]["content"],
}
break
except Exception as e:
time.sleep(API_RETRY_SLEEP)
return output
@register_api("anthropic")
def chat_completion_anthropic(model, messages, temperature, max_tokens, api_dict=None, **kwargs):
import anthropic
if api_dict:
api_key = api_dict["api_key"]
else:
api_key = os.environ["ANTHROPIC_API_KEY"]
sys_msg = ""
if messages[0]["role"] == "system":
sys_msg = messages[0]["content"]
messages = messages[1:]
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
c = anthropic.Anthropic(api_key=api_key)
response = c.messages.create(
model=model,
messages=messages,
stop_sequences=[anthropic.HUMAN_PROMPT],
max_tokens=max_tokens,
temperature=temperature,
system=sys_msg
)
output = {
"answer": response.content[0].text
}
break
except anthropic.APIError as e:
time.sleep(API_RETRY_SLEEP)
return output
@register_api("anthropic_thinking")
def chat_completion_anthropic_thinking(model, messages, max_tokens, budget_tokens, **kwargs):
import anthropic
client = anthropic.Anthropic(
timeout=1200,
)
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
response = client.messages.create(
model=model,
max_tokens=max_tokens,
thinking={
"type": "enabled",
"budget_tokens": budget_tokens
},
messages=messages,
)
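            # Assumes the extended-thinking response arrives as a thinking
            # block followed by a text block.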
output = {
"thought": response.content[0].thinking,
"answer": response.content[1].text,
}
break
except anthropic.APIError as e:
time.sleep(API_RETRY_SLEEP)
return output
@register_api("mistral")
def chat_completion_mistral(model, messages, temperature, max_tokens, **kwargs):
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from mistralai.exceptions import MistralException
api_key = os.environ["MISTRAL_API_KEY"]
client = MistralClient(api_key=api_key)
prompts = [ChatMessage(role=message["role"], content=message["content"]) for message in messages]
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
chat_response = client.chat(
model=model,
messages=prompts,
temperature=temperature,
max_tokens=max_tokens,
)
output = {
"answer": chat_response.choices[0].message.content
}
break
except MistralException as e:
break
return output
@register_api("xai")
def chat_completion_xai(model, messages, temperature, max_tokens, api_dict=None, **kwargs):
import xai_sdk
client = xai_sdk.Client(api_key=api_dict['api_key'], api_host=api_dict['api_base']).compat
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
stream = client.chat.completions.create(
model=model,
messages=messages,
stream=True,
max_tokens=max_tokens,
temperature=temperature,
top_p=0.95,
)
output_text = ""
for chunk in stream:
if chunk.choices[0].delta.content:
output_text += chunk.choices[0].delta.content
output = {
"answer": output_text
}
break
except Exception as e:
time.sleep(API_RETRY_SLEEP)
return output
@register_api("litellm")
def chat_completion_litellm(model, messages, temperature, max_tokens, api_dict=None, **kwargs):
import litellm
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
response = litellm.completion(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
)
output = {
"answer": response.choices[0].message.content
}
break
except Exception as e:
time.sleep(API_RETRY_SLEEP)
return output
@register_api("litellm_streaming")
def chat_completion_litellm_streaming(model, messages, temperature, max_tokens, api_dict=None, **kwargs):
"""Streaming version of litellm completion"""
import litellm
try:
response = litellm.completion(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True
)
for chunk in response:
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
except Exception as e:
print(f"Error in litellm streaming completion: {e}")
yield f"Error: {str(e)}"
@register_api("anthropic_streaming")
def chat_completion_anthropic_streaming(model, messages, temperature, max_tokens, api_dict=None, **kwargs):
"""Streaming version of Anthropic completion"""
import anthropic
if api_dict:
client = anthropic.Anthropic(api_key=api_dict["api_key"])
else:
client = anthropic.Anthropic()
try:
# Convert messages to Anthropic format
system_message = ""
conversation_messages = []
for msg in messages:
if msg["role"] == "system":
system_message = msg["content"]
else:
conversation_messages.append(msg)
        create_kwargs = {
            "model": model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "messages": conversation_messages,
            "stream": True,
        }
        # Only include `system` when one was provided; the API rejects an
        # explicit null system field.
        if system_message:
            create_kwargs["system"] = system_message
        stream = client.messages.create(**create_kwargs)
for chunk in stream:
if chunk.type == "content_block_delta" and chunk.delta.text:
yield chunk.delta.text
except Exception as e:
print(f"Error in Anthropic streaming completion: {e}")
yield f"Error: {str(e)}"
@register_api("gemini")
def http_completion_gemini(model, messages, **kwargs):
api_key = os.environ["GEMINI_API_KEY"]
safety_settings = [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
},
]
sys_prompt = None
if messages[0]["role"] == "system":
sys_prompt = {
"parts":[
{"text": messages[0]["content"]}
]
}
messages = messages[1:]
role_map = {"user": "user",
"assistant": "model"}
conv = [{"parts":[{"text":turn["content"]}], "role":role_map[turn["role"]]} for turn in messages]
json_request = {
"contents": conv,
"safetySettings": safety_settings,
"systemInstruction": sys_prompt,
}
if "temperature" in kwargs and "max_tokens" in kwargs:
gen_config = {
"temperature": kwargs["temperature"],
"maxOutputTokens": kwargs["max_tokens"],
}
json_request["generationConfig"] = gen_config
elif "temperature" in kwargs:
gen_config = {
"temperature": kwargs["temperature"],
}
json_request["generationConfig"] = gen_config
elif "max_tokens" in kwargs:
gen_config = {
"maxOutputTokens": kwargs["max_tokens"],
}
json_request["generationConfig"] = gen_config
output = API_ERROR_OUTPUT
    for _ in range(API_MAX_RETRY):
        try:
            response = requests.post(
                f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}",
                json=json_request,
            )
        except Exception as e:
            print(f"**API REQUEST ERROR** Reason: {e}.")
            time.sleep(API_RETRY_SLEEP)
            continue
        if response.status_code != 200:
            print(f"**API REQUEST ERROR** Reason: status code {response.status_code}.")
            time.sleep(API_RETRY_SLEEP)
            continue
        try:
            output = {
                "answer": response.json()["candidates"][0]["content"]["parts"][0]["text"],
            }
            break
        except KeyError:
            print(response.json())
            time.sleep(API_RETRY_SLEEP)
    return output
@register_api("vertex")
def vertex_completion_gemini(model, messages, project_id, regions, **kwargs):
import subprocess
output = API_ERROR_OUTPUT
# Obtain the access token using gcloud CLI
access_token = subprocess.check_output(
["gcloud", "auth", "application-default", "print-access-token"],
text=True
).strip()
if messages[0]["role"] == "system":
data = {
"systemInstruction": {
"role": "system", # ignored by vertexi api (04/18/2025)
"parts": [{
"text": messages[0]["content"]
}]
},
}
messages = messages[1:]
else:
data = {}
role_map = {
"user": "user",
"assistant": "model"
}
messages = [{"parts":[{"text":turn["content"]}], "role":role_map[turn["role"]]} for turn in messages]
    url = (
        f"https://{regions}-aiplatform.googleapis.com/v1/projects/"
        f"{project_id}/locations/{regions}/publishers/google/models/"
        f"{model}:generateContent"
    )
headers = {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
}
data = data | {
"contents": messages,
}
if "temperature" in kwargs or "max_tokens" in kwargs:
gen_config = {}
if "temperature" in kwargs:
gen_config["temperature"] = kwargs["temperature"]
if "max_tokens" in kwargs:
gen_config["maxOutputTokens"] = kwargs["max_tokens"]
data["generationConfig"] = gen_config
response = requests.post(url, json=data, headers=headers)
try:
output = {
"answer": response.json()["candidates"][0]["content"]["parts"][0]["text"],
}
except KeyError as e:
print(type(e), e)
print(response.json())
return output
@register_api("cohere")
def chat_completion_cohere(model, messages, temperature, max_tokens, **kwargs):
import cohere
co = cohere.Client(os.environ["COHERE_API_KEY"])
assert len(messages) > 0
template_map = {"system":"SYSTEM",
"assistant":"CHATBOT",
"user":"USER"}
assert messages[-1]["role"] == "user"
prompt = messages[-1]["content"]
if len(messages) > 1:
history = []
for message in messages[:-1]:
history.append({"role":template_map[message["role"]], "message":message["content"]})
else:
history = None
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
response = co.chat(
message=prompt,
model=model,
temperature=temperature,
max_tokens=max_tokens,
chat_history=history,
)
output = {
"answer": response.text
}
break
        except cohere.core.api_error.ApiError as e:
            # Cohere API errors are treated as fatal and re-raised.
            raise
        except Exception as e:
            # Any other failure ends the retry loop, leaving output as API_ERROR_OUTPUT.
            print(e)
            break
return output
@register_api("meta")
def chat_completion_meta(model, messages, temperature, max_tokens, api_dict, **kwargs):
assert api_dict
texts = [{"role": m["role"],
"text": m["content"]} for m in messages]
output = ""
for _ in range(API_MAX_RETRY):
try:
res = requests.post(
f"{api_dict['api_base']}/chat_stream_completions?access_token={api_dict['api_key']}",
stream=True,
headers={"Content-Type": "application/json"},
json={
"model": model,
"chunks_delimited": True,
"messages": texts,
"options": {
"max_tokens": max_tokens,
"generation_algorithm": "top_p",
"top_p": 1,
"temperature": temperature,
},
},
timeout=30,
)
if res.status_code == 200:
for line in res.iter_lines():
if line:
part = json.loads(line.decode("utf-8"))
if "text" in part:
output += part["text"]
break
else:
print(f"**API REQUEST ERROR** Code: {res.status_code}")
time.sleep(API_RETRY_SLEEP)
        except Exception as e:
            print(f"**API REQUEST ERROR** Reason: {e}.")
            time.sleep(API_RETRY_SLEEP)
            continue
return {
"answer": output
}
def batch_submit_sglang(
executor,
tokenizer,
temperature,
max_tokens,
all_context,
max_context_length=None,
end_think_token=None,
):
print(f"DEBUG: sglang_completion_qwq: max_context_length: {max_context_length}")
sampling_params = {
"temperature": temperature,
"skip_special_tokens": False,
"max_new_tokens": max_tokens - 1,
"no_stop_trim": True,
}
batch_prompt_token_ids = []
    batch_uids = []
uid_to_prompt = {}
uid_to_response = {}
for context in all_context:
prompt_token_ids = tokenizer.apply_chat_template(
context['turns'],
add_generation_prompt=True,
tokenize=True,
)
if max_context_length and (len(prompt_token_ids) + max_tokens) > max_context_length:
print(f"DEBUG: sglang_completion_qwq: context length ({len(prompt_token_ids) + max_tokens}) > max_context_length ({max_context_length}), skip this context")
continue
batch_prompt_token_ids.append(prompt_token_ids)
batch_uids.append(context['uid'])
uid_to_prompt[context['uid']] = context['turns']
err_msg = f"ERROR: len(batch_prompt_token_ids): {len(batch_prompt_token_ids)} != len(batch_uids): {len(batch_uids)}"
assert len(batch_prompt_token_ids) == len(batch_uids), err_msg
_ = executor.submit(
prompt_token_ids=batch_prompt_token_ids,
sampling_params=[sampling_params] * len(batch_uids),
keys=batch_uids,
)
for request in tqdm(executor.as_completed(), total=len(batch_uids)):
uid = request.key()
result = request.result()
raw_response = tokenizer.decode(
result['output_ids'],
skip_special_tokens=True,
)
if end_think_token:
thought, _, ans = raw_response.partition(end_think_token)
if ans == "":
uid_to_response[uid] = {"thought": thought, "answer": raw_response}
else:
uid_to_response[uid] = {"thought": thought, "answer": ans}
else:
uid_to_response[uid] = {"answer": raw_response}
# assert len(uid_to_response) == len(all_context), f"ERROR: len output ({len(uid_to_response)}) != len input ({len(all_context)})"
return uid_to_response
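# Each element of all_context is expected to look like (illustrative):
#   {"uid": "q_001", "turns": [{"role": "user", "content": "..."}]},
# as constructed by sglang_completion below.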
def _infer_cuda_tp_world_size():
cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if cuda_devices is None:
tp_world_size = 8
else:
tp_world_size = len(cuda_devices.split(","))
return tp_world_size
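# e.g. CUDA_VISIBLE_DEVICES="0,1,2,3" -> tensor-parallel world size 4; when the
# variable is unset, a full 8-GPU node is assumed.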
def download_model(model: str, max_workers: int = 64):
import subprocess
env = os.environ.copy()
env["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
cmd = [
"huggingface-cli",
"download",
f"--max-workers={max_workers}",
model
]
try:
subprocess.run(cmd, env=env, check=True)
print(f"Successfully downloaded model '{model}' with {max_workers} max workers.")
except subprocess.CalledProcessError as e:
print(f"Error occurred while downloading the model: {e}")
@register_engine("sglang")
def sglang_completion(
model,
batch_context,
answer_file,
temperature,
max_tokens=32768,
end_think_token=None,
**kwargs,
):
    from transformers import AutoTokenizer
    from utils.sglang_server import SGLangServerExecutor
tokenizer = AutoTokenizer.from_pretrained(model)
uids = [context['uid'] for context in batch_context]
prompts = [context['instruction'] for context in batch_context]
code_envs = [context['environment'] for context in batch_context]
processed_context = [
{
"uid": uids[i],
"turns": [{
"content": prompts[i],
"role": "user",
}]
}
for i in tqdm(range(len(uids)))
]
download_model(model=model)
server_args = {
"model_path": model,
"dtype": "auto",
"tp_size": _infer_cuda_tp_world_size(),
"mem_fraction_static": 0.7,
"max_prefill_tokens": max_tokens,
"max_workers": 256,
"server_port": 30000,
}
executor = SGLangServerExecutor(
**server_args,
)
print(f"DEBUG: sglang_completion: model: {model}")
uid_to_response = batch_submit_sglang(
executor=executor,
tokenizer=tokenizer,
temperature=temperature,
max_tokens=max_tokens,
all_context=processed_context,
end_think_token=end_think_token,
)
executor.join()
print("DEBUG: sglang_completion: done, sleep 10 seconds...")
time.sleep(10)
num_null = sum(
[uid_to_response[uid]['answer'] is None for uid in uids if uid in uid_to_response]
)
print(f"Number of null responses: {num_null}")
records = []
for i, context in enumerate(processed_context):
uid = context['uid']
if uid not in uid_to_response:
continue
answer_data = uid_to_response[uid]
record = {
"uid": uid,
"ans_id": shortuuid.uuid(),
"model": kwargs.get("model_display_name", model),
"messages": context['turns'] + [
{"content": answer_data, "role": "assistant"}
],
"environment": code_envs[i],
"tstamp": time.time(),
"metadata": {},
}
records.append(record)
with open(answer_file, 'w', encoding='utf-8') as f:
for rec in records:
f.write(json.dumps(rec, ensure_ascii=True) + '\n')
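# Each line of answer_file is a JSON record of the shape assembled above, e.g.
# (illustrative): {"uid": ..., "ans_id": ..., "model": ..., "messages": [...],
#  "environment": ..., "tstamp": ..., "metadata": {}}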
@register_api("aws_claude")
def chat_completion_aws_bedrock_claude(messages, api_dict=None, aws_region="us-west-2", **kwargs):
"""
Call AWS Bedrock API for chat completion
Args:
model (str): Model ID
conv (object): Conversation object containing messages
temperature (float): Temperature parameter for response generation
max_tokens (int): Maximum tokens in response
api_dict (dict, optional): API configuration dictionary
aws_region (str, optional): AWS region, defaults to "us-west-2"
Returns:
str: Generated response text or error message
"""
# Configure AWS client if api_dict provided
if api_dict is not None:
bedrock_rt_client = boto3.client(
service_name='bedrock-runtime',
region_name=aws_region,
aws_access_key_id=api_dict.get('aws_access_key_id'),
aws_secret_access_key=api_dict.get('aws_secret_access_key')
)
else:
bedrock_rt_client = boto3.client(
service_name='bedrock-runtime',
region_name=aws_region,)
output = API_ERROR_OUTPUT
    # get generation settings from kwargs
    temperature = kwargs["temperature"]
    max_tokens = kwargs["max_tokens"]
    model = kwargs["model_id"]
sys_msg = ""
if messages[0]["role"] == "system":
sys_msg = messages[0]["content"]
messages = messages[1:]
else:
prompt = messages[0]['content']
# Retry logic for API calls
for _ in range(API_MAX_RETRY):
try:
# Prepare request body
prompt_json = {
"system": sys_msg,
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature,
"anthropic_version": "bedrock-2023-05-31",
"stop_sequences": ["Human"]
}
# Call Bedrock API
response = bedrock_rt_client.invoke_model(
body=json.dumps(prompt_json),
modelId=model,
accept='application/json',
contentType='application/json'
)
# Parse response
response_body = json.loads(response.get('body').read())
output = {"answer":response_body.get("content")[0].get("text")}
break
except Exception as e:
time.sleep(API_RETRY_SLEEP)
return output
@register_api("aws_mistral")
def chat_completion_aws_bedrock_mistral(messages, api_dict=None, aws_region="us-west-2", **kwargs):
"""
Call AWS Bedrock API for chat completion
Args:
model (str): Model ID
conv (object): Conversation object containing messages
temperature (float): Temperature parameter for response generation
max_tokens (int): Maximum tokens in response
api_dict (dict, optional): API configuration dictionary
aws_region (str, optional): AWS region, defaults to "us-west-2"
Returns:
str: Generated response text or error message
"""
# Configure AWS client if api_dict provided
if api_dict is not None:
bedrock_rt_client = boto3.client(
service_name='bedrock-runtime',
region_name=aws_region,
aws_access_key_id=api_dict.get('aws_access_key_id'),
aws_secret_access_key=api_dict.get('aws_secret_access_key')
)
else:
bedrock_rt_client = boto3.client(
service_name='bedrock-runtime',
region_name=aws_region,)
output = API_ERROR_OUTPUT
    # get generation settings from kwargs
    temperature = kwargs["temperature"]
    max_tokens = kwargs["max_tokens"]
    model = kwargs["model_id"]
# Retry logic for API calls
for _ in range(API_MAX_RETRY):
try:
## =============== Format prompt ================
prompt = "\n".join([content for message in messages for content in message["content"]])
formatted_prompt = f"<s>[INST] {prompt.strip()} [/INST]"
body = {
"prompt": formatted_prompt,
"max_tokens": max_tokens,
"stop": ["Human:"],
"temperature": temperature,
}
# Call Bedrock API
response = bedrock_rt_client.invoke_model(
body=json.dumps(body),
modelId=model,
accept='application/json',
contentType='application/json'
)
# Parse response
response_body = json.loads(response.get('body').read())
if "pixtral-large" in model: #us.mistral.pixtral-large-2502-v1:0
output = {"answer": response_body.get("choices")[0].get("message").get("content")}
else:
output = {"answer": response_body.get("outputs")[0].get("text")}
break
except Exception as e:
time.sleep(API_RETRY_SLEEP)
return output
@register_api("mistral_streaming")
def chat_completion_mistral_streaming(model, messages, temperature, max_tokens, api_dict=None, **kwargs):
"""Streaming version of Mistral completion"""
import openai
if api_dict:
client = openai.OpenAI(
base_url=api_dict["api_base"],
api_key=api_dict["api_key"],
)
else:
client = openai.OpenAI()
try:
stream = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True
)
for chunk in stream:
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
except Exception as e:
print(f"Error in Mistral streaming completion: {e}")
yield f"Error: {str(e)}"
@register_api("gemini_streaming")
def chat_completion_gemini_streaming(model, messages, **kwargs):
"""Streaming version of Gemini completion"""
import google.generativeai as genai
try:
# Configure the API
genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))
# Create model
model_genai = genai.GenerativeModel(model)
        # Start a fresh chat session; prior turns are not replayed, only the
        # last user message is sent
        conversation = model_genai.start_chat(history=[])
        # Find the last user message
        last_user_message = None
        for msg in messages:
            if msg["role"] == "user":
                last_user_message = msg["content"]
if not last_user_message:
yield "Error: No user message found"
return
# Stream the response
response = conversation.send_message(last_user_message, stream=True)
for chunk in response:
if chunk.text:
yield chunk.text
except Exception as e:
print(f"Error in Gemini streaming completion: {e}")
yield f"Error: {str(e)}"