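"""Build a FAISS vector store from a folder of documents.

Example invocation (the script filename below is illustrative):
    python build_vector_store.py --folder final_dataset --save_path vector_db --batch_size 8 --device cuda:0
"""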
import os
import argparse
import logging

from tqdm import tqdm

from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

# Logging Configuration
logging.getLogger().setLevel(logging.ERROR)

# Embedding model loading
def get_embeddings(model_name="intfloat/multilingual-e5-large-instruct", device="cuda"):
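    """Return a LangChain HuggingFaceEmbeddings wrapper for the given model.

    Embeddings are L2-normalized (normalize_embeddings=True), so cosine and
    inner-product rankings coincide when searching the index.
    """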
    print(f"[INFO] Embedding model device: {device}")
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True}
    )

def build_vector_store_batch(documents, embeddings, save_path="vector_db", batch_size=4):
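    """Embed documents in batches and persist the resulting FAISS index.

    Batching bounds peak memory use; a temporary checkpoint is written every
    10 batches, and a partial index is saved if a batch fails, so a long run
    does not lose all progress.
    """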
    if not documents:
        raise ValueError("No documents found. Check if documents were loaded correctly.")

    texts = [doc.page_content for doc in documents]
    metadatas = [doc.metadata for doc in documents]

    # Print chunk lengths
    lengths = [len(t) for t in texts]
    print(f"💡 Number of chunks: {len(texts)}")
    print(f"💡 Longest chunk length: {max(lengths)} chars")
    print(f"💡 Average chunk length: {sum(lengths) // len(lengths)} chars")

    # Split into batches
    batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
    metadata_batches = [metadatas[i:i + batch_size] for i in range(0, len(metadatas), batch_size)]

    print(f"Processing {len(batches)} batches with size {batch_size}")
    print(f"Initializing vector store with batch 1/{len(batches)}")

    # Initialize the FAISS index from the first batch of documents
    first_docs = [
        Document(page_content=text, metadata=meta)
        for text, meta in zip(batches[0], metadata_batches[0])
    ]
    vectorstore = FAISS.from_documents(first_docs, embeddings)

    # Add remaining batches
    for i in tqdm(range(1, len(batches)), desc="Processing batches"):
        try:
            docs_batch = [
                Document(page_content=text, metadata=meta)
                for text, meta in zip(batches[i], metadata_batches[i])
            ]
            vectorstore.add_documents(docs_batch)

            if i % 10 == 0:
                temp_save_path = f"{save_path}_temp"
                os.makedirs(os.path.dirname(temp_save_path) or '.', exist_ok=True)
                vectorstore.save_local(temp_save_path)
                print(f"Temporary vector store saved to {temp_save_path} after batch {i}")

        except Exception as e:
            print(f"Error processing batch {i}: {e}")
            error_save_path = f"{save_path}_error_at_batch_{i}"
            os.makedirs(os.path.dirname(error_save_path) or '.', exist_ok=True)
            vectorstore.save_local(error_save_path)
            print(f"Partial vector store saved to {error_save_path}")
            raise

    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    vectorstore.save_local(save_path)
    print(f"Vector store saved to {save_path}")

    return vectorstore

def load_vector_store(embeddings, load_path="vector_db"):
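    """Load a FAISS index saved by build_vector_store_batch.

    allow_dangerous_deserialization=True is required because the docstore is
    pickled on disk; only load index directories you created yourself.
    """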
    if not os.path.exists(load_path):
        raise FileNotFoundError(f"Cannot find vector store: {load_path}")
    return FAISS.load_local(load_path, embeddings, allow_dangerous_deserialization=True)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Builds a vector store")
    parser.add_argument("--folder", type=str, default="final_dataset", help="Path to the folder containing the documents")
    parser.add_argument("--save_path", type=str, default="vector_db", help="Path to save the vector store")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--model_name", type=str, default="intfloat/multilingual-e5-large-instruct", help="Name of the embedding model")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use ('cuda' or 'cpu' or 'cuda:0')") #Ermöglicht cuda:0
    
    args = parser.parse_args()

    # Import the document processing module
    from document_processor_image_test import load_documents, split_documents

    documents = load_documents(args.folder)
    chunks = split_documents(documents, chunk_size=800, chunk_overlap=100)

    print(f"[DEBUG] Document loading and chunk splitting complete, entering embedding stage")
    print(f"[INFO] Selected device: {args.device}")

    try:
        embeddings = get_embeddings(
            model_name=args.model_name,
            device=args.device
        )
        print(f"[DEBUG] Embedding model created")
    except Exception as e:
        print(f"[ERROR] Error creating embedding model: {e}")
        import traceback
        traceback.print_exc()
        raise SystemExit(1)

    build_vector_store_batch(chunks, embeddings, args.save_path, args.batch_size)
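
    # Minimal sanity check (a sketch; the query string is only a placeholder):
    # reload the saved index with the same embedding model and run a search.
    store = load_vector_store(embeddings, args.save_path)
    for doc in store.similarity_search("test query", k=3):
        print(f"[CHECK] {doc.metadata.get('source', '?')}: {doc.page_content[:80]}")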