Spaces:
Running
Running
""" | |
Phase management router for the 14-phase workflow system. | |
""" | |
from typing import List | |
from fastapi import APIRouter, Depends, HTTPException, status | |
from sqlalchemy import select, update | |
import logging | |
from sqlalchemy.ext.asyncio import AsyncSession | |
from database import get_async_db | |
from models import User, Project, Phase, PhaseDraft, PhaseStatus | |
from schemas import ( | |
PhaseResponse, PhaseUpdate, PhaseGenerateRequest, PhaseGenerateResponse, | |
PhaseDraftResponse, APIResponse | |
) | |
from dependencies import get_current_active_user, check_project_access | |
from services.phase_service import PhaseService | |
from services.rag_service import RAGService | |
# Router object for this module's phase endpoints.
# NOTE(review): none of the handlers in this section carry a
# `@router.<method>(...)` decorator, so as shown they are never registered on
# this router. The imported-but-unused `List`, `PhaseResponse`,
# `PhaseDraftResponse` and `APIResponse` suggest decorators with
# `response_model=` existed and were lost -- confirm the route registrations.
router = APIRouter()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
async def get_project_phases(
    project_id: str,
    project: Project = Depends(check_project_access),
    db: AsyncSession = Depends(get_async_db)
):
    """Return every phase belonging to ``project_id``, ordered by phase number.

    ``check_project_access`` is resolved as a dependency (presumably the
    authorization check); its ``project`` result is not used in the body.
    """
    stmt = (
        select(Phase)
        .where(Phase.project_id == project_id)
        .order_by(Phase.phase_number)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
async def get_phase(
    project_id: str,
    phase_number: int,
    project: Project = Depends(check_project_access),
    db: AsyncSession = Depends(get_async_db)
):
    """Return one phase of a project, looked up by its number.

    Raises:
        HTTPException(400): ``phase_number`` is outside 1..14.
        HTTPException(404): the project has no phase with that number.
    """
    if phase_number < 1 or phase_number > 14:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Phase number must be between 1 and 14"
        )

    query = select(Phase).where(
        Phase.project_id == project_id,
        Phase.phase_number == phase_number,
    )
    found = (await db.execute(query)).scalar_one_or_none()
    if not found:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Phase not found"
        )
    return found
async def update_phase(
    project_id: str,
    phase_number: int,
    phase_data: PhaseUpdate,
    project: Project = Depends(check_project_access),
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_async_db)
):
    """Apply a partial update to one phase of a project.

    Only the fields of ``phase_data`` that are not ``None`` (``title``,
    ``description``, ``user_input``, ``prompt_template``) are written. When at
    least one such field is present:

    * the phase's current ``user_input``/``ai_response`` pair is saved as a
      draft first (only if both are non-empty), and
    * if the phase was ``COMPLETED``, later phases are marked stale via
      ``PhaseService.mark_subsequent_phases_stale``.

    A request with no non-``None`` fields leaves the phase, its drafts and the
    later phases untouched.

    Returns:
        The phase as re-read from the database after the commit.

    Raises:
        HTTPException(400): ``phase_number`` is outside 1..14.
        HTTPException(404): the phase does not exist in this project.
    """
    if not (1 <= phase_number <= 14):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Phase number must be between 1 and 14"
        )

    result = await db.execute(
        select(Phase).where(
            Phase.project_id == project_id,
            Phase.phase_number == phase_number,
        )
    )
    phase = result.scalar_one_or_none()
    if not phase:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Phase not found"
        )

    # Read these before committing: AsyncSession expires loaded attributes on
    # commit by default, and lazily reloading them afterwards needs implicit IO,
    # which the async session does not allow (MissingGreenlet).
    phase_id = phase.id
    was_completed = phase.status == PhaseStatus.COMPLETED
    user_email = current_user.email

    updatable_fields = ("title", "description", "user_input", "prompt_template")
    update_data = {
        field: getattr(phase_data, field)
        for field in updatable_fields
        if getattr(phase_data, field) is not None
    }

    # Skip everything for a no-op request, so it neither creates a duplicate
    # draft nor invalidates later phases.
    if update_data:
        # Snapshot the current content as a draft before overwriting it.
        if phase.ai_response and phase.user_input:
            await PhaseService.save_draft(db, phase_id, phase.user_input, phase.ai_response)

        await db.execute(
            update(Phase)
            .where(Phase.id == phase_id)
            .values(**update_data)
        )

        # Only edits to a phase that was COMPLETED invalidate later phases.
        if was_completed:
            await PhaseService.mark_subsequent_phases_stale(db, project_id, phase_number)

    await db.commit()

    # Re-select so the returned object reflects the committed row.
    result = await db.execute(select(Phase).where(Phase.id == phase_id))
    updated_phase = result.scalar_one()
    logger.info(f"Phase {phase_number} updated in project {project_id} by {user_email}")
    return updated_phase
