kyanmahajan commited on
Commit
22c4d3c
·
verified ·
1 Parent(s): ed253dd

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +22 -0
  2. requirements.txt +10 -0
  3. xray_classifier.py +106 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies
6
+ RUN apt-get update && apt-get install -y \
7
+ git \
8
+ libgl1 \
9
+ libglib2.0-0 \
10
+ && rm -rf /var/lib/apt/lists/*
11
+
12
+ # Copy all files into the container
13
+ COPY . /app
14
+
15
+ # Install Python dependencies
16
+ RUN pip install --no-cache-dir -r requirements.txt
17
+
18
+ # Expose port for Flask/Gradio
19
+ EXPOSE 7860
20
+
21
+ # Start the app
22
+ CMD ["python", "app.py"]
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ flask
2
+ flask-cors
3
+ torch
4
+ torchvision
5
+ pillow
6
+ numpy
7
+ matplotlib
8
+ opencv-python-headless
9
+ git+https://github.com/jacobgil/pytorch-grad-cam.git
10
+
xray_classifier.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import models, transforms
5
+ from flask import Flask, jsonify, request, render_template
6
+ from PIL import Image
7
+ import os
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from PIL import Image
11
+ import numpy as np
12
+ import cv2
13
+ import cv2
14
+ import torch
15
+
16
+ from pytorch_grad_cam import GradCAM
17
+ from pytorch_grad_cam.utils.image import show_cam_on_image
18
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
19
+
20
+ from flask_cors import CORS
21
+
22
+
23
+
24
+
25
+ app = Flask(__name__)
26
+ CORS(app)
27
+ os.makedirs("static", exist_ok=True)
28
+
29
+ # Device setup
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+
32
+
33
+ # Transform setup (same as training)
34
+ data_transforms = transforms.Compose([
35
+ transforms.Resize(256),
36
+ transforms.CenterCrop(224),
37
+ transforms.ToTensor(),
38
+ transforms.Normalize(
39
+ mean=[0.485, 0.456, 0.406],
40
+ std=[0.229, 0.224, 0.225]
41
+ )
42
+ ])
43
+
44
+
45
+ model = models.resnet18(pretrained=False);
46
+ model.fc = nn.Linear(model.fc.in_features, 3);
47
+ model.load_state_dict(torch.load("resnet18_brain_tumor.pth", map_location=device))
48
+
49
+ model.to(device)
50
+ model.eval()
51
+
52
+
53
+ class_names = [
54
+ "wound",
55
+ "brain",
56
+ "lung"
57
+ ]
58
+
59
+ # @app.route("/")
60
+ # def home():
61
+ # return render_template("index.html")
62
+
63
+ @app.route("/predict_classify", methods=["POST"])
64
+ def predict():
65
+ if "file" not in request.files:
66
+ return jsonify({"error": "No file provided"}), 400
67
+
68
+ file = request.files["file"]
69
+ filepath = os.path.join("static", file.filename)
70
+ file.save(filepath)
71
+
72
+ try:
73
+ image = Image.open(filepath).convert("RGB")
74
+ input_tensor = data_transforms(image).unsqueeze(0).to(device)
75
+
76
+ with torch.no_grad():
77
+ output = model(input_tensor)
78
+ pred_idx = torch.argmax(output, dim=1).item()
79
+ pred_label = class_names[pred_idx]
80
+
81
+
82
+
83
+
84
+
85
+ file={
86
+ "prediction": pred_label,
87
+
88
+ }
89
+ print(file)
90
+ return jsonify(file)
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+ except Exception as e:
100
+ return jsonify({"error": str(e)}), 500
101
+
102
+
103
+ if __name__ == '__main__':
104
+ app.run(debug=True, host="0.0.0.0", port=7860)
105
+
106
+