import logging
import os
from datetime import datetime, timedelta, timezone

from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel, EmailStr
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.database import get_db  # Updated: use the correct async session dependency
from app.models import User

router = APIRouter()
logger = logging.getLogger(__name__)

load_dotenv()

# Load secret key and JWT algorithm.
SECRET_KEY = os.getenv("SECRET_KEY")
if SECRET_KEY is None:
    # Keep the historical fallback so local setups still start, but make it loud:
    # anyone who knows this value can mint valid tokens for any user.
    logger.warning("SECRET_KEY is not set; falling back to an insecure default signing key.")
    SECRET_KEY = "secret"
ALGORITHM = "HS256"

# Optional lifetime of login access tokens, in minutes. Unset or 0 keeps the
# original behavior of issuing tokens without an "exp" claim (no expiry).
_access_ttl = os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES")
ACCESS_TOKEN_EXPIRE_MINUTES = int(_access_ttl) if _access_ttl else 0

# Development convenience: when true (the default, matching earlier behavior),
# /auth/password/request-reset returns the reset token in its response body.
# SECURITY: that lets anyone who knows an email address reset that account's
# password. Set EXPOSE_RESET_TOKEN=false once email/SMS delivery is wired up.
EXPOSE_RESET_TOKEN = os.getenv("EXPOSE_RESET_TOKEN", "true").strip().lower() in ("1", "true", "yes")

# Generic reply used whenever we must not reveal whether an email is registered.
_RESET_REQUESTED_MESSAGE = "If the email exists, a reset link has been sent."

# Password hashing config
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
security = HTTPBearer()


def _join_subjects(subjects: list[str] | None) -> str | None:
    """Normalize a subjects list into the comma-separated string stored on ``User``.

    Entries are stripped and blank ones dropped; an empty/missing list yields ``None``.
    """
    if not subjects:
        return None
    return ",".join(s.strip() for s in subjects if s and s.strip()) or None


def _split_subjects(stored: str | None) -> list[str]:
    """Inverse of ``_join_subjects`` for API responses: ``None``/"" -> ``[]``."""
    return stored.split(",") if stored else []


def _as_aware_utc(value: datetime) -> datetime:
    """Return *value* as a timezone-aware datetime.

    Naive values are assumed to be UTC, since this module only ever writes
    ``datetime.now(timezone.utc)``-based timestamps; whether the DB column keeps
    the tzinfo depends on the model (not visible here).
    """
    return value.replace(tzinfo=timezone.utc) if value.tzinfo is None else value


