HoneyTian's picture
first commit
4464055
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://docs.aws.amazon.com/bedrock/latest/userguide/api-inference-examples-claude-messages-code-examples.html
https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages-request-response.html
https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html
https://docs.aws.amazon.com/bedrock/latest/userguide/inference-invoke.html
https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-examples.html
"""
import argparse
from datetime import datetime
import json
import os
from pathlib import Path
import sys
import time
from zoneinfo import ZoneInfo # Python 3.9+ 自带,无需安装
# Absolute directory of this script. Its parent directory is appended to
# sys.path, presumably so that the project-local `project_settings` module
# imported below can be resolved when the script is run directly.
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../"))
import boto3
from project_settings import environment, project_path
def get_args(argv=None):
    r"""Parse command-line options for a Bedrock Claude evaluation run.

    Example usage::

        python3 aws_claude.py --model_name anthropic.claude-instant-v1 \
            --eval_dataset_name agent-lingoace-zh-400-choice.jsonl \
            --client "us_west(47.88.76.239)" \
            --create_time_str 20250723-interval-10 \
            --interval 10

        python3 aws_claude.py --model_name anthropic.claude-v2 \
            --eval_dataset_name agent-lingoace-zh-400-choice.jsonl \
            --client "us_west(47.88.76.239)" \
            --create_time_str 20250723-interval-10 \
            --interval 10

    :param argv: optional list of argument strings to parse instead of
        ``sys.argv[1:]``; ``None`` (the default) keeps the original behavior
        of reading the real command line.
    :return: the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        default="anthropic.claude-instant-v1",
        type=str,
        help="Bedrock model ID, e.g. anthropic.claude-instant-v1, "
             "anthropic.claude-instant-v1:2, anthropic.claude-v2, "
             "anthropic.claude-v2:0, anthropic.claude-v2:1.",
    )
    parser.add_argument(
        "--eval_dataset_name",
        default="agent-lingoace-zh-400-choice.jsonl",
        type=str,
        help="JSONL dataset file name inside --eval_dataset_dir "
             "(e.g. arc-easy-1000-choice.jsonl).",
    )
    parser.add_argument(
        "--eval_dataset_dir",
        default=(project_path / "data/dataset").as_posix(),
        type=str,
        help="Directory containing the evaluation datasets.",
    )
    parser.add_argument(
        "--eval_data_dir",
        default=(project_path / "data/eval_data").as_posix(),
        type=str,
        help="Root directory under which evaluation results are written.",
    )
    parser.add_argument(
        "--client",
        default="shenzhen_sase",
        type=str,
        help="Free-form client label; only used to build the output path.",
    )
    parser.add_argument(
        "--service",
        default="aws_us_east",
        type=str,
        help="Key of the project environment entry holding the AWS "
             "credentials/region as JSON; also part of the output path.",
    )
    parser.add_argument(
        "--create_time_str",
        default="null",
        type=str,
        help='Run identifier used in the output path; "null" means use the '
             "current Asia/Shanghai time. Reuse a value to resume that run.",
    )
    parser.add_argument(
        "--interval",
        default=10,
        type=int,
        help="Seconds to sleep before each request.",
    )
    args = parser.parse_args(argv)
    return args
def _load_progress(output_file):
    """Read an existing results file so an interrupted run can be resumed.

    Each non-empty line is expected to be a JSON object previously written by
    ``main`` (with ``idx``, ``total`` and ``total_correct`` keys). Blank lines
    and lines that fail to parse (e.g. a record truncated by an interrupted
    write) are skipped with a warning, so that item is simply retried.

    :param output_file: ``Path`` of the JSONL results file (may not exist).
    :return: ``(finished_idx_set, total, total_correct)``; the counters come
        from the last valid line, or are ``0`` when there is none.
    """
    finished_idx_set = set()
    total = 0
    total_correct = 0
    if not output_file.exists():
        return finished_idx_set, total, total_correct
    with open(output_file.as_posix(), "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"skip unparsable line {line_no} in {output_file}: {e}")
                continue
            finished_idx_set.add(row["idx"])
            # Counters are cumulative, so the last record holds the running totals.
            total = row["total"]
            total_correct = row["total_correct"]
    return finished_idx_set, total, total_correct


def _needs_leading_newline(path):
    """Return True if *path* is a non-empty file whose last byte is not ``\\n``.

    Used to detect a truncated final record, so the next appended record does
    not get concatenated onto it.
    """
    if not path.exists() or path.stat().st_size == 0:
        return False
    with open(path.as_posix(), "rb") as f:
        f.seek(-1, os.SEEK_END)
        return f.read(1) != b"\n"


def _build_request_body(prompt):
    """Build the Anthropic Messages API request body for a single user prompt."""
    return {
        "anthropic_version": "bedrock-2023-05-31",
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            }
        ],
        "max_tokens": 1000,
        "temperature": 0.5,
        "top_p": 0.95,
    }


