import gradio as gr
import numpy as np
import os
import pandas as pd
import faiss
from huggingface_hub import hf_hub_download
import time
import json
import fastapi
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import threading
import math
import traceback

# Writable cache directory for Hub downloads.
CACHE_DIR = "/home/user/cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Reduce CPU/memory footprint.
# NOTE: faiss (and the OpenMP runtime it links) is already imported above, and
# OpenMP runtimes typically read OMP_NUM_THREADS when they are loaded, so the
# thread cap is also applied explicitly through faiss.
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
faiss.omp_set_num_threads(2)

# Hub repository holding the prebuilt FAISS index and its row metadata.
REPO_ID = "GOGO198/GOGO_rag_index"
INDEX_FILENAME = "faiss_index.bin"
METADATA_FILENAME = "metadata.csv"
EXPECTED_DIM = 768  # this service only accepts 768-dimensional indexes

# Shared state: read by request handlers, replaced by the refresh thread.
index = None
metadata = None
last_updated = 0  # time.time() of the most recent successful (re)load
index_refresh_interval = 300  # seconds between update checks (5 minutes)

# Repo filename -> resolved local path of the copy currently loaded in memory.
_loaded_files = {}
# Background refresh thread; started at most once by load_resources().
_refresh_thread = None


def _download_repo_file(filename):
    """Return a local path to *filename* from REPO_ID, downloading only if needed.

    No ``force_download`` is passed, so hf_hub_download can reuse the cached
    copy when the remote file has not changed.
    """
    return hf_hub_download(
        repo_id=REPO_ID,
        filename=filename,
        cache_dir=CACHE_DIR,
        token=os.getenv("HF_TOKEN"),
    )


def _read_checked_index(path):
    """Read a FAISS index from *path*; raise ValueError unless it is EXPECTED_DIM-dimensional."""
    loaded = faiss.read_index(path)
    if loaded.d != EXPECTED_DIM:
        raise ValueError("❌ 索引维度错误:预期768维")
    return loaded


def _remove_stale_locks():
    """Best-effort removal of leftover ``*.lock`` files directly under CACHE_DIR."""
    for lock_file in os.listdir(CACHE_DIR):
        if not lock_file.endswith(".lock"):
            continue
        try:
            os.remove(os.path.join(CACHE_DIR, lock_file))
            print(f"🧹 清理锁文件: {lock_file}")
        except OSError as e:
            # Deliberately non-fatal: a lock we cannot delete should not block startup.
            print(f"⚠️ 无法清理锁文件 {lock_file}: {e}")


def refresh_index():
    """Poll the Hub repo forever and hot-swap the index when its files change.

    Meant to run in a daemon thread. The loop wakes every 30 seconds; once
    ``index_refresh_interval`` seconds have passed since the previous check it
    resolves both repo files and compares their resolved cache paths with the
    copies already in memory. Only when they differ are the files re-read and
    the globals replaced, so an unchanged repo never triggers a reload.
    Failures are logged and the loop keeps running.
    """
    global index, metadata, last_updated
    last_checked = time.time()
    while True:
        time.sleep(30)  # wake up every 30 seconds
        if time.time() - last_checked <= index_refresh_interval:
            continue
        last_checked = time.time()
        try:
            print("🔄 检查索引更新...")
            metadata_path = _download_repo_file(METADATA_FILENAME)
            index_path = _download_repo_file(INDEX_FILENAME)
            current_files = {
                INDEX_FILENAME: os.path.realpath(index_path),
                METADATA_FILENAME: os.path.realpath(metadata_path),
            }
            if current_files == _loaded_files:
                continue  # nothing changed on the Hub

            print("📥 检测到新索引,重新加载...")
            # Load fully before swapping so requests never see a half-loaded state.
            new_index = _read_checked_index(index_path)
            new_metadata = pd.read_csv(metadata_path)
            index, metadata = new_index, new_metadata
            _loaded_files.update(current_files)
            last_updated = time.time()
            print(f"✅ 索引更新完成 | 记录数: {len(metadata)}")
        except Exception as e:
            print(f"索引更新失败: {str(e)}")


def load_resources():
    """Load the FAISS index (768-dim only) and metadata, then start the refresh thread.

    Resources that are already loaded are kept. Stale ``*.lock`` files in
    CACHE_DIR are removed first. Download/read errors are logged and re-raised
    so a broken startup fails loudly. The refresh thread is started only once.
    """
    global index, metadata, last_updated, _refresh_thread

    _remove_stale_locks()

    if index is None or metadata is None:
        print("🔄 正在加载所有资源...")

        if index is None:
            print("📥 正在下载FAISS索引...")
            try:
                index_path = _download_repo_file(INDEX_FILENAME)
                # Validate before assigning so a wrong-dimension index never
                # becomes the global.
                index = _read_checked_index(index_path)
                _loaded_files[INDEX_FILENAME] = os.path.realpath(index_path)
                print(f"✅ FAISS索引加载完成 | 维度: {index.d}")
            except Exception as e:
                print(f"❌ FAISS索引加载失败: {str(e)}")
                raise

        if metadata is None:
            print("📄 正在下载元数据...")
            try:
                metadata_path = _download_repo_file(METADATA_FILENAME)
                metadata = pd.read_csv(metadata_path)
                _loaded_files[METADATA_FILENAME] = os.path.realpath(metadata_path)
                print(f"✅ 元数据加载完成 | 记录数: {len(metadata)}")
            except Exception as e:
                print(f"❌ 元数据加载失败: {str(e)}")
                raise

        # Record the load time so the refresh loop does not immediately reload.
        last_updated = time.time()

    if _refresh_thread is None:
        _refresh_thread = threading.Thread(target=refresh_index, daemon=True)
        _refresh_thread.start()