async def get_current_user(token: HTTPAuthorizationCredentials = Depends(security), db: AsyncSession = Depends(get_db)):
    """Resolve the authenticated ``User`` from an ``Authorization: Bearer <jwt>`` header.

    Raises:
        HTTPException(401): the token fails verification (bad signature, malformed,
            or past its ``exp``), lacks ``user_id``, is a password-reset token, or
            refers to a user that no longer exists.
    """
    credentials_exception = HTTPException(
        status_code=401,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token.credentials, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception from None

    user_id = payload.get("user_id")
    # Password-reset tokens are signed with the same key and also carry
    # "user_id"; without the "pr" check they would be accepted as access tokens.
    if user_id is None or payload.get("pr"):
        raise credentials_exception

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None:
        raise credentials_exception
    return user


# Request Schemas
class SignUp(BaseModel):
    """Body of POST /auth/signup."""

    email: EmailStr
    password: str
    mobile: str | None = None
    name: str | None = None
    dob: str | None = None
    preparing_for: str | None = None
    exam: str | None = None
    subjects: list[str] | None = None


class Login(BaseModel):
    """Body of POST /auth/login."""

    email: EmailStr
    password: str
    # Allow capturing preferences at sign-in if provided by UI
    exam: str | None = None
    subjects: list[str] | None = None


class UpdateProfile(BaseModel):
    """Body of PUT /auth/profile; ``None`` fields are left unchanged."""

    mobile: str | None = None
    name: str | None = None
    dob: str | None = None
    preparing_for: str | None = None
    exam: str | None = None
    subjects: list[str] | None = None


class RequestPasswordReset(BaseModel):
    """Body of POST /auth/password/request-reset."""

    email: EmailStr


class ConfirmPasswordReset(BaseModel):
    """Body of POST /auth/password/confirm-reset."""

    token: str
    new_password: str


class ChangePassword(BaseModel):
    """Body of POST /auth/password/change."""

    current_password: str
    new_password: str


# Profile fields copied verbatim from UpdateProfile onto User when provided.
_PLAIN_PROFILE_FIELDS = ("mobile", "name", "dob", "preparing_for", "exam")


@router.put("/auth/profile")
async def update_profile(data: UpdateProfile, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)):
    """Update the caller's profile; only fields present (non-``None``) are changed.

    Returns the updated profile. Raises HTTPException(500) if the commit fails.
    """
    for field in _PLAIN_PROFILE_FIELDS:
        value = getattr(data, field)
        if value is not None:
            setattr(current_user, field, value)
    if data.subjects is not None:
        current_user.subjects = _join_subjects(data.subjects)

    try:
        await db.commit()
        await db.refresh(current_user)
    except Exception as e:
        await db.rollback()
        logger.exception("Profile update error: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error")

    return {
        "message": "Profile updated successfully",
        "user": {
            "id": current_user.id,
            "email": current_user.email,
            "mobile": current_user.mobile,
            "name": current_user.name,
            "dob": current_user.dob,
            "preparing_for": current_user.preparing_for,
            # Added so clients can see the values they may have just updated.
            "exam": current_user.exam,
            "subjects": _split_subjects(current_user.subjects),
        },
    }


@router.post("/auth/signup")
async def signup(data: SignUp, db: AsyncSession = Depends(get_db)):
    """Create a user account.

    Raises:
        HTTPException(400): a user with this email already exists.
        HTTPException(500): the insert could not be committed.
    """
    result = await db.execute(select(User).where(User.email == data.email))
    if result.scalar_one_or_none():
        raise HTTPException(status_code=400, detail="Email already exists")

    new_user = User(
        email=data.email,
        hashed_password=pwd_context.hash(data.password),
        mobile=data.mobile,
        name=data.name,
        dob=data.dob,
        preparing_for=data.preparing_for,
        exam=data.exam,
        # Same normalization as profile update / login (strip, drop blanks).
        subjects=_join_subjects(data.subjects),
    )
    try:
        db.add(new_user)
        await db.commit()
        await db.refresh(new_user)
    except Exception as e:
        await db.rollback()
        logger.exception("Signup error: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error")
    return {"message": "User created", "user_id": new_user.id}


@router.post("/auth/login")
async def login(data: Login, db: AsyncSession = Depends(get_db)):
    """Verify credentials and issue a bearer access token.

    If the body carries ``exam``/``subjects`` that differ from what is stored,
    they are saved on a best-effort basis: a failure there is logged and does
    not block the login.

    Raises:
        HTTPException(401): unknown email or wrong password.
    """
    result = await db.execute(select(User).where(User.email == data.email))
    user = result.scalar_one_or_none()
    if not user or not pwd_context.verify(data.password, user.hashed_password):
        raise HTTPException(status_code=401, detail="Invalid credentials")

    user_id = user.id  # captured up front for logging if the update below fails

    # Optionally update preferences at login if provided
    try:
        updated = False
        if data.exam is not None and data.exam != user.exam:
            user.exam = data.exam
            updated = True
        if data.subjects is not None:
            subjects_joined = _join_subjects(data.subjects)
            if subjects_joined != (user.subjects or None):
                user.subjects = subjects_joined
                updated = True
        if updated:
            await db.commit()
            await db.refresh(user)
    except Exception:
        logger.exception("Failed to update preferences at login for user %s", user_id)
        await db.rollback()
        # rollback() expires every ORM instance in the session; reload `user`
        # explicitly so the attribute reads below don't trigger implicit async IO
        # (and so the response reflects what is actually stored).
        await db.refresh(user)

    claims: dict = {"user_id": user.id}
    if ACCESS_TOKEN_EXPIRE_MINUTES:
        claims["exp"] = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    token = jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)

    return {
        "access_token": token,
        "token_type": "bearer",
        "user": {
            "id": user.id,
            "email": user.email,
            "exam": user.exam,
            "subjects": _split_subjects(user.subjects),
        },
    }


