| | |
| | import os |
| | import uuid |
| | import logging |
| | from datetime import datetime, timedelta |
| | from urllib.parse import quote_plus |
| | from typing import Optional |
| |
|
| | from dotenv import load_dotenv |
| | from fastapi import APIRouter, HTTPException, Depends, Request, UploadFile, File, Form |
| | from fastapi.responses import StreamingResponse |
| | from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
| | from jose import JWTError, jwt |
| | from passlib.context import CryptContext |
| | from pymongo import MongoClient |
| | import gridfs |
| | from bson import ObjectId |
| |
|
| | from models import User, UserUpdate, Token, LoginResponse |
| | from config import CONNECTION_STRING, SECRET_KEY, ACCESS_TOKEN_EXPIRE_MINUTES, REFRESH_TOKEN_EXPIRE_DAYS |
| |
|
# Load variables from a .env file into the process environment.
# NOTE(review): `config` is imported above, before this call runs; if config
# reads environment variables at import time they will not come from .env —
# confirm.
load_dotenv()

# Log through the logger named "uvicorn" at INFO level.
logger = logging.getLogger("uvicorn")
logger.setLevel(logging.INFO)

# MongoDB handles: the client, the users database and its users collection.
client = MongoClient(CONNECTION_STRING)
db = client["users_database"]
users_collection = db["users"]

# GridFS bucket in the same database; used by the avatar helpers below.
fs = gridfs.GridFS(db, collection="avatars")

# Bearer-token dependency for protected routes, and the router that every
# endpoint in this module is registered on.
# NOTE(review): tokenUrl="token" does not match any route defined on this
# router (login lives at /auth/login) — confirm the intended docs URL.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
router = APIRouter(prefix="/auth", tags=["auth"])

# Password hashing context (bcrypt) shared by the hashing helpers.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
| |
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a candidate password against a stored hash.

    Delegates to the module-level ``pwd_context``.

    Args:
        plain_password: The password as typed by the user.
        hashed_password: The hash stored on the user document.

    Returns:
        The result of ``pwd_context.verify`` for the two values.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
| |
|
def get_password_hash(password: str) -> str:
    """Hash ``password`` with the module-level ``pwd_context``.

    Args:
        password: The plain-text password to hash.

    Returns:
        The hash string produced by ``pwd_context.hash``.
    """
    hashed = pwd_context.hash(password)
    return hashed
| |
|
def get_user(email: str) -> Optional[dict]:
    """Look up a user document by its ``email`` field.

    Args:
        email: The address to search for in ``users_collection``.

    Returns:
        Whatever ``users_collection.find_one`` returns for that filter: a
        matching document, or ``None``.
    """
    query = {"email": email}
    return users_collection.find_one(query)
| |
|
def authenticate_user(email: str, password: str) -> Optional[dict]:
    """Return the user document when ``email`` and ``password`` match.

    Args:
        email: The address identifying the account.
        password: The plain-text password to check.

    Returns:
        The user document if a user with that email exists and the password
        verifies against its ``hashed_password``; otherwise ``None``.
    """
    candidate = get_user(email)
    # The password is only checked when a user was actually found.
    if candidate and verify_password(password, candidate["hashed_password"]):
        return candidate
    return None
