|
import json
import math
import os
import threading
import time
import traceback

import faiss
import fastapi
import gradio as gr
import numpy as np
import pandas as pd
import uvicorn
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
|
|
|
|
|
# Local directory used as the Hugging Face Hub download cache.
CACHE_DIR = "/home/user/cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Limit native thread pools. OMP_NUM_THREADS is normally read when the OpenMP
# runtime initialises, and faiss/numpy are already imported at this point, so
# the limit is also applied to faiss directly.
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
faiss.omp_set_num_threads(2)

# Shared search state; populated by load_resources() and swapped by refresh_index().
index = None      # FAISS index (None until loaded)
metadata = None   # pandas DataFrame read from metadata.csv (None until loaded)

last_updated = 0  # epoch-seconds bookkeeping for the background refresher
index_refresh_interval = 300  # seconds between update checks in refresh_index()
|
|
|
|
|
def refresh_index():
    """Poll the Hub forever and hot-swap the FAISS index and metadata when they change.

    Intended to run in a daemon thread. Every ``index_refresh_interval`` seconds
    it resolves the latest ``faiss_index.bin`` and ``metadata.csv`` from the
    ``GOGO198/GOGO_rag_index`` repo. Without ``force_download``,
    ``hf_hub_download`` only transfers a file when the remote copy changed, and
    the returned path resolves to the cached blob for that content. Comparing
    the resolved paths therefore detects an update to either file without
    re-downloading unchanged data.

    On change, the new index is validated (768 dimensions, as in
    ``load_resources``) and then published together with its metadata. Any
    failure is printed and retried on the next 30-second tick.
    """
    global index, metadata, last_updated

    # The initial load is done by load_resources, so defer the first check by one interval.
    last_checked = time.time()
    loaded_signature = None  # resolved (index, metadata) paths currently served

    while True:
        try:
            if time.time() - last_checked > index_refresh_interval:
                print("🔄 检查索引更新...")

                metadata_path = hf_hub_download(
                    repo_id="GOGO198/GOGO_rag_index",
                    filename="metadata.csv",
                    cache_dir=CACHE_DIR,
                    token=os.getenv("HF_TOKEN"),
                )
                index_path = hf_hub_download(
                    repo_id="GOGO198/GOGO_rag_index",
                    filename="faiss_index.bin",
                    cache_dir=CACHE_DIR,
                    token=os.getenv("HF_TOKEN"),
                )

                signature = (os.path.realpath(index_path), os.path.realpath(metadata_path))
                if signature != loaded_signature:
                    print("📥 检测到新索引,重新加载...")
                    new_index = faiss.read_index(index_path)
                    # Never publish an index the rest of the service cannot query.
                    if new_index.d != 768:
                        raise ValueError("❌ 索引维度错误:预期768维")
                    new_metadata = pd.read_csv(metadata_path)

                    index, metadata = new_index, new_metadata
                    loaded_signature = signature
                    last_updated = time.time()
                    print(f"✅ 索引更新完成 | 记录数: {len(new_metadata)}")

                # Only a successful check resets the timer; failures retry on the next tick.
                last_checked = time.time()
        except Exception as e:
            print(f"索引更新失败: {str(e)}")

        time.sleep(30)