@router.post("/auth/password/request-reset")
async def request_password_reset(data: RequestPasswordReset, db: AsyncSession = Depends(get_db)):
    """Generate a 30-minute password-reset token and store it on the user.

    Unknown emails get the same generic message as the hidden-token path so the
    endpoint does not confirm account existence. When ``EXPOSE_RESET_TOKEN`` is
    on (development default), the token is returned in the response instead,
    which does reveal that the email exists.

    Raises:
        HTTPException(500): the token could not be stored.
    """
    result = await db.execute(select(User).where(User.email == data.email))
    user = result.scalar_one_or_none()
    if not user:
        # Do not reveal whether email exists
        return {"message": _RESET_REQUESTED_MESSAGE}

    # Create a short-lived token; "pr" marks it as a password-reset token.
    expires_at = datetime.now(timezone.utc) + timedelta(minutes=30)
    reset_token = jwt.encode({"user_id": user.id, "pr": True, "exp": expires_at}, SECRET_KEY, algorithm=ALGORITHM)
    user.password_reset_token = reset_token
    user.password_reset_expires = expires_at
    try:
        await db.commit()
    except Exception as e:
        await db.rollback()
        logger.exception("Failed to store reset token: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error")

    # NOTE: Integrate email/SMS sending here. Until then the token is returned
    # for development when EXPOSE_RESET_TOKEN is enabled.
    if EXPOSE_RESET_TOKEN:
        return {"message": "Reset token generated", "reset_token": reset_token}
    return {"message": _RESET_REQUESTED_MESSAGE}


@router.post("/auth/password/confirm-reset")
async def confirm_password_reset(data: ConfirmPasswordReset, db: AsyncSession = Depends(get_db)):
    """Set a new password using a token from /auth/password/request-reset.

    The token must verify, carry the ``pr`` marker, match the one stored on the
    user, and be unexpired; it is cleared on success so it cannot be reused.

    Raises:
        HTTPException(400): invalid, mismatched, or expired token.
        HTTPException(500): the update could not be committed.
    """
    try:
        payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(status_code=400, detail="Invalid or expired token") from None
    user_id = payload.get("user_id")
    if not user_id or not payload.get("pr"):
        raise HTTPException(status_code=400, detail="Invalid token")

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user or user.password_reset_token != data.token:
        raise HTTPException(status_code=400, detail="Invalid token")
    # The stored expiry may come back naive depending on the column type;
    # comparing naive with aware datetimes would raise TypeError (a 500).
    if user.password_reset_expires and datetime.now(timezone.utc) > _as_aware_utc(user.password_reset_expires):
        raise HTTPException(status_code=400, detail="Token expired")

    user.hashed_password = pwd_context.hash(data.new_password)
    user.password_reset_token = None
    user.password_reset_expires = None
    try:
        await db.commit()
    except Exception as e:
        await db.rollback()
        logger.exception("Password reset error: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error")
    return {"message": "Password reset successful"}


@router.post("/auth/password/change")
async def change_password(data: ChangePassword, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)):
    """Change the authenticated user's password after re-verifying the current one.

    Raises:
        HTTPException(400): ``current_password`` does not match.
        HTTPException(500): the update could not be committed.
    """
    if not pwd_context.verify(data.current_password, current_user.hashed_password):
        raise HTTPException(status_code=400, detail="Current password incorrect")

    current_user.hashed_password = pwd_context.hash(data.new_password)
    try:
        await db.commit()
    except Exception as e:
        await db.rollback()
        logger.exception("Change password error: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error")
    return {"message": "Password changed"}