Spaces:
Runtime error
Runtime error
File size: 7,665 Bytes
63deadc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
import datetime
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float:
    """Get the hours passed between two datetimes.

    Args:
        time: The later point in time (typically "now").
        ref_time: The reference point in time to measure from.

    Returns:
        ``time - ref_time`` expressed in hours. The value is negative when
        ``ref_time`` lies after ``time``.
    """
    elapsed = time - ref_time
    return elapsed.total_seconds() / 3600
class TimeWeightedVectorStoreRetriever(BaseRetriever):
    """Retriever that combines embedding similarity with
    recency in retrieving values.

    Documents added through ``add_documents`` / ``aadd_documents`` are copied,
    stamped with ``last_accessed_at``, ``created_at`` and ``buffer_idx``
    metadata, appended to ``memory_stream`` and forwarded to the vector store.
    At query time the most recent ``k`` memories and the vector store hits are
    rescored with ``_get_combined_score`` and the top ``k`` are returned.
    """

    vectorstore: VectorStore
    """The vectorstore to store documents and determine salience."""

    search_kwargs: dict = Field(default_factory=lambda: dict(k=100))
    """Keyword arguments to pass to the vectorstore similarity search."""

    # TODO: abstract as a queue
    memory_stream: List[Document] = Field(default_factory=list)
    """The memory_stream of documents to search through."""

    decay_rate: float = Field(default=0.01)
    """The exponential decay factor used as (1.0-decay_rate)**(hrs_passed)."""

    k: int = 4
    """The maximum number of documents to retrieve in a given call."""

    other_score_keys: List[str] = []
    """Other keys in the metadata to factor into the score, e.g. 'importance'."""

    default_salience: Optional[float] = None
    """The salience to assign memories not retrieved from the vector store.
    None assigns no salience to documents not fetched from the vector store.
    """

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def _document_get_date(self, field: str, document: Document) -> datetime.datetime:
        """Return the value of the date field of a document.

        Args:
            field: Metadata key holding the date (e.g. ``"last_accessed_at"``).
            document: Document whose metadata is inspected.

        Returns:
            The current time when ``field`` is absent. A numeric value is
            treated as a POSIX timestamp and converted with
            ``datetime.datetime.fromtimestamp``. Any other value is returned
            unchanged.
        """
        if field not in document.metadata:
            return datetime.datetime.now()
        value = document.metadata[field]
        # Integer epoch timestamps are converted just like floats. ``bool`` is
        # an ``int`` subclass but is never a timestamp, so it is passed through.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return datetime.datetime.fromtimestamp(value)
        return value

    def _get_combined_score(
        self,
        document: Document,
        vector_relevance: Optional[float],
        current_time: datetime.datetime,
    ) -> float:
        """Return the combined score for a document.

        The score is ``(1.0 - decay_rate) ** hours_since_last_access``, plus the
        metadata value of every key in ``other_score_keys`` that the document
        carries, plus ``vector_relevance`` when it is not ``None``.
        """
        hours_passed = _get_hours_passed(
            current_time,
            self._document_get_date("last_accessed_at", document),
        )
        score = (1.0 - self.decay_rate) ** hours_passed
        for key in self.other_score_keys:
            if key in document.metadata:
                score += document.metadata[key]
        if vector_relevance is not None:
            score += vector_relevance
        return score

    def _map_hits_to_memory_stream(
        self, docs_and_scores: List[Tuple[Document, float]]
    ) -> Dict[int, Tuple[Document, float]]:
        """Map vector store hits onto their ``memory_stream`` entries.

        Hits without a ``buffer_idx`` in their metadata are skipped. The
        returned dict is keyed by ``buffer_idx`` and holds the document taken
        from ``memory_stream`` together with the hit's relevance score.
        """
        results = {}
        for fetched_doc, relevance in docs_and_scores:
            if "buffer_idx" in fetched_doc.metadata:
                buffer_idx = fetched_doc.metadata["buffer_idx"]
                # NOTE(review): raises IndexError when ``buffer_idx`` is not a
                # valid index into ``memory_stream`` -- presumably possible if
                # the vector store holds documents from another retriever
                # instance; confirm the intended handling.
                doc = self.memory_stream[buffer_idx]
                results[buffer_idx] = (doc, relevance)
        return results

    def get_salient_docs(self, query: str) -> Dict[int, Tuple[Document, float]]:
        """Return documents that are salient to the query."""
        docs_and_scores: List[Tuple[Document, float]]
        docs_and_scores = self.vectorstore.similarity_search_with_relevance_scores(
            query, **self.search_kwargs
        )
        return self._map_hits_to_memory_stream(docs_and_scores)

    async def aget_salient_docs(self, query: str) -> Dict[int, Tuple[Document, float]]:
        """Return documents that are salient to the query."""
        docs_and_scores: List[Tuple[Document, float]]
        docs_and_scores = (
            await self.vectorstore.asimilarity_search_with_relevance_scores(
                query, **self.search_kwargs
            )
        )
        return self._map_hits_to_memory_stream(docs_and_scores)

    def _get_rescored_docs(
        self, docs_and_scores: Dict[Any, Tuple[Document, Optional[float]]]
    ) -> List[Document]:
        """Rank candidates by combined score and return the top ``k``.

        Every returned document is the ``memory_stream`` entry and has its
        ``last_accessed_at`` metadata set to the time of this call.
        """
        current_time = datetime.datetime.now()
        rescored_docs = [
            (doc, self._get_combined_score(doc, relevance, current_time))
            for doc, relevance in docs_and_scores.values()
        ]
        rescored_docs.sort(key=lambda x: x[1], reverse=True)
        result = []
        # Ensure frequently accessed memories aren't forgotten
        for doc, _ in rescored_docs[: self.k]:
            # TODO: Update vector store doc once `update` method is exposed.
            buffered_doc = self.memory_stream[doc.metadata["buffer_idx"]]
            buffered_doc.metadata["last_accessed_at"] = current_time
            result.append(buffered_doc)
        return result

    def _get_recent_docs(self) -> Dict[Any, Tuple[Document, Optional[float]]]:
        """Return the last ``k`` memories keyed by ``buffer_idx``.

        Each entry is paired with ``default_salience`` as its relevance.
        """
        if self.k == 0:
            # ``memory_stream[-0:]`` is the whole stream, not an empty slice;
            # with k == 0 nothing is returned, so there is nothing to score.
            return {}
        return {
            doc.metadata["buffer_idx"]: (doc, self.default_salience)
            for doc in self.memory_stream[-self.k :]
        }

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the top ``k`` documents for ``query`` after rescoring."""
        docs_and_scores = self._get_recent_docs()
        # If a doc is considered salient, update the salience score
        docs_and_scores.update(self.get_salient_docs(query))
        return self._get_rescored_docs(docs_and_scores)

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the top ``k`` documents for ``query`` after rescoring."""
        docs_and_scores = self._get_recent_docs()
        # If a doc is considered salient, update the salience score
        docs_and_scores.update(await self.aget_salient_docs(query))
        return self._get_rescored_docs(docs_and_scores)

    def _buffer_documents(
        self, documents: List[Document], current_time: Optional[datetime.datetime]
    ) -> List[Document]:
        """Copy, stamp and append ``documents`` to ``memory_stream``.

        Args:
            documents: Documents to add; they are deep-copied, not mutated.
            current_time: Timestamp for missing ``last_accessed_at`` and
                ``created_at`` metadata. ``None`` means the current time.

        Returns:
            The stamped copies, in input order.
        """
        if current_time is None:
            current_time = datetime.datetime.now()
        # Avoid mutating input documents
        dup_docs = [deepcopy(d) for d in documents]
        # buffer_idx is the position the copy will occupy in memory_stream.
        for buffer_idx, doc in enumerate(dup_docs, start=len(self.memory_stream)):
            doc.metadata.setdefault("last_accessed_at", current_time)
            doc.metadata.setdefault("created_at", current_time)
            doc.metadata["buffer_idx"] = buffer_idx
        self.memory_stream.extend(dup_docs)
        return dup_docs

    def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
        """Add documents to vectorstore.

        An optional ``current_time`` keyword overrides the timestamp written
        to the documents. All keyword arguments, including ``current_time``,
        are forwarded to ``vectorstore.add_documents``.
        """
        dup_docs = self._buffer_documents(documents, kwargs.get("current_time"))
        return self.vectorstore.add_documents(dup_docs, **kwargs)

    async def aadd_documents(
        self, documents: List[Document], **kwargs: Any
    ) -> List[str]:
        """Add documents to vectorstore.

        An optional ``current_time`` keyword overrides the timestamp written
        to the documents. All keyword arguments, including ``current_time``,
        are forwarded to ``vectorstore.aadd_documents``.
        """
        dup_docs = self._buffer_documents(documents, kwargs.get("current_time"))
        return await self.vectorstore.aadd_documents(dup_docs, **kwargs)
|