import io
import os
import glob
import torch
import torch.nn as nn
from torchvision import models, transforms as T
from PIL import Image
from fastapi import FastAPI, UploadFile, File, HTTPException
import uvicorn

app = FastAPI()

# =================CONFIGURATION=================
CHECKPOINT_PATH = "checkpoints/best_age_model.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ===============================================


# 1. Define the Model Architecture (Must match training EXACTLY)
def get_model():
    """Build the (untrained) ResNet18 age-regression network.

    The stock ResNet18 classification head is replaced by a small MLP
    (Linear -> ReLU -> Dropout -> Linear) that outputs a single scalar.
    The layer sizes must match the training script, otherwise
    ``load_state_dict`` will reject the checkpoint.

    Returns:
        torch.nn.Module: the network with randomly initialised weights.
    """
    # weights=None: no pretrained download; real weights come from the checkpoint.
    model = models.resnet18(weights=None)

    # Reconstruct the final layer exactly as we defined in training
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(num_features, 128),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(128, 1),
    )
    return model


# 2. Load Weights
print(f"Initializing model on {DEVICE}...")
# Move to DEVICE and switch to eval mode unconditionally. predict_age always
# sends its input tensor to DEVICE, so the model must live there even when no
# checkpoint exists. eval() turns off Dropout for inference.
# load_state_dict below copies into the existing parameters and keeps both
# their device and eval mode.
model = get_model().to(DEVICE)
model.eval()

if os.path.exists(CHECKPOINT_PATH):
    print(f"Loading weights from {CHECKPOINT_PATH}...")
    # NOTE(review): torch.load unpickles the file. Only point CHECKPOINT_PATH at
    # trusted files, and consider weights_only=True if the installed torch
    # supports it.
    state_dict = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(state_dict)
    print("✅ Model loaded successfully!")
else:
    print(f"⚠️ Warning: Checkpoint not found at {CHECKPOINT_PATH}")
    print("Inference will use random weights (garbage output).")


# 3. Define Transforms (Must match training)
transform = T.Compose([
    T.Resize((200, 200)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _decode_image(content):
    """Decode raw upload bytes into an RGB PIL image.

    Args:
        content: the raw bytes of the uploaded file.

    Returns:
        PIL.Image.Image: the decoded image converted to RGB.

    Raises:
        HTTPException: 400 if the bytes are empty, not a recognised image
            format, truncated/corrupt, or exceed Pillow's decompression-bomb
            limit. These are client errors, not server failures.
    """
    try:
        # Image.open is lazy; convert() forces the actual decode, so both
        # calls must be inside the try.
        return Image.open(io.BytesIO(content)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # PIL.UnidentifiedImageError is a subclass of OSError.
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e


@app.get("/")
def home():
    """Liveness check: report that the service is running."""
    return {"status": "Age Predictor API is running"}


@app.post("/predict")
async def predict_age(file: UploadFile = File(...)):
    """
    Accepts an image file and returns predicted age.

    Args:
        file: the uploaded image (any format Pillow can decode).

    Returns:
        dict: ``{"filename": <upload name>, "predicted_age": <float rounded
        to 1 decimal>}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 for any other failure while reading or running inference.
    """
    try:
        # 1. Read and Process Image
        content = await file.read()
        image = _decode_image(content)

        # 2. Transform (add a batch dimension of 1)
        input_tensor = transform(image).unsqueeze(0).to(DEVICE)

        # 3. Inference
        # NOTE(review): this forward pass is synchronous inside an async
        # handler, so it blocks the event loop while it runs.
        with torch.no_grad():
            prediction = model(input_tensor).squeeze()
        predicted_age = prediction.item()
    except HTTPException:
        # Already carries the right status code (e.g. 400 from _decode_image).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # 4. Return Result
    return {
        "filename": file.filename,
        "predicted_age": round(predicted_age, 1),  # Round to 1 decimal
    }


if __name__ == "__main__":
    # Host 0.0.0.0 is needed for cloud environments/Docker
    uvicorn.run(app, host="0.0.0.0", port=8000)