Lukas Folle commited on
Commit
70583d3
·
1 Parent(s): 29f1be9

nit: fixed formatting

Browse files
Files changed (5) hide show
  1. DummyModel.py +9 -4
  2. Model.py +8 -2
  3. app.py +26 -10
  4. backend.py +2 -4
  5. entrypoint.py +5 -7
DummyModel.py CHANGED
@@ -1,12 +1,17 @@
1
  import torch
 
2
  import torch.nn
 
3
 
4
 
5
  def load_dummy_model(DEBUG):
6
  model = DummyModel()
7
  if not DEBUG:
8
- file_path = hf_hub_download("lfolle/DeepNAPSIModel", "dummy_model.pth",
9
- use_auth_token=os.environ['DeepNAPSIModel'])
 
 
 
10
  model.load_state_dict(torch.load(file_path))
11
  return model
12
 
@@ -15,8 +20,8 @@ class DummyModel(torch.nn.Module):
15
  def __init__(self):
16
  super().__init__()
17
 
18
- def forward(self, x:list):
19
  return torch.softmax(torch.rand(len(x), 5), 1), 0
20
 
21
- def __call__(self, x:list):
22
  return self.forward(x)
 
1
  import torch
2
+ import os
3
  import torch.nn
4
+ from huggingface_hub import hf_hub_download
5
 
6
 
7
  def load_dummy_model(DEBUG):
8
  model = DummyModel()
9
  if not DEBUG:
10
+ file_path = hf_hub_download(
11
+ "lfolle/DeepNAPSIModel",
12
+ "dummy_model.pth",
13
+ use_auth_token=os.environ["DeepNAPSIModel"],
14
+ )
15
  model.load_state_dict(torch.load(file_path))
16
  return model
17
 
 
20
  def __init__(self):
21
  super().__init__()
22
 
23
+ def forward(self, x: list):
24
  return torch.softmax(torch.rand(len(x), 5), 1), 0
25
 
26
+ def __call__(self, x: list):
27
  return self.forward(x)
Model.py CHANGED
@@ -9,8 +9,14 @@ class Model:
9
  base = r"C:\Users\follels\Documents\hand-ki-model-weights\DeepNAPSIModel\inference_checkpoints_v1"
10
  file_paths = [os.path.join(base, f"version_{v}") for v in range(10, 15)]
11
  else:
12
- file_paths = [hf_hub_download("lfolle/DeepNAPSIModel", f"version_{v}.ckpt",
13
- use_auth_token=os.environ['DeepNAPSIModel']) for v in [10, 11, 12, 13, 14]]
 
 
 
 
 
 
14
  self.inference = Inference(file_paths)
15
 
16
  def predict(self, x):
 
9
  base = r"C:\Users\follels\Documents\hand-ki-model-weights\DeepNAPSIModel\inference_checkpoints_v1"
10
  file_paths = [os.path.join(base, f"version_{v}") for v in range(10, 15)]
11
  else:
12
+ file_paths = [
13
+ hf_hub_download(
14
+ "lfolle/DeepNAPSIModel",
15
+ f"version_{v}.ckpt",
16
+ use_auth_token=os.environ["DeepNAPSIModel"],
17
+ )
18
+ for v in [10, 11, 12, 13, 14]
19
+ ]
20
  self.inference = Inference(file_paths)
21
 
22
  def predict(self, x):
app.py CHANGED
@@ -1,7 +1,4 @@
1
- import os
2
- import pip
3
  import gradio as gr
4
- from PIL import Image
5
 
6
  from backend import Infer
7
 
@@ -9,7 +6,11 @@ from backend import Infer
9
  DEBUG = False
10
 
11
  infer = Infer(DEBUG)
12
- example_image_path = ["assets/example_1.jpg", "assets/example_2.jpg", "assets/example_3.jpg"]
 
 
 
 
13
 
