# Author: aghaai
# Commit bd161ec: Initial deployment of Unified Assistant with OpenAI and Hugging Face integration
"""
Phase management router for the 14-phase workflow system.
"""
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select, update
import logging
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_async_db
from models import User, Project, Phase, PhaseDraft, PhaseStatus
from schemas import (
PhaseResponse, PhaseUpdate, PhaseGenerateRequest, PhaseGenerateResponse,
PhaseDraftResponse, APIResponse
)
from dependencies import get_current_active_user, check_project_access
from services.phase_service import PhaseService
from services.rag_service import RAGService
# Router that every phase endpoint in this module registers itself on.
router = APIRouter()
# Logger scoped to this module's import name.
logger = logging.getLogger(__name__)
@router.get("/projects/{project_id}/phases", response_model=List[PhaseResponse])
async def get_project_phases(
    project_id: str,
    project: Project = Depends(check_project_access),
    db: AsyncSession = Depends(get_async_db)
):
    """List every phase of a project, ordered by ascending phase number.

    ``project`` is resolved by the ``check_project_access`` dependency and is
    not read in the body; the dependency is presumably kept for its
    access-gating side effect.
    """
    query = (
        select(Phase)
        .where(Phase.project_id == project_id)
        .order_by(Phase.phase_number)
    )
    rows = await db.execute(query)
    return rows.scalars().all()
@router.get("/projects/{project_id}/phases/{phase_number}", response_model=PhaseResponse)
async def get_phase(
    project_id: str,
    phase_number: int,
    project: Project = Depends(check_project_access),
    db: AsyncSession = Depends(get_async_db)
):
    """Return one phase of a project.

    Raises:
        HTTPException 400: ``phase_number`` is outside 1..14.
        HTTPException 404: the project has no phase with that number.
    """
    if phase_number < 1 or phase_number > 14:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Phase number must be between 1 and 14"
        )

    lookup = select(Phase).where(
        Phase.project_id == project_id,
        Phase.phase_number == phase_number,
    )
    found = (await db.execute(lookup)).scalar_one_or_none()
    if not found:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Phase not found"
        )
    return found
@router.put("/projects/{project_id}/phases/{phase_number}", response_model=PhaseResponse)
async def update_phase(
    project_id: str,
    phase_number: int,
    phase_data: PhaseUpdate,
    project: Project = Depends(check_project_access),
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_async_db)
):
    """Update the editable fields of a phase.

    Only the fields of ``phase_data`` that are not ``None`` (title,
    description, user_input, prompt_template) are written. When at least one
    field is written:

    * the current ``user_input``/``ai_response`` pair is first saved as a
      draft, if both are set;
    * if the phase was ``COMPLETED``, ``PhaseService`` is asked to mark the
      subsequent phases stale.

    A request that supplies no fields is a no-op. The phase is returned
    unchanged, and no draft is created or downstream phase invalidated.

    Raises:
        HTTPException 400: ``phase_number`` is outside 1..14.
        HTTPException 404: the project has no phase with that number.
    """
    if not (1 <= phase_number <= 14):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Phase number must be between 1 and 14"
        )

    result = await db.execute(
        select(Phase).where(Phase.project_id == project_id, Phase.phase_number == phase_number)
    )
    phase = result.scalar_one_or_none()
    if not phase:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Phase not found"
        )

    # Collect only the fields the client actually supplied.
    update_data = {}
    for field in ("title", "description", "user_input", "prompt_template"):
        value = getattr(phase_data, field)
        if value is not None:
            update_data[field] = value

    if not update_data:
        # Nothing to change. Previously this path still saved a duplicate
        # draft and marked every later phase stale.
        return phase

    # Read what we need from the instance up front, so nothing below depends
    # on its state surviving save_draft() or commit().
    phase_id = phase.id
    was_completed = phase.status == PhaseStatus.COMPLETED

    # Snapshot the current content so the edit can be undone via drafts.
    if phase.ai_response and phase.user_input:
        await PhaseService.save_draft(db, phase_id, phase.user_input, phase.ai_response)

    await db.execute(
        update(Phase)
        .where(Phase.id == phase_id)
        .values(**update_data)
    )

    # Later phases were derived from this one's completed content.
    if was_completed:
        await PhaseService.mark_subsequent_phases_stale(db, project_id, phase_number)

    await db.commit()

    result = await db.execute(select(Phase).where(Phase.id == phase_id))
    updated_phase = result.scalar_one()

    logger.info(f"Phase {phase_number} updated in project {project_id} by {current_user.email}")
    return updated_phase
