# ImageEditPro / util.py
# selfit-camera's picture
# init
# 2c7d451
import os
import sys
import cv2
import json
import random
import time
import datetime
import requests
import func_timeout
import numpy as np
import gradio as gr
import boto3
import tempfile
from botocore.client import Config
from PIL import Image
# All service credentials are packed into the single "OneKey" environment
# variable as '#'-separated fields, in this order:
#   TOKEN#APIKEY#UKAPIURL#LLMKEY#R2_ACCESS_KEY#R2_SECRET_KEY#R2_ENDPOINT
# (Earlier versions read TOKEN / APIKEY / UKAPIURL from separate variables.)
OneKey = os.environ['OneKey'].strip()
OneKey = OneKey.split("#")
if len(OneKey) < 7:
    # Fail fast with an actionable message instead of a bare IndexError from
    # the unpacking below. Only the field count is reported, never the secret.
    raise RuntimeError(
        "OneKey env var must contain at least 7 '#'-separated fields "
        "(TOKEN#APIKEY#UKAPIURL#LLMKEY#R2_ACCESS_KEY#R2_SECRET_KEY#R2_ENDPOINT), "
        f"got {len(OneKey)}"
    )
(
    TOKEN,
    APIKEY,
    UKAPIURL,
    LLMKEY,
    R2_ACCESS_KEY,
    R2_SECRET_KEY,
    R2_ENDPOINT,
) = OneKey[:7]
# Local staging directory for images before they are uploaded.
tmpFolder = "tmp"
os.makedirs(tmpFolder, exist_ok=True)
def upload_user_img(clientIp, timeId, img):
    """Re-encode an image as JPEG and upload it via the UK API upload endpoint.

    The image at path ``img`` is read with OpenCV and written as a JPEG into
    ``tmpFolder``. An upload URL is requested from ``{UKAPIURL}/upload``
    (field ``upload1`` of the JSON reply), and the JPEG bytes are PUT to it.
    The staging file is always removed afterwards, even on error.

    Args:
        clientIp: Client IP string; dots are removed to build the file name.
        timeId: Value appended to the file name (stringified).
        img: Path of the source image, as accepted by ``cv2.imread``.

    Returns:
        The upload URL when the PUT answers HTTP 200. Otherwise ``""``: the
        upload endpoint did not answer 200, returned no ``upload1`` field, or
        the PUT did not answer 200.
    """
    fileName = clientIp.replace(".", "") + str(timeId) + ".jpg"
    local_path = os.path.join(tmpFolder, fileName)
    try:
        cv2.imwrite(local_path, cv2.imread(img))
        json_data = {
            "token": TOKEN,
            "input1": fileName,
            "input2": "",
            "protocol": "",
            "cloud": "ali",  # NOTE(review): presumably selects Alibaba Cloud storage; confirm
        }
        # One session for both requests; the context manager closes it.
        with requests.Session() as session:
            ret = session.post(
                f"{UKAPIURL}/upload",
                headers={'Content-Type': 'application/json'},
                json=json_data,
            )
            if ret.status_code != 200:
                return ""
            payload = ret.json()
            if 'upload1' not in payload:
                return ""
            upload_url = payload['upload1']
            with open(local_path, 'rb') as fh:
                body = fh.read()
            response = session.put(
                upload_url, data=body, headers={'Content-Type': 'image/jpeg'}
            )
            return upload_url if response.status_code == 200 else ""
    finally:
        # Remove the staging copy on every path, including exceptions raised
        # by the requests above, so tmpFolder does not accumulate files.
        if os.path.exists(local_path):
            os.remove(local_path)
