File size: 7,287 Bytes
751d628
 
 
 
 
 
 
 
 
 
 
4701375
751d628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4701375
751d628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4701375
 
 
751d628
 
 
 
4701375
751d628
 
 
 
 
 
4701375
 
751d628
 
 
4701375
751d628
 
4701375
 
 
 
 
 
 
 
751d628
 
 
 
 
 
4701375
 
751d628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import json
import os
import logging
import torch
from typing import List
from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer
try:
    from datasets import load_dataset
except ImportError:
    load_dataset = None

logger = logging.getLogger(__name__)

def get_device():
    """
    Select the best available PyTorch device.

    Probes accelerators in preference order (CUDA first, then Apple
    Silicon MPS) and falls back to the CPU.

    Returns:
        str: Device name ('cuda', 'mps', or 'cpu').
    """
    probes = (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    )
    for device_name, is_available in probes:
        if is_available():
            return device_name
    return "cpu"

def _mock_guest_docs() -> List[Document]:
    """Build the single-entry fallback dataset used when no real source loads."""
    guest = {
        "name": "Dr. Nikola Tesla",
        "relation": "old friend from university days",
        "description": "Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
        "email": "nikola.tesla@gmail.com",
    }
    return [
        Document(
            page_content="\n".join([
                f"Name: {guest['name']}",
                f"Relation: {guest['relation']}",
                f"Description: {guest['description']}",
                f"Email: {guest['email']}",
            ]),
            metadata=guest,
        )
    ]

def load_guest_dataset(dataset_path: str = "agents-course/unit3-invitees") -> List[Document]:
    """
    Load guest dataset from a local JSON file or Hugging Face dataset.

    Resolution order: a Hugging Face dataset (when the `datasets` library is
    installed and `dataset_path` is not an existing local path), then a local
    JSON file, then a hard-coded single-guest mock dataset. Any exception
    along the way is logged and also falls back to the mock dataset, so this
    function never raises.

    Args:
        dataset_path (str): Path to local JSON file or Hugging Face dataset name.

    Returns:
        List[Document]: List of Document objects with guest information.
    """
    try:
        # Try loading from Hugging Face dataset if datasets library is available
        if load_dataset and not os.path.exists(dataset_path):
            logger.info(f"Attempting to load Hugging Face dataset: {dataset_path}")
            guest_dataset = load_dataset(dataset_path, split="train")
            docs = [
                Document(
                    page_content="\n".join([
                        f"Name: {guest['name']}",
                        f"Relation: {guest['relation']}",
                        f"Description: {guest['description']}",
                        f"Email: {guest['email']}"
                    ]),
                    metadata={
                        "name": guest["name"],
                        "relation": guest["relation"],
                        "description": guest["description"],
                        "email": guest["email"]
                    }
                )
                for guest in guest_dataset
            ]
            logger.info(f"Loaded {len(docs)} guests from Hugging Face dataset")
            return docs

        # Try loading from local JSON file
        if os.path.exists(dataset_path):
            logger.info(f"Loading guest dataset from local path: {dataset_path}")
            # Explicit encoding: JSON is UTF-8 by spec; don't rely on the
            # platform default.
            with open(dataset_path, 'r', encoding='utf-8') as f:
                guests = json.load(f)
            # NOTE(review): this branch puts only the description in
            # page_content, unlike the HF branch above — preserved as-is
            # since downstream consumers may depend on it; confirm intent.
            docs = [
                Document(
                    page_content=guest.get('description', ''),
                    metadata={
                        'name': guest.get('name', ''),
                        'relation': guest.get('relation', ''),
                        'description': guest.get('description', ''),
                        'email': guest.get('email', '')  # Optional email field
                    }
                )
                for guest in guests
            ]
            logger.info(f"Loaded {len(docs)} guests from local JSON")
            return docs

        # Fallback to mock dataset if both fail
        logger.warning(f"Dataset not found at {dataset_path}, using mock dataset")
        docs = _mock_guest_docs()
        logger.info("Loaded mock dataset with 1 guest")
        return docs

    except Exception as e:
        # Best-effort boundary: log and degrade to the mock dataset rather
        # than propagate, so the agent keeps working without real data.
        logger.error(f"Failed to load guest dataset: {e}")
        docs = _mock_guest_docs()
        logger.info("Loaded mock dataset with 1 guest due to error")
        return docs

class BM25Retriever:
    """
    Guest-information retriever backed by BM25 keyword search.

    `search` loads the guest dataset and ranks guests with LangChain's
    BM25 retriever. A SentenceTransformer embedder is also initialized.
    NOTE(review): the embedder (`self.model`) is never used by `search`
    in this file — it is kept for interface compatibility; confirm
    whether any external caller relies on it before removing.
    """
    def __init__(self, dataset_path: str):
        """
        Initialize the retriever with a SentenceTransformer model.

        Args:
            dataset_path (str): Path to the dataset for retrieval.

        Raises:
            Exception: If embedder initialization fails.
        """
        try:
            self.model = SentenceTransformer("all-MiniLM-L6-v2", device=get_device())
            self.dataset_path = dataset_path
            logger.info("Initialized SentenceTransformer")
        except Exception as e:
            logger.error(f"Failed to initialize embedder: {e}")
            raise

    def search(self, query: str) -> List[dict]:
        """
        Search the dataset for relevant guest information.

        Args:
            query (str): Search query (e.g., guest name or relation).

        Returns:
            List[dict]: Up to 3 matching guest metadata dictionaries in
            BM25 relevance order; empty list if nothing loads or on error.
        """
        try:
            # Dataset is (re)loaded on every call; load_guest_dataset
            # handles its own fallbacks and never raises.
            docs = load_guest_dataset(self.dataset_path)
            if not docs:
                logger.warning("No documents available for search")
                return []

            # Key each BM25 input text to its metadata so results map back
            # exactly. (The previous substring scan lost relevance order and
            # could pair a result with the wrong guest whenever one guest's
            # text was a substring of another's.)
            text_to_metadata = {
                f"{doc.metadata['name']} {doc.metadata['relation']} {doc.metadata['description']}": doc.metadata
                for doc in docs
            }

            # Local import keeps the heavy dependency off module import time;
            # the alias avoids shadowing this class's own name.
            from langchain_community.retrievers import BM25Retriever as LCBM25Retriever
            retriever = LCBM25Retriever.from_texts(list(text_to_metadata))
            retriever.k = 3  # Limit to top 3 results

            # from_texts stores each input text verbatim as page_content, so
            # exact dict lookup recovers the metadata in relevance order.
            results = retriever.invoke(query)
            matches = [
                text_to_metadata[r.page_content]
                for r in results
                if r.page_content in text_to_metadata
            ]
            logger.info(f"Found {len(matches)} matches for query: {query}")
            return matches[:3]  # Return top 3 matches

        except Exception as e:
            logger.error(f"Search failed for query '{query}': {e}")
            return []