Spaces:
Running
Running
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
import requests | |
import json | |
import time | |
import sys | |
import base64 | |
import os | |
# 设置UTF-8编码 | |
if sys.platform.startswith('win'): | |
os.system("chcp 65001") | |
if hasattr(sys.stdout, 'reconfigure'): | |
sys.stdout.reconfigure(encoding='utf-8') | |
API_URL = "http://127.0.0.1:8890/v1/chat/completions" | |
API_KEY = "sk-123456" # 替换为实际的API key | |
def test_text_to_image(prompt="生成一只可爱的猫咪", stream=False): | |
"""测试文本到图像生成""" | |
print(f"\n===== 测试文本到图像生成 =====") | |
try: | |
print(f"提示词: '{prompt}'") | |
except UnicodeEncodeError: | |
print(f"提示词: [包含非ASCII字符]") | |
print(f"流式响应: {stream}") | |
headers = { | |
"Content-Type": "application/json", | |
"Authorization": f"Bearer {API_KEY}" | |
} | |
payload = { | |
"model": "sora-1.0", | |
"messages": [ | |
{"role": "user", "content": prompt} | |
], | |
"n": 1, | |
"stream": stream | |
} | |
start_time = time.time() | |
response = requests.post( | |
API_URL, | |
headers=headers, | |
json=payload, | |
stream=stream | |
) | |
if response.status_code != 200: | |
print(f"错误: 状态码 {response.status_code}") | |
print(response.text) | |
return | |
if stream: | |
# 处理流式响应 | |
print("流式响应内容:") | |
for line in response.iter_lines(): | |
if line: | |
line = line.decode('utf-8') | |
if line.startswith("data: "): | |
data = line[6:] | |
if data == "[DONE]": | |
print("[完成]") | |
else: | |
try: | |
json_data = json.loads(data) | |
if 'choices' in json_data and json_data['choices'] and 'delta' in json_data['choices'][0]: | |
delta = json_data['choices'][0]['delta'] | |
if 'content' in delta: | |
print(f"接收内容: {delta['content']}") | |
except Exception as e: | |
print(f"解析响应时出错: {e}") | |
else: | |
# 处理普通响应 | |
try: | |
data = response.json() | |
print(f"响应内容:") | |
print(json.dumps(data, indent=2, ensure_ascii=False)) | |
if 'choices' in data and data['choices']: | |
image_url = None | |
content = data['choices'][0]['message']['content'] | |
if "[1].split(")")[0] | |
print(f"\n生成的图片URL: {image_url}") | |
except Exception as e: | |
print(f"解析响应时出错: {e}") | |
elapsed = time.time() - start_time | |
print(f"请求耗时: {elapsed:.2f}秒") | |
def test_image_to_image(image_path, prompt="将这张图片变成动漫风格"): | |
"""测试图像到图像生成(Remix)""" | |
print(f"\n===== 测试图像到图像生成 =====") | |
print(f"图片路径: '{image_path}'") | |
print(f"提示词: '{prompt}'") | |
# 读取并转换图片为base64 | |
try: | |
with open(image_path, "rb") as image_file: | |
base64_image = base64.b64encode(image_file.read()).decode('utf-8') | |
except Exception as e: | |
print(f"读取图片失败: {e}") | |
return | |
# 构建请求 | |
headers = { | |
"Content-Type": "application/json", | |
"Authorization": f"Bearer {API_KEY}" | |
} | |
payload = { | |
"model": "sora-1.0", | |
"messages": [ | |
{"role": "user", "content": f"data:image/jpeg;base64,{base64_image}\n{prompt}"} | |
], | |
"n": 1, | |
"stream": False | |
} | |
start_time = time.time() | |
response = requests.post( | |
API_URL, | |
headers=headers, | |
json=payload | |
) | |
if response.status_code != 200: | |
print(f"错误: 状态码 {response.status_code}") | |
print(response.text) | |
return | |
# 处理响应 | |
try: | |
data = response.json() | |
print(f"响应内容:") | |
print(json.dumps(data, indent=2, ensure_ascii=False)) | |
if 'choices' in data and data['choices']: | |
image_url = None | |
content = data['choices'][0]['message']['content'] | |
if "[1].split(")")[0] | |
print(f"\n生成的图片URL: {image_url}") | |
except Exception as e: | |
print(f"解析响应时出错: {e}") | |
elapsed = time.time() - start_time | |
print(f"请求耗时: {elapsed:.2f}秒") | |
def main(): | |
"""主函数""" | |
if len(sys.argv) < 2: | |
print("用法: python test_client.py <测试类型> [参数...]") | |
print("测试类型:") | |
print(" text2img <提示词> [stream=true/false]") | |
print(" img2img <图片路径> <提示词>") | |
return | |
test_type = sys.argv[1].lower() | |
if test_type == "text2img": | |
prompt = sys.argv[2] if len(sys.argv) > 2 else "生成一只可爱的猫咪" | |
stream = False | |
if len(sys.argv) > 3 and sys.argv[3].lower() == "stream=true": | |
stream = True | |
test_text_to_image(prompt, stream) | |
elif test_type == "img2img": | |
if len(sys.argv) < 3: | |
print("错误: 需要图片路径") | |
return | |
image_path = sys.argv[2] | |
prompt = sys.argv[3] if len(sys.argv) > 3 else "将这张图片变成动漫风格" | |
test_image_to_image(image_path, prompt) | |
else: | |
print(f"错误: 未知的测试类型 '{test_type}'") | |
if __name__ == "__main__": | |
main() |