import uuid
from typing import Any, Iterable, List, Optional

from langchain_community.vectorstores.lancedb import LanceDB
from langchain_core.embeddings import Embeddings


class MultimodalLanceDB(LanceDB):
"""`LanceDB` vector store to process multimodal data
To use, you should have ``lancedb`` python package installed.
You can install it with ``pip install lancedb``.
Args:
connection: LanceDB connection to use. If not provided, a new connection
will be created.
embedding: Embedding to use for the vectorstore.
vector_key: Key to use for the vector in the database. Defaults to ``vector``.
id_key: Key to use for the id in the database. Defaults to ``id``.
text_key: Key to use for the text in the database. Defaults to ``text``.
image_path_key: Key to use for the path to image in the database. Defaults to ``image_path``.
table_name: Name of the table to use. Defaults to ``vectorstore``.
api_key: API key to use for LanceDB cloud database.
region: Region to use for LanceDB cloud database.
mode: Mode to use for adding data to the table. Defaults to ``overwrite``.
Example:
.. code-block:: python
vectorstore = MultimodalLanceDB(uri='/lancedb', embedding_function)
vectorstore.add_texts(['text1', 'text2'])
result = vectorstore.similarity_search('text1')
"""
def __init__(
self,
connection: Optional[Any] = None,
embedding: Optional[Embeddings] = None,
uri: Optional[str] = "/tmp/lancedb",
vector_key: Optional[str] = "vector",
id_key: Optional[str] = "id",
text_key: Optional[str] = "text",
image_path_key: Optional[str] = "image_path",
table_name: Optional[str] = "vectorstore",
api_key: Optional[str] = None,
region: Optional[str] = None,
mode: Optional[str] = "append",
):
        super().__init__(
            connection=connection,
            embedding=embedding,
            uri=uri,
            vector_key=vector_key,
            id_key=id_key,
            text_key=text_key,
            table_name=table_name,
            api_key=api_key,
            region=region,
            mode=mode,
        )
        self._image_path_key = image_path_key

def add_text_image_pairs(
self,
texts: Iterable[str],
image_paths: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
"""Turn text-image pairs into embedding and add it to the database
Args:
texts: Iterable of strings to combine with corresponding images to add to the vectorstore.
images: Iterable of path-to-images as strings to combine with corresponding texts to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
ids: Optional list of ids to associate w ith the texts.
Returns:
List of ids of the added text-image pairs.
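
        Example:
            A minimal call sketch; the text and image path below are illustrative only.

            .. code-block:: python

                ids = vectorstore.add_text_image_pairs(
                    texts=["a cat sitting on a sofa"],
                    image_paths=["./images/cat.png"],
                )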
"""
        # each text must be paired with exactly one image; materialize the
        # iterables so they can be measured and indexed below
        texts, image_paths = list(texts), list(image_paths)
        assert len(texts) == len(image_paths), \
            "the number of texts must equal the number of image paths"

        # embed each text-image pair and build one record per pair
        docs = []
        ids = ids or [str(uuid.uuid4()) for _ in texts]
        embeddings = self._embedding.embed_image_text_pairs(texts=texts, images=image_paths)  # type: ignore
for idx, text in enumerate(texts):
embedding = embeddings[idx]
metadata = metadatas[idx] if metadatas else {"id": ids[idx]}
docs.append(
{
self._vector_key: embedding,
self._id_key: ids[idx],
self._text_key: text,
                    self._image_path_key: image_paths[idx],
"metadata": metadata,
}
)
        # allow callers to override the table write mode per call
        mode = kwargs.get("mode", self.mode)
        # add to an existing table when present, otherwise create it from this first batch
        if self._table_name in self._connection.table_names():
tbl = self._connection.open_table(self._table_name)
if self.api_key is None:
tbl.add(docs, mode=mode)
else:
tbl.add(docs)
else:
self._connection.create_table(self._table_name, data=docs)
return ids
@classmethod
def from_text_image_pairs(
cls,
texts: List[str],
image_paths: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
connection: Any = None,
vector_key: Optional[str] = "vector",
id_key: Optional[str] = "id",
text_key: Optional[str] = "text",
image_path_key: Optional[str] = "image_path",
table_name: Optional[str] = "vectorstore",
**kwargs: Any,
    ) -> "MultimodalLanceDB":
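        """Return a ``MultimodalLanceDB`` initialized from text-image pairs.

        Example:
            A minimal construction sketch; ``embedding_function`` stands in for any
            embeddings object that implements ``embed_image_text_pairs``.

            .. code-block:: python

                vectorstore = MultimodalLanceDB.from_text_image_pairs(
                    texts=["a cat sitting on a sofa"],
                    image_paths=["./images/cat.png"],
                    embedding=embedding_function,
                )
        """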
        instance = cls(
connection=connection,
embedding=embedding,
vector_key=vector_key,
id_key=id_key,
text_key=text_key,
image_path_key=image_path_key,
table_name=table_name,
)
instance.add_text_image_pairs(texts, image_paths, metadatas=metadatas, **kwargs)
        return instance
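

# ---------------------------------------------------------------------------
# Usage sketch (not part of the vectorstore itself). This is a minimal,
# self-contained example assuming ``lancedb`` is installed and uses a local
# on-disk database. ``ToyPairEmbeddings`` is a hypothetical stand-in for a real
# multimodal embedding model (for example, a CLIP- or BridgeTower-style model);
# it only illustrates the interface this class expects: ``embed_query`` for
# search and ``embed_image_text_pairs`` for indexing.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import random

    class ToyPairEmbeddings(Embeddings):
        """Hypothetical embeddings that return fixed-size random vectors."""

        def __init__(self, dim: int = 8) -> None:
            self.dim = dim

        def _vector(self) -> List[float]:
            return [random.random() for _ in range(self.dim)]

        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            return [self._vector() for _ in texts]

        def embed_query(self, text: str) -> List[float]:
            return self._vector()

        # The extra hook that MultimodalLanceDB.add_text_image_pairs relies on.
        def embed_image_text_pairs(
            self, texts: List[str], images: List[str]
        ) -> List[List[float]]:
            return [self._vector() for _ in texts]

    store = MultimodalLanceDB(uri="/tmp/lancedb_demo", embedding=ToyPairEmbeddings())
    store.add_text_image_pairs(
        texts=["a cat sitting on a sofa"],
        image_paths=["./images/cat.png"],  # illustrative path; the toy embeddings never read it
    )
    print(store.similarity_search("cat", k=1))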