Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| import glob | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms as T | |
| from PIL import Image | |
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| import uvicorn | |
app = FastAPI()
# =================CONFIGURATION=================
# Location of the trained state_dict loaded at startup. It can be overridden with
# the CHECKPOINT_PATH environment variable so a deployment can point at another
# file without editing code; when the variable is unset the original path is used.
CHECKPOINT_PATH = os.environ.get("CHECKPOINT_PATH", "checkpoints/best_age_model.pth")
# Use CUDA when torch reports it is available, otherwise run on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ===============================================
# 1. Define the Model Architecture (Must match training EXACTLY)
def get_model():
    """Build the ResNet18-based age regressor with untrained weights.

    The torchvision ResNet18 backbone is created without pretrained weights and
    its final ``fc`` classifier is replaced by a small regression head
    (Linear -> ReLU -> Dropout(0.2) -> Linear) that emits one scalar per image.

    The head layout has to mirror the training script exactly; otherwise
    ``load_state_dict`` will not accept the saved checkpoint.

    Returns:
        The freshly constructed ``torch.nn.Module``.
    """
    backbone = models.resnet18(weights=None)
    in_dim = backbone.fc.in_features
    head_layers = [
        nn.Linear(in_dim, 128),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(128, 1),
    ]
    backbone.fc = nn.Sequential(*head_layers)
    return backbone
# 2. Load Weights
print(f"Initializing model on {DEVICE}...")
model = get_model()
if os.path.exists(CHECKPOINT_PATH):
    print(f"Loading weights from {CHECKPOINT_PATH}...")
    # Load the state dictionary.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. On torch >= 1.13, passing weights_only=True restricts
    # loading to tensors; consider it once the minimum torch version allows.
    state_dict = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(state_dict)
    print("✅ Model loaded successfully!")
else:
    print(f"⚠️ Warning: Checkpoint not found at {CHECKPOINT_PATH}")
    print("Inference will use random weights (garbage output).")
# Always move the model to DEVICE and switch it to inference mode, even when
# no checkpoint was found: predict_age sends its input tensors to DEVICE, and
# Dropout must be disabled for deterministic inference.
model.to(DEVICE)
model.eval()
# 3. Define Transforms (Must match training)
# Per-channel RGB normalization statistics. These are the widely used ImageNet
# mean/std values and must be the same ones the model was trained with.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Resize to a fixed 200x200, convert the PIL image to a tensor, then normalize.
_preprocess_steps = [
    T.Resize((200, 200)),
    T.ToTensor(),
    T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform = T.Compose(_preprocess_steps)
# NOTE(review): no route decorator was visible in this source, so this handler
# was never registered with the app; "/" is the assumed path — confirm.
@app.get("/")
def home():
    """Return a fixed status payload indicating the API process is running."""
    return {"status": "Age Predictor API is running"}
# NOTE(review): no route decorator was visible in this source, so this handler
# was never registered with the app; "/predict" is an assumed path — confirm
# against existing clients.
@app.post("/predict")
async def predict_age(file: UploadFile = File(...)):
    """
    Accepts an image file and returns predicted age.

    The upload is decoded with PIL, converted to RGB, preprocessed with
    ``transform``, and passed to ``model`` as a batch of one. The single
    regression output is returned rounded to one decimal place.

    Returns:
        dict with ``filename`` (as uploaded) and ``predicted_age`` (float).

    Raises:
        HTTPException(400): the upload could not be decoded as an image.
        HTTPException(500): any other failure while reading, preprocessing,
            or running inference.
    """
    try:
        # 1. Read and Process Image
        content = await file.read()
        try:
            image = Image.open(io.BytesIO(content)).convert("RGB")
        except OSError as e:  # includes PIL.UnidentifiedImageError / truncated files
            # Bad or non-image uploads are client errors, not server faults.
            raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e
        # 2. Transform: add a batch dimension and move to the model's device.
        input_tensor = transform(image).unsqueeze(0).to(DEVICE)
        # 3. Inference (no autograd bookkeeping needed).
        with torch.no_grad():
            prediction = model(input_tensor).squeeze()
        predicted_age = prediction.item()
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) instead of turning them into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # 4. Return Result
    return {
        "filename": file.filename,
        "predicted_age": round(predicted_age, 1),  # Round to 1 decimal
    }
if __name__ == "__main__":
    # Listen on every interface (0.0.0.0) rather than loopback only, so the
    # server is reachable from outside a Docker container or cloud VM.
    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)