kyanmahajan commited on
Commit
10046cc
·
verified ·
1 Parent(s): 53f2d7b

Update xray_classifier.py

Browse files
Files changed (1) hide show
  1. xray_classifier.py +106 -106
xray_classifier.py CHANGED
@@ -1,106 +1,106 @@
1
-
2
- import torch
3
- import torch.nn as nn
4
- from torchvision import models, transforms
5
- from flask import Flask, jsonify, request, render_template
6
- from PIL import Image
7
- import os
8
- import numpy as np
9
- import matplotlib.pyplot as plt
10
- from PIL import Image
11
- import numpy as np
12
- import cv2
13
- import cv2
14
- import torch
15
-
16
- from pytorch_grad_cam import GradCAM
17
- from pytorch_grad_cam.utils.image import show_cam_on_image
18
- from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
19
-
20
- from flask_cors import CORS
21
-
22
-
23
-
24
-
25
- app = Flask(__name__)
26
- CORS(app)
27
- os.makedirs("static", exist_ok=True)
28
-
29
- # Device setup
30
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
-
32
-
33
- # Transform setup (same as training)
34
- data_transforms = transforms.Compose([
35
- transforms.Resize(256),
36
- transforms.CenterCrop(224),
37
- transforms.ToTensor(),
38
- transforms.Normalize(
39
- mean=[0.485, 0.456, 0.406],
40
- std=[0.229, 0.224, 0.225]
41
- )
42
- ])
43
-
44
-
45
- model = models.resnet18(pretrained=False);
46
- model.fc = nn.Linear(model.fc.in_features, 3);
47
- model.load_state_dict(torch.load("resnet18_brain_tumor.pth", map_location=device))
48
-
49
- model.to(device)
50
- model.eval()
51
-
52
-
53
- class_names = [
54
- "wound",
55
- "brain",
56
- "lung"
57
- ]
58
-
59
- # @app.route("/")
60
- # def home():
61
- # return render_template("index.html")
62
-
63
- @app.route("/predict_classify", methods=["POST"])
64
- def predict():
65
- if "file" not in request.files:
66
- return jsonify({"error": "No file provided"}), 400
67
-
68
- file = request.files["file"]
69
- filepath = os.path.join("static", file.filename)
70
- file.save(filepath)
71
-
72
- try:
73
- image = Image.open(filepath).convert("RGB")
74
- input_tensor = data_transforms(image).unsqueeze(0).to(device)
75
-
76
- with torch.no_grad():
77
- output = model(input_tensor)
78
- pred_idx = torch.argmax(output, dim=1).item()
79
- pred_label = class_names[pred_idx]
80
-
81
-
82
-
83
-
84
-
85
- file={
86
- "prediction": pred_label,
87
-
88
- }
89
- print(file)
90
- return jsonify(file)
91
-
92
-
93
-
94
-
95
-
96
-
97
-
98
-
99
- except Exception as e:
100
- return jsonify({"error": str(e)}), 500
101
-
102
-
103
- if __name__ == '__main__':
104
- app.run(debug=True, host="127.0.0.1", port=5000)
105
-
106
-
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import models, transforms
5
+ from flask import Flask, jsonify, request, render_template
6
+ from PIL import Image
7
+ import os
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from PIL import Image
11
+ import numpy as np
12
+ import cv2
13
+ import cv2
14
+ import torch
15
+
16
+ from pytorch_grad_cam import GradCAM
17
+ from pytorch_grad_cam.utils.image import show_cam_on_image
18
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
19
+
20
+ from flask_cors import CORS
21
+
22
+
23
+
24
+
25
+ app = Flask(__name__)
26
+ CORS(app)
27
+ os.makedirs("static", exist_ok=True)
28
+
29
+ # Device setup
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+
32
+
33
+ # Transform setup (same as training)
34
+ data_transforms = transforms.Compose([
35
+ transforms.Resize(256),
36
+ transforms.CenterCrop(224),
37
+ transforms.ToTensor(),
38
+ transforms.Normalize(
39
+ mean=[0.485, 0.456, 0.406],
40
+ std=[0.229, 0.224, 0.225]
41
+ )
42
+ ])
43
+
44
+
45
+ model = models.resnet18(pretrained=False);
46
+ model.fc = nn.Linear(model.fc.in_features, 3);
47
+ model.load_state_dict(torch.load("resnet18_brain_tumor.pth", map_location=device))
48
+
49
+ model.to(device)
50
+ model.eval()
51
+
52
+
53
+ class_names = [
54
+ "wound",
55
+ "brain",
56
+ "lung"
57
+ ]
58
+
59
+ # @app.route("/")
60
+ # def home():
61
+ # return render_template("index.html")
62
+
63
+ @app.route("/predict_classify", methods=["POST"])
64
+ def predict():
65
+ if "file" not in request.files:
66
+ return jsonify({"error": "No file provided"}), 400
67
+
68
+ file = request.files["file"]
69
+ filepath = os.path.join("static", file.filename)
70
+ file.save(filepath)
71
+
72
+ try:
73
+ image = Image.open(filepath).convert("RGB")
74
+ input_tensor = data_transforms(image).unsqueeze(0).to(device)
75
+
76
+ with torch.no_grad():
77
+ output = model(input_tensor)
78
+ pred_idx = torch.argmax(output, dim=1).item()
79
+ pred_label = class_names[pred_idx]
80
+
81
+
82
+
83
+
84
+
85
+ file={
86
+ "prediction": pred_label,
87
+
88
+ }
89
+ print(file)
90
+ return jsonify(file)
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+ except Exception as e:
100
+ return jsonify({"error": str(e)}), 500
101
+
102
+
103
+ if __name__ == '__main__':
104
+ app.run(debug=True, host="0.0.0.0", port=7860)
105
+
106
+