async def generate_phase_content(
    project_id: str,
    phase_number: int,
    request: PhaseGenerateRequest,
    project: Project = Depends(check_project_access),
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_async_db)
):
    """Generate AI content for a phase from the request's user input.

    Steps, followed by a single commit:

    1. Save the phase's existing ``user_input``/``ai_response`` as a draft
       (when both are non-empty), so regeneration does not discard it.
    2. Call ``PhaseService.generate_content`` with the request's
       ``user_input``, ``use_rag`` and ``temperature``.
    3. Store the new input and response and set the phase to ``COMPLETED``.
    4. If ``request.use_rag`` is set, create an embedding of the response.
    5. Mark later phases stale.

    Returns:
        PhaseGenerateResponse carrying the new response and the context used.

    Raises:
        HTTPException(400): ``phase_number`` is outside 1..14.
        HTTPException(404): the phase does not exist in this project.
        HTTPException(500): any other failure; the session is rolled back.
    """
    if not (1 <= phase_number <= 14):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Phase number must be between 1 and 14"
        )

    result = await db.execute(
        select(Phase).where(
            Phase.project_id == project_id,
            Phase.phase_number == phase_number,
        )
    )
    phase = result.scalar_one_or_none()
    if not phase:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Phase not found"
        )

    # Capture before commit: AsyncSession expires attributes on commit by
    # default and cannot lazily reload them afterwards (MissingGreenlet).
    phase_id = phase.id

    try:
        # Previously a `pass` placeholder waiting for a "sync" save_draft; the
        # async PhaseService.save_draft is what the other handlers await.
        if phase.ai_response and phase.user_input:
            await PhaseService.save_draft(db, phase_id, phase.user_input, phase.ai_response)

        ai_response, context_used = await PhaseService.generate_content(
            db, phase, request.user_input, request.use_rag, request.temperature
        )

        await db.execute(
            update(Phase)
            .where(Phase.id == phase_id)
            .values(
                user_input=request.user_input,
                ai_response=ai_response,
                status=PhaseStatus.COMPLETED
            )
        )

        if request.use_rag:
            await RAGService.create_embedding(db, phase_id, ai_response)

        await PhaseService.mark_subsequent_phases_stale(db, project_id, phase_number)

        await db.commit()
    except HTTPException:
        # Errors already expressed as HTTP responses keep their status code
        # instead of being wrapped into a generic 500.
        await db.rollback()
        raise
    except Exception as e:
        # logger.exception records the traceback as well as the message.
        logger.exception(f"Error generating content: {e}")
        await db.rollback()
        # NOTE(review): str(e) is exposed to the client; consider a generic message.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate content: {str(e)}"
        ) from e

    logger.info(f"Content generated for phase {phase_number} in project {project_id}")
    return PhaseGenerateResponse(
        phase_id=phase_id,
        ai_response=ai_response,
        status=PhaseStatus.COMPLETED,
        context_used=context_used
    )
