File size: 7,801 Bytes
bd161ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7ed1ea
bd161ec
 
 
 
 
 
 
 
 
 
 
 
d7ed1ea
bd161ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""
RAG (Retrieval Augmented Generation) service for semantic search and context retrieval.
"""
from typing import List, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, text
import logging

from models import Phase, PhaseEmbedding
from services.openai_service import openai_service

# Module-level logger, named after this module, used by every RAGService method.
logger = logging.getLogger(__name__)


class RAGService:
    """Service for RAG functionality with semantic search.

    All methods are static and operate on the ``AsyncSession`` passed in by
    the caller. Embedding vectors are produced by
    ``openai_service.create_embedding`` and stored on ``PhaseEmbedding`` rows.
    """

    @staticmethod
    async def create_embedding(db: AsyncSession, phase_id: str, content: str) -> PhaseEmbedding:
        """Create or update the stored embedding for a phase's content.

        If a ``PhaseEmbedding`` row already exists for ``phase_id`` its content
        and vector are overwritten; otherwise a new row is added. The change
        is committed on ``db`` before returning.

        Args:
            db: Session used to look up, persist and commit the embedding.
            phase_id: Identifier of the phase the embedding belongs to.
            content: Text to embed and store alongside the vector.

        Returns:
            The persisted ``PhaseEmbedding`` row, refreshed after the commit.

        Raises:
            Exception: If embedding generation or persistence fails; the
                original error is attached as ``__cause__``.
        """
        try:
            # Generate the vector first so no row is touched if this fails.
            embedding_vector = await openai_service.create_embedding(content)

            existing_result = await db.execute(
                select(PhaseEmbedding).where(PhaseEmbedding.phase_id == phase_id)
            )
            phase_embedding = existing_result.scalar_one_or_none()

            if phase_embedding is not None:
                phase_embedding.content = content
                phase_embedding.embedding = embedding_vector
                action = "Updated"
            else:
                phase_embedding = PhaseEmbedding(
                    phase_id=phase_id,
                    content=content,
                    embedding=embedding_vector,
                )
                db.add(phase_embedding)
                action = "Created"

            await db.commit()
            # Refresh on both paths: if the session expires objects on commit,
            # reading attributes of the returned row later would otherwise
            # need an implicit (unsupported under asyncio) lazy load.
            await db.refresh(phase_embedding)

            logger.info(f"{action} embedding for phase {phase_id}")
            return phase_embedding

        except Exception as e:
            logger.error(f"Failed to create embedding: {e}")
            raise Exception(f"Failed to create embedding: {str(e)}") from e

    @staticmethod
    async def search_similar_content(
        db: AsyncSession,
        query: str,
        project_id: str,
        limit: int = 5,
        similarity_threshold: float = 0.7
    ) -> List[Tuple[Phase, float]]:
        """Search a project's phases for content semantically similar to ``query``.

        Args:
            db: Session used to run the similarity query.
            query: Free text to embed and compare against stored embeddings.
            project_id: Only phases of this project are considered.
            limit: Maximum number of matches to return.
            similarity_threshold: Matches must score strictly above this value.

        Returns:
            ``(phase, similarity)`` tuples ordered from most to least similar.
            An empty list is returned when the query is blank, ``limit`` is not
            positive, nothing matches, or any error occurs (errors are logged,
            never raised).
        """
        # Nothing meaningful to search for; skip the embedding call entirely.
        if not query or not query.strip() or limit <= 0:
            return []

        try:
            query_embedding = await openai_service.create_embedding(query)

            # `<=>` is used here as cosine distance, so `1 - distance` is the
            # cosine similarity that is filtered and ordered on.
            search_query = text("""
                SELECT p.*, pe.content, (1 - (pe.embedding <=> :query_embedding)) as similarity
                FROM phases p
                JOIN phase_embeddings pe ON p.id = pe.phase_id
                WHERE p.project_id = :project_id
                AND (1 - (pe.embedding <=> :query_embedding)) > :threshold
                ORDER BY similarity DESC
                LIMIT :limit
            """)

            result = await db.execute(
                search_query,
                {
                    "query_embedding": query_embedding,
                    "project_id": project_id,
                    "threshold": similarity_threshold,
                    "limit": limit
                }
            )
            rows = result.all()

            # Load every matching Phase entity with one IN query instead of
            # one SELECT per row.
            phases_by_id = {}
            if rows:
                phases_result = await db.execute(
                    select(Phase).where(Phase.id.in_([row.id for row in rows]))
                )
                # Key by str(): the raw SQL row and the ORM attribute may
                # represent the same id with different Python types.
                phases_by_id = {
                    str(phase.id): phase for phase in phases_result.scalars().all()
                }

            # Walk the rows (not the dict) so the similarity ordering of the
            # raw query is preserved.
            similar_phases = [
                (phases_by_id[str(row.id)], row.similarity)
                for row in rows
                if str(row.id) in phases_by_id
            ]

            logger.info(f"Found {len(similar_phases)} similar phases for query in project {project_id}")
            return similar_phases

        except Exception as e:
            # Best-effort by design: callers get an empty result, not an error.
            # NOTE(review): if the failure came from the database, the
            # session's transaction may need a rollback before `db` is reused
            # -- confirm how callers manage the session.
            logger.exception(f"Similarity search error: {e}")
            return []

    @staticmethod
    async def get_context_for_phase(
        db: AsyncSession,
        project_id: str,
        current_phase_number: int,
        user_input: str
    ) -> Tuple[str, List[str]]:
        """Build prompt context for a phase from earlier phases and RAG hits.

        The context consists of the AI responses of the last three earlier
        phases (in phase order), followed by up to three semantically similar
        phases that are not already included. Parts are joined with a
        ``---`` separator and the result is cut to 6000 characters.

        Args:
            db: Session used for all queries.
            project_id: Project whose phases supply the context.
            current_phase_number: Only phases numbered below this one are used
                for the sequential part.
            user_input: Text used as the semantic-search query.

        Returns:
            ``(context_text, sources)`` where ``sources`` labels each included
            part. ``("", [])`` is returned if anything fails (the error is
            logged, never raised).
        """
        try:
            context_parts = []
            context_sources = []

            # 1. Sequential context: earlier phases that have an AI response.
            previous_phases_result = await db.execute(
                select(Phase)
                .where(
                    Phase.project_id == project_id,
                    Phase.phase_number < current_phase_number,
                    Phase.ai_response.isnot(None)
                )
                .order_by(Phase.phase_number)
            )
            # Only the last 3 phases are kept for immediate context.
            recent_phases = previous_phases_result.scalars().all()[-3:]

            # Phase numbers already covered, so RAG hits are not duplicated.
            # Built before the search so no ORM attribute is read afterwards.
            added_phases = set()
            for phase in recent_phases:
                added_phases.add(phase.phase_number)
                if phase.ai_response:
                    context_parts.append(
                        f"Phase {phase.phase_number} ({phase.title}):\n{phase.ai_response}"
                    )
                    context_sources.append(f"Phase {phase.phase_number}")

            # 2. Semantically similar content (never raises; [] on failure).
            similar_phases = await RAGService.search_similar_content(
                db, user_input, project_id, limit=3, similarity_threshold=0.6
            )

            for phase, similarity in similar_phases:
                if phase.phase_number in added_phases or not phase.ai_response:
                    continue
                context_parts.append(
                    f"Related content from Phase {phase.phase_number} ({phase.title}) "
                    f"[similarity: {similarity:.2f}]:\n{phase.ai_response}"
                )
                context_sources.append(f"Phase {phase.phase_number} (RAG)")
                added_phases.add(phase.phase_number)

            full_context = "\n\n---\n\n".join(context_parts)

            # Truncate if too long (character count as a rough token limit).
            if len(full_context) > 6000:
                full_context = full_context[:6000] + "... [context truncated]"

            logger.info(f"Built context for phase {current_phase_number} with {len(context_sources)} sources")
            return full_context, context_sources

        except Exception as e:
            # Best-effort by design: a missing context must not fail the caller.
            logger.exception(f"Context building error: {e}")
            return "", []

    @staticmethod
    async def update_all_embeddings(db: AsyncSession, project_id: str) -> None:
        """Regenerate embeddings for every phase of a project that has a response.

        Args:
            db: Session used to load phases and persist embeddings.
            project_id: Project whose phases are re-embedded.

        Raises:
            Exception: If loading the phases or any single embedding fails;
                the original error is attached as ``__cause__``. Phases
                processed before the failure stay committed.
        """
        try:
            phases_result = await db.execute(
                select(Phase)
                .where(
                    Phase.project_id == project_id,
                    Phase.ai_response.isnot(None)
                )
            )
            # Snapshot plain values up front: create_embedding() commits, and
            # if the session expires objects on commit, reading `phase.*`
            # inside the loop would need an implicit lazy load.
            # Empty responses are skipped.
            pending = [
                (phase.id, phase.ai_response)
                for phase in phases_result.scalars().all()
                if phase.ai_response
            ]

            for phase_id, content in pending:
                await RAGService.create_embedding(db, phase_id, content)

            # Report the phases actually processed, not merely selected.
            logger.info(f"Updated embeddings for {len(pending)} phases in project {project_id}")

        except Exception as e:
            logger.error(f"Failed to update embeddings: {e}")
            raise Exception(f"Failed to update embeddings: {str(e)}") from e