File size: 5,326 Bytes
ad022d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d9878f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from typing import Any, Iterable, List, Optional
from langchain_core.embeddings import Embeddings
import uuid
from langchain_community.vectorstores.lancedb import LanceDB

class MultimodalLanceDB(LanceDB):
    """`LanceDB` vector store extended to hold text-image pairs.

    Each stored row carries, next to the vector, id, text and metadata
    columns, an extra column holding the path to the image paired with
    the text.

    To use, you should have ``lancedb`` python package installed.
    You can install it with ``pip install lancedb``.

    Args:
        connection: LanceDB connection to use. If not provided, a new connection
                    will be created.
        embedding: Embedding to use for the vectorstore. To add text-image
            pairs it must provide an ``embed_image_text_pairs(texts=...,
            images=...)`` method.
        uri: URI forwarded to the ``LanceDB`` base class. Defaults to
            ``/tmp/lancedb``.
        vector_key: Key to use for the vector in the database. Defaults to ``vector``.
        id_key: Key to use for the id in the database. Defaults to ``id``.
        text_key: Key to use for the text in the database. Defaults to ``text``.
        image_path_key: Key to use for the path to image in the database.
            Defaults to ``image_path``.
        table_name: Name of the table to use. Defaults to ``vectorstore``.
        api_key: API key to use for LanceDB cloud database.
        region: Region to use for LanceDB cloud database.
        mode: Mode to use for adding data to the table. Defaults to ``append``.

    Example:
        .. code-block:: python

            vectorstore = MultimodalLanceDB(
                uri='/lancedb', embedding=embedding_function
            )
            vectorstore.add_text_image_pairs(
                ['text1', 'text2'], ['image1.jpg', 'image2.jpg']
            )
            result = vectorstore.similarity_search('text1')
    """

    def __init__(
        self,
        connection: Optional[Any] = None,
        embedding: Optional[Embeddings] = None,
        uri: Optional[str] = "/tmp/lancedb",
        vector_key: Optional[str] = "vector",
        id_key: Optional[str] = "id",
        text_key: Optional[str] = "text",
        image_path_key: Optional[str] = "image_path",
        table_name: Optional[str] = "vectorstore",
        api_key: Optional[str] = None,
        region: Optional[str] = None,
        mode: Optional[str] = "append",
    ):
        # Everything except ``image_path_key`` is handed to the base class
        # positionally, in the order its constructor expects.
        super().__init__(
            connection,
            embedding,
            uri,
            vector_key,
            id_key,
            text_key,
            table_name,
            api_key,
            region,
            mode,
        )
        self._image_path_key = image_path_key

    def add_text_image_pairs(
        self,
        texts: Iterable[str],
        image_paths: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Turn text-image pairs into embeddings and add them to the database.

        Args:
            texts: Iterable of strings to combine with corresponding images
                to add to the vectorstore.
            image_paths: Iterable of path-to-images as strings to combine
                with corresponding texts to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the pairs.
                When omitted, each row's metadata is ``{"id": <row id>}``.
            ids: Optional list of ids to associate with the pairs. When
                omitted, a random UUID4 string is generated per pair.
            **kwargs: ``mode`` overrides the instance's default add mode for
                this call; other keys are ignored.

        Returns:
            List of ids of the added text-image pairs (empty when there is
            nothing to add).

        Raises:
            ValueError: If ``texts`` and ``image_paths`` differ in length.
        """
        # Materialize once: the parameters are typed as Iterable, but the
        # code below needs len(), indexing and a second pass over the data.
        texts = list(texts)
        image_paths = list(image_paths)

        # The length of texts must be equal to the length of images. A real
        # exception is used because ``assert`` is stripped under ``-O``.
        if len(texts) != len(image_paths):
            raise ValueError(
                "the len of transcripts should be equal to the len of images"
            )

        # Nothing to embed or store.
        if not texts:
            return []

        ids = ids or [str(uuid.uuid4()) for _ in texts]

        # Embed the pairs and build one row per pair.
        embeddings = self._embedding.embed_image_text_pairs(  # type: ignore
            texts=texts, images=image_paths
        )
        docs = []
        for idx, (text, image_path) in enumerate(zip(texts, image_paths)):
            metadata = metadatas[idx] if metadatas else {"id": ids[idx]}
            docs.append(
                {
                    self._vector_key: embeddings[idx],
                    self._id_key: ids[idx],
                    self._text_key: text,
                    self._image_path_key: image_path,
                    "metadata": metadata,
                }
            )

        mode = kwargs.get("mode", self.mode)
        if self._table_name in self._connection.table_names():
            tbl = self._connection.open_table(self._table_name)
            if self.api_key is None:
                tbl.add(docs, mode=mode)
            else:
                # NOTE(review): with an api_key set, ``mode`` is not
                # forwarded -- presumably the cloud table does not accept
                # it; confirm against the LanceDB client in use.
                tbl.add(docs)
        else:
            self._connection.create_table(self._table_name, data=docs)
        return ids

    @classmethod
    def from_text_image_pairs(
        cls,
        texts: List[str],
        image_paths: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        connection: Any = None,
        vector_key: Optional[str] = "vector",
        id_key: Optional[str] = "id",
        text_key: Optional[str] = "text",
        image_path_key: Optional[str] = "image_path",
        table_name: Optional[str] = "vectorstore",
        **kwargs: Any,
    ) -> "MultimodalLanceDB":
        """Build a vector store and populate it with text-image pairs.

        Args:
            texts: Texts to store, one per image.
            image_paths: Paths to the images paired with ``texts``.
            embedding: Embedding used to embed the text-image pairs.
            metadatas: Optional list of metadatas associated with the pairs.
            connection: LanceDB connection passed to the constructor.
            vector_key: Key to use for the vector in the database.
            id_key: Key to use for the id in the database.
            text_key: Key to use for the text in the database.
            image_path_key: Key to use for the path to image in the database.
            table_name: Name of the table to use.
            **kwargs: Forwarded to ``add_text_image_pairs`` only (not to the
                constructor).

        Returns:
            The newly created store, after the pairs have been added.
        """
        # ``cls`` rather than the hard-coded class name, so subclasses get
        # an instance of their own type back.
        instance = cls(
            connection=connection,
            embedding=embedding,
            vector_key=vector_key,
            id_key=id_key,
            text_key=text_key,
            image_path_key=image_path_key,
            table_name=table_name,
        )
        instance.add_text_image_pairs(
            texts, image_paths, metadatas=metadatas, **kwargs
        )

        return instance