class R2Api:
    """Uploads local files to an R2 bucket through presigned PUT URLs.

    A boto3 "s3" client (signature v4) pointed at ``R2_ENDPOINT`` signs the
    URLs. The bytes are sent with a ``requests`` session, and the returned
    public URL is built from ``self.domain`` plus the object key.
    """

    # File extension (lower-cased) -> Content-Type used both when signing and
    # when uploading. Unknown extensions fall back to application/octet-stream.
    _CONTENT_TYPES = {
        'jpg': 'image/jpeg',
        'jpeg': 'image/jpeg',
        'png': 'image/png',
        'gif': 'image/gif',
        'bmp': 'image/bmp',
        'webp': 'image/webp',
        'ico': 'image/x-icon',
    }

    def __init__(self, session=None):
        """Create the signing client.

        Args:
            session: Optional ``requests`` session used for the PUTs. A new
                ``requests.Session()`` is created when omitted.
        """
        super().__init__()
        # Target bucket and the public base URL objects are served from.
        self.R2_BUCKET = "trump-ai-voice"
        self.domain = "https://www.trumpaivoice.net/"
        # Credentials/endpoint come from the module-level OneKey fields.
        self.R2_ACCESS_KEY = R2_ACCESS_KEY
        self.R2_SECRET_KEY = R2_SECRET_KEY
        self.R2_ENDPOINT = R2_ENDPOINT
        self.client = boto3.client(
            "s3",
            endpoint_url=self.R2_ENDPOINT,
            aws_access_key_id=self.R2_ACCESS_KEY,
            aws_secret_access_key=self.R2_SECRET_KEY,
            config=Config(signature_version="s3v4"),
        )
        self.session = requests.Session() if session is None else session

    def upload_file(self, local_path, cloud_path):
        """Upload ``local_path`` and return its public URL.

        The object key is always
        ``QwenImageEdit/Uploads/<today's date>/<basename of local_path>``.
        NOTE: ``cloud_path`` is ignored. It is kept only so existing callers
        keep working.

        Args:
            local_path: Path of the file to upload. Its extension selects the
                Content-Type.
            cloud_path: Unused.

        Returns:
            ``self.domain`` followed by the object key.

        Raises:
            Exception: if all 3 PUT attempts fail, either with a network error
                or a non-2xx HTTP status. The last error is chained as the cause.
            OSError: if ``local_path`` cannot be opened (not retried).
        """
        t1 = time.time()
        ftype = os.path.basename(local_path).split(".")[-1].lower()
        ctype = self._CONTENT_TYPES.get(ftype, 'application/octet-stream')
        headers = {"Content-Type": ctype}
        cloud_path = f"QwenImageEdit/Uploads/{str(datetime.date.today())}/{os.path.basename(local_path)}"
        url = self.client.generate_presigned_url(
            "put_object",
            # ContentType is part of the signature, so the PUT must send the same header.
            Params={"Bucket": self.R2_BUCKET, "Key": cloud_path, "ContentType": ctype},
            ExpiresIn=604800,  # 7 days
        )
        max_attempts = 3
        for attempt in range(1, max_attempts + 1):
            try:
                with open(local_path, 'rb') as f:
                    response = self.session.put(url, data=f.read(), headers=headers, timeout=8)
                # A presigned PUT can be rejected (e.g. 403 on a signature
                # mismatch, or 5xx). Treat that as a failed attempt rather
                # than returning a URL to an object that was never stored.
                response.raise_for_status()
                break
            except requests.exceptions.RequestException as err:
                if attempt == max_attempts:
                    raise Exception('Failed to upload file to R2 after 3 retries!') from err
        print("upload_file time is ====>", time.time() - t1)
        return f"{self.domain}{cloud_path}"
def upload_user_img_r2(clientIp, timeId, img):
    """Re-encode an image as JPEG and upload it to R2, returning its public URL.

    The image at path ``img`` is read with OpenCV and written as a JPEG into
    ``tmpFolder`` under ``<clientIp without dots><timeId>.jpg``. That file is
    uploaded with ``R2Api().upload_file`` and then deleted.

    Args:
        clientIp: Client IP string; dots are removed to build the file name.
        timeId: Value appended to the file name (stringified).
        img: Path of the source image, as accepted by ``cv2.imread``.

    Returns:
        The public URL returned by ``R2Api.upload_file``.

    Raises:
        Whatever OpenCV or ``R2Api.upload_file`` raise. The staging file is
        removed before the exception propagates.
    """
    fileName = clientIp.replace(".", "") + str(timeId) + ".jpg"
    local_path = os.path.join(tmpFolder, fileName)
    try:
        cv2.imwrite(local_path, cv2.imread(img))
        return R2Api().upload_file(local_path, fileName)
    finally:
        # Clean up even when the upload raises (e.g. after exhausting its
        # retries), so failed uploads do not pile up in tmpFolder.
        if os.path.exists(local_path):
            os.remove(local_path)