| |
|
def create_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Encode ``data`` as an HS256-signed JWT with an ``exp`` claim.

    Args:
        data: Claims to embed in the token. The mapping is not mutated.
        expires_delta: Token lifetime. When omitted (or falsy, e.g. a zero
            ``timedelta``) a 15 minute lifetime is used.

    Returns:
        The encoded token, signed with ``SECRET_KEY``.
    """
    # Imported locally so this block does not depend on the file-level
    # ``from datetime import ...`` line being extended.
    from datetime import timezone

    lifetime = expires_delta or timedelta(minutes=15)
    # Timezone-aware "now": datetime.utcnow() is deprecated and returns a
    # naive value that is easy to misinterpret as local time.
    expires_at = datetime.now(timezone.utc) + lifetime
    claims = {**data, "exp": expires_at}
    return jwt.encode(claims, SECRET_KEY, algorithm="HS256")
| |
|
def create_access_token(email: str) -> str:
    """Issue a token for ``email`` valid for ``ACCESS_TOKEN_EXPIRE_MINUTES``.

    Args:
        email: Stored in the token's ``sub`` claim.

    Returns:
        The encoded token from ``create_token``.
    """
    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    return create_token({"sub": email}, lifetime)
| |
|
def create_refresh_token(email: str) -> str:
    """Issue a token for ``email`` valid for ``REFRESH_TOKEN_EXPIRE_DAYS``.

    Args:
        email: Stored in the token's ``sub`` claim.

    Returns:
        The encoded token from ``create_token``.
    """
    lifetime = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    return create_token({"sub": email}, lifetime)
| |
|
def get_current_user(token: str = Depends(oauth2_scheme)) -> dict:
    """Resolve the bearer token on the request to a user document.

    Intended for use as a FastAPI dependency; ``oauth2_scheme`` supplies the
    token.

    Args:
        token: The encoded JWT.

    Returns:
        The user document whose ``email`` equals the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when the token cannot be decoded, has no ``sub``
            claim, or names a user that no longer exists. Each response
            carries a ``WWW-Authenticate: Bearer`` header.
    """
    # Bearer-protected resources should send this challenge on every 401.
    bearer_challenge = {"WWW-Authenticate": "Bearer"}

    # Only the decode call can raise JWTError, so only it sits inside the try.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
    except JWTError as exc:
        raise HTTPException(
            status_code=401,
            detail="Invalid token",
            headers=bearer_challenge,
        ) from exc

    email = payload.get("sub")
    if not email:
        raise HTTPException(
            status_code=401,
            detail="Invalid credentials",
            headers=bearer_challenge,
        )

    user = get_user(email)
    if not user:
        raise HTTPException(
            status_code=401,
            detail="User not found",
            headers=bearer_challenge,
        )
    return user
| |
|
async def save_avatar_file_to_gridfs(file: UploadFile) -> str:
    """Store an uploaded avatar image in GridFS and return its id.

    Only the content type reported on the upload is checked; the bytes
    themselves are not inspected here.

    Args:
        file: The uploaded file.

    Returns:
        The GridFS file id, as a string.

    Raises:
        HTTPException: 400 when the reported content type is not JPEG, PNG
            or GIF; 500 when reading the upload or writing to GridFS fails.
    """
    accepted_types = ["image/jpeg", "image/png", "image/gif"]
    reported_type = file.content_type
    if reported_type not in accepted_types:
        logger.error(f"Unsupported file type: {reported_type}")
        raise HTTPException(
            status_code=400,
            detail="Invalid image format. Only JPEG, PNG, and GIF are accepted."
        )

    try:
        payload = await file.read()
        stored_id = fs.put(
            payload,
            filename=file.filename,
            contentType=file.content_type,
        )
        logger.info(f"Avatar stored in GridFS with file_id: {stored_id}")
        return str(stored_id)
    except Exception:
        # logger.exception records the traceback of the failure.
        logger.exception("Failed to store avatar in GridFS")
        raise HTTPException(status_code=500, detail="Could not store avatar file in MongoDB.")
| |
|
@router.post("/signup", response_model=Token)
async def signup(
    request: Request,
    name: str = Form(...),
    email: str = Form(...),
    password: str = Form(...),
    role: str = Form(...),
    avatar: Optional[UploadFile] = File(None),
):
    """Register a new user from form data and return a token pair.

    Args:
        request: The incoming request (not used by the body).
        name: Display name for the account.
        email: Address for the account; must not already be registered.
        password: Plain-text password; stored only as a hash.
        role: Role string stored on the user document as received.
        avatar: Optional image, stored in GridFS.

    Returns:
        A dict with ``access_token``, ``refresh_token`` and ``token_type``.

    Raises:
        HTTPException: 400 when building a ``User`` from the inputs fails or
            the email is already registered.
    """
    # The User model is built only for its validation; the instance is unused.
    try:
        User(name=name, email=email, password=password)
    except Exception as exc:
        logger.error(f"Validation error during signup: {exc}")
        raise HTTPException(status_code=400, detail=str(exc))

    if get_user(email):
        logger.warning(f"Attempt to register already existing email: {email}")
        raise HTTPException(status_code=400, detail="Email already registered")

    # NOTE(review): `role` is client-supplied and stored unchanged — confirm
    # callers cannot self-assign privileged roles.
    new_user = {
        "name": name,
        "email": email,
        "hashed_password": get_password_hash(password),
        "role": role,
        "chat_histories": [],
    }

    if avatar:
        new_user["avatar"] = await save_avatar_file_to_gridfs(avatar)

    users_collection.insert_one(new_user)
    logger.info(f"New user registered: {email} with role: {role}")

    access_token = create_access_token(email)
    refresh_token = create_refresh_token(email)
    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer",
    }
| |
|
@router.post("/login", response_model=LoginResponse)
async def login(request: Request, form_data: OAuth2PasswordRequestForm = Depends()):
    """Authenticate with form credentials and return tokens plus profile data.

    Args:
        request: The incoming request (not used by the body).
        form_data: OAuth2 password form; ``username`` is used as the email.

    Returns:
        A dict with ``access_token``, ``refresh_token``, ``token_type``,
        ``name``, ``avatar`` (URL path or ``None``) and ``role`` (defaults to
        ``"user"`` when the document has none).

    Raises:
        HTTPException: 401 when the credentials do not match a user.
    """
    account = authenticate_user(form_data.username, form_data.password)
    if not account:
        logger.warning(f"Failed login attempt for: {form_data.username}")
        raise HTTPException(status_code=401, detail="Incorrect username or password")

    email = account["email"]
    logger.info(f"User logged in: {email}")

    # A missing or empty avatar id yields no URL.
    avatar_id = account.get("avatar")
    avatar_url = f"/auth/avatar/{avatar_id}" if avatar_id else None

    access_token = create_access_token(account["email"])
    refresh_token = create_refresh_token(account["email"])
    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer",
        "name": account["name"],
        "avatar": avatar_url,
        "role": account.get("role", "user"),
    }
| |
|
@router.get("/user/data")
async def get_user_data(request: Request, current_user: dict = Depends(get_current_user)):
    """Return the authenticated user's profile fields.

    Args:
        request: The incoming request (not used by the body).
        current_user: User document resolved by ``get_current_user``.

    Returns:
        A dict with ``name``, ``email``, ``avatar`` (URL path or ``None``),
        ``role`` (default ``"user"``) and ``chat_histories`` (default ``[]``).
    """
    # A missing or empty avatar id yields no URL.
    avatar_id = current_user.get("avatar")
    avatar_url = f"/auth/avatar/{avatar_id}" if avatar_id else None

    profile = {
        "name": current_user["name"],
        "email": current_user["email"],
        "avatar": avatar_url,
        "role": current_user.get("role", "user"),
        "chat_histories": current_user.get("chat_histories", []),
    }
    return profile
| |
|
@router.put("/user/update")
async def update_user(
    request: Request,
    name: Optional[str] = Form(None),
    email: Optional[str] = Form(None),
    password: Optional[str] = Form(None),
    avatar: Optional[UploadFile] = File(None),
    current_user: dict = Depends(get_current_user)
):
    """Update any subset of the authenticated user's name, email, password or avatar.

    Args:
        request: The incoming request (not used by the body).
        name: New display name, if changing.
        email: New email, if changing. Must not belong to another account.
        password: New plain-text password, if changing; stored as a hash.
        avatar: New avatar image, if changing; stored in GridFS.
        current_user: User document resolved by ``get_current_user``.

    Returns:
        ``{"message": "User updated successfully"}``.

    Raises:
        HTTPException: 400 when the new email is already registered, the new
            password fails ``User`` validation, or no field was supplied.

    Note:
        Tokens carry the email in ``sub`` and ``get_current_user`` looks the
        user up by it, so tokens issued before an email change stop resolving
        once the change is saved.
    """
    update_data = {}
    if name is not None:
        update_data["name"] = name

    if email is not None:
        # Same uniqueness rule as signup: lookups are by email, so two
        # documents sharing one address would make them ambiguous. Re-sending
        # the current address is allowed.
        if email != current_user["email"] and get_user(email):
            logger.warning(f"Attempt to change email to an already registered one: {email}")
            raise HTTPException(status_code=400, detail="Email already registered")
        update_data["email"] = email

    if password is not None:
        # The User model is built only to validate the new password; the
        # current name and email fill the other required fields.
        try:
            User(name=current_user["name"], email=current_user["email"], password=password)
        except Exception as e:
            logger.error(f"Password validation error during update: {e}")
            raise HTTPException(status_code=400, detail=str(e))
        update_data["hashed_password"] = get_password_hash(password)

    # Stored only after the checks above, so a rejected request leaves no
    # file behind in GridFS.
    if avatar:
        file_id = await save_avatar_file_to_gridfs(avatar)
        update_data["avatar"] = file_id

    if not update_data:
        logger.info("No update parameters provided")
        raise HTTPException(status_code=400, detail="No update parameters provided")

    users_collection.update_one({"email": current_user["email"]}, {"$set": update_data})
    logger.info(f"User updated: {current_user['email']}")

    return {"message": "User updated successfully"}
| |
|
@router.post("/logout")
async def logout(request: Request, current_user: dict = Depends(get_current_user)):
    """Log that the authenticated user logged out and acknowledge it.

    The body only writes a log line; it does not touch any token.

    Args:
        request: The incoming request (not used by the body).
        current_user: User document resolved by ``get_current_user``.

    Returns:
        ``{"message": "User logged out successfully"}``.
    """
    email = current_user["email"]
    logger.info(f"User logged out: {email}")
    return {"message": "User logged out successfully"}
| |
|
@router.get("/avatar/{file_id}")
async def get_avatar(file_id: str):
    """Stream a stored avatar image from GridFS.

    Args:
        file_id: String form of the GridFS file id.

    Returns:
        A ``StreamingResponse`` over the stored file, using the file's
        ``content_type`` as the media type.

    Raises:
        HTTPException: 404 when anything in the lookup fails, including an
            id that ``ObjectId`` rejects.
    """
    try:
        object_id = ObjectId(file_id)
        grid_file = fs.get(object_id)
        return StreamingResponse(grid_file, media_type=grid_file.content_type)
    except Exception as exc:
        logger.error(f"Avatar not found for file_id {file_id}: {exc}")
        raise HTTPException(status_code=404, detail="Avatar not found")