"""Flask service that classifies an uploaded image with a ResNet-18 model.

Exposes ``POST /predict_classify``. The image is sent as the multipart form
field ``file``; on success the response is ``{"prediction": <label>}``, on
failure ``{"error": <message>}`` with a 4xx/5xx status code.
"""

import io
import os

import torch
import torch.nn as nn
from flask import Flask, jsonify, request
from flask_cors import CORS
from PIL import Image
from torchvision import models, transforms

# --------------------------
# Flask setup
# --------------------------
app = Flask(__name__)
CORS(app)  # flask-cors defaults

# --------------------------
# Device setup
# --------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------------------------
# Transform setup (same as training)
# --------------------------
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# --------------------------
# Model setup
# --------------------------
# Index i of the model output corresponds to class_names[i].
# NOTE(review): the checkpoint is named "brain_tumor" but the labels are
# wound/brain/lung -- confirm the order matches the training label encoding.
class_names = ["wound", "brain", "lung"]

# Build the bare architecture only; every weight comes from the checkpoint.
# Calling resnet18() without arguments avoids the deprecated ``pretrained``
# keyword and still skips downloading pretrained weights.
model = models.resnet18()
# Size the classification head from the label list so the two cannot drift.
model.fc = nn.Linear(model.fc.in_features, len(class_names))
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True where the torch version allows).
model.load_state_dict(torch.load("resnet18_brain_tumor.pth", map_location=device))
model.to(device)
model.eval()


# --------------------------
# Predict route
# --------------------------
@app.route("/predict_classify", methods=["POST"])
def predict():
    """Classify the image uploaded in the ``file`` form field.

    Returns:
        ``{"prediction": <label>}`` with status 200 on success.
        ``{"error": ...}`` with status 400 when no file was sent or the
        upload cannot be decoded as an image.
        ``{"error": ...}`` with status 500 when preprocessing or inference
        fails.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400

    uploaded = request.files["file"]

    # Decode directly from memory (nothing is written to disk). A corrupt,
    # empty or non-image upload is the client's fault, so report 400, not 500.
    try:
        with Image.open(io.BytesIO(uploaded.read())) as raw_image:
            image = raw_image.convert("RGB")
    except OSError:  # includes PIL.UnidentifiedImageError
        return jsonify({"error": "Invalid or unreadable image file"}), 400

    try:
        # Add the batch dimension and move the input to the model's device.
        input_tensor = data_transforms(image).unsqueeze(0).to(device)

        # Gradients are not needed for inference.
        with torch.no_grad():
            outputs = model(input_tensor)
        pred_idx = torch.argmax(outputs, dim=1).item()
    except Exception as e:
        # Top-level boundary: keep the traceback in the server log instead of
        # silently swallowing it, then report the failure to the client.
        app.logger.exception("Inference failed")
        return jsonify({"error": str(e)}), 500

    return jsonify({"prediction": class_names[pred_idx]})


# --------------------------
# Run server
# --------------------------
if __name__ == "__main__":
    # The Werkzeug debugger allows arbitrary code execution, so it must never
    # be on by default while listening on all interfaces. Opt in explicitly
    # with FLASK_DEBUG=1 for local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "0") == "1"
    app.run(debug=debug_enabled, host="0.0.0.0", port=7860)