@func_timeout.func_set_timeout(10)
def get_country_info(ip):
    """Return the country name for an IP address, or "Unknown" on failure.

    Queries the Baidu qifu IP geolocation API. Network/HTTP errors, other
    exceptions and non-"Success" API codes are logged and mapped to "Unknown".
    The ``func_timeout`` decorator limits the call to 10 seconds; when that
    limit is hit it raises ``func_timeout.FunctionTimedOut`` instead of
    returning (``get_country_info_safe`` handles that case).

    Args:
        ip: IP address to look up.

    Returns:
        The response's ``data.country`` value, or "Unknown" if that value is
        missing/empty or the lookup failed.
    """
    try:
        # Pass the IP via ``params`` so requests URL-encodes it, in case the
        # value is client-supplied. The per-request timeout (applied to the
        # connect and read phases) sits below the decorator's 10 s, so a slow
        # server is normally handled here instead of by the decorator.
        ret = requests.get(
            "https://qifu-api.baidubce.com/ip/geo/v1/district",
            params={"ip": ip},
            timeout=8,
        )
        ret.raise_for_status()  # raise on HTTP error statuses (4xx/5xx)
        json_data = ret.json()
        # The country lives under 'data' -> 'country' in the response JSON.
        if json_data.get("code") == "Success":
            country = json_data.get("data", {}).get("country")
            return country if country else "Unknown"
        # The API answered but reported an error code.
        print(f"API请求失败: {json_data.get('msg', '未知错误')}")
        return "Unknown"
    except requests.exceptions.RequestException as e:
        print(f"网络请求失败: {e}")
        return "Unknown"
    except Exception as e:
        print(f"获取IP属地失败: {e}")
        return "Unknown"
@func_timeout.func_set_timeout(10)
def get_location_info(ip):
    """Return detailed location info for an IP address.

    Queries the Baidu qifu IP geolocation API. The ``func_timeout`` decorator
    limits the call to 10 seconds and raises ``func_timeout.FunctionTimedOut``
    when that limit is hit (``get_location_info_safe`` handles that case).

    Args:
        ip: IP address to look up.

    Returns:
        A dict with keys "country", "prov", "city", "isp", "owner" (strings,
        "Unknown" when absent from the response) and "full_data" (the raw
        ``data`` object). On any logged failure every string field is
        "Unknown" and "full_data" is ``{}``.
    """
    def _unknown():
        # Fresh dict per call so callers may mutate the result safely.
        return {
            "country": "Unknown",
            "prov": "Unknown",
            "city": "Unknown",
            "isp": "Unknown",
            "owner": "Unknown",
            "full_data": {},
        }

    try:
        # Pass the IP via ``params`` so it is URL-encoded. Keep the
        # per-request timeout below the decorator's 10 s limit.
        ret = requests.get(
            "https://qifu-api.baidubce.com/ip/geo/v1/district",
            params={"ip": ip},
            timeout=8,
        )
        ret.raise_for_status()
        json_data = ret.json()
        if json_data.get("code") != "Success":
            print(f"API请求失败: {json_data.get('msg', '未知错误')}")
            return _unknown()
        data = json_data.get("data", {})
        return {
            "country": data.get("country", "Unknown"),
            "prov": data.get("prov", "Unknown"),
            "city": data.get("city", "Unknown"),
            "isp": data.get("isp", "Unknown"),
            "owner": data.get("owner", "Unknown"),
            "full_data": data,
        }
    except requests.exceptions.RequestException as e:
        print(f"网络请求失败: {e}")
        return _unknown()
    except Exception as e:
        print(f"获取IP属地失败: {e}")
        return _unknown()
