AlzDisServer / app_fast.py
gouravbhadraDev's picture
Update app_fast.py
6e1b524 verified
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import numpy as np
import cv2
from tensorflow import keras
import base64
# --- Existing imports for PyTorch ViT ---
import torch
import torch.nn.functional as F
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
from io import BytesIO
from collections import OrderedDict
app = FastAPI()

# CORS is fully open: any origin, method and header is accepted, and
# credentialed requests are allowed. Adjust for production.
# NOTE(review): with allow_credentials=True, Starlette's CORSMiddleware echoes
# the caller's Origin back for credentialed requests. In effect any site can
# make credentialed calls, so restrict allow_origins before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Labels for the Keras classifier. /predict indexes this list with the argmax
# of the model output, so the order must match the model's output units.
class_names = [
    "Mild Dementia",
    "Moderate Dementia",
    "Non Dementia",
    "Very Mild Dementia",
]

# --- Load TensorFlow Keras model once at startup ---
# The path is relative, so it is resolved against the process working directory.
model = keras.models.load_model('./model/AlzDisConvModel_hyperTuned.h5')

# Most recent upload, held in memory as a module-level global (None when empty).
# /upload stores a 176x176 array decoded with cv2.IMREAD_COLOR (BGR).
# Each prediction endpoint clears it. It is shared by every client of the process.
uploaded_image = None
@app.get("/")
def greet_json():
    """Root endpoint: respond with a fixed JSON greeting."""
    greeting = {"Hello": "World!"}
    return greeting
@app.post("/upload")
async def upload_file(image: UploadFile = File(...)):
    """Decode an uploaded image, keep it for the next prediction and echo it back.

    The file is decoded with ``cv2.IMREAD_COLOR`` (3-channel BGR) and resized
    to 176x176. On success the resized array is stored in the module-level
    ``uploaded_image``, which ``/predict`` and ``/predict_vit`` read. The
    response carries the resized image as a base64-encoded PNG under
    ``"image_b64"``.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image; 500 for any other processing error.
    """
    global uploaded_image
    # Drop any previous upload first, so a failed upload never leaves a stale
    # or half-processed image behind for the prediction endpoints.
    uploaded_image = None
    try:
        contents = await image.read()
        decoded = None
        # Reject empty uploads up front instead of handing an empty buffer to imdecode.
        if contents:
            decoded = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
        if decoded is None:
            # Bytes that are not a readable image are a client error, not a server fault.
            raise HTTPException(
                status_code=400,
                detail="Error reading the image: Cannot decode image",
            )
        resized = cv2.resize(decoded, (176, 176))
        # Encode resized image to PNG bytes and then base64 string
        ok, buffer = cv2.imencode('.png', resized)
        if not ok:
            raise ValueError("Cannot encode image as PNG")
        img_b64 = base64.b64encode(buffer.tobytes()).decode('utf-8')
    except HTTPException:
        # Deliberate HTTP errors (the 400 above) pass through unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error reading the image: {e}") from e
    # Publish the image only after every step has succeeded.
    uploaded_image = resized
    return JSONResponse(content={
        "message": "File uploaded successfully",
        "image_b64": img_b64,
    })
@app.post("/predict")
async def predict():
    """Classify the most recently uploaded image with the Keras model.

    Returns:
        ``{"prediction": <label from class_names>, "confidence": <float>}``.
        The confidence is the model's output value for the predicted class.

    Raises:
        HTTPException: 400 if no image has been uploaded; 500 if inference fails.

    The stored image is cleared afterwards, on success or failure, so each
    upload serves a single prediction.
    """
    global uploaded_image
    try:
        if uploaded_image is None:
            raise HTTPException(status_code=400, detail="No uploaded image")
        # The model is called on a batch of one: add a leading batch axis.
        prediction_probs = model.predict(np.expand_dims(uploaded_image, axis=0))
        scores = prediction_probs[0]
        predicted_class_index = int(np.argmax(scores))
        predicted_class_name = class_names[predicted_class_index]
        confidence = float(scores[predicted_class_index])
        return {
            "prediction": predicted_class_name,
            "confidence": confidence,
        }
    except HTTPException:
        # Without this, the 400 above would be re-wrapped by the generic
        # handler into a misleading 500 response.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error predicting: {e}") from e
    finally:
        # Reset uploaded image after prediction
        uploaded_image = None
