File size: 9,474 Bytes
6d01d5b
 
 
 
 
ed6b1d2
6d01d5b
 
 
 
 
2a4b4c6
 
992bd88
6d01d5b
 
 
 
 
 
 
 
 
 
 
 
 
2a4b4c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d01d5b
 
 
 
 
3f8cf16
 
 
 
992bd88
 
6d01d5b
 
 
 
 
992bd88
 
 
6d01d5b
 
2a4b4c6
 
 
 
 
992bd88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a4b4c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
992bd88
 
 
 
2a4b4c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d01d5b
 
 
 
 
 
 
 
 
 
992bd88
 
 
 
 
 
 
 
 
 
6d01d5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
992bd88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d01d5b
 
 
 
992bd88
 
 
 
 
 
6d01d5b
992bd88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from passlib.context import CryptContext
from jose import jwt
from pydantic import BaseModel, EmailStr
from app.database import get_db  # Updated: use the correct async session dependency
from app.models import User
import os
import logging
from dotenv import load_dotenv
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError
from datetime import datetime, timedelta, timezone

router = APIRouter()
logger = logging.getLogger(__name__)

# Load variables from a local .env file (if present) before SECRET_KEY is read.
load_dotenv()

# Key used to sign and verify every JWT in this module (access and reset tokens).
SECRET_KEY = os.getenv("SECRET_KEY", "secret")
if SECRET_KEY == "secret":
    # The fallback is a publicly known value: anyone could forge valid tokens
    # with it, so make the misconfiguration visible instead of silent.
    logger.warning(
        "SECRET_KEY is unset or left at the insecure default; "
        "set a strong random SECRET_KEY in the environment."
    )
ALGORITHM = "HS256"

# Password hashing config
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Bearer-token scheme consumed by get_current_user via Depends(security).
security = HTTPBearer()

