Spaces:
Sleeping
Sleeping
File size: 7,801 Bytes
bd161ec d7ed1ea bd161ec d7ed1ea bd161ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
"""
RAG (Retrieval Augmented Generation) service for semantic search and context retrieval.
"""
from typing import List, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, text
import logging
from models import Phase, PhaseEmbedding
from services.openai_service import openai_service
# Module-level logger, named after this module, used by every RAGService method.
logger = logging.getLogger(__name__)
class RAGService:
    """Service for RAG functionality with semantic search.

    All methods are static and operate on a caller-supplied ``AsyncSession``.
    """

    # Number of most recent preceding phases always included as context.
    _SEQUENTIAL_CONTEXT_PHASES = 3
    # Character budget for the assembled context (rough stand-in for a token limit).
    _MAX_CONTEXT_CHARS = 6000

    @staticmethod
    async def create_embedding(db: AsyncSession, phase_id: str, content: str) -> PhaseEmbedding:
        """Create or update the stored embedding for a phase.

        Args:
            db: Session used to look up and persist the embedding. It is
                committed by this method.
            phase_id: Identifier of the phase the embedding belongs to.
            content: Text to embed; also stored alongside the vector.

        Returns:
            The persisted ``PhaseEmbedding``, refreshed after the commit.

        Raises:
            Exception: If embedding generation or persistence fails. The
                original error is attached as ``__cause__``.
        """
        try:
            # Generate embedding
            embedding_vector = await openai_service.create_embedding(content)

            # Check if embedding already exists for this phase
            existing_result = await db.execute(
                select(PhaseEmbedding).where(PhaseEmbedding.phase_id == phase_id)
            )
            phase_embedding = existing_result.scalar_one_or_none()

            if phase_embedding is not None:
                # Update existing embedding in place
                phase_embedding.content = content
                phase_embedding.embedding = embedding_vector
                action = "Updated"
            else:
                # Create new embedding
                phase_embedding = PhaseEmbedding(
                    phase_id=phase_id,
                    content=content,
                    embedding=embedding_vector,
                )
                db.add(phase_embedding)
                action = "Created"

            await db.commit()
            # Refresh on both paths: a commit may expire the instance, and
            # reloading here keeps the returned object usable without
            # triggering implicit IO on attribute access.
            await db.refresh(phase_embedding)
            logger.info(f"{action} embedding for phase {phase_id}")
            return phase_embedding
        except Exception as e:
            logger.error(f"Failed to create embedding: {e}")
            raise Exception(f"Failed to create embedding: {str(e)}") from e

    @staticmethod
    async def search_similar_content(
        db: AsyncSession,
        query: str,
        project_id: str,
        limit: int = 5,
        similarity_threshold: float = 0.7
    ) -> List[Tuple[Phase, float]]:
        """Search a project's phases for content semantically similar to ``query``.

        Args:
            db: Session used to run the search.
            query: Free text to embed and compare against stored embeddings.
            project_id: Only phases of this project are considered.
            limit: Maximum number of matches to return.
            similarity_threshold: Matches must have a similarity strictly
                greater than this value.

        Returns:
            ``(phase, similarity)`` pairs ordered by descending similarity.
            Any failure is logged and yields an empty list (best-effort).
        """
        try:
            # Generate query embedding
            query_embedding = await openai_service.create_embedding(query)

            # Perform similarity search using pgvector.
            # Note: `<=>` is cosine distance, so similarity = 1 - distance.
            # NOTE(review): binding a Python list to :query_embedding relies on
            # the driver being able to adapt it to a vector - confirm.
            search_query = text("""
                SELECT p.id, (1 - (pe.embedding <=> :query_embedding)) AS similarity
                FROM phases p
                JOIN phase_embeddings pe ON p.id = pe.phase_id
                WHERE p.project_id = :project_id
                AND (1 - (pe.embedding <=> :query_embedding)) > :threshold
                ORDER BY similarity DESC
                LIMIT :limit
            """)
            result = await db.execute(
                search_query,
                {
                    "query_embedding": query_embedding,
                    "project_id": project_id,
                    "threshold": similarity_threshold,
                    "limit": limit,
                },
            )
            rows = result.all()

            # Load all matching Phase objects with one query instead of one
            # query per row.
            phases_by_id = {}
            if rows:
                phase_result = await db.execute(
                    select(Phase).where(Phase.id.in_([row.id for row in rows]))
                )
                # Key by str() so raw-SQL ids and ORM ids compare equal even if
                # the driver and the mapped type use different Python types.
                phases_by_id = {
                    str(phase.id): phase for phase in phase_result.scalars().all()
                }

            # Re-apply the ranking from the similarity query.
            similar_phases = [
                (phases_by_id[str(row.id)], row.similarity)
                for row in rows
                if str(row.id) in phases_by_id
            ]

            logger.info(f"Found {len(similar_phases)} similar phases for query in project {project_id}")
            return similar_phases
        except Exception as e:
            logger.error(f"Similarity search error: {e}")
            # Return empty results instead of raising exception
            return []

    @staticmethod
    async def get_context_for_phase(
        db: AsyncSession,
        project_id: str,
        current_phase_number: int,
        user_input: str
    ) -> Tuple[str, List[str]]:
        """Build the prompt context for a phase from earlier phases and RAG hits.

        The context consists of the most recent preceding phases that have an
        AI response, followed by semantically similar phases that are not
        already included.

        Args:
            db: Session used for all lookups.
            project_id: Project whose phases are considered.
            current_phase_number: Only phases numbered below this are used as
                sequential context.
            user_input: Text used as the semantic search query.

        Returns:
            ``(context_text, sources)`` where ``sources`` labels each included
            part. Any failure is logged and yields ``("", [])``.
        """
        try:
            context_parts = []
            context_sources = []

            # 1. Sequential context: fetch only the latest few preceding phases
            #    (newest first), then restore ascending order.
            recent_result = await db.execute(
                select(Phase)
                .where(
                    Phase.project_id == project_id,
                    Phase.phase_number < current_phase_number,
                    Phase.ai_response.isnot(None),
                )
                .order_by(Phase.phase_number.desc())
                .limit(RAGService._SEQUENTIAL_CONTEXT_PHASES)
            )
            recent_phases = list(reversed(recent_result.scalars().all()))

            for phase in recent_phases:
                if phase.ai_response:
                    context_parts.append(
                        f"Phase {phase.phase_number} ({phase.title}):\n{phase.ai_response}"
                    )
                    context_sources.append(f"Phase {phase.phase_number}")

            # Phase numbers already covered, used to skip duplicate RAG hits.
            added_phases = {phase.phase_number for phase in recent_phases}

            # 2. Get semantically similar content using RAG
            similar_phases = await RAGService.search_similar_content(
                db, user_input, project_id, limit=3, similarity_threshold=0.6
            )

            # Add RAG context (avoid duplicates from sequential context)
            for phase, similarity in similar_phases:
                if phase.phase_number not in added_phases and phase.ai_response:
                    context_parts.append(
                        f"Related content from Phase {phase.phase_number} ({phase.title}) "
                        f"[similarity: {similarity:.2f}]:\n{phase.ai_response}"
                    )
                    context_sources.append(f"Phase {phase.phase_number} (RAG)")
                    added_phases.add(phase.phase_number)

            # Combine context
            full_context = "\n\n---\n\n".join(context_parts)

            # Truncate if too long (rough token limit)
            if len(full_context) > RAGService._MAX_CONTEXT_CHARS:
                full_context = full_context[:RAGService._MAX_CONTEXT_CHARS] + "... [context truncated]"

            logger.info(f"Built context for phase {current_phase_number} with {len(context_sources)} sources")
            return full_context, context_sources
        except Exception as e:
            logger.error(f"Context building error: {e}")
            return "", []

    @staticmethod
    async def update_all_embeddings(db: AsyncSession, project_id: str) -> None:
        """Regenerate embeddings for every phase of a project that has an AI response.

        Args:
            db: Session used to read phases; ``create_embedding`` commits it
                once per phase.
            project_id: Project whose phases are re-embedded.

        Raises:
            Exception: If reading phases or any embedding update fails. The
                original error is attached as ``__cause__``.
        """
        try:
            # Select plain column values rather than ORM instances:
            # create_embedding() commits inside the loop, which may expire
            # loaded instances and make later attribute access fail under
            # AsyncSession.
            rows_result = await db.execute(
                select(Phase.id, Phase.ai_response).where(
                    Phase.project_id == project_id,
                    Phase.ai_response.isnot(None),
                )
            )
            pending = [
                (phase_id, ai_response)
                for phase_id, ai_response in rows_result.all()
                if ai_response
            ]

            for phase_id, ai_response in pending:
                await RAGService.create_embedding(db, phase_id, ai_response)

            logger.info(f"Updated embeddings for {len(pending)} phases in project {project_id}")
        except Exception as e:
            logger.error(f"Failed to update embeddings: {e}")
            raise Exception(f"Failed to update embeddings: {str(e)}") from e
|