|
|
|
def load_resources():
    """Load all required resources into the module globals (768-dimensional index only).

    Removes leftover ``*.lock`` files from the top level of ``CACHE_DIR``. For
    whichever of ``index`` / ``metadata`` is still ``None``, it downloads
    ``faiss_index.bin`` / ``metadata.csv`` from the ``GOGO198/GOGO_rag_index``
    Hub repo. It then starts the ``refresh_index`` background thread. Calling
    it again once both globals are set only repeats the lock cleanup.

    Raises:
        ValueError: if the downloaded index is not 768-dimensional; the global
            ``index`` is left unset in that case.
        Exception: any download or parse error is printed and re-raised.
    """
    global index, metadata, last_updated

    # Best-effort cleanup of stale lock files left by an earlier run.
    # NOTE(review): only top-level entries of CACHE_DIR are scanned; confirm this is
    # where huggingface_hub actually writes its lock files.
    for lock_file in os.listdir(CACHE_DIR):
        if not lock_file.endswith('.lock'):
            continue
        try:
            os.remove(os.path.join(CACHE_DIR, lock_file))
            print(f"🧹 清理锁文件: {lock_file}")
        except OSError as e:
            print(f"⚠️ 无法清理锁文件 {lock_file}: {e}")

    if index is not None and metadata is not None:
        return

    print("🔄 正在加载所有资源...")

    if index is None:
        print("📥 正在下载FAISS索引...")
        try:
            index_path = hf_hub_download(
                repo_id="GOGO198/GOGO_rag_index",
                filename="faiss_index.bin",
                cache_dir=CACHE_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            new_index = faiss.read_index(index_path)
            # Validate before publishing so a wrong-dimension index never becomes the global.
            if new_index.d != 768:
                raise ValueError("❌ 索引维度错误:预期768维")
            index = new_index
            print(f"✅ FAISS索引加载完成 | 维度: {index.d}")
        except Exception as e:
            print(f"❌ FAISS索引加载失败: {str(e)}")
            raise

    if metadata is None:
        print("📄 正在下载元数据...")
        try:
            metadata_path = hf_hub_download(
                repo_id="GOGO198/GOGO_rag_index",
                filename="metadata.csv",
                cache_dir=CACHE_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            metadata = pd.read_csv(metadata_path)
            print(f"✅ 元数据加载完成 | 记录数: {len(metadata)}")
        except Exception as e:
            print(f"❌ 元数据加载失败: {str(e)}")
            raise

    # Record the load time so the periodic refresher does not treat the
    # just-loaded resources as stale at startup.
    last_updated = time.time()

    # Periodic index refresher; daemon=True so it does not keep the process alive on exit.
    threading.Thread(target=refresh_index, daemon=True).start()


# Load everything at import time so the service is ready before it starts serving.
load_resources()
|
|
|
def sanitize_floats(obj):
    """Recursively make *obj* safe for strict JSON encoding.

    Non-finite floats (NaN, +inf, -inf) are replaced with ``0.0``. NumPy scalars
    are first converted to the equivalent Python scalar via ``.item()``. As a
    result, ``np.float32('nan')`` is sanitised too, and values such as
    ``np.int64`` become plain ints.

    Dicts (values only), lists and tuples are processed element by element;
    every other value is returned unchanged.
    """
    if isinstance(obj, np.generic):
        obj = obj.item()

    if isinstance(obj, float):
        return obj if math.isfinite(obj) else 0.0
    if isinstance(obj, dict):
        return {k: sanitize_floats(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [sanitize_floats(x) for x in obj]
    if isinstance(obj, tuple):
        return tuple(sanitize_floats(x) for x in obj)
    return obj
|
|
|
def predict(vector):
    """Return the metadata records nearest to a query embedding.

    Args:
        vector: Query embedding as a sequence of numbers. After flattening, its
            length must equal the index dimension.

    Returns:
        A JSON-ready dict, in one of these shapes:

        * ``{"status": "success", "results": [...]}`` with up to 3 hits. Each hit
          has ``source``, ``content``, ``distance`` and ``confidence``.
          Non-finite or negative distances are replaced by 100.0, and
          ``confidence = 1 / (1 + distance)`` is clamped to [0, 1]. A hit whose
          metadata row cannot be read becomes
          ``{"error": ..., "confidence": 0.5, "distance": 0.0}``.
        * ``{"status": "success", "results": [], "message": "索引为空"}`` when the
          index is empty.
        * ``{"status": "error", "message": ...}`` on a dimension mismatch. Any
          other failure returns the same shape plus ``details.trace``
          containing the traceback.
    """
    try:
        print(f"接收向量: {vector[:3]}... (长度: {len(vector)})")

        # Read each global once: refresh_index may swap them while this request runs.
        current_index = index
        current_metadata = metadata

        query_vector = np.array(vector).astype('float32').reshape(1, -1)
        if query_vector.shape[1] != current_index.d:
            # Reject with a clear message instead of an opaque assertion from FAISS.
            return {
                "status": "error",
                "message": f"向量维度错误: 预期{current_index.d}维, 收到{query_vector.shape[1]}维",
            }

        k = min(3, current_index.ntotal)
        if k == 0:
            return {
                "status": "success",
                "results": [],
                "message": "索引为空",
            }

        print(f"执行FAISS搜索 | k={k}")
        D, I = current_index.search(query_vector, k=k)

        print(f"搜索结果索引: {I[0]}")
        print(f"搜索距离: {D[0]}")

        results = []
        for idx, distance in zip(I[0], D[0]):
            # FAISS marks "no neighbour found" with id -1; iloc[-1] would
            # silently return the last metadata row as a bogus match.
            if idx < 0:
                continue
            try:
                if not np.isfinite(distance) or distance < 0:
                    distance = 100.0  # treat unusable distances as a very poor match
                confidence = max(0.0, min(1.0, 1 / (1 + distance)))

                row = current_metadata.iloc[idx]
                results.append({
                    "source": row["source"],
                    "content": row.get("content", ""),
                    "confidence": float(confidence),
                    "distance": float(distance),
                })
            except Exception as e:
                # NOTE(review): error entries report confidence 0.5 / distance 0.0,
                # which ranks above genuine weak matches; confirm clients rely on this.
                results.append({
                    "error": str(e),
                    "confidence": 0.5,
                    "distance": 0.0,
                })

        return {
            "status": "success",
            "results": sanitize_floats(results),
        }
    except Exception as e:
        return {
            "status": "error",
            "message": f"服务器内部错误: {str(e)}",
            "details": sanitize_floats({"trace": traceback.format_exc()}),
        }
|
|
|
|
|
app = FastAPI()

# Allow cross-origin calls from any site for every method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True deserve a
# security check; confirm that credentialed cross-origin access is intended.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
@app.post("/predict")
async def api_predict(request: Request):
    """API prediction endpoint.

    Expects a JSON object body of the form ``{"vector": [...]}``.

    * Responds 400 when the body is not valid JSON, is not an object, or lacks a
      non-empty ``vector`` list.
    * Otherwise returns the dict produced by ``predict`` with HTTP 200; this
      includes cases where ``predict`` itself reports ``"status": "error"``.
    * Unexpected failures inside this handler yield HTTP 500.
    """
    try:
        try:
            data = await request.json()
        except ValueError:
            # Malformed JSON / undecodable body is a client error, not a server fault.
            data = None

        vector = data.get("vector") if isinstance(data, dict) else None
        if not vector or not isinstance(vector, list):
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "无效输入: 需要向量数据"}
            )

        # predict() does blocking NumPy/FAISS work; run it in a worker thread so the
        # event loop keeps serving other requests meanwhile.
        result = await run_in_threadpool(predict, vector)
        return JSONResponse(content=result)

    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                "message": f"服务器内部错误了: {str(e)}"
            }
        )
|
|
|
|
|
if __name__ == "__main__":
    # Print a startup banner summarising the loaded resources, then serve the
    # FastAPI app with uvicorn (this call blocks).
    banner_rule = "=" * 50
    print(banner_rule)
    print("Space启动完成 | 准备接收请求")
    print(f"索引维度: {index.d}")
    print(f"元数据记录数: {len(metadata)}")
    print(banner_rule)

    uvicorn.run(app, port=7860, host="0.0.0.0")