# age-detection / server.py
# ayushpfullstack's picture
# Upload 5 files
# 4bdca69 verified
import io
import os
import glob
import torch
import torch.nn as nn
from torchvision import models, transforms as T
from PIL import Image
from fastapi import FastAPI, UploadFile, File, HTTPException
import uvicorn
# ASGI application object: the route decorators in this file register their
# handlers on it, and it is what gets handed to uvicorn when run as a script.
app = FastAPI()
# =================CONFIGURATION=================
# Location of the trained weights. The default is the original hard-coded
# path; set the AGE_MODEL_CHECKPOINT environment variable to point at a
# different file without editing the source (e.g. a mounted volume).
CHECKPOINT_PATH = os.environ.get(
    "AGE_MODEL_CHECKPOINT", "checkpoints/best_age_model.pth"
)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ===============================================
# 1. Define the Model Architecture (Must match training EXACTLY)
def get_model():
    """Build an uninitialised ResNet18 with a single-output head.

    The stock ``fc`` layer is replaced by a small stack
    (Linear -> ReLU -> Dropout(0.2) -> Linear) that ends in one output unit.
    The order and sizes of these layers determine the state-dict keys, so
    they have to stay in sync with whatever produced the checkpoint.

    Returns:
        The ResNet18 module with the replaced ``fc`` head. No trained
        weights are loaded here.
    """
    # Architecture only: weights=None skips any pretrained download.
    backbone = models.resnet18(weights=None)

    # Width of the features feeding the stock classifier layer.
    head_inputs = backbone.fc.in_features
    hidden_units = 128

    head_layers = [
        nn.Linear(head_inputs, hidden_units),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(hidden_units, 1),
    ]
    backbone.fc = nn.Sequential(*head_layers)
    return backbone
# 2. Load Weights
print(f"Initializing model on {DEVICE}...")
model = get_model()
# Place the model on the target device and switch to inference mode
# unconditionally. Previously this only happened when the checkpoint existed,
# so the fallback path left the model on the CPU in training mode (Dropout
# active) while requests were still sent to DEVICE.
model.to(DEVICE)
model.eval()  # Important: Turn off Dropout for inference
if os.path.exists(CHECKPOINT_PATH):
    print(f"Loading weights from {CHECKPOINT_PATH}...")
    # Load the state dictionary straight onto the target device.
    # NOTE(review): torch.load unpickles the file, so CHECKPOINT_PATH must
    # only ever point at a trusted checkpoint; consider weights_only=True on
    # PyTorch versions that support it.
    state_dict = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(state_dict)
    print("✅ Model loaded successfully!")
else:
    # Deliberate best-effort: keep the server usable for smoke tests even
    # without trained weights, but make the consequence obvious in the log.
    print(f"⚠️ Warning: Checkpoint not found at {CHECKPOINT_PATH}")
    print("Inference will use random weights (garbage output).")
# 3. Define Transforms (Must match training)
# Every image is resized to this fixed (height, width) before inference.
_INPUT_SIZE = (200, 200)
# Per-channel normalization statistics (these are the values commonly used
# for ImageNet-trained torchvision models).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Applied in order: resize, convert the PIL image to a tensor, normalize.
_preprocessing_steps = [
    T.Resize(_INPUT_SIZE),
    T.ToTensor(),
    T.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = T.Compose(_preprocessing_steps)
@app.get("/")
def home():
    """Return a fixed status message showing that the API is reachable."""
    status_payload = {"status": "Age Predictor API is running"}
    return status_payload
@app.post("/predict")
async def predict_age(file: UploadFile = File(...)):
    """Predict an age from an uploaded image.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        A dict with the original ``filename`` and ``predicted_age``, the
        model's single output rounded to one decimal place.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image
            (this is a client error, previously reported as 500);
            500 for any other failure while reading or running inference.
    """
    try:
        # 1. Read and Process Image
        content = await file.read()
        try:
            image = Image.open(io.BytesIO(content)).convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError) as exc:
            # Pillow signals unreadable, truncated or oversized image data
            # with these types (UnidentifiedImageError is an OSError).
            raise HTTPException(
                status_code=400,
                detail=f"Uploaded file is not a readable image: {exc}",
            ) from exc

        # 2. Transform: add a batch dimension and move to the model's device.
        input_tensor = transform(image).unsqueeze(0).to(DEVICE)

        # 3. Inference (no gradients needed).
        with torch.no_grad():
            prediction = model(input_tensor).squeeze()
        predicted_age = prediction.item()

        # 4. Return Result
        return {
            "filename": file.filename,
            "predicted_age": round(predicted_age, 1),  # Round to 1 decimal
        }
    except HTTPException:
        # Keep the status code chosen above instead of re-wrapping it as 500.
        raise
    except Exception as exc:
        # Top-level boundary: report anything unexpected as a server error.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
if __name__ == "__main__":
    # Bind to all interfaces (0.0.0.0) so the server is reachable from outside
    # a container / cloud environment, listening on port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)