import os
import logging
import random
import zipfile
import tempfile
import shutil
import json
from typing import List, Dict, Any, Optional, Union, Callable

from PIL import Image

from app.api import get_chat_completion
from app.config import (
    STICKER_RERANKING_SYSTEM_PROMPT,
    PUBLIC_URL,
    TEMP_DIR
)
from app.database import db
from app.image_utils import (
    save_image_temp,
    generate_temp_image,
    upload_folder_to_huggingface,
    upload_to_huggingface,
    get_image_cdn_url,
    get_image_description,
    calculate_image_hash
)
from app.gradio_formatter import gradio_formatter
from multiprocessing import Queue

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)

# Lower-case file extensions treated as images when importing a dataset.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.webp')


class StickerService:
    """Sticker service: business logic for uploading, importing, searching and deleting stickers."""

    @staticmethod
    def upload_sticker(image_file_path: str, title: str, description: str, tags: str) -> str:
        """Upload a single sticker and store its metadata.

        Steps: hash the image for deduplication, upload it via
        ``upload_to_huggingface``, obtain a description from
        ``get_image_description`` when none was supplied, then store the
        record with ``db.store_sticker``.

        Args:
            image_file_path: Local path of the image to upload.
            title: Sticker title stored alongside the image.
            description: Sticker description; when empty one is generated.
            tags: Tags stored alongside the image.

        Returns:
            ``"Upload successful! <filename>"`` on success, otherwise
            ``"Upload failed: <reason>"``. Errors are never raised to the
            caller. A duplicate image yields ``"Upload failed: File_Exists"``.
        """
        try:
            # Open the image only long enough to hash it; the context manager
            # closes the underlying file handle.
            with Image.open(image_file_path) as image:
                image_hash = calculate_image_hash(image)

            # Reject images whose hash is already stored.
            if db.check_image_exists(image_hash):
                logger.info("文件已存在 %s", image_hash)
                raise Exception('File_Exists')

            # Upload to HuggingFace
            file_path, image_filename = upload_to_huggingface(image_file_path)

            # No description given: ask the description service for one,
            # pointing it at a URL where the image can be fetched.
            if not description:
                if PUBLIC_URL:
                    # Reference the local file through the public Gradio file route.
                    image_cdn_url = f'{PUBLIC_URL}/gradio_api/file={image_file_path}'
                else:
                    image_cdn_url = get_image_cdn_url(file_path)
                logger.debug("image_cdn_url %s", image_cdn_url)
                description = get_image_description(image_cdn_url)

            # Store in Milvus
            db.store_sticker(title, description, tags, file_path, image_hash)
            return f"Upload successful! {image_filename}"
        except Exception as e:
            logger.error(f"Upload failed: {str(e)}")
            return f"Upload failed: {str(e)}"

    @staticmethod
    def _unique_image_filename(used: set) -> str:
        """Return a random ``image_NNNNNN.png`` name not yet in *used*, and record it there."""
        # Only guards against collisions within one import run; names that
        # already exist remotely are not checked here.
        while True:
            name = f"image_{random.randint(100000, 999999)}.png"
            if name not in used:
                used.add(name)
                return name

    @staticmethod
    def import_stickers(
        sticker_dataset: str,
        upload: bool = False,
        save_to_milvus: bool = False,
        progress_callback: Optional[Callable[[str, str], Any]] = None,
    ) -> List[str]:
        """Import a zipped sticker dataset.

        The archive is extracted into ``TEMP_DIR/cache/``. If it contains a
        top-level ``data.json`` (a list of ``{"filename", "content"}``
        records), those contents are used as image descriptions. Images that
        already exist (by hash) or have no description are skipped. Every
        other image is written under ``TEMP_DIR/images/`` with a new
        ``image_NNNNNN.png`` name. Both temp folders are removed when the
        import finishes.

        Args:
            sticker_dataset: Path to the dataset zip archive.
            upload: Upload the generated image folder to HuggingFace.
                Defaults to False. When False, the temporary images are
                deleted without being uploaded.
            save_to_milvus: Store each imported sticker in the database.
                Defaults to False.
            progress_callback: Optional ``callback(filename, status)``
                invoked once per image file. Defaults to None.

        Returns:
            Human-readable status lines: one per image file, plus an upload
            line or a failure line. Errors are reported in this list rather
            than raised.
        """
        results = []
        descriptions = {}
        # Resolve paths before `try` so the `finally` cleanup can always use them.
        cache_folder = os.path.join(TEMP_DIR, 'cache/')
        img_folder = os.path.join(TEMP_DIR, 'images/')
        data_json_path = os.path.join(cache_folder, 'data.json')
        stickers = []
        # Filenames handed out during this run, so that two images never
        # overwrite each other in img_folder.
        used_filenames = set()
        try:
            logger.info("start import dataset")

            # Extract the dataset
            with zipfile.ZipFile(sticker_dataset, 'r') as zip_ref:
                zip_ref.extractall(cache_folder)
            logger.info(f"Extracted dataset to: {cache_folder}")

            # Optionally load descriptions from data.json, keyed by file name.
            if os.path.exists(data_json_path):
                with open(data_json_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                descriptions = {x["filename"]: x["content"] for x in data}
                logger.info("Loaded descriptions from data.json")

            # Walk the extracted tree
            for root, _dirs, files in os.walk(cache_folder):
                for file in files:
                    if not file.lower().endswith(_IMAGE_EXTENSIONS):
                        continue
                    image_path = os.path.join(root, file)
                    try:
                        # Context manager closes each image's file handle.
                        with Image.open(image_path) as image:
                            image_hash = calculate_image_hash(image)
                            if db.check_image_exists(image_hash):
                                results.append(f"跳过已存在的图片: {file}")
                                if progress_callback:
                                    progress_callback(file, "Skipped (exists)")
                                continue

                            # Look up the description for this file
                            description = descriptions.get(file)
                            if not description:
                                results.append(f"跳过无描述的图片: {file}")
                                if progress_callback:
                                    progress_callback(file, "Skipped (no description)")
                                continue

                            image_filename = StickerService._unique_image_filename(used_filenames)
                            file_path = f"images/{image_filename}"
                            generate_temp_image(img_folder, image, image_filename)

                        if save_to_milvus:
                            db.store_sticker("", description, "", file_path, image_hash)

                        stickers.append({
                            "title": "",
                            "description": description,
                            "tags": "",
                            "file_path": file_path,
                            "image_hash": image_hash
                        })
                        # Record success regardless of whether a callback is set,
                        # matching the skip/failure paths.
                        results.append(f"成功导入: {image_filename}")
                        if progress_callback:
                            progress_callback(file, "Imported")
                    except Exception as e:
                        logger.error(f"Failed to process image {file}: {str(e)}")
                        results.append(f"处理失败 {file}: {str(e)}")
                        if progress_callback:
                            progress_callback(file, f"Failed: {str(e)}")

            # Upload to HuggingFace
            if upload and stickers:
                logger.info(f"upload to huggingface, {len(stickers)} stickers")
                upload_folder_to_huggingface(img_folder)
                results.append("上传到 HuggingFace 成功")

            return results
        except Exception as e:
            logger.error(f"Import failed: {str(e)}")
            results.append(f"导入失败: {str(e)}")
            return results
        finally:
            # Clean up temporary directories. Best effort: a cleanup error
            # must not mask the import result.
            for folder in (cache_folder, img_folder):
                if os.path.exists(folder):
                    shutil.rmtree(folder, ignore_errors=True)
                    logger.info(f"Cleaned up temporary directory: {folder}")

    @staticmethod
    def search_stickers(description: str, limit: int = 2, reranking: bool = False) -> List[Dict[str, Any]]:
        """Search stickers by description text.

        Args:
            description: Query text; an empty query returns ``[]``.
            limit: Maximum number of hits requested from the database.
            reranking: When True, reorder the hits with
                ``rerank_search_results``.

        Returns:
            The hits from ``db.search_stickers`` (reranked if requested), or
            ``[]`` on any error.
        """
        if not description:
            return []

        try:
            results = db.search_stickers(description, limit)
            if reranking:
                # Rerank the search results
                results = StickerService.rerank_search_results(description, results, limit)
            return results
        except Exception as e:
            logger.error(f"Search failed: {str(e)}")
            return []

    @staticmethod
    def get_all_stickers(limit: int = 1000) -> List[List]:
        """Fetch up to *limit* stickers, formatted for Gradio; ``[]`` on error."""
        try:
            results = db.get_all_stickers(limit)
            return gradio_formatter.format_all_stickers(results)
        except Exception as e:
            logger.error(f"Failed to get all stickers: {str(e)}")
            return []

    @staticmethod
    def delete_sticker(sticker_id: str) -> str:
        """Delete a sticker by id and return a human-readable status message.

        No existence check is made here. The success message is returned
        whenever ``db.delete_sticker`` does not raise.
        """
        try:
            db.delete_sticker(sticker_id)
            return f"Sticker with ID {sticker_id} deleted successfully"
        except Exception as e:
            logger.error(f"Delete failed: {str(e)}")
            return f"Delete failed: {str(e)}"

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding Markdown code fence (e.g. ```json ... ```) from an LLM reply, if present."""
        stripped = text.strip()
        if stripped.startswith("```"):
            # Drop the opening fence line, which may carry a language tag.
            newline = stripped.find("\n")
            stripped = stripped[newline + 1:] if newline != -1 else ""
            stripped = stripped.rstrip()
            if stripped.endswith("```"):
                stripped = stripped[:-3]
        return stripped.strip()

    @staticmethod
    def _score_of(item: Dict[str, Any]) -> float:
        """Return the item's ``score`` as a float, treating missing or non-numeric values as 0."""
        try:
            return float(item.get("score", 0))
        except (TypeError, ValueError):
            return 0.0

    @staticmethod
    def rerank_search_results(query: str, sticker_list: List[Dict[str, Any]], limit: int = 5) -> List[Dict[str, Any]]:
        """Rerank search hits using the LLM.

        Args:
            query: The user's search text.
            sticker_list: Search hits. Each hit is read as ``hit["id"]`` and
                ``hit["entity"]["description"]``.
            limit: Maximum number of reranked hits to return (ignored when
                <= 0).

        Returns:
            The hits the LLM scored, highest score first. Each hit's
            ``entity`` is annotated in place with ``score`` and ``reason``.
            Hits the LLM did not mention are dropped. Returns ``[]`` if the
            LLM call or response parsing fails.
        """
        try:
            # Build the prompts
            system_prompt = STICKER_RERANKING_SYSTEM_PROMPT

            # The user prompt contains the query and each sticker's id and description.
            _sticker_list = [
                {"id": hit["id"], "description": hit["entity"]["description"]}
                for hit in sticker_list
            ]
            user_prompt = f"请分析关键词 '{query}' 与以下表情包的相关性:\n{_sticker_list}"
            logger.debug(">>> 使用 LLM 模型重新排序.... %s %s", user_prompt, system_prompt)

            # Ask the LLM for the reranking
            response = get_chat_completion(user_prompt, system_prompt)

            # Parse the JSON reply; LLMs often wrap JSON in a Markdown code fence.
            reranked_stickers = json.loads(StickerService._strip_code_fence(response))

            # Validate the reply format
            if not isinstance(reranked_stickers, list):
                raise ValueError("Invalid response format")
            reranked_stickers = [x for x in reranked_stickers if isinstance(x, dict)]

            # Sort by score, highest first
            reranked_stickers.sort(key=StickerService._score_of, reverse=True)
            logger.debug(">>> LLM 排序结果 %s", reranked_stickers)

            # Map the LLM entries back to the original hits. Compare ids as
            # strings because the LLM may return them as str or int. On
            # duplicate ids, keep the first hit (as the original linear scan did).
            hits_by_id: Dict[str, Dict[str, Any]] = {}
            for hit in sticker_list:
                hits_by_id.setdefault(str(hit["id"]), hit)

            rerank_results = []
            seen = set()
            for sticker in reranked_stickers:
                if "sticker_id" not in sticker:
                    continue
                sticker_id = str(sticker["sticker_id"])
                hit = hits_by_id.get(sticker_id)
                # Skip unknown ids and ids the LLM repeated.
                if hit is None or sticker_id in seen:
                    continue
                seen.add(sticker_id)
                hit["entity"]["score"] = sticker.get("score", 0)
                hit["entity"]["reason"] = sticker.get("reason", "")
                rerank_results.append(hit)
                if limit > 0 and len(rerank_results) >= limit:
                    break

            logger.debug(">>> rerank_results %s", rerank_results)
            return rerank_results
        except Exception as e:
            logger.error(f"Reranking failed: {str(e)}")
            return []


# Create the service instance
sticker_service = StickerService()