14
  outputs = [
15
  gr.Image(label="Thumb"),
@@ -28,16 +29,27 @@ outputs = [
28
  with gr.Blocks(analytics_enabled=False, title="DeepNAPSI") as demo:
29
  with gr.Column():
30
  gr.Markdown("## Welcome to the DeepNAPSI application!")
31
- gr.Markdown("Upload an image of the one hand and click **Predict NAPSI** to see the output.")
32
- gr.Markdown("*Note*: Make sure there are no identifying information present in the image. The prediction can take up to 4.5 minutes." )
33
- gr.Markdown("*Note*: This is not a medical product and cannot be used for a patient diagnosis in any way.")
 
 
 
 
 
 
34
  with gr.Column():
35
  with gr.Row():
36
  with gr.Column():
37
  with gr.Row():
38
  image_input = gr.Image()
39
- example_images = gr.Examples(example_image_path, image_input, outputs,
40
- fn=infer.predict, cache_examples=True)
 
 
 
 
 
41
  with gr.Row():
42
  image_button = gr.Button("Predict NAPSI")
43
  with gr.Row():
@@ -59,4 +71,8 @@ with gr.Blocks(analytics_enabled=False, title="DeepNAPSI") as demo:
59
  outputs[10].render()
60
  image_button.click(infer.predict, inputs=image_input, outputs=outputs)
61
 
62
- demo.launch(share=True if DEBUG else False, enable_queue=True, favicon_path="assets/favicon-32x32.png")
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
 
3
  from backend import Infer
4
 
 
6
  DEBUG = False
7
 
8
  infer = Infer(DEBUG)
9
+ example_image_path = [
10
+ "assets/example_1.jpg",
11
+ "assets/example_2.jpg",
12
+ "assets/example_3.jpg",
13
+ ]
14
 
15
  outputs = [
16
  gr.Image(label="Thumb"),
 
29
  with gr.Blocks(analytics_enabled=False, title="DeepNAPSI") as demo:
30
  with gr.Column():
31
  gr.Markdown("## Welcome to the DeepNAPSI application!")
32
+ gr.Markdown(
33
+ "Upload an image of the one hand and click **Predict NAPSI** to see the output."
34
+ )
35
+ gr.Markdown(
36
+ "*Note*: Make sure there are no identifying information present in the image. The prediction can take up to 4.5 minutes."
37
+ )
38
+ gr.Markdown(
39
+ "*Note*: This is not a medical product and cannot be used for a patient diagnosis in any way."
40
+ )
41
  with gr.Column():
42
  with gr.Row():
43
  with gr.Column():
44
  with gr.Row():
45
  image_input = gr.Image()
46
+ example_images = gr.Examples(
47
+ example_image_path,
48
+ image_input,
49
+ outputs,
50
+ fn=infer.predict,
51
+ cache_examples=True,
52
+ )
53
  with gr.Row():
54
  image_button = gr.Button("Predict NAPSI")
55
  with gr.Row():
 
71
  outputs[10].render()
72
  image_button.click(infer.predict, inputs=image_input, outputs=outputs)
73
 
74
+ demo.launch(
75
+ share=True if DEBUG else False,
76
+ enable_queue=True,
77
+ favicon_path="assets/favicon-32x32.png",
78
+ )
backend.py CHANGED
@@ -1,14 +1,12 @@
1
  import torch
2
  import cv2
3
  import numpy as np
4
- from huggingface_hub import hf_hub_download
5
  from nail_detection.main import get_nails
6
 
7
- from DummyModel import load_dummy_model
8
  from Model import Model
9
 
10
 
11
- class Infer():
12
  def __init__(self, DEBUG):
13
  # self.model = load_dummy_model(DEBUG)
14
  self.model = Model(DEBUG)
@@ -22,7 +20,7 @@ class Infer():
22
  predictions.append(-1)
23
  predictions.append("-1")
24
  else:
25
- model_prediction, uncertainty = self.model(nails)
26
  model_prediction = model_prediction[0]
27
  napsi_predictions = torch.argmax(model_prediction, 1)
28
  napsi_sum = int(napsi_predictions.sum().detach().cpu())
 
1
  import torch
2
  import cv2
3
  import numpy as np
 
4
  from nail_detection.main import get_nails
5
 
 
6
  from Model import Model
7
 
8
 
9
+ class Infer:
10
  def __init__(self, DEBUG):
11
  # self.model = load_dummy_model(DEBUG)
12
  self.model = Model(DEBUG)
 
20
  predictions.append(-1)
21
  predictions.append("-1")
22
  else:
23
+ model_prediction, _ = self.model(nails)
24
  model_prediction = model_prediction[0]
25
  napsi_predictions = torch.argmax(model_prediction, 1)
26
  napsi_sum = int(napsi_predictions.sum().detach().cpu())
entrypoint.py CHANGED
@@ -1,14 +1,12 @@
1
  import os
2
  import sys
3
- import shutil
4
  import subprocess
5
  from huggingface_hub import Repository
6
 
7
- # Bug got fixed lately
8
- # hooks_path = ".git/hooks/"
9
- # if os.path.exists(hooks_path):
10
- # shutil.rmtree(hooks_path)
11
- Repository("repos/hand-ki-model", f"https://oauth2:{os.getenv('HANDKIGIT5')}@git5.cs.fau.de/folle/hand-ki-model.git", use_auth_token=os.getenv(""))
12
 
13
  subprocess.check_call([sys.executable, "-m", "pip", "install", "repos/hand-ki-model/"])
14
- import app
 
1
  import os
2
  import sys
 
3
  import subprocess
4
  from huggingface_hub import Repository
5
 
6
+ Repository(
7
+ "repos/hand-ki-model",
8
+ f"https://oauth2:{os.getenv('HANDKIGIT5')}@git5.cs.fau.de/folle/hand-ki-model.git",
9
+ use_auth_token=os.getenv(""),
10
+ )
11
 
12
  subprocess.check_call([sys.executable, "-m", "pip", "install", "repos/hand-ki-model/"])