# File size: 7,864 bytes
# Snapshot: d57551e
# cosmosConnector.py
from jsonschema import ValidationError
from langchain_openai import AzureOpenAIEmbeddings
from models.converterModels import PowerConverter
from models.converterVectorStoreModels import PowerConverterVector
import os
from azure.cosmos import CosmosClient
from typing import List, Optional, Dict
from rapidfuzz import fuzz
import logging
import os
from dotenv import load_dotenv
from semantic_kernel.functions import kernel_function
# Module-level logger for this file, named after the module.
logger = logging.getLogger(__name__)

# Copy settings from a local .env file into os.environ so the AZURE_* /
# OPENAI_* variables read by CosmosLampHandler can be supplied that way.
load_dotenv()
class CosmosLampHandler:
    """Look up power converters and their compatible lamps in Azure Cosmos DB.

    Connection settings come from the environment: ``AZURE_COSMOS_DB_ENDPOINT``
    and ``AZURE_COSMOS_DB_KEY`` for Cosmos DB, and ``OPENAI_API_ENDPOINT``,
    ``OPENAI_EMBEDDINGS_MODEL_DEPLOYMENT`` and ``AZURE_OPENAI_KEY`` for the
    Azure OpenAI embedding deployment used by ``RAG_search``.

    NOTE(review): ``CosmosClient`` is the synchronous SDK client, so every
    query blocks the event loop even though the public methods are ``async``;
    ``azure.cosmos.aio`` would avoid that.
    """

    # Vector search used by RAG_search. ``@vector`` is the query embedding and
    # ``SimilarityScore`` is returned alongside each converter's fields.
    _RAG_SEARCH_SQL = """
        SELECT TOP 5
            c.id,
            c.converter_description,
            c.ip,
            c.efficiency_full_load,
            c.name,
            c.artnr,
            c.type,
            c.lamps,
            c.pdf_link,
            c.nom_input_voltage_v,
            c.output_voltage_v,
            c.unit,
            c["listprice"],
            c["lifecycle"],
            c.size,
            c.ccr_amplitude,
            c.dimmability,
            c.dimlist_type,
            c.strain_relief,
            c.gross_weight,
            VectorDistance(c.embedding, @vector) AS SimilarityScore
        FROM c
        ORDER BY VectorDistance(c.embedding, @vector)
    """

    def __init__(
        self,
        logger: Optional[logging.Logger] = None,
        *,
        database_name: str = "TAL_DB",
        container_name: str = "Converters_with_embeddings",
    ):
        """Create the Cosmos DB container client and the embedding client.

        Args:
            logger: Logger used by every method. Defaults to this module's
                logger, so ``self.logger`` is never ``None``.
            database_name: Cosmos DB database holding the converter documents.
            container_name: Container holding the converter documents and
                their ``embedding`` vectors.

        Raises:
            KeyError: If ``OPENAI_API_ENDPOINT``,
                ``OPENAI_EMBEDDINGS_MODEL_DEPLOYMENT`` or ``AZURE_OPENAI_KEY``
                is not set in the environment.
        """
        # getLogger returns the same object for the same name, so this is the
        # module-level logger when the caller does not inject one.
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self.client = CosmosClient(
            os.getenv("AZURE_COSMOS_DB_ENDPOINT"),
            os.getenv("AZURE_COSMOS_DB_KEY"),
        )
        self.database = self.client.get_database_client(database_name)
        self.container = self.database.get_container_client(container_name)
        self.embedding_model = AzureOpenAIEmbeddings(
            azure_endpoint=os.environ["OPENAI_API_ENDPOINT"],
            azure_deployment=os.environ["OPENAI_EMBEDDINGS_MODEL_DEPLOYMENT"],
            api_key=os.environ["AZURE_OPENAI_KEY"],
        )

    async def _generate_embedding(self, query: str) -> List[float]:
        """Return the Azure OpenAI embedding vector for *query*.

        Uses the async ``aembed_query`` so the event loop is not blocked while
        waiting on the embeddings endpoint.

        Raises:
            Exception: Whatever the embedding client raises; it is logged first.
        """
        try:
            return await self.embedding_model.aembed_query(query)
        except Exception as e:
            self.logger.error("Embedding generation failed: %s", e)
            raise

    def _fetch_lamps(self, artnr: int) -> Optional[Dict[str, dict]]:
        """Return the ``lamps`` mapping of the first document whose artnr matches.

        Returns ``None`` when no document matches. A matching document without
        a (truthy) ``lamps`` field yields an empty dict. Query errors propagate.
        """
        results = list(self.container.query_items(
            query="SELECT TOP 1 c.lamps FROM c WHERE c.artnr = @artnr",
            parameters=[{"name": "@artnr", "value": artnr}],
            # No partition key is supplied, so allow a cross-partition query
            # like the other queries in this class.
            enable_cross_partition_query=True,
        ))
        if not results:
            return None
        return results[0].get("lamps") or {}

    @staticmethod
    def _coerce_lamp_limits(doc: dict) -> dict:
        """Cast each lamp's ``min``/``max`` to ``int`` in place and return *doc*.

        Stored limits may be floats, and they are converted before model
        validation. Missing ``lamps`` / ``min`` / ``max`` entries are left
        for the model to judge.
        """
        for limits in (doc.get("lamps") or {}).values():
            for bound in ("min", "max"):
                if bound in limits:
                    limits[bound] = int(limits[bound])
        return doc

    async def get_compatible_lamps(self, artnr: int) -> List[str]:
        """Return the lamp-type names listed for the converter *artnr*.

        The converter is looked up by exact ``artnr``; no fuzzy matching is
        involved.

        Returns:
            The keys of the converter's ``lamps`` mapping, or ``[]`` when the
            converter is unknown or the query fails (the error is logged).
        """
        try:
            lamps = self._fetch_lamps(artnr)
        except Exception as e:
            self.logger.error("Failed to get compatible lamps: %s", e)
            return []
        return list(lamps) if lamps else []

    async def get_converters_by_lamp_type(
        self, lamp_type: str, threshold: int = 75
    ) -> List[PowerConverter]:
        """Return converters listing a lamp type that fuzzily matches *lamp_type*.

        A converter qualifies when any key of its ``lamps`` mapping scores at
        least *threshold* (``fuzz.ratio``, 0-100 scale) against *lamp_type*.
        The comparison is case-insensitive. Every document that defines
        ``lamps`` is scanned.

        Returns:
            Matching ``PowerConverter`` objects, or ``[]`` on failure (the
            error is logged).
        """
        try:
            wanted = lamp_type.lower()
            documents = self.container.query_items(
                query="SELECT * FROM c WHERE IS_DEFINED(c.lamps)",
                enable_cross_partition_query=True,
            )
            converters = []
            for doc in documents:
                # IS_DEFINED lets ``lamps: null`` through, hence ``or {}``.
                lamp_keys = doc.get("lamps") or {}
                if any(fuzz.ratio(key.lower(), wanted) >= threshold for key in lamp_keys):
                    converters.append(PowerConverter(**doc))
            return converters
        except Exception as e:
            self.logger.error("Lamp type search failed: %s", e)
            return []

    async def get_lamp_limits(
        self, artnr: int, lamp_type: str, threshold: int = 65
    ) -> Dict[str, int]:
        """Return the ``{"min": ..., "max": ...}`` lamp limits for converter *artnr*.

        The lamp key with the highest case-insensitive ``fuzz.ratio`` score
        against *lamp_type* is used. The first key wins ties.

        Args:
            artnr: Converter article number (exact match).
            lamp_type: Lamp type name; typos are tolerated.
            threshold: Minimum score (0-100) the best match must reach.

        Returns:
            The limits as ints, or ``{}`` when no converter has this artnr.

        Raises:
            ValueError: If the converter lists no lamp scoring >= *threshold*,
                including when it lists no lamps at all.
            Exception: Any query/lookup error; it is logged before re-raising.
        """
        try:
            lamps = self._fetch_lamps(artnr)
            if lamps is None:
                return {}
            wanted = lamp_type.lower()
            scores = {key: fuzz.ratio(key.lower(), wanted) for key in lamps}
            best_match = max(scores, key=scores.get, default=None)
            if best_match is None or scores[best_match] < threshold:
                raise ValueError("No matching lamp type found")
            limits = lamps[best_match]
            return {"min": int(limits["min"]), "max": int(limits["max"])}
        except Exception as e:
            self.logger.error("Failed to get lamp limits: %s", e)
            raise

    async def query_converters(self, query: str) -> str:
        """Run a raw Cosmos SQL *query* and return up to 10 converters as text.

        NOTE(review): *query* is executed verbatim against the container. If
        it is produced by an LLM or a user, it can presumably read any
        document there. Make sure callers are trusted.

        Returns:
            ``str()`` of the list of ``PowerConverter`` objects built from the
            first 10 result documents, or ``"Query failed: <error>"``.
        """
        try:
            self.logger.debug("Executing query: %s", query)
            items = list(self.container.query_items(
                query=query,
                enable_cross_partition_query=True,
            ))
            self.logger.debug("Query returned %d items", len(items))
            converters = [PowerConverter(**item) for item in items[:10]]
            self.logger.info("Query returned %d items after conversion", len(converters))
            return str(converters)
        except Exception as e:
            self.logger.error("Query failed: %s", e)
            return f"Query failed: {e}"

    async def RAG_search(
        self,
        query: str,
        artnr: Optional[int] = None,
        threshold: int = 75,
    ) -> List[PowerConverterVector]:
        """Return the top 5 converters from a vector search on *query*.

        Embeds *query* with Azure OpenAI and runs a Cosmos DB vector search
        ordered by ``VectorDistance(c.embedding, @vector)``.

        NOTE(review): despite the "hybrid" wording, *artnr* is only logged and
        *threshold* is unused. Confirm whether an artnr filter was intended.

        Returns:
            Up to 5 ``PowerConverterVector`` objects in ``VectorDistance``
            order, or ``[]`` if embedding, querying or model construction
            fails (the error is logged).
        """
        self.logger.info("Performing hybrid search for query: %s (ARTNR: %s)", query, artnr)
        try:
            query_vector = await self._generate_embedding(query)
            documents = self.container.query_items(
                query=self._RAG_SEARCH_SQL,
                parameters=[{"name": "@vector", "value": query_vector}],
                enable_cross_partition_query=True,
            )
            return [
                PowerConverterVector(**self._coerce_lamp_limits(doc))
                for doc in documents
            ]
        except ValidationError as exc:
            # NOTE(review): this is jsonschema's ValidationError. If
            # PowerConverterVector is a pydantic model, its validation errors
            # are pydantic.ValidationError and fall through to the branch below.
            self.logger.error("Hybrid search validation failed: %s", exc)
            return []
        except Exception as e:
            self.logger.error("Hybrid search failed: %s", e)
            return []
|