Spaces:
Building
Building
from fastapi import APIRouter, HTTPException, Depends | |
from sqlalchemy.ext.asyncio import AsyncSession | |
from sqlalchemy.future import select | |
from passlib.context import CryptContext | |
from jose import jwt | |
from pydantic import BaseModel, EmailStr | |
from app.database import get_db # Updated: use the correct async session dependency | |
from app.models import User | |
import os | |
import logging | |
from dotenv import load_dotenv | |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
from jose import JWTError | |
from datetime import datetime, timedelta, timezone | |
router = APIRouter()
logger = logging.getLogger(__name__)

# Populate os.environ from a .env file, if one is present, before the
# settings below are read.
load_dotenv()

# JWT signing key and algorithm. Every token this module issues or accepts
# (access tokens and password-reset tokens) is HS256-signed with SECRET_KEY.
SECRET_KEY = os.getenv("SECRET_KEY", "secret")
ALGORITHM = "HS256"
if SECRET_KEY in ("", "secret"):
    # Anyone who knows or guesses the key can mint a token for any user, so a
    # missing/placeholder key must be loud rather than silently accepted.
    logger.warning(
        "SECRET_KEY is unset, empty, or the insecure default 'secret'; "
        "set a strong random value in the environment before deploying."
    )

# Password hashing config: bcrypt via passlib.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Extracts the "Authorization: Bearer <token>" credentials for get_current_user.
security = HTTPBearer()
async def get_current_user(token: HTTPAuthorizationCredentials = Depends(security),
                           db: AsyncSession = Depends(get_db)):
    """Resolve the request's bearer token to a ``User`` row (FastAPI dependency).

    The token must be a JWT signed with ``SECRET_KEY`` using ``ALGORITHM`` and
    carry a ``user_id`` claim. Password-reset tokens are refused even though
    they are signed with the same key.

    Raises:
        HTTPException: 401 if the token fails to decode/verify, is a
            password-reset token, has no ``user_id``, or names a user that
            does not exist.
    """
    credentials_exception = HTTPException(
        status_code=401,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token.credentials, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception from None

    user_id = payload.get("user_id")
    # Reset tokens are issued with a truthy "pr" claim and also carry user_id;
    # without this check one could be replayed as an ordinary access token.
    if user_id is None or payload.get("pr"):
        raise credentials_exception

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None:
        raise credentials_exception
    return user
# Request Schemas
# Request body for signup. Only email and password are required; the rest are
# optional profile fields. `subjects` arrives as a list and is stored on the
# User row as a single comma-joined string.
class SignUp(BaseModel):
    email: EmailStr
    # Plain-text password; only its bcrypt hash is persisted.
    password: str
    mobile: str | None = None
    name: str | None = None
    # NOTE(review): free-form string with no date validation; presumably an
    # ISO date from the client — confirm.
    dob: str | None = None
    preparing_for: str | None = None
    exam: str | None = None
    subjects: list[str] | None = None
# Request body for login. `exam`/`subjects` are optional preferences that are
# saved on the user when provided and different from the stored values.
class Login(BaseModel):
    email: EmailStr
    password: str
    # Allow capturing preferences at sign-in if provided by UI
    exam: str | None = None
    subjects: list[str] | None = None
# Request body for a partial profile update: a field left as None means
# "leave unchanged". Consequently no field can be cleared this way, except
# `subjects`, where an empty/all-blank list clears the stored value.
class UpdateProfile(BaseModel):
    mobile: str | None = None
    name: str | None = None
    dob: str | None = None
    preparing_for: str | None = None
    exam: str | None = None
    subjects: list[str] | None = None
# Request body asking for a password-reset token for the given account email.
class RequestPasswordReset(BaseModel):
    email: EmailStr
# Request body completing a reset: the reset JWT plus the password to set.
class ConfirmPasswordReset(BaseModel):
    token: str
    # NOTE(review): no length/strength validation is applied to the new password.
    new_password: str
# Request body for an authenticated password change; the current password is
# re-checked before the new one is accepted.
class ChangePassword(BaseModel):
    current_password: str
    # NOTE(review): no length/strength validation is applied to the new password.
    new_password: str
async def update_profile(data: UpdateProfile,
                         current_user: User = Depends(get_current_user),
                         db: AsyncSession = Depends(get_db)):
    """Apply a partial profile update to the authenticated user.

    Only fields that were provided (not ``None``) are written. ``subjects`` is
    stored comma-joined with entries trimmed and blanks dropped; if nothing
    remains it is cleared to ``None``.

    Returns:
        A message plus the updated user. ``exam`` and ``subjects`` (as a list)
        are included so the response reflects every field this endpoint can
        change, in the same shape ``login`` uses for them.

    Raises:
        HTTPException: 500 if the commit fails (the transaction is rolled back).
    """
    # Scalar fields: copy over only what the client actually sent.
    for field in ("mobile", "name", "dob", "preparing_for", "exam"):
        value = getattr(data, field)
        if value is not None:
            setattr(current_user, field, value)
    if data.subjects is not None:
        current_user.subjects = ",".join(
            s.strip() for s in data.subjects if s and s.strip()
        ) or None

    try:
        await db.commit()
        await db.refresh(current_user)
        return {"message": "Profile updated successfully",
                "user": {"id": current_user.id,
                         "email": current_user.email,
                         "mobile": current_user.mobile,
                         "name": current_user.name,
                         "dob": current_user.dob,
                         "preparing_for": current_user.preparing_for,
                         "exam": current_user.exam,
                         "subjects": (current_user.subjects.split(",")
                                      if current_user.subjects else [])}}
    except Exception as e:
        await db.rollback()
        # logger.exception keeps the traceback, not just the message.
        logger.exception("Profile update error: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error")
async def signup(data: SignUp, db: AsyncSession = Depends(get_db)):
    """Create a new user account.

    The password is persisted only as a ``pwd_context`` hash. ``subjects`` is
    stored comma-joined with entries trimmed and blanks dropped (``None`` when
    nothing remains), the same normalisation ``update_profile`` and ``login``
    apply.

    Returns:
        ``{"message": "User created", "user_id": <new id>}``.

    Raises:
        HTTPException: 400 if the email is already registered; 500 if the
            insert fails (the transaction is rolled back).
    """
    # NOTE(review): two concurrent signups can both pass this check; only a
    # unique constraint on User.email (not visible here) prevents duplicates.
    result = await db.execute(select(User).where(User.email == data.email))
    if result.scalar_one_or_none() is not None:
        raise HTTPException(status_code=400, detail="Email already exists")

    subjects = None
    if data.subjects is not None:
        subjects = ",".join(
            s.strip() for s in data.subjects if s and s.strip()
        ) or None

    new_user = User(
        email=data.email,
        hashed_password=pwd_context.hash(data.password),
        mobile=data.mobile,
        name=data.name,
        dob=data.dob,
        preparing_for=data.preparing_for,
        exam=data.exam,
        subjects=subjects,
    )
    try:
        db.add(new_user)
        await db.commit()
        await db.refresh(new_user)
        return {"message": "User created", "user_id": new_user.id}
    except Exception as e:
        await db.rollback()
        logger.exception("Signup error: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error")
async def login(data: Login, db: AsyncSession = Depends(get_db)):
    """Authenticate by email and password and issue a bearer access token.

    If the request also carries ``exam`` and/or ``subjects`` that differ from
    what is stored, they are saved on a best-effort basis: a failed save is
    logged and rolled back, but the login itself still succeeds.

    Returns:
        ``access_token`` (a JWT whose only claim is ``user_id``),
        ``token_type`` and a user summary with ``subjects`` as a list.

    Raises:
        HTTPException: 401 if the email is unknown or the password is wrong.
    """
    result = await db.execute(select(User).where(User.email == data.email))
    user = result.scalar_one_or_none()
    if not user or not pwd_context.verify(data.password, user.hashed_password):
        raise HTTPException(status_code=401, detail="Invalid credentials")

    # Snapshot everything the response needs *before* any commit/rollback.
    # Both expire loaded instances, and reading an expired attribute through an
    # AsyncSession requires implicit IO, which raises instead of reloading.
    user_id = user.id
    email = user.email
    exam = user.exam
    subjects = user.subjects

    # Optionally update preferences at login if provided.
    changes = {}
    if data.exam is not None and data.exam != exam:
        changes["exam"] = data.exam
    if data.subjects is not None:
        subjects_joined = ",".join(
            s.strip() for s in data.subjects if s and s.strip()
        ) or None
        if subjects_joined != (subjects or None):
            changes["subjects"] = subjects_joined

    if changes:
        try:
            for field, value in changes.items():
                setattr(user, field, value)
            await db.commit()
        except Exception:
            # Best-effort: don't fail the login, but don't hide the error either.
            await db.rollback()
            logger.exception("Failed to save login preferences for user %s", user_id)
        else:
            exam = changes.get("exam", exam)
            subjects = changes.get("subjects", subjects)

    # NOTE(review): the token has no "exp" claim, so it never expires;
    # consider adding one.
    token = jwt.encode({"user_id": user_id}, SECRET_KEY, algorithm=ALGORITHM)
    return {
        "access_token": token,
        "token_type": "bearer",
        "user": {
            "id": user_id,
            "email": email,
            "exam": exam,
            "subjects": (subjects.split(",") if subjects else []),
        },
    }
async def request_password_reset(data: RequestPasswordReset, db: AsyncSession = Depends(get_db)):
    """Issue a 30-minute password-reset token for the account with this email.

    The token is a JWT with ``user_id``, ``pr=True`` and ``exp`` claims. It is
    also stored on the user row together with its expiry, so that
    ``confirm_password_reset`` can require an exact match.

    Raises:
        HTTPException: 500 if the token cannot be persisted (rolled back).
    """
    result = await db.execute(select(User).where(User.email == data.email))
    user = result.scalar_one_or_none()
    if not user:
        # Intended not to reveal whether the email exists; see the SECURITY
        # note on the success response below.
        return {"message": "If the email exists, a reset link has been sent."}

    # Create a short-lived token.
    expires_at = datetime.now(timezone.utc) + timedelta(minutes=30)
    reset_token = jwt.encode({"user_id": user.id, "pr": True, "exp": expires_at}, SECRET_KEY, algorithm=ALGORITHM)
    user.password_reset_token = reset_token
    user.password_reset_expires = expires_at
    try:
        await db.commit()
    except Exception as e:
        await db.rollback()
        logger.exception("Failed to store reset token: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error")

    # NOTE: Integrate email/SMS sending here. For now, return token for development.
    # SECURITY(review): handing the token to an unauthenticated caller lets
    # anyone who knows an account's email reset its password, and the message
    # differing from the unknown-email branch reveals which emails are
    # registered. Drop "reset_token" and reuse that message before production.
    return {"message": "Reset token generated", "reset_token": reset_token}
async def confirm_password_reset(data: ConfirmPasswordReset, db: AsyncSession = Depends(get_db)):
    """Set a new password using a token from ``request_password_reset``.

    The token must verify as a JWT, carry ``user_id`` and a truthy ``pr``
    claim, exactly match the token stored on that user, and not be past the
    stored expiry. On success the stored token and expiry are cleared, so the
    token is single-use.

    Raises:
        HTTPException: 400 for an invalid, mismatched or expired token; 500 if
            the commit fails (the transaction is rolled back).
    """
    try:
        payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(status_code=400, detail="Invalid or expired token") from None
    user_id = payload.get("user_id")
    if not user_id or not payload.get("pr"):
        raise HTTPException(status_code=400, detail="Invalid token")

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user or user.password_reset_token != data.token:
        raise HTTPException(status_code=400, detail="Invalid token")

    expires = user.password_reset_expires
    if expires is not None:
        # The stored value may come back naive (e.g. a DateTime column without
        # timezone=True); comparing naive with aware raises TypeError. It was
        # written as a UTC instant, so interpret a naive value as UTC.
        if expires.tzinfo is None:
            expires = expires.replace(tzinfo=timezone.utc)
        if datetime.now(timezone.utc) > expires:
            raise HTTPException(status_code=400, detail="Token expired")

    user.hashed_password = pwd_context.hash(data.new_password)
    user.password_reset_token = None
    user.password_reset_expires = None
    try:
        await db.commit()
        return {"message": "Password reset successful"}
    except Exception as e:
        await db.rollback()
        logger.exception("Password reset error: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error")
async def change_password(data: ChangePassword,
                          current_user: User = Depends(get_current_user),
                          db: AsyncSession = Depends(get_db)):
    """Change the authenticated user's password after re-checking the current one.

    Any outstanding password-reset token is cleared in the same commit. The
    reset flow only accepts a token that matches the stored one, so a reset
    token issued before this change can no longer override the new password.

    Raises:
        HTTPException: 400 if ``current_password`` is wrong; 500 if the
            commit fails (the transaction is rolled back).
    """
    if not pwd_context.verify(data.current_password, current_user.hashed_password):
        raise HTTPException(status_code=400, detail="Current password incorrect")

    current_user.hashed_password = pwd_context.hash(data.new_password)
    current_user.password_reset_token = None
    current_user.password_reset_expires = None
    # NOTE(review): previously issued access tokens carry only user_id, so
    # they remain valid after a password change.
    try:
        await db.commit()
        return {"message": "Password changed"}
    except Exception as e:
        await db.rollback()
        logger.exception("Change password error: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error")