# Make sure resources are loaded before the API serves any request.
load_resources()


def sanitize_floats(obj):
    """Recursively make *obj* JSON-safe.

    NaN and +/-inf floats become 0.0. Numpy float/int scalars are converted to
    plain Python numbers (the standard json encoder rejects numpy scalars).
    Dict values and list items are processed recursively; anything else is
    returned unchanged.
    """
    if isinstance(obj, (float, np.floating)):
        value = float(obj)
        return value if math.isfinite(value) else 0.0
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, dict):
        return {k: sanitize_floats(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [sanitize_floats(x) for x in obj]
    return obj


def predict(vector, top_k=3):
    """Search the FAISS index for the nearest neighbours of *vector*.

    Args:
        vector: the query embedding as a flat sequence of numbers; it must have
            as many elements as the index dimension.
        top_k: maximum number of results (default 3). Clamped to at least 1
            and at most the number of vectors in the index.

    Returns:
        A JSON-safe dict. On success: ``{"status": "success", "results": [...]}``.
        Each result has ``source``, ``content``, ``confidence`` and ``distance``.
        If one result fails to build, that entry gets an ``error`` field instead.
        On failure: ``{"status": "error", "message": ..., ...}``.
    """
    try:
        print(f"接收向量: {vector[:3]}... (长度: {len(vector)})")

        # Snapshot the globals once so a concurrent hot-swap by the refresh
        # thread cannot change them halfway through this request.
        current_index, current_metadata = index, metadata

        query_vector = np.asarray(vector, dtype="float32").reshape(1, -1)
        if query_vector.shape[1] != current_index.d:
            return {
                "status": "error",
                "message": f"向量维度错误: 预期 {current_index.d} 维, 实际 {query_vector.shape[1]} 维",
            }

        # Dynamic result count, never more than the number of indexed documents.
        k = min(max(int(top_k), 1), current_index.ntotal)
        if k == 0:
            return {
                "status": "success",
                "results": [],
                "message": "索引为空"
            }

        print(f"执行FAISS搜索 | k={k}")
        D, I = current_index.search(query_vector, k=k)
        print(f"搜索结果索引: {I[0]}")
        print(f"搜索距离: {D[0]}")

        results = []
        for idx, raw_distance in zip(I[0], D[0]):
            if idx < 0:
                # FAISS pads with label -1 when it finds fewer than k neighbours;
                # iloc[-1] would otherwise silently return the last row.
                continue
            try:
                distance = float(raw_distance)
                # Non-finite or negative distances are replaced by a safe
                # "far away" value. NOTE(review): this assumes an L2-style
                # metric where smaller is better; confirm for inner-product indexes.
                if not math.isfinite(distance) or distance < 0:
                    distance = 100.0
                confidence = min(1.0, max(0.0, 1.0 / (1.0 + distance)))

                row = current_metadata.iloc[int(idx)]
                results.append({
                    "source": row["source"],
                    "content": row.get("content", ""),
                    "confidence": confidence,
                    "distance": distance,
                })
            except Exception as e:
                # Keep failed entries JSON-serializable too.
                results.append({
                    "error": str(e),
                    "confidence": 0.5,
                    "distance": 0.0
                })

        return {
            "status": "success",
            "results": sanitize_floats(results)
        }
    except Exception as e:
        trace = traceback.format_exc()
        print(trace)
        return {
            "status": "error",
            "message": f"服务器内部错误: {str(e)}",
            "details": sanitize_floats({"trace": trace})
        }


# FastAPI application.
app = FastAPI()

# CORS support.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/predict")
async def api_predict(request: Request):
    """Prediction endpoint: expects a JSON object ``{"vector": [floats...]}``.

    Malformed JSON, a non-object body, or a missing/non-list ``vector`` returns
    HTTP 400. Otherwise it returns the ``predict`` result, or HTTP 500 if
    building the response fails unexpectedly.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        return JSONResponse(
            status_code=400,
            content={"status": "error", "message": "无效输入: 请求体不是合法的JSON"}
        )

    vector = data.get("vector") if isinstance(data, dict) else None
    if not vector or not isinstance(vector, list):
        return JSONResponse(
            status_code=400,
            content={"status": "error", "message": "无效输入: 需要向量数据"}
        )

    try:
        result = predict(vector)
        return JSONResponse(content=result)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                "message": f"服务器内部错误了: {str(e)}"
            }
        )


# Start the application.
if __name__ == "__main__":
    # Report the loaded resources.
    print("=" * 50)
    print("Space启动完成 | 准备接收请求")
    print(f"索引维度: {index.d}")
    print(f"元数据记录数: {len(metadata)}")
    print("=" * 50)

    # Serve only the FastAPI app.
    uvicorn.run(app, host="0.0.0.0", port=7860)