def _extract_text(llm_response):
    """Return the text of the first content block of a Messages API response.

    :raises ValueError: if the response carries no text content block (e.g. an
        error payload), so the caller can treat it as a failed request.
    """
    content = llm_response.get("content") if isinstance(llm_response, dict) else None
    if not content or not isinstance(content[0], dict) or "text" not in content[0]:
        raise ValueError(f"response has no text content: {llm_response}")
    return content[0]["text"]


def main():
    """Evaluate a Claude model on AWS Bedrock against a JSONL dataset.

    AWS credentials and region are read from the project environment entry
    named by ``--service``. Every dataset row (``idx``, ``prompt``,
    ``response``) is sent to ``invoke_model``. One JSON line per answered item
    is appended to
    ``<eval_data_dir>/aws_claude/anthropic/<model_name>/<client>/<service>/<create_time_str>/<eval_dataset_name>``
    with the prediction, exact-match correctness and running score.

    Items already present in that file are skipped, so re-running with the same
    ``--create_time_str`` resumes an interrupted run. Failed requests (including
    responses without text content) are logged and not written, so they are
    retried on the next resume.
    """
    args = get_args()

    service = environment.get(key=args.service, dtype=json.loads)
    aws_access_key_id = service["AWS_ACCESS_KEY_ID"]
    aws_secret_access_key = service["AWS_SECRET_ACCESS_KEY"]
    aws_default_region = service["AWS_DEFAULT_REGION"]
    # boto3 resolves credentials from these environment variables.
    os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
    os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
    os.environ["AWS_DEFAULT_REGION"] = aws_default_region

    client = boto3.client(
        service_name="bedrock-runtime",
        region_name=aws_default_region,
    )

    eval_dataset_dir = Path(args.eval_dataset_dir)
    eval_dataset_dir.mkdir(parents=True, exist_ok=True)
    eval_data_dir = Path(args.eval_data_dir)
    eval_data_dir.mkdir(parents=True, exist_ok=True)

    if args.create_time_str == "null":
        now = datetime.now(ZoneInfo("Asia/Shanghai"))
        create_time_str = now.strftime("%Y%m%d_%H%M%S")
    else:
        create_time_str = args.create_time_str

    eval_dataset = eval_dataset_dir / args.eval_dataset_name
    output_file = eval_data_dir / f"aws_claude/anthropic/{args.model_name}/{args.client}/{args.service}/{create_time_str}/{args.eval_dataset_name}"
    output_file.parent.mkdir(parents=True, exist_ok=True)

    # Resume support: skip items already written by a previous run.
    finished_idx_set, total, total_correct = _load_progress(output_file)
    print(f"finished count: {len(finished_idx_set)}")
    needs_newline = _needs_leading_newline(output_file)

    with open(eval_dataset.as_posix(), "r", encoding="utf-8") as fin, \
            open(output_file.as_posix(), "a+", encoding="utf-8") as fout:
        if needs_newline:
            # Terminate a truncated last record so new records start on their own line.
            fout.write("\n")

        for line in fin:
            if not line.strip():
                continue
            row = json.loads(line)
            idx = row["idx"]
            prompt = row["prompt"]
            response = row["response"]
            if idx in finished_idx_set:
                continue
            finished_idx_set.add(idx)

            body = _build_request_body(prompt)
            try:
                # Throttle requests to stay under the account's rate limits.
                time.sleep(args.interval)
                print(f"sleep: {args.interval}")
                time_begin = time.time()
                llm_response = client.invoke_model(
                    modelId=args.model_name,
                    body=json.dumps(body),
                    contentType="application/json",
                )
                llm_response = json.loads(llm_response["body"].read())
                time_cost = time.time() - time_begin
                print(f"time_cost: {time_cost}")
                # Parsed inside the try so a malformed/error payload is logged
                # and skipped instead of aborting the whole run.
                prediction = _extract_text(llm_response)
            except Exception as e:
                print(f"request failed, error type: {type(e)}, error text: {str(e)}")
                continue

            # Exact-match scoring against the reference answer.
            correct = 1 if prediction == response else 0
            total += 1
            total_correct += correct
            score = total_correct / total

            row_ = {
                "idx": idx,
                "prompt": prompt,
                "response": response,
                "prediction": prediction,
                "correct": correct,
                "total": total,
                "total_correct": total_correct,
                "score": score,
                "time_cost": time_cost,
            }
            row_ = json.dumps(row_, ensure_ascii=False)
            fout.write(f"{row_}\n")
            # Flush each record so an interrupted run loses nothing on resume.
            fout.flush()
    return
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()