@router.post("/projects/{project_id}/phases/{phase_number}/generate", response_model=PhaseGenerateResponse)
async def generate_phase_content(
    project_id: str,
    phase_number: int,
    request: PhaseGenerateRequest,
    project: Project = Depends(check_project_access),
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_async_db)
):
    """Generate AI content for a phase and mark it ``COMPLETED``.

    Steps:

    1. If both ``user_input`` and ``ai_response`` are set, the existing pair
       is saved as a draft, so regeneration no longer discards the previous
       response.
    2. New content is produced by ``PhaseService.generate_content``.
    3. The content is stored with ``request.user_input``.
    4. It is embedded via ``RAGService`` when ``request.use_rag`` is true.
    5. ``PhaseService`` is asked to mark the subsequent phases stale.

    Raises:
        HTTPException 400: ``phase_number`` is outside 1..14.
        HTTPException 404: the project has no phase with that number.
        HTTPException 500: generation or persistence failed; the session is
            rolled back.
    """
    if not (1 <= phase_number <= 14):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Phase number must be between 1 and 14"
        )

    result = await db.execute(
        select(Phase).where(Phase.project_id == project_id, Phase.phase_number == phase_number)
    )
    phase = result.scalar_one_or_none()
    if not phase:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Phase not found"
        )

    # Capture the primary key up front so the response does not depend on
    # the ORM instance's state after commit.
    phase_id = phase.id

    try:
        # Snapshot the current content before it is overwritten. The async
        # save_draft used by the other endpoints works here too; this used to
        # be a placeholder that silently dropped the old response.
        if phase.ai_response and phase.user_input:
            await PhaseService.save_draft(db, phase_id, phase.user_input, phase.ai_response)

        ai_response, context_used = await PhaseService.generate_content(
            db, phase, request.user_input, request.use_rag, request.temperature
        )

        await db.execute(
            update(Phase)
            .where(Phase.id == phase_id)
            .values(
                user_input=request.user_input,
                ai_response=ai_response,
                status=PhaseStatus.COMPLETED
            )
        )

        if request.use_rag:
            await RAGService.create_embedding(db, phase_id, ai_response)

        # Later phases were built on this phase's previous content.
        await PhaseService.mark_subsequent_phases_stale(db, project_id, phase_number)

        await db.commit()
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Error generating content: %s", e)
        await db.rollback()
        # NOTE(review): the raw exception text is sent to the client; consider
        # a generic message if errors can carry internal details.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate content: {str(e)}"
        ) from e

    logger.info(f"Content generated for phase {phase_number} in project {project_id}")
    return PhaseGenerateResponse(
        phase_id=phase_id,
        ai_response=ai_response,
        status=PhaseStatus.COMPLETED,
        context_used=context_used
    )
@router.post("/projects/{project_id}/phases/{phase_number}/reconstruct-context", response_model=PhaseGenerateResponse)
async def reconstruct_phase_context(
    project_id: str,
    phase_number: int,
    project: Project = Depends(check_project_access),
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_async_db)
):
    """Regenerate a phase's AI response from its stored ``user_input``.

    Steps:

    1. The existing ``ai_response`` (if any) is saved as a draft.
    2. Content is regenerated with RAG enabled at temperature 0.7.
    3. The phase is stored as ``COMPLETED`` and its embedding is recreated.
    4. ``PhaseService`` is asked to mark the subsequent phases stale, as the
       ``generate`` endpoint does, because they were built on the replaced
       response.

    Raises:
        HTTPException 400: ``phase_number`` is outside 1..14, or the phase
            has no ``user_input``.
        HTTPException 404: the project has no phase with that number.
        HTTPException 500: regeneration or persistence failed; the session is
            rolled back.
    """
    if not (1 <= phase_number <= 14):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Phase number must be between 1 and 14"
        )

    result = await db.execute(
        select(Phase).where(Phase.project_id == project_id, Phase.phase_number == phase_number)
    )
    phase = result.scalar_one_or_none()
    if not phase:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Phase not found"
        )

    if not phase.user_input:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Phase has no user input to reconstruct from"
        )

    # Capture the primary key up front so the response does not depend on
    # the ORM instance's state after commit.
    phase_id = phase.id

    try:
        # Snapshot the response that is about to be replaced.
        if phase.ai_response:
            await PhaseService.save_draft(db, phase_id, phase.user_input, phase.ai_response)

        ai_response, context_used = await PhaseService.generate_content(
            db, phase, phase.user_input, use_rag=True, temperature=0.7
        )

        await db.execute(
            update(Phase)
            .where(Phase.id == phase_id)
            .values(
                ai_response=ai_response,
                status=PhaseStatus.COMPLETED
            )
        )

        # Refresh the RAG embedding so it reflects the new response.
        await RAGService.create_embedding(db, phase_id, ai_response)

        # The response changed, so phases derived from it are now outdated.
        await PhaseService.mark_subsequent_phases_stale(db, project_id, phase_number)

        await db.commit()
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Error reconstructing context: %s", e)
        await db.rollback()
        # NOTE(review): the raw exception text is sent to the client.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to reconstruct context: {str(e)}"
        ) from e

    logger.info(f"Context reconstructed for phase {phase_number} in project {project_id}")
    return PhaseGenerateResponse(
        phase_id=phase_id,
        ai_response=ai_response,
        status=PhaseStatus.COMPLETED,
        context_used=context_used
    )