async def get_current_user(token: HTTPAuthorizationCredentials = Depends(security),
                           db: AsyncSession = Depends(get_db)):
    """Resolve the bearer token of the current request to a ``User`` row.

    Args:
        token: Bearer credentials extracted by the ``security`` scheme.
        db: Async database session.

    Returns:
        The ``User`` whose id matches the token's ``user_id`` claim.

    Raises:
        HTTPException: 401 when the token cannot be decoded, has no
            ``user_id``, is a password-reset token, or the user no longer
            exists.
    """
    credentials_exception = HTTPException(
        status_code=401,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token.credentials, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception

    user_id = payload.get("user_id")
    # Password-reset tokens are signed with the same key and also carry
    # ``user_id``; they are marked with a truthy ``pr`` claim and must never
    # be accepted as an access token.
    if user_id is None or payload.get("pr"):
        raise credentials_exception

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None:
        raise credentials_exception
    return user


# Request Schemas
class SignUp(BaseModel):
    # Request body accepted by the signup endpoint.
    # Only email and password are required; every other field defaults to None.
    email: EmailStr
    password: str  # plain-text password; hashed by the signup handler before storage
    mobile: str | None = None
    name: str | None = None
    dob: str | None = None  # kept as a plain string; no date format is enforced here
    preparing_for: str | None = None
    exam: str | None = None
    subjects: list[str] | None = None  # stored by the handler as one comma-joined string


class Login(BaseModel):
    # Request body accepted by the login endpoint.
    email: EmailStr
    password: str
    # Optional preferences: when the UI sends them at sign-in, the login
    # handler stores them on the user before issuing the token.
    exam: str | None = None
    subjects: list[str] | None = None


class UpdateProfile(BaseModel):
    # Request body accepted by the profile-update endpoint.
    # Every field is optional; the handler leaves a field unchanged when it is None.
    mobile: str | None = None
    name: str | None = None
    dob: str | None = None  # kept as a plain string; no date format is enforced here
    preparing_for: str | None = None
    exam: str | None = None
    subjects: list[str] | None = None  # stored by the handler as one comma-joined string


class RequestPasswordReset(BaseModel):
    # Request body for starting a password reset: the account's email address.
    email: EmailStr


class ConfirmPasswordReset(BaseModel):
    # Request body for completing a password reset.
    token: str  # reset JWT previously issued by the request-reset endpoint
    new_password: str  # plain-text replacement password; hashed by the handler


class ChangePassword(BaseModel):
    # Request body for an authenticated password change.
    current_password: str  # verified against the stored hash before any change
    new_password: str  # plain-text replacement password; hashed by the handler


@router.put("/auth/profile")
async def update_profile(data: UpdateProfile,
                         current_user: User = Depends(get_current_user),
                         db: AsyncSession = Depends(get_db)):
    """Update the authenticated user's profile with the fields that were sent.

    Fields left as ``None`` in the request are not modified. ``subjects`` is
    stored as a single comma-joined string with blank entries removed.

    Returns:
        A confirmation message and the user's profile after the update,
        including ``exam`` and ``subjects``.

    Raises:
        HTTPException: 500 when the change cannot be committed.
    """
    # Scalar fields are copied over verbatim when present.
    for field in ("mobile", "name", "dob", "preparing_for", "exam"):
        value = getattr(data, field)
        if value is not None:
            setattr(current_user, field, value)
    if data.subjects is not None:
        # An empty result is stored as None rather than as an empty string.
        current_user.subjects = ",".join(s.strip() for s in data.subjects if s and s.strip()) or None

    try:
        await db.commit()
        await db.refresh(current_user)
    except Exception:
        await db.rollback()
        logger.exception("Profile update error")
        raise HTTPException(status_code=500, detail="Internal Server Error")

    return {
        "message": "Profile updated successfully",
        "user": {
            "id": current_user.id,
            "email": current_user.email,
            "mobile": current_user.mobile,
            "name": current_user.name,
            "dob": current_user.dob,
            "preparing_for": current_user.preparing_for,
            # Echo the preference fields as well, in the same shape the login
            # response uses, so clients can see what was stored.
            "exam": current_user.exam,
            "subjects": (current_user.subjects.split(",") if current_user.subjects else []),
        },
    }


@router.post("/auth/signup")
async def signup(data: SignUp, db: AsyncSession = Depends(get_db)):
    """Create a new user from the signup payload.

    The password is stored only as a hash. ``subjects`` is stored as a single
    comma-joined string with surrounding whitespace and blank entries removed,
    matching how the profile-update and login handlers store it.

    Returns:
        A confirmation message and the new user's id.

    Raises:
        HTTPException: 400 when the email is already registered, 500 when the
            user cannot be persisted.
    """
    # Check if user already exists
    result = await db.execute(select(User).where(User.email == data.email))
    if result.scalar_one_or_none():
        raise HTTPException(status_code=400, detail="Email already exists")

    subjects = None
    if data.subjects:
        # An empty result (e.g. only blank entries) is stored as None.
        subjects = ",".join(s.strip() for s in data.subjects if s and s.strip()) or None

    new_user = User(
        email=data.email,
        hashed_password=pwd_context.hash(data.password),
        mobile=data.mobile,
        name=data.name,
        dob=data.dob,
        preparing_for=data.preparing_for,
        exam=data.exam,
        subjects=subjects,
    )

    try:
        db.add(new_user)
        await db.commit()
        await db.refresh(new_user)
    except Exception:
        await db.rollback()
        logger.exception("Signup error")
        raise HTTPException(status_code=500, detail="Internal Server Error")

    return {"message": "User created", "user_id": new_user.id}


@router.post("/auth/login")
async def login(data: Login, db: AsyncSession = Depends(get_db)):
    """Authenticate by email and password and issue an access token.

    When ``exam`` or ``subjects`` are supplied and differ from the stored
    values, they are saved on the user first. That save is best-effort: a
    failure is logged and the login still proceeds with the stored values.

    Returns:
        The access token, its type, and a summary of the user.

    Raises:
        HTTPException: 401 when the email is unknown or the password is wrong.
    """
    result = await db.execute(select(User).where(User.email == data.email))
    user = result.scalar_one_or_none()

    if not user or not pwd_context.verify(data.password, user.hashed_password):
        raise HTTPException(status_code=401, detail="Invalid credentials")

    # Optionally update preferences at login if provided
    try:
        updated = False
        if data.exam is not None and data.exam != user.exam:
            user.exam = data.exam
            updated = True
        if data.subjects is not None:
            subjects_joined = ",".join(s.strip() for s in data.subjects if s and s.strip()) or None
            if subjects_joined != (user.subjects or None):
                user.subjects = subjects_joined
                updated = True
        if updated:
            await db.commit()
            await db.refresh(user)
    except Exception:
        # Best-effort: do not fail the login, but leave a trace of the failure.
        logger.exception("Failed to update preferences at login")
        await db.rollback()
        # A rollback expires the attributes loaded on ``user``; reload them
        # explicitly so building the response below does not need implicit I/O.
        await db.refresh(user)

    # NOTE(review): the access token carries no "exp" claim, so it never
    # expires. Consider adding an expiry once clients can handle re-login.
    token = jwt.encode({"user_id": user.id}, SECRET_KEY, algorithm=ALGORITHM)
    return {
        "access_token": token,
        "token_type": "bearer",
        "user": {
            "id": user.id,
            "email": user.email,
            "exam": user.exam,
            "subjects": (user.subjects.split(",") if user.subjects else []),
        },
    }


@router.post("/auth/password/request-reset")
async def request_password_reset(data: RequestPasswordReset, db: AsyncSession = Depends(get_db)):
    """Issue a 30-minute password-reset token for the given email.

    The token is a JWT marked with ``pr`` and is stored on the user together
    with its expiry. By default the token is returned in the response for
    development use; set the environment variable ``EXPOSE_RESET_TOKEN`` to a
    false value (e.g. ``0``) to return the same generic message for known and
    unknown emails.

    Raises:
        HTTPException: 500 when the token cannot be stored.
    """
    generic_response = {"message": "If the email exists, a reset link has been sent."}

    result = await db.execute(select(User).where(User.email == data.email))
    user = result.scalar_one_or_none()
    if not user:
        # Do not reveal whether email exists
        return generic_response

    # Create a short-lived token
    expires_at = datetime.now(timezone.utc) + timedelta(minutes=30)
    reset_token = jwt.encode({"user_id": user.id, "pr": True, "exp": expires_at}, SECRET_KEY, algorithm=ALGORITHM)
    user.password_reset_token = reset_token
    user.password_reset_expires = expires_at
    try:
        await db.commit()
    except Exception:
        await db.rollback()
        logger.exception("Failed to store reset token")
        raise HTTPException(status_code=500, detail="Internal Server Error")

    # NOTE: Integrate email/SMS sending here.
    # SECURITY: returning the token lets anyone who knows an email address reset
    # that account's password, and makes this response differ from the
    # unknown-email one. It stays on by default for development only; deployed
    # environments should set EXPOSE_RESET_TOKEN=0.
    expose_token = os.getenv("EXPOSE_RESET_TOKEN", "true").strip().lower() in {"1", "true", "yes", "on"}
    if not expose_token:
        return generic_response
    return {"message": "Reset token generated", "reset_token": reset_token}


@router.post("/auth/password/confirm-reset")
async def confirm_password_reset(data: ConfirmPasswordReset, db: AsyncSession = Depends(get_db)):
    """Set a new password using a previously issued reset token.

    The token must decode with the module's key, carry ``user_id`` and the
    ``pr`` marker, match the token stored on the user, and not be past the
    stored expiry. On success the stored token and expiry are cleared, so the
    token cannot be used again.

    Raises:
        HTTPException: 400 for an invalid, mismatched or expired token,
            500 when the new password cannot be committed.
    """
    try:
        payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(status_code=400, detail="Invalid or expired token")

    user_id = payload.get("user_id")
    if not user_id or not payload.get("pr"):
        raise HTTPException(status_code=400, detail="Invalid token")

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user or user.password_reset_token != data.token:
        raise HTTPException(status_code=400, detail="Invalid token")

    expires_at = user.password_reset_expires
    if expires_at is not None:
        # The stored value may come back without tzinfo depending on the
        # column type; it was written as UTC, so treat a naive value as UTC
        # rather than failing on a naive/aware comparison.
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        if datetime.now(timezone.utc) > expires_at:
            raise HTTPException(status_code=400, detail="Token expired")

    user.hashed_password = pwd_context.hash(data.new_password)
    # Clear the stored token so it is single-use.
    user.password_reset_token = None
    user.password_reset_expires = None
    try:
        await db.commit()
    except Exception:
        await db.rollback()
        logger.exception("Password reset error")
        raise HTTPException(status_code=500, detail="Internal Server Error")

    return {"message": "Password reset successful"}

@router.post("/auth/password/change")
async def change_password(data: ChangePassword,
                         current_user: User = Depends(get_current_user),
                         db: AsyncSession = Depends(get_db)):
    """Change the authenticated user's password.

    The current password must verify against the stored hash. Any pending
    password-reset token is cleared at the same time, so a reset token issued
    before the change cannot be used to override the new password.

    Raises:
        HTTPException: 400 when the current password is wrong, 500 when the
            change cannot be committed.
    """
    if not pwd_context.verify(data.current_password, current_user.hashed_password):
        raise HTTPException(status_code=400, detail="Current password incorrect")

    current_user.hashed_password = pwd_context.hash(data.new_password)
    # Invalidate any outstanding reset token along with the old password.
    current_user.password_reset_token = None
    current_user.password_reset_expires = None
    try:
        await db.commit()
    except Exception:
        await db.rollback()
        logger.exception("Change password error")
        raise HTTPException(status_code=500, detail="Internal Server Error")

    return {"message": "Password changed"}