# --- PyTorch ViT model and processor loading ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_name = 'dhritic9/vit-base-brain-mri-dementia-detection'
checkpoint_path = './best_vit_model.pth'  # adjust path if needed


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``'module.'`` removed from every key.

    Keys without the prefix are kept as-is. This lets a checkpoint saved with
    ``module.``-prefixed keys (as produced by DataParallel-style wrappers) load
    into the bare model.
    """
    prefix = 'module.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def _load_vit_model(name, path, target_device):
    """Build the ViT classifier, load fine-tuned weights from *path*, and ready it for inference.

    The checkpoint is only referenced by locals here. Its tensors are therefore
    released when this returns, instead of staying alive as module globals
    beside the copies that ``load_state_dict`` writes into the model.
    """
    net = ViTForImageClassification.from_pretrained(name)
    checkpoint = torch.load(path, map_location=target_device)
    net.load_state_dict(_strip_module_prefix(checkpoint))
    net.to(target_device)
    net.eval()
    return net


processor = ViTImageProcessor.from_pretrained(model_name)
vit_model = _load_vit_model(model_name, checkpoint_path, device)
# --- New endpoint for PyTorch ViT prediction ---
@app.post("/predict_vit")
async def predict_vit():
    """Classify the most recently uploaded image with the PyTorch ViT model.

    Returns:
        ``{"prediction": "<label> [by VIT]", "confidence": <float>}``. The label
        comes from ``vit_model.config.id2label``; the confidence is the softmax
        probability of the predicted class.

    Raises:
        HTTPException: 400 if no image has been uploaded; 500 if inference fails.

    The stored image is cleared afterwards, on success or failure.
    """
    global uploaded_image
    try:
        if uploaded_image is None:
            raise HTTPException(status_code=400, detail="No uploaded image")
        # Convert OpenCV image (BGR) to PIL Image (RGB) and resize to 224x224 for ViT
        img_rgb = cv2.cvtColor(uploaded_image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(img_rgb).resize((224, 224))
        inputs = processor(images=pil_image, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = vit_model(**inputs).logits
        probs = F.softmax(logits, dim=1)
        predicted_class_idx = probs.argmax(dim=1).item()
        confidence = probs[0, predicted_class_idx].item()
        class_name = vit_model.config.id2label[predicted_class_idx]
        return {
            "prediction": class_name + " [by VIT]",
            "confidence": confidence,
        }
    except HTTPException:
        # Without this, the 400 above would be re-wrapped as a 500 by the generic handler.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"ViT prediction error: {e}") from e
    finally:
        # Reset uploaded image after prediction
        uploaded_image = None
# NOTE(review): everything below is a bare, unassigned string literal. It holds
# an earlier Keras-only copy of this module (duplicate app, CORS setup, model
# load and endpoints). It is never executed. Consider deleting it and relying
# on version-control history instead.
'''
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import numpy as np
import cv2
from tensorflow import keras
import base64
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Adjust for production
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class_names = ["Mild Dementia", "Moderate Dementia", "Non Dementia", "Very Mild Dementia"]
# Load model once at startup
model = keras.models.load_model('./model/AlzDisConvModel_hyperTuned.h5')
# Store uploaded image in memory (global variable)
uploaded_image = None
@app.get("/")
def greet_json():
return {"Hello": "World!"}
@app.post("/upload")
async def upload_file(image: UploadFile = File(...)):
global uploaded_image
try:
contents = await image.read()
nparr = np.frombuffer(contents, np.uint8)
uploaded_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if uploaded_image is None:
raise ValueError("Cannot decode image")
uploaded_image = cv2.resize(uploaded_image, (176, 176))
# Encode resized image to PNG bytes and then base64 string
_, buffer = cv2.imencode('.png', uploaded_image)
img_bytes = buffer.tobytes()
img_b64 = base64.b64encode(img_bytes).decode('utf-8')
return JSONResponse(content={
"message": "File uploaded successfully",
"image_b64": img_b64
})
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error reading the image: {e}")
@app.post("/predict")
async def predict():
global uploaded_image
try:
if uploaded_image is None:
raise HTTPException(status_code=400, detail="No uploaded image")
prediction_probs = model.predict(np.expand_dims(uploaded_image, axis=0))
predicted_class_index = np.argmax(prediction_probs)
predicted_class_name = class_names[predicted_class_index]
confidence = float(prediction_probs[0][predicted_class_index])
return {
"prediction": predicted_class_name,
"confidence": confidence
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error predicting: {e}")
finally:
# Reset uploaded image after prediction
uploaded_image = None
'''