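# NOTE(review): "Spaces: Running" (followed by a second "Running") is page-status
# text captured along with this file, not Python; kept as a comment so the module parses.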
# app/routers/auth.py | |
import html
import os
import random
import secrets
import string
from datetime import datetime, timedelta

from dotenv import load_dotenv
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, EmailStr
from sendgrid import SendGridAPIClient
from sendgrid.helpers.mail import Mail
from werkzeug.security import generate_password_hash, check_password_hash

from app.database.database_query import DatabaseQuery
from app.middleware.auth import create_access_token, get_current_user

# Populate os.environ from a local .env file first; the lookups below depend on it.
load_dotenv()

# SendGrid credentials and sender address (None when the variable is unset).
SENDGRID_API_KEY = os.environ.get("SENDGRID_API_KEY")
FROM_EMAIL = os.environ.get("FROM_EMAIL")

# NOTE(review): no @router.<method>(...) decorators are visible on the handlers
# in this copy of the file, so nothing shown here registers routes on `router`
# -- confirm against the original source.
router = APIRouter()

# Single database helper shared by every handler in this module.
query = DatabaseQuery()


class LoginRequest(BaseModel):
    """Credentials submitted to ``login``."""
    identifier: str  # username or email address
    password: str  # plaintext; checked against the stored hash


class LoginResponse(BaseModel):
    """Shape of a successful login reply (mirrors the dict ``login`` returns)."""
    message: str
    token: str  # access token from create_access_token


class RegisterRequest(BaseModel):
    """Sign-up details submitted to ``register``."""
    username: str
    email: EmailStr  # validated as an email address by pydantic
    password: str  # plaintext; register() hashes it before storing
    name: str  # used in the greeting of the verification email
    age: int


class VerifyEmailRequest(BaseModel):
    """Username plus the emailed verification code, submitted to ``verify_email``."""
    username: str
    code: str  # the 6-digit code sent by register() / resend_code()


class ResendCodeRequest(BaseModel):
    """Username of a pending (unverified) registration that needs a new code."""
    username: str


class ForgotPasswordRequest(BaseModel):
    """Email address that should receive a password-reset link."""
    email: EmailStr


class ResetPasswordRequest(BaseModel):
    """Reset token from the emailed link plus the new password."""
    token: str  # the token embedded in forgot_password()'s reset link
    password: str  # new plaintext password; hashed before storage


class ChatSessionCheck(BaseModel):
    """Chat session id whose existence ``check_chatsession`` should report."""
    session_id: str


async def login(login_data: LoginRequest):
    """Authenticate a user by username or email and issue an access token.

    Args:
        login_data: The submitted identifier (username or email) and password.

    Returns:
        A dict with a success ``message`` and a ``token`` built by
        ``create_access_token`` with the username as the ``sub`` claim.

    Raises:
        HTTPException: 401 when the credentials are wrong or the account is
            not verified; 500 for any unexpected error.
    """
    try:
        user = query.get_user_by_identifier(login_data.identifier)
        # Check the password BEFORE saying anything about the account: the
        # "please verify" hint must only reach someone who already knows the
        # password, otherwise it reveals that an unverified account exists.
        # NOTE(review): when no user is found no hash is checked, so response
        # time can still hint at account existence.
        if not user or not check_password_hash(user['password'], login_data.password):
            raise HTTPException(status_code=401, detail="Invalid username/email or password")
        if not user.get('is_verified'):
            raise HTTPException(status_code=401, detail="Please verify your email before logging in")
        access_token = create_access_token({"sub": user['username']})
        return {"message": "Login successful", "token": access_token}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


async def register(register_data: RegisterRequest):
    """Stage a new account and email it a 6-digit verification code.

    The account is written with ``query.create_or_update_temp_user`` rather
    than as a real user; ``verify_email`` promotes it once the emailed code is
    confirmed. The code expires 10 minutes after it is issued.

    Args:
        register_data: Username, email, plaintext password, name and age.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 409 if the username or email already exists; 500 if the
            verification email cannot be sent or any other step fails.
    """
    try:
        username = register_data.username
        email = register_data.email
        name = register_data.name
        if query.is_username_or_email_exists(username, email):
            raise HTTPException(status_code=409, detail="Username or email already exists")

        # The code proves ownership of the mailbox, so draw it from a CSPRNG;
        # the `random` module is documented as unsuitable for security use.
        verification_code = ''.join(secrets.choice(string.digits) for _ in range(6))
        code_expiration = datetime.utcnow() + timedelta(minutes=10)
        temp_user = {
            'username': username,
            'email': email,
            'password': generate_password_hash(register_data.password),
            'name': name,
            'age': register_data.age,
            'created_at': datetime.utcnow(),
            'verification_code': verification_code,
            'code_expiration': code_expiration,
        }
        query.create_or_update_temp_user(username, email, temp_user)

        message = Mail(
            from_email=FROM_EMAIL,
            to_emails=email,
            subject='Verify your email address',
            # `name` is free text from the request; escape it so a crafted name
            # cannot inject markup or links into mail sent from FROM_EMAIL.
            html_content=f'''
            <p>Hi {html.escape(name)},</p>
            <p>Thank you for registering. Please use the following code to verify your email address:</p>
            <h2>{verification_code}</h2>
            <p>This code will expire in 10 minutes.</p>
            '''
        )
        try:
            SendGridAPIClient(SENDGRID_API_KEY).send(message)
        except Exception as exc:
            raise HTTPException(status_code=500, detail="Failed to send verification email") from exc
        return {"message": "Registration successful. A verification code has been sent to your email."}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


async def verify_email(verify_data: VerifyEmailRequest):
    """Promote a pending registration to a verified user when its code matches.

    On success the pending record is copied into a user record (without the
    verification fields and ``_id``) with ``is_verified`` set to True. The
    pending record is then deleted, and the defaults are applied: English
    language, light theme, and the preference flags below.

    Args:
        verify_data: The username and the code the user received by email.

    Returns:
        A dict with a success ``message``.

    Raises:
        HTTPException: 404 if no pending registration exists for the username;
            400 if the code does not match or has expired; 500 for any other
            failure.
    """
    try:
        username = verify_data.username
        code = verify_data.code
        temp_user = query.get_temp_user_by_username(username)
        if not temp_user:
            raise HTTPException(status_code=404, detail="User not found or already verified")

        # NOTE(review): failed attempts are not counted here. A 6-digit code
        # has 10**6 values and stays valid for 10 minutes, so confirm that
        # rate limiting exists elsewhere.
        # Compare in constant time so response timing does not reveal how much
        # of a guess matched. compare_digest() only accepts ASCII str, so a
        # non-ASCII guess (or a non-str stored value) counts as a mismatch,
        # which is what `==` would conclude for the generated digit codes.
        stored_code = temp_user['verification_code']
        code_matches = (
            isinstance(stored_code, str)
            and stored_code.isascii()
            and code.isascii()
            and secrets.compare_digest(stored_code, code)
        )
        if not code_matches:
            raise HTTPException(status_code=400, detail="Invalid verification code")
        if datetime.utcnow() > temp_user['code_expiration']:
            raise HTTPException(status_code=400, detail="Verification code has expired")

        # Drop the verification-only fields and the storage id before saving
        # the record as a real user.
        user_data = {
            key: value
            for key, value in temp_user.items()
            if key not in ('verification_code', 'code_expiration', '_id')
        }
        user_data['is_verified'] = True
        query.create_user_from_data(user_data)
        query.delete_temp_user(username)

        # Defaults for a new account: English, light theme (False = not dark).
        query.set_user_language(username, "English")
        query.set_user_theme(username, False)
        default_preferences = {
            'keywords': True,
            'references': True,
            'websearch': False,
            'personalized_recommendations': True,
            'environmental_recommendations': True
        }
        query.set_user_preferences(username, default_preferences)
        return {"message": "Email verification successful"}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


async def resend_code(resend_data: ResendCodeRequest):
    """Issue and email a fresh verification code for a pending registration.

    The stored pending record gets a new code and a new 10-minute expiry
    before the email is sent.

    Args:
        resend_data: Username of the pending registration.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 when no pending registration exists for the
            username; 500 if the email cannot be sent or any other step fails.
    """
    try:
        username = resend_data.username
        temp_user = query.get_temp_user_by_username(username)
        if not temp_user:
            raise HTTPException(status_code=404, detail="User not found or already verified")

        # CSPRNG-backed code; the `random` module is not meant for secrets.
        # NOTE(review): nothing here throttles resends -- confirm rate limiting
        # exists elsewhere.
        verification_code = ''.join(secrets.choice(string.digits) for _ in range(6))
        temp_user['verification_code'] = verification_code
        temp_user['code_expiration'] = datetime.utcnow() + timedelta(minutes=10)
        query.create_or_update_temp_user(username, temp_user['email'], temp_user)

        # The stored name came from the sign-up form; escape it before it is
        # embedded in HTML.
        recipient_name = html.escape(str(temp_user['name']))
        message = Mail(
            from_email=FROM_EMAIL,
            to_emails=temp_user['email'],
            subject='Your new verification code',
            html_content=f'''
            <p>Hi {recipient_name},</p>
            <p>You requested a new verification code. Please use the following code to verify your email address:</p>
            <h2>{verification_code}</h2>
            <p>This code will expire in 10 minutes.</p>
            '''
        )
        try:
            SendGridAPIClient(SENDGRID_API_KEY).send(message)
        except Exception as exc:
            raise HTTPException(status_code=500, detail="Failed to send verification email") from exc
        return {"message": "A new verification code has been sent to your email."}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


async def check_chatsession(data: ChatSessionCheck, username: str = Depends(get_current_user)):
    """Report whether a chat session with the given id exists.

    ``username`` is not read in the body. The ``get_current_user`` dependency
    is declared so the request must carry a valid token (presumably it raises
    otherwise -- confirm in app.middleware.auth).

    NOTE(review): the lookup is not scoped to ``username``. Confirm whether
    ``query.check_chat_session`` should also verify ownership.

    Returns:
        ``{"ischatexit": <result of query.check_chat_session>}``.
    """
    exists = query.check_chat_session(data.session_id)
    return {"ischatexit": exists}


async def check_token(username: str = Depends(get_current_user)):
    """Confirm that the caller's access token is valid.

    Token checking is delegated to the ``get_current_user`` dependency, which
    presumably rejects bad tokens before this body runs (confirm in
    app.middleware.auth). No exception handling is needed here: building the
    returned dict cannot raise.

    Returns:
        ``{'valid': True, 'user': <username resolved by the dependency>}``.
    """
    return {'valid': True, 'user': username}


async def forgot_password(data: ForgotPasswordRequest):
    """Email a one-hour password-reset link to a registered address.

    The link points at ``<FRONTEND_URL>/reset-password?token=...``. FRONTEND_URL
    is read from the environment and defaults to ``http://localhost:3000``.
    The token and its expiry are saved with ``query.store_reset_token``.

    Args:
        data: The email address to send the reset link to.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 if no user has that email; 500 if the email cannot
            be sent or any other step fails.
    """
    try:
        email = data.email
        user = query.get_user_by_identifier(email)
        # NOTE(review): a 404 tells callers whether an address is registered
        # (account enumeration). register() already reveals this through its
        # 409, so changing only this endpoint would not close the leak.
        if not user:
            raise HTTPException(status_code=404, detail="Email not found")

        # The reset token is a bearer credential, so generate it with a CSPRNG.
        # The alphabet and length are unchanged, so the token format stays the same.
        alphabet = string.ascii_letters + string.digits
        reset_token = ''.join(secrets.choice(alphabet) for _ in range(32))
        expiration = datetime.utcnow() + timedelta(hours=1)
        query.store_reset_token(email, reset_token, expiration)

        # Configurable so deployed instances do not email localhost links.
        frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000").rstrip('/')
        reset_link = f"{frontend_url}/reset-password?token={reset_token}"
        message = Mail(
            from_email=FROM_EMAIL,
            to_emails=email,
            subject='Reset Your Password',
            html_content=f'''
            <p>Hi,</p>
            <p>You requested to reset your password. Click the link below to reset it:</p>
            <p><a href="{reset_link}">Reset Password</a></p>
            <p>This link will expire in 1 hour.</p>
            <p>If you didn't request this, please ignore this email.</p>
            '''
        )
        # Same handling as the verification emails: report a fixed message
        # instead of leaking the SendGrid exception text to the client.
        try:
            SendGridAPIClient(SENDGRID_API_KEY).send(message)
        except Exception as exc:
            raise HTTPException(status_code=500, detail="Failed to send password reset email") from exc
        return {"message": "Password reset instructions sent to email"}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


async def reset_password(data: ResetPasswordRequest):
    """Replace a user's password using a token from ``forgot_password``.

    The token is validated by ``query.verify_reset_token``. The password is
    stored as a werkzeug hash for the email that the token resolves to.

    NOTE(review): nothing here invalidates the token after use. Unless
    verify_reset_token or update_password consumes it, the same link could be
    reused until it expires -- confirm.

    Args:
        data: The reset token and the new plaintext password.

    Returns:
        A dict with a success ``message``.

    Raises:
        HTTPException: 400 if either field is empty or the token is invalid or
            expired; 500 for any other failure.
    """
    try:
        if not (data.token and data.password):
            raise HTTPException(status_code=400, detail="Token and new password are required")

        reset_info = query.verify_reset_token(data.token)
        if not reset_info:
            raise HTTPException(status_code=400, detail="Invalid or expired reset token")

        query.update_password(reset_info['email'], generate_password_hash(data.password))
        return {"message": "Password successfully reset"}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))