# GOGO_rag / app.py
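# Retrieval backend for the GOGO_rag Space. On startup it downloads a
# 768-dimensional FAISS index and its metadata.csv from the
# GOGO198/GOGO_rag_index repo, keeps them fresh with a background refresh
# thread, and exposes a FastAPI endpoint (POST /predict) that takes a query
# embedding and returns the nearest documents with distance-based confidence
# scores.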
import gradio as gr
import numpy as np
import os
import pandas as pd
import faiss
from huggingface_hub import hf_hub_download
import time
import json
import traceback
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import threading
import math
# Create a writable cache directory for Hub downloads
CACHE_DIR = "/home/user/cache"
os.makedirs(CACHE_DIR, exist_ok=True)
# Limit OpenMP/tokenizer threads to reduce memory usage
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Global state shared between the API and the refresh thread
index = None
metadata = None
# Refresh bookkeeping: mtime of the loaded metadata file and the check interval
last_updated = 0
index_refresh_interval = 300  # refresh every 5 minutes
# Background index refresher
def refresh_index():
    global index, metadata, last_updated
    last_checked = 0.0
    while True:
        try:
            # Only query the Hub once per refresh interval
            current_time = time.time()
            if current_time - last_checked > index_refresh_interval:
                last_checked = current_time
                print("🔄 Checking for index updates...")
                # Fetch the latest metadata (force a fresh download)
                METADATA_PATH = hf_hub_download(
                    repo_id="GOGO198/GOGO_rag_index",
                    filename="metadata.csv",
                    cache_dir=CACHE_DIR,
                    token=os.getenv("HF_TOKEN"),
                    force_download=True
                )
                # Compare the file's modification time with the currently loaded version
                file_mtime = os.path.getmtime(METADATA_PATH)
                if file_mtime > last_updated:
                    print("📥 New index detected, reloading...")
                    # Re-download and reload the FAISS index
                    INDEX_PATH = hf_hub_download(
                        repo_id="GOGO198/GOGO_rag_index",
                        filename="faiss_index.bin",
                        cache_dir=CACHE_DIR,
                        token=os.getenv("HF_TOKEN"),
                        force_download=True
                    )
                    new_index = faiss.read_index(INDEX_PATH)
                    new_metadata = pd.read_csv(METADATA_PATH)
                    # Swap in the new objects; rebinding a global name is atomic for readers
                    index = new_index
                    metadata = new_metadata
                    last_updated = file_mtime
                    print(f"✅ Index update complete | records: {len(metadata)}")
        except Exception as e:
            print(f"Index refresh failed: {str(e)}")
        # Poll every 30 seconds
        time.sleep(30)
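# Note: this is a single-writer/many-reader pattern. predict() only reads the
# module-level `index`/`metadata` names, and the refresh thread replaces them
# by rebinding, so a request sees either the old objects or the new ones. A
# lock would only be needed if a request had to see `index` and `metadata`
# from the exact same snapshot.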
def load_resources():
    """Load the FAISS index and metadata (768-dimensional index expected)."""
    global index, metadata
    # Remove stale download lock files left over from a previous run
    lock_files = [f for f in os.listdir(CACHE_DIR) if f.endswith('.lock')]
    for lock_file in lock_files:
        try:
            os.remove(os.path.join(CACHE_DIR, lock_file))
            print(f"🧹 Removed lock file: {lock_file}")
        except OSError:
            pass
    if index is None or metadata is None:
        print("🔄 Loading all resources...")
        # Load the FAISS index (768-dimensional)
        if index is None:
            print("📥 Downloading FAISS index...")
            try:
                INDEX_PATH = hf_hub_download(
                    repo_id="GOGO198/GOGO_rag_index",
                    filename="faiss_index.bin",
                    cache_dir=CACHE_DIR,
                    token=os.getenv("HF_TOKEN")
                )
                index = faiss.read_index(INDEX_PATH)
                if index.d != 768:
                    raise ValueError("❌ Wrong index dimension: expected 768")
                print(f"✅ FAISS index loaded | dimension: {index.d}")
            except Exception as e:
                print(f"❌ Failed to load FAISS index: {str(e)}")
                raise
        # Load the metadata
        if metadata is None:
            print("📄 Downloading metadata...")
            try:
                METADATA_PATH = hf_hub_download(
                    repo_id="GOGO198/GOGO_rag_index",
                    filename="metadata.csv",
                    cache_dir=CACHE_DIR,
                    token=os.getenv("HF_TOKEN")
                )
                metadata = pd.read_csv(METADATA_PATH)
                print(f"✅ Metadata loaded | records: {len(metadata)}")
            except Exception as e:
                print(f"❌ Failed to load metadata: {str(e)}")
                raise
    # Start the background index refresh thread
    threading.Thread(target=refresh_index, daemon=True).start()
# Load resources eagerly so they are ready before the first API request
load_resources()
def sanitize_floats(obj):
    """Recursively replace NaN/Inf floats so the result serializes to valid JSON."""
    if isinstance(obj, float):
        if math.isnan(obj) or math.isinf(obj):
            return 0.0  # replace non-finite values with a safe default
        return obj
    elif isinstance(obj, dict):
        return {k: sanitize_floats(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [sanitize_floats(x) for x in obj]
    else:
        return obj
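# For example, sanitize_floats({"distance": float("inf"), "scores": [0.3, float("nan")]})
# returns {"distance": 0.0, "scores": [0.3, 0.0]}; NaN/Infinity are not valid JSON and
# would otherwise make the JSONResponse serialization fail.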
def predict(vector):
    try:
        print(f"Received vector: {vector[:3]}... (length: {len(vector)})")
        # Ensure the query is a float32 row vector as FAISS expects
        query_vector = np.array(vector).astype('float32').reshape(1, -1)
        # Number of results, capped by the number of indexed documents
        k = min(3, index.ntotal)
        if k == 0:
            return {
                "status": "success",
                "results": [],
                "message": "Index is empty"
            }
        print(f"Running FAISS search | k={k}")
        D, I = index.search(query_vector, k=k)
        # Log the raw search output
        print(f"Result indices: {I[0]}")
        print(f"Distances: {D[0]}")
        # Build the response
        results = []
        for i in range(k):
            try:
                idx = I[0][i]
                distance = D[0][i]
                # Guard 1: handle non-finite or negative distances
                if not np.isfinite(distance) or distance < 0:
                    distance = 100.0  # fall back to a safe threshold
                # Guard 2: derive a confidence score clamped to [0, 1]
                confidence = 1 / (1 + distance)
                confidence = max(0.0, min(1.0, confidence))
                # Guard 3: cast to plain Python floats for JSON serialization
                distance = float(distance)
                confidence = float(confidence)
                result = {
                    "source": metadata.iloc[idx]["source"],
                    "content": metadata.iloc[idx].get("content", ""),
                    "confidence": confidence,
                    "distance": distance
                }
                results.append(result)
            except Exception as e:
                # Keep error entries JSON-safe as well
                results.append({
                    "error": str(e),
                    "confidence": 0.5,
                    "distance": 0.0
                })
        return {
            "status": "success",
            "results": sanitize_floats(results)
        }
    except Exception as e:
        # Return a structured error response
        return {
            "status": "error",
            "message": f"Internal server error: {str(e)}",
            "details": sanitize_floats({"trace": traceback.format_exc()})
        }
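# A successful predict() call returns a payload shaped like this (values are
# illustrative only):
# {
#     "status": "success",
#     "results": [
#         {"source": "doc_001", "content": "…", "confidence": 0.42, "distance": 1.38},
#         ...
#     ]
# }
# On failure it returns {"status": "error", "message": "...", "details": {"trace": "..."}}.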
# Create the FastAPI application
app = FastAPI()
# Enable CORS so browser-based clients can call the API
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.post("/predict")
async def api_predict(request: Request):
    """Prediction API endpoint."""
    try:
        data = await request.json()
        vector = data.get("vector")
        if not vector or not isinstance(vector, list):
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "Invalid input: a 'vector' list is required"}
            )
        result = predict(vector)
        return JSONResponse(content=result)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                "message": f"Internal server error: {str(e)}"
            }
        )
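# Example client call (a minimal sketch; the Space URL below is a placeholder,
# and the 768 zeros stand in for a real query embedding from whichever
# 768-dimensional encoder was used to build the index):
#
#     import requests
#     resp = requests.post(
#         "https://<your-space-host>/predict",   # replace with the deployed URL
#         json={"vector": [0.0] * 768},          # 768-dim embedding of the query
#         timeout=30,
#     )
#     print(resp.json()["results"])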
# Application entry point
if __name__ == "__main__":
    # Report the loaded resources before serving
    print("=" * 50)
    print("Space startup complete | ready to accept requests")
    print(f"Index dimension: {index.d}")
    print(f"Metadata records: {len(metadata)}")
    print("=" * 50)
    # Run the FastAPI service only (no Gradio UI)
    uvicorn.run(app, host="0.0.0.0", port=7860)
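# Once the server is running, the endpoint can also be exercised locally, e.g.
# (assuming the index files downloaded successfully and a full 768-float
# vector is supplied in place of the ellipsis):
#
#     curl -X POST http://localhost:7860/predict \
#          -H "Content-Type: application/json" \
#          -d '{"vector": [0.0, 0.0, ...]}'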