File size: 3,969 Bytes
0edd049
 
 
 
 
 
 
607801a
0edd049
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
607801a
 
 
 
0edd049
 
 
 
 
 
 
 
 
 
 
 
 
cad80c7
0edd049
 
cad80c7
0edd049
 
 
 
 
607801a
0edd049
607801a
 
 
0edd049
 
 
cad80c7
0edd049
 
607801a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import sys
from importlib import import_module, invalidate_caches
from importlib.util import module_from_spec, spec_from_file_location
from tempfile import TemporaryDirectory

import cv2
import mediapipe as mp
import numpy as np
import plotly.express as px
import requests
import torch
from git import Repo
from huggingface_hub import hf_hub_download


class FECNetModel:
    """CPU-only wrapper around a pretrained FECNet model.

    Constructing an instance clones the FECNet repository from GitHub into a
    temporary directory, rewrites every occurrence of ``"cuda"`` in
    ``models/FECNet.py`` to ``"cpu"``, loads that file as a module,
    downloads ``fecnet.pt`` from the Hugging Face Hub and loads it into the
    model, and creates a MediaPipe face detector used for optional cropping.
    """

    def __init__(self, hf_token: str) -> None:
        """Clone FECNet, patch it for CPU use and load the weights.

        Args:
            hf_token: Hugging Face token passed to ``hf_hub_download`` when
                fetching the weights.

        Raises:
            ImportError: If no import spec/loader can be built for the
                cloned ``models/FECNet.py``.
        """
        self.hf_token = hf_token
        # Keep the TemporaryDirectory referenced by the instance: once that
        # object is garbage collected the checkout is removed from disk,
        # which would leave a dangling entry on sys.path.
        self._repo_dir = TemporaryDirectory()
        repo_path = self._repo_dir.name
        Repo.clone_from(
            "https://github.com/AmirSh15/FECNet.git",
            repo_path,
        )
        invalidate_caches()
        # The cloned module is loaded from inside the checkout, so the
        # checkout root has to be importable.
        sys.path.append(repo_path)

        # Patch the upstream source in place so it never asks for a GPU.
        fecnet_module_path = os.path.join(repo_path, "models", "FECNet.py")
        with open(fecnet_module_path, "r", encoding="utf-8") as f:
            content = f.read()
        with open(fecnet_module_path, "w", encoding="utf-8") as f:
            f.write(content.replace("cuda", "cpu"))

        spec = spec_from_file_location("FECNet", fecnet_module_path)
        if spec is None or spec.loader is None:
            raise ImportError(
                f"Cannot create an import spec for {fecnet_module_path}"
            )
        fecnet_module = module_from_spec(spec)
        spec.loader.exec_module(fecnet_module)

        self.model = self.__load_model(
            self.__download_weights(repo_path), fecnet_module.FECNet
        )

        self.face_detector = mp.solutions.face_detection.FaceDetection(
            min_detection_confidence=0.5
        )

    def __download_weights(self, model_dir: str) -> str:
        """Download ``fecnet.pt`` from the Hub and return its local path.

        Args:
            model_dir: Currently unused; kept for interface compatibility.
        """
        return hf_hub_download(
            "natexcvi/pretrained-fecnet",
            "fecnet.pt",
            token=self.hf_token,
        )

    def __load_model(self, model_path: str, model_class):
        """Instantiate ``model_class`` and load weights from ``model_path``.

        The weights are mapped onto the CPU and the model is switched to
        evaluation mode before being returned.
        """
        model = model_class(pretrained=False)
        model_weights = torch.load(model_path, map_location=torch.device("cpu"))
        model.load_state_dict(model_weights)
        model.eval()
        return model

    def predict(self, image: np.ndarray):
        """Run the model on ``image`` and return its raw output."""
        return self.model(image)

    @staticmethod
    def distance(a, b):
        """Return the Euclidean norm of ``a - b``.

        Declared as a static method so that it can be called both as
        ``FECNetModel.distance(a, b)`` and on an instance.
        """
        return np.linalg.norm(a - b)

    def embed_image(self, image, crop_face: bool = False) -> np.ndarray:
        """Decode an encoded image buffer and return the model output.

        Args:
            image: Encoded image bytes as accepted by ``cv2.imdecode``.
            crop_face: When true, crop to the first detected face before
                resizing (see ``extract_face``).

        Returns:
            The model output for the single-image batch as a NumPy array.
        """
        image = cv2.imdecode(image, cv2.IMREAD_COLOR)
        if crop_face:
            image = self.extract_face(image)
        # NOTE(review): the pixels stay in the channel order produced by
        # OpenCV's decoder (no RGB conversion is applied) -- confirm this is
        # the order the pretrained weights expect.
        image = cv2.resize(image, (224, 224))
        # Height/width/channel -> channel/height/width, plus a batch axis.
        image = np.transpose(image, (2, 0, 1))
        image = np.expand_dims(image, axis=0)
        tensor = torch.from_numpy(image.astype(np.float32))
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            pred = self.predict(tensor)
        return pred.detach().numpy()

    @staticmethod
    def _pixel_box(rel_box, height: int, width: int) -> tuple:
        """Convert a relative bounding box to clamped pixel corners.

        Args:
            rel_box: Object with ``xmin``, ``ymin``, ``width`` and ``height``
                attributes expressed as fractions of the image size.
            height: Image height in pixels.
            width: Image width in pixels.

        Returns:
            ``(x0, y0, x1, y1)`` with every coordinate clamped to the image,
            so the values are always safe to use as slice bounds.
        """
        x = int(rel_box.xmin * width)
        y = int(rel_box.ymin * height)
        w = int(rel_box.width * width)
        h = int(rel_box.height * height)
        # Detections may extend past the frame; a negative start would be
        # interpreted as "from the end" by NumPy slicing.
        x0 = min(max(x, 0), width)
        y0 = min(max(y, 0), height)
        x1 = min(max(x + w, 0), width)
        y1 = min(max(y + h, 0), height)
        return x0, y0, x1, y1

    def extract_face(self, image):
        """Crop ``image`` to the first face found by the detector.

        Args:
            image: Image array in OpenCV's BGR channel order.

        Returns:
            The region of ``image`` covered by the first detection, in the
            same channel order as the input. If no face is detected, or the
            clamped box is empty, ``image`` itself is returned so callers
            always receive a usable array.
        """
        # The detector is run on an RGB copy; the crop is taken from the
        # original array so no conversion back is required.
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        results = self.face_detector.process(rgb_image)
        if not results.detections:
            return image

        height, width = image.shape[:2]
        rel_box = results.detections[0].location_data.relative_bounding_box
        x0, y0, x1, y1 = self._pixel_box(rel_box, height, width)
        if x1 <= x0 or y1 <= y0:
            return image
        return image[y0:y1, x0:x1]