Spaces:
Running
Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
""" | |
https://docs.aws.amazon.com/bedrock/latest/userguide/api-inference-examples-claude-messages-code-examples.html | |
https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages-request-response.html | |
https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html | |
https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html | |
https://docs.aws.amazon.com/bedrock/latest/userguide/inference-invoke.html | |
https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-examples.html | |
""" | |
import argparse | |
from datetime import datetime | |
import json | |
import os | |
from pathlib import Path | |
import sys | |
import time | |
from zoneinfo import ZoneInfo # Python 3.9+ 自带,无需安装 | |
pwd = os.path.abspath(os.path.dirname(__file__)) | |
sys.path.append(os.path.join(pwd, "../")) | |
import boto3 | |
from project_settings import environment, project_path | |
def get_args():
    """
    Build the command-line parser for the Bedrock Claude evaluation script and parse ``sys.argv``.

    Example invocations:

    python3 aws_claude_chat.py --model_name anthropic.claude-instant-v1 \
    --eval_dataset_name agent-lingoace-zh-80-chat.jsonl \
    --client "us_west(47.88.76.239)" \
    --create_time_str 20250724-interval-1 \
    --interval 1

    python3 aws_claude_chat.py --model_name anthropic.claude-v2 \
    --eval_dataset_name agent-lingoace-zh-80-chat.jsonl \
    --client "us_west(47.88.76.239)" \
    --create_time_str 20250724-interval-1 \
    --interval 1

    :return: argparse.Namespace with the attributes ``model_name``, ``eval_dataset_name``,
        ``eval_dataset_dir``, ``eval_data_dir``, ``client``, ``service``,
        ``create_time_str`` and ``interval``.
    """
    # Alternative model ids for --model_name (kept for reference):
    #   anthropic.claude-v2, anthropic.claude-v2:1,
    #   anthropic.claude-instant-v1:2, anthropic.claude-v2:0
    option_table = (
        # (flag, default value, type converter)
        ("--model_name", "anthropic.claude-instant-v1", str),
        ("--eval_dataset_name", "agent-lingoace-zh-80-chat.jsonl", str),
        ("--eval_dataset_dir", (project_path / "data/dataset").as_posix(), str),
        ("--eval_data_dir", (project_path / "data/eval_data").as_posix(), str),
        ("--client", "shenzhen_sase", str),
        ("--service", "aws_us_east", str),
        # The literal string "null" means "no value supplied".
        ("--create_time_str", "null", str),
        ("--interval", 1, int),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value, converter in option_table:
        parser.add_argument(flag, default=default_value, type=converter)

    return parser.parse_args()
def main():
    """
    Send every prompt of the evaluation dataset to a Claude model on AWS Bedrock
    and append the raw predictions to a JSON-lines output file.

    Workflow:

    1. Load the JSON settings stored under ``args.service`` in ``environment`` and
       export the AWS key id, secret and region as environment variables before
       the ``bedrock-runtime`` client is created.
    2. Resolve the dataset path and the output path
       ``<eval_data_dir>/aws_claude/anthropic/<model_name>/<client>/<service>/<create_time_str>/<eval_dataset_name>.raw``.
       When ``--create_time_str`` is the literal ``"null"`` the current
       Asia/Shanghai time is used instead.
    3. Read the output file, if it already exists, to collect the ``idx`` values
       that were written by a previous run so they are not requested again.
    4. For each remaining dataset row, sleep ``args.interval`` seconds, call
       ``invoke_model`` with a single-user-message request and append one JSON
       line containing ``idx``, ``prompt``, ``response``, ``prediction``,
       ``total`` and ``time_cost``.

    A request that raises, or whose response does not contain a text block at
    ``content[0]``, is reported on stdout and skipped; nothing is written for it,
    so a later run with the same output path will request it again.

    :return: None
    """
    args = get_args()

    service = environment.get(key=args.service, dtype=json.loads)
    aws_access_key_id = service["AWS_ACCESS_KEY_ID"]
    aws_secret_access_key = service["AWS_SECRET_ACCESS_KEY"]
    aws_default_region = service["AWS_DEFAULT_REGION"]

    os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
    os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
    os.environ["AWS_DEFAULT_REGION"] = aws_default_region

    client = boto3.client(
        service_name="bedrock-runtime",
        region_name=aws_default_region
    )

    eval_dataset_dir = Path(args.eval_dataset_dir)
    eval_dataset_dir.mkdir(parents=True, exist_ok=True)
    eval_data_dir = Path(args.eval_data_dir)
    eval_data_dir.mkdir(parents=True, exist_ok=True)

    if args.create_time_str == "null":
        tz = ZoneInfo("Asia/Shanghai")
        now = datetime.now(tz)
        create_time_str = now.strftime("%Y%m%d_%H%M%S")
    else:
        create_time_str = args.create_time_str

    eval_dataset = eval_dataset_dir / args.eval_dataset_name

    output_file = eval_data_dir / f"aws_claude/anthropic/{args.model_name}/{args.client}/{args.service}/{create_time_str}/{args.eval_dataset_name}.raw"
    output_file.parent.mkdir(parents=True, exist_ok=True)

    total = 0

    # Resume support: remember which rows an earlier run already wrote, and
    # continue the running ``total`` counter from the last written record.
    finished_idx_set = set()
    if output_file.exists():
        with open(output_file.as_posix(), "r", encoding="utf-8") as f:
            for row in f:
                if not row.strip():
                    # Tolerate blank lines instead of crashing in json.loads.
                    continue
                row = json.loads(row)
                finished_idx_set.add(row["idx"])
                total = row["total"]
    print(f"finished count: {len(finished_idx_set)}")

    with open(eval_dataset.as_posix(), "r", encoding="utf-8") as fin, open(output_file.as_posix(), "a+", encoding="utf-8") as fout:
        for row in fin:
            if not row.strip():
                # A trailing empty line is common in .jsonl files.
                continue
            row = json.loads(row)
            idx = row["idx"]
            prompt = row["prompt"]
            response = row["response"]

            if idx in finished_idx_set:
                continue
            # Mark as seen for this run, so a duplicated idx is requested only once.
            finished_idx_set.add(idx)

            body = {
                "anthropic_version": "bedrock-2023-05-31",
                "messages": [
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": prompt}]
                    }
                ],
                "max_tokens": 1000,
                "temperature": 0.5,
                "top_p": 0.95,
                "stop_sequences": ["client"],
            }

            try:
                # Throttle requests; the sleep is not part of the measured latency.
                time.sleep(args.interval)
                print(f"sleep: {args.interval}")

                time_begin = time.time()
                llm_response = client.invoke_model(
                    modelId=args.model_name,
                    body=json.dumps(body),
                    contentType="application/json"
                )
                llm_response = json.loads(llm_response["body"].read())
                time_cost = time.time() - time_begin
                print(f"time_cost: {time_cost}")

                # Extract inside the guarded section: a payload without a text
                # block must skip this row, not abort the whole run.
                prediction = llm_response["content"][0]["text"]
            except Exception as e:
                print(f"request failed, error type: {type(e)}, error text: {str(e)}")
                continue

            total += 1

            row_ = {
                "idx": idx,
                "prompt": prompt,
                "response": response,
                "prediction": prediction,
                "total": total,
                "time_cost": time_cost,
            }
            row_ = json.dumps(row_, ensure_ascii=False)
            fout.write(f"{row_}\n")
            # Persist each record immediately so an interrupted run can resume
            # without losing buffered results.
            fout.flush()

    return
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()