"""OpenAI-compatible text embedding service backed by a SentenceTransformer model.

Exposes ``POST /v1/embeddings``, protected by a static bearer token read from
the ``AUTHORIZATION`` environment variable.
"""

import hmac
import logging
import os

import numpy as np
from fastapi import Depends, FastAPI, Header, HTTPException, Request
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

# Configure logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Single source of truth for the model identifier, used both for loading the
# model and for the "model" field of every response.
MODEL_NAME = "BAAI/bge-large-zh-v1.5"

_BEARER_PREFIX = "Bearer "


async def check_authorization(
    authorization: str = Header(..., alias="Authorization"),
) -> str:
    """Validate the ``Authorization: Bearer <token>`` header.

    Args:
        authorization: Raw value of the ``Authorization`` request header.

    Returns:
        The bearer token with the ``"Bearer "`` prefix removed.

    Raises:
        HTTPException: 401 if the header does not start with ``"Bearer "``,
            if the token does not match the ``AUTHORIZATION`` environment
            variable, or if that variable is unset or empty.
    """
    if not authorization.startswith(_BEARER_PREFIX):
        raise HTTPException(status_code=401, detail="Invalid Authorization header format")

    # Strip the "Bearer " prefix to obtain the token itself.
    token = authorization[len(_BEARER_PREFIX):]
    expected = os.environ.get("AUTHORIZATION")

    # Reject when no secret is configured (an empty secret must never match an
    # empty token), and compare in constant time so response timing does not
    # reveal how much of the token was correct. Bytes are compared because
    # hmac.compare_digest only accepts ASCII-only str arguments.
    if not expected or not hmac.compare_digest(
        token.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return token


app = FastAPI()

try:
    # Load the Sentence Transformer model once, at import time.
    model = SentenceTransformer(MODEL_NAME)
except Exception as exc:
    logger.exception("Failed to load model: %s", exc)
    # No request is in flight at import time, so an HTTPException could never
    # be turned into a response; fail startup with a plain error instead.
    raise RuntimeError("Model loading failed") from exc


class EmbeddingRequest(BaseModel):
    """Request body for ``POST /v1/embeddings``."""

    # Text to embed; between 1 and 1000 characters.
    input: str = Field(..., min_length=1, max_length=1000)


@app.post("/v1/embeddings")
async def embeddings(
    request: EmbeddingRequest,
    authorization: str = Depends(check_authorization),
):
    """Embed the request text and return it in an OpenAI-style response.

    Args:
        request: Validated request body carrying the text to embed.
        authorization: Bearer token, already validated by
            ``check_authorization``.

    Returns:
        A dict with ``object``, ``data``, ``model`` and ``usage`` keys.
        ``data`` holds one entry per input text, each with a flat list of
        floats under ``embedding``.

    Raises:
        HTTPException: 500 if encoding or response formatting fails.
    """
    text = request.input

    if not text:
        # Normally unreachable because of min_length=1; kept for callers that
        # build the request object without validation.
        return {
            "object": "list",
            "data": [],
            "model": MODEL_NAME,
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    try:
        # Encode the whole string as ONE text. Iterating over the string would
        # embed every character separately and return a nested list.
        # NOTE(review): encode() is a blocking call inside an async endpoint;
        # consider offloading it to a worker thread if latency under load matters.
        vectors = model.encode([text], normalize_embeddings=True)

        # Format the embeddings in OpenAI compatible format.
        return {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": vector.tolist(),
                    "index": index,
                }
                for index, vector in enumerate(vectors)
            ],
            "model": MODEL_NAME,
            # Character count, as before; this is not a real tokenizer count.
            "usage": {
                "prompt_tokens": len(text),
                "total_tokens": len(text),
            },
        }
    except Exception as exc:
        logger.exception("Error processing embeddings: %s", exc)
        raise HTTPException(status_code=500, detail="Internal Server Error")