import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone, ServerlessSpec
import uuid
import os
from contextlib import asynccontextmanager

# --- Environment Setup ---
# It's best practice to get sensitive keys from environment variables
# We will set these up in Hugging Face Spaces Secrets
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "memoria-index")
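
# For local development these can be exported before starting the server.
# The values below are placeholders, not real credentials:
#
#   export PINECONE_API_KEY="..."
#   export PINECONE_INDEX_NAME="memoria-index"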

# --- Global objects ---
# Created once in the lifespan handler below so every request reuses the
# same model and Pinecone connection instead of re-initializing them
model = None
pc = None
index = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handles startup and shutdown events for the FastAPI app.
    Loads the model and connects to Pinecone on startup.
    """
    global model, pc, index
    print("Application startup...")

    if not PINECONE_API_KEY:
        raise ValueError("PINECONE_API_KEY environment variable not set.")

    # 1. Load the AI Model
    print("Loading lightweight sentence transformer model...")
    model = SentenceTransformer('sentence-transformers/paraphrase-albert-small-v2')
    print("Model loaded.")

    # 2. Connect to Pinecone
    print("Connecting to Pinecone...")
    pc = Pinecone(api_key=PINECONE_API_KEY)

    # 3. Get or create the Pinecone index
    if PINECONE_INDEX_NAME not in pc.list_indexes().names():
        print(f"Creating new Pinecone index: {PINECONE_INDEX_NAME}")
        pc.create_index(
            name=PINECONE_INDEX_NAME,
            dimension=model.get_sentence_embedding_dimension(),
            metric="cosine", # Cosine similarity is great for sentence vectors
            spec=ServerlessSpec(cloud="aws", region="us-east-1")
        )
    index = pc.Index(PINECONE_INDEX_NAME)
    print("Pinecone setup complete.")
    yield
    # Cleanup logic can go here if needed on shutdown
    print("Application shutdown.")

# --- Pydantic Models ---
class Memory(BaseModel):
    content: str

class SearchQuery(BaseModel):
    query: str
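
# Example request bodies for the models above (illustrative values):
#   Memory:      {"content": "Met Alice for coffee on Tuesday"}
#   SearchQuery: {"query": "when did I last see Alice?"}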

# --- FastAPI App ---
app = FastAPI(
    title="Memoria API",
    description="API for storing and retrieving memories.",
    version="1.0.0",
    lifespan=lifespan # Use the lifespan context manager
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], # Allow all origins for simplicity
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- API Endpoints ---
@app.get("/")
def read_root():
    return {"status": "ok", "message": "Welcome to the Memoria API!"}

@app.post("/save_memory")
def save_memory(memory: Memory):
    try:
        embedding = model.encode(memory.content).tolist()
        memory_id = str(uuid.uuid4())
        
        # Upsert (update or insert) the vector into Pinecone
        index.upsert(vectors=[{
            "id": memory_id,
            "values": embedding,
            "metadata": {"text": memory.content},
        }])
        
        print(f"Successfully saved memory with ID: {memory_id}")
        return {"status": "success", "id": memory_id}
    except Exception as e:
        print(f"An error occurred during save: {e}")
        raise HTTPException(status_code=500, detail=str(e))
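
# Example call against a local server (hypothetical host/port; adjust to
# your deployment):
#
#   curl -X POST http://127.0.0.1:8000/save_memory \
#        -H "Content-Type: application/json" \
#        -d '{"content": "Met Alice for coffee on Tuesday"}'
#
# Response shape: {"status": "success", "id": "<uuid4>"}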

@app.post("/search_memory")
def search_memory(search: SearchQuery):
    try:
        query_embedding = model.encode(search.query).tolist()
        
        # Query Pinecone for the most similar vectors
        results = index.query(vector=query_embedding, top_k=5, include_metadata=True)
        
        # Extract the original text from the metadata
        retrieved_documents = [match['metadata']['text'] for match in results['matches']]
        
        print(f"Found {len(retrieved_documents)} results for query: '{search.query}'")
        return {"status": "success", "results": retrieved_documents}
    except Exception as e:
        print(f"An error occurred during search: {e}")
        raise HTTPException(status_code=500, detail=str(e))
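
# Example call (same assumptions as the save_memory example above):
#
#   curl -X POST http://127.0.0.1:8000/search_memory \
#        -H "Content-Type: application/json" \
#        -d '{"query": "when did I last see Alice?"}'
#
# Response shape: {"status": "success", "results": ["<text>", ...]}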

if __name__ == "__main__":
    # Local development entry point. On Hugging Face Spaces the server is
    # usually launched by the Space itself (typically on 0.0.0.0:7860).
    uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)