def get_country_info_safe(ip):
    """Look up the country for ``ip`` without ever raising.

    Calls ``get_country_info``. A ``func_timeout.FunctionTimedOut`` or any
    other exception is logged and turned into the string "Unknown".
    """
    fallback = "Unknown"
    try:
        country = get_country_info(ip)
    except func_timeout.FunctionTimedOut:
        print(f"获取IP属地超时: {ip}")
        country = fallback
    except Exception as err:
        print(f"获取IP属地失败: {err}")
        country = fallback
    return country
def get_location_info_safe(ip):
    """Look up detailed location info for ``ip`` without ever raising.

    Returns whatever ``get_location_info`` returns. If that call times out
    (``func_timeout.FunctionTimedOut``) or raises anything else, the cause is
    logged and a default dict is returned: all string fields are "Unknown"
    and "full_data" is empty.
    """
    try:
        return get_location_info(ip)
    except func_timeout.FunctionTimedOut:
        print(f"获取IP位置超时: {ip}")
    except Exception as err:
        print(f"获取IP位置失败: {err}")
    # Only reached after a handled exception; a new dict is built every time.
    return {
        "country": "Unknown",
        "prov": "Unknown",
        "city": "Unknown",
        "isp": "Unknown",
        "owner": "Unknown",
        "full_data": {},
    }
def contains_chinese(text):
    """Return True if ``text`` contains a character in U+4E00..U+9FFF.

    Only the basic CJK Unified Ideographs block is checked. Extension blocks
    and CJK punctuation do not count.
    """
    import re

    han_char = re.compile(r'[\u4e00-\u9fff]')
    return han_char.search(text) is not None
def submit_image_edit_task(user_image_url, prompt, timeout=60):
    """Submit an image-edit job to ``{UKAPIURL}/public_image_edit``.

    Args:
        user_image_url: Public URL of the source image.
        prompt: Edit instruction text.
        timeout: Seconds passed to ``requests.post``. Without a timeout, a
            stalled connection would block the caller indefinitely.

    Returns:
        ``(task_id, None)`` on success, otherwise ``(None, error_message)``.
        Request and response problems are reported in the message and never
        raised.
    """
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {APIKEY}',
    }
    data = {
        "user_image": user_image_url,
        "task_type": "80",
        "prompt": prompt,
        # NOTE(review): hard-coded credential shipped in source; consider
        # moving it into the OneKey secret.
        "secret_key": "219ngu",
        "is_private": "0",
    }
    try:
        response = requests.post(
            f'{UKAPIURL}/public_image_edit',
            headers=headers,
            json=data,
            timeout=timeout,
        )
        if response.status_code != 200:
            return None, f"HTTP Error: {response.status_code}"
        result = response.json()
        if result.get('code') != 0:
            return None, f"API Error: {result.get('message', 'Unknown error')}"
        # A success code without a task id used to surface as a confusing
        # "Request Exception: 'task_id'", or as (None, None) for a null id.
        task_id = (result.get('data') or {}).get('task_id')
        if task_id is None:
            return None, "API Error: response did not include a task_id"
        return task_id, None
    except Exception as e:
        return None, f"Request Exception: {str(e)}"
def check_task_status(task_id, timeout=30):
    """Query the status of an image-edit task via ``{UKAPIURL}/status_image_edit``.

    Args:
        task_id: Id returned by ``submit_image_edit_task``.
        timeout: Seconds passed to ``requests.post``. Without it, one hung
            poll would stall a polling loop regardless of its attempt cap.

    Returns:
        ``(status, output_url, task_data)`` on success, where ``output_url`` is
        ``task_data.get('output1')``. On any failure it returns
        ``('error', None, message)`` and never raises.
    """
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {APIKEY}',
    }
    data = {"task_id": task_id}
    try:
        response = requests.post(
            f'{UKAPIURL}/status_image_edit',
            headers=headers,
            json=data,
            timeout=timeout,
        )
        if response.status_code != 200:
            return 'error', None, f"HTTP Error: {response.status_code}"
        result = response.json()
        if result.get('code') != 0:
            return 'error', None, result.get('message', 'Unknown error')
        task_data = result['data']
        return task_data['status'], task_data.get('output1'), task_data
    except Exception as e:
        return 'error', None, f"Request Exception: {str(e)}"