async def reconstruct_phase_context(
    project_id: str,
    phase_number: int,
    project: Project = Depends(check_project_access),
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_async_db)
):
    """Regenerate a phase's AI response from its stored ``user_input``.

    Steps, followed by a single commit:

    1. Save the existing ``ai_response`` (if any) as a draft.
    2. Regenerate content with RAG enabled at temperature 0.7.
    3. Set the phase to ``COMPLETED``.
    4. Create an embedding of the new response.
    5. Mark later phases stale, since this phase's content has changed.

    Returns:
        PhaseGenerateResponse carrying the new response and the context used.

    Raises:
        HTTPException(400): ``phase_number`` is outside 1..14, or the phase has
            no ``user_input`` to regenerate from.
        HTTPException(404): the phase does not exist in this project.
        HTTPException(500): any other failure; the session is rolled back.
    """
    if not (1 <= phase_number <= 14):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Phase number must be between 1 and 14"
        )

    result = await db.execute(
        select(Phase).where(
            Phase.project_id == project_id,
            Phase.phase_number == phase_number,
        )
    )
    phase = result.scalar_one_or_none()
    if not phase:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Phase not found"
        )
    if not phase.user_input:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Phase has no user input to reconstruct from"
        )

    # Capture before commit: AsyncSession expires attributes on commit by
    # default and cannot lazily reload them afterwards (MissingGreenlet).
    phase_id = phase.id

    try:
        if phase.ai_response:
            await PhaseService.save_draft(db, phase_id, phase.user_input, phase.ai_response)

        ai_response, context_used = await PhaseService.generate_content(
            db, phase, phase.user_input, use_rag=True, temperature=0.7
        )

        await db.execute(
            update(Phase)
            .where(Phase.id == phase_id)
            .values(
                ai_response=ai_response,
                status=PhaseStatus.COMPLETED
            )
        )

        await RAGService.create_embedding(db, phase_id, ai_response)

        # The content changed, so phases after this one are invalidated; this
        # mirrors generate_phase_content and restore_phase_draft.
        await PhaseService.mark_subsequent_phases_stale(db, project_id, phase_number)

        await db.commit()
    except HTTPException:
        # Errors already expressed as HTTP responses keep their status code.
        await db.rollback()
        raise
    except Exception as e:
        logger.exception(f"Error reconstructing context: {e}")
        await db.rollback()
        # NOTE(review): str(e) is exposed to the client; consider a generic message.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to reconstruct context: {str(e)}"
        ) from e

    logger.info(f"Context reconstructed for phase {phase_number} in project {project_id}")
    return PhaseGenerateResponse(
        phase_id=phase_id,
        ai_response=ai_response,
        status=PhaseStatus.COMPLETED,
        context_used=context_used
    )
async def get_phase_drafts(
    project_id: str,
    phase_number: int,
    project: Project = Depends(check_project_access),
    db: AsyncSession = Depends(get_async_db)
):
    """Return all saved drafts of a phase, highest ``version`` first.

    Raises:
        HTTPException(400): ``phase_number`` is outside 1..14.
        HTTPException(404): the phase does not exist in this project.
    """
    # Same range check as the other phase endpoints, so an out-of-range number
    # yields 400 here too rather than a 404.
    if not (1 <= phase_number <= 14):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Phase number must be between 1 and 14"
        )

    result = await db.execute(
        select(Phase).where(
            Phase.project_id == project_id,
            Phase.phase_number == phase_number,
        )
    )
    phase = result.scalar_one_or_none()
    if not phase:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Phase not found"
        )

    drafts_result = await db.execute(
        select(PhaseDraft)
        .where(PhaseDraft.phase_id == phase.id)
        .order_by(PhaseDraft.version.desc())
    )
    return drafts_result.scalars().all()
async def restore_phase_draft(
    project_id: str,
    phase_number: int,
    version: int,
    project: Project = Depends(check_project_access),
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_async_db)
):
    """Restore a phase's ``user_input``/``ai_response`` from draft ``version``.

    The phase's current content is first saved as a new draft (when both
    fields are non-empty). Later phases are then marked stale. The phase's
    ``status`` is not changed.

    Returns:
        The phase as re-read from the database after the commit.

    Raises:
        HTTPException(400): ``phase_number`` is outside 1..14.
        HTTPException(404): the phase, or that draft version of it, does not exist.
    """
    # Same range check as the other phase endpoints.
    if not (1 <= phase_number <= 14):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Phase number must be between 1 and 14"
        )

    result = await db.execute(
        select(Phase).where(
            Phase.project_id == project_id,
            Phase.phase_number == phase_number,
        )
    )
    phase = result.scalar_one_or_none()
    if not phase:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Phase not found"
        )

    # Capture before commit: AsyncSession expires attributes on commit by
    # default and cannot lazily reload them afterwards (MissingGreenlet).
    phase_id = phase.id

    draft_result = await db.execute(
        select(PhaseDraft).where(
            PhaseDraft.phase_id == phase_id,
            PhaseDraft.version == version
        )
    )
    draft = draft_result.scalar_one_or_none()
    if not draft:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Draft not found"
        )

    # Snapshot the current content before overwriting it with the draft.
    if phase.ai_response and phase.user_input:
        await PhaseService.save_draft(db, phase_id, phase.user_input, phase.ai_response)

    await db.execute(
        update(Phase)
        .where(Phase.id == phase_id)
        .values(
            user_input=draft.user_input,
            ai_response=draft.ai_response
        )
    )

    await PhaseService.mark_subsequent_phases_stale(db, project_id, phase_number)

    await db.commit()

    # Re-select so the returned object reflects the committed row.
    result = await db.execute(select(Phase).where(Phase.id == phase_id))
    updated_phase = result.scalar_one()
    logger.info(f"Phase {phase_number} restored to version {version} in project {project_id}")
    return updated_phase