@router.get("/projects/{project_id}/phases/{phase_number}/drafts", response_model=List[PhaseDraftResponse])
async def get_phase_drafts(
    project_id: str,
    phase_number: int,
    project: Project = Depends(check_project_access),
    db: AsyncSession = Depends(get_async_db)
):
    """List the saved drafts of a phase, highest version first.

    Raises:
        HTTPException 400: ``phase_number`` is outside 1..14 (the same check
            every other phase endpoint applies).
        HTTPException 404: the project has no phase with that number.
    """
    if not (1 <= phase_number <= 14):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Phase number must be between 1 and 14"
        )

    result = await db.execute(
        select(Phase).where(Phase.project_id == project_id, Phase.phase_number == phase_number)
    )
    phase = result.scalar_one_or_none()
    if not phase:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Phase not found"
        )

    drafts_result = await db.execute(
        select(PhaseDraft)
        .where(PhaseDraft.phase_id == phase.id)
        .order_by(PhaseDraft.version.desc())
    )
    return drafts_result.scalars().all()
@router.post("/projects/{project_id}/phases/{phase_number}/drafts/{version}/restore", response_model=PhaseResponse)
async def restore_phase_draft(
    project_id: str,
    phase_number: int,
    version: int,
    project: Project = Depends(check_project_access),
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_async_db)
):
    """Restore a phase's ``user_input`` and ``ai_response`` from a draft version.

    Before the restore, the phase's current pair (when both are set) is saved
    as a new draft, so the restore itself can be undone. ``PhaseService`` is
    then asked to mark the subsequent phases stale.

    Raises:
        HTTPException 400: ``phase_number`` is outside 1..14 (the same check
            every other phase endpoint applies).
        HTTPException 404: the phase, or the requested draft version, does
            not exist.
    """
    if not (1 <= phase_number <= 14):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Phase number must be between 1 and 14"
        )

    result = await db.execute(
        select(Phase).where(Phase.project_id == project_id, Phase.phase_number == phase_number)
    )
    phase = result.scalar_one_or_none()
    if not phase:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Phase not found"
        )
    phase_id = phase.id

    draft_result = await db.execute(
        select(PhaseDraft).where(
            PhaseDraft.phase_id == phase_id,
            PhaseDraft.version == version
        )
    )
    draft = draft_result.scalar_one_or_none()
    if not draft:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Draft not found"
        )

    # Copy the draft's content before save_draft() runs, in case it flushes
    # or commits and changes the loaded instances' state.
    restored_input = draft.user_input
    restored_response = draft.ai_response

    # Snapshot the current content so this restore can itself be undone.
    if phase.ai_response and phase.user_input:
        await PhaseService.save_draft(db, phase_id, phase.user_input, phase.ai_response)

    await db.execute(
        update(Phase)
        .where(Phase.id == phase_id)
        .values(
            user_input=restored_input,
            ai_response=restored_response
        )
    )

    # NOTE(review): unlike the generate/reconstruct endpoints, the RAG
    # embedding is not refreshed here, so it may still reflect the
    # pre-restore response. Confirm whether that is intended.

    # Later phases were built on the content that was just replaced.
    await PhaseService.mark_subsequent_phases_stale(db, project_id, phase_number)

    await db.commit()

    result = await db.execute(select(Phase).where(Phase.id == phase_id))
    updated_phase = result.scalar_one()

    logger.info(f"Phase {phase_number} restored to version {version} in project {project_id}")
    return updated_phase