def _cleanup_temp_image(temp_dir, temp_img_path):
    """Best-effort removal of the temp JPEG and its directory from process_image_edit."""
    if not temp_dir or not os.path.exists(temp_dir):
        return
    try:
        # The file may be missing if saving the PIL image failed; the
        # directory is still removed in that case.
        if temp_img_path and os.path.exists(temp_img_path):
            os.remove(temp_img_path)
        os.rmdir(temp_dir)
        print(f"🗑️ 已清理临时文件: {temp_img_path}")
    except Exception as cleanup_error:
        print(f"⚠️ 清理临时文件失败: {cleanup_error}")


def _poll_edit_task(task_id, progress_callback, max_attempts, poll_interval):
    """Poll check_task_status until the task completes, fails, or attempts run out."""
    for attempt in range(max_attempts):
        status, output_url, task_data = check_task_status(task_id)
        if status == 'completed':
            if output_url:
                return output_url, "image edit completed"
            return None, "Task completed but no result image returned"
        if status in ('error', 'failed'):
            return None, f"task processing failed: {task_data}"
        if progress_callback:
            if status in ('queued', 'processing', 'running', 'created', 'working'):
                progress_callback(f"task processing... (status: {status})")
            else:
                progress_callback(f"unknown status: {status}")
        # Do not sleep after the final poll; nothing follows it.
        if attempt < max_attempts - 1:
            time.sleep(poll_interval)
    return None, "task processing timeout"


def process_image_edit(img_input, prompt, progress_callback=None, *, max_attempts=60, poll_interval=1):
    """Run the full image-edit flow: stage the input, upload, submit, poll.

    Args:
        img_input: A file path (str) or a PIL Image. Anything with a ``save``
            attribute is treated as an image object and written to a temp JPEG.
        prompt: Edit instruction.
        progress_callback: Optional callable that receives progress strings.
        max_attempts: Maximum number of status polls before giving up.
        poll_interval: Seconds to sleep between polls. With the defaults
            (60 polls, 1 s apart) the wait is about one minute plus request time.

    Returns:
        ``(output_url, "image edit completed")`` on success, otherwise
        ``(None, error_message)``. Exceptions are caught and reported in the
        message.
    """
    import uuid

    temp_dir = None
    temp_img_path = None
    try:
        client_ip = "127.0.0.1"  # placeholder IP, only used to build the upload file name
        # Whole seconds alone are not unique: two calls in the same second
        # would share the tmp file name and R2 object key. The random suffix
        # keeps them apart if calls ever overlap.
        time_id = f"{int(time.time())}_{uuid.uuid4().hex}"

        if hasattr(img_input, 'save'):  # PIL Image object
            temp_dir = tempfile.mkdtemp()
            temp_img_path = os.path.join(temp_dir, f"temp_img_{time_id}.jpg")
            # JPEG cannot store alpha/palette modes, so normalise to RGB first.
            if img_input.mode != 'RGB':
                img_input = img_input.convert('RGB')
            img_input.save(temp_img_path, 'JPEG', quality=95)
            img_path = temp_img_path
            print(f"💾 PIL Image已保存为临时文件: {temp_img_path}")
        else:
            img_path = img_input  # assumed to be a file path

        if progress_callback:
            progress_callback("uploading image...")
        uploaded_url = upload_user_img_r2(client_ip, time_id, img_path)
        if not uploaded_url:
            return None, "image upload failed"
        # Drop any query string so only the plain object URL is sent on.
        uploaded_url = uploaded_url.split("?")[0]

        if progress_callback:
            progress_callback("submitting edit task...")
        task_id, error = submit_image_edit_task(uploaded_url, prompt)
        if error:
            return None, error

        if progress_callback:
            progress_callback(f"task submitted, ID: {task_id}, processing...")
        return _poll_edit_task(task_id, progress_callback, max_attempts, poll_interval)
    except Exception as e:
        return None, f"error occurred during processing: {str(e)}"
    finally:
        _cleanup_temp_image(temp_dir, temp_img_path)
if __name__ == "__main__":
    # This module only defines helpers; running it directly does nothing.
    ...