Spaces:
Running
on
Zero
Running
on
Zero
fix: rdd
Browse files
imcui/hloc/extractors/rdd.py
CHANGED
@@ -9,6 +9,7 @@ sys.path.append(str(rdd_path))
|
|
9 |
|
10 |
from RDD.RDD import build as build_rdd
|
11 |
|
|
|
12 |
class Rdd(BaseModel):
|
13 |
default_conf = {
|
14 |
"keypoint_threshold": 0.1,
|
@@ -22,9 +23,7 @@ class Rdd(BaseModel):
|
|
22 |
logger.info("Loading RDD model...")
|
23 |
model_path = self._download_model(
|
24 |
repo_id=MODEL_REPO_ID,
|
25 |
-
filename="{}/{}".format(
|
26 |
-
Path(__file__).stem, self.conf["model_name"]
|
27 |
-
),
|
28 |
)
|
29 |
config_path = rdd_path / "configs/default.yaml"
|
30 |
with open(config_path, "r") as file:
|
@@ -38,6 +37,9 @@ class Rdd(BaseModel):
|
|
38 |
|
39 |
def _forward(self, data):
|
40 |
image = data["image"]
|
|
|
|
|
|
|
41 |
pred = self.net.extract(image)[0]
|
42 |
keypoints = pred["keypoints"]
|
43 |
descriptors = pred["descriptors"]
|
|
|
9 |
|
10 |
from RDD.RDD import build as build_rdd
|
11 |
|
12 |
+
|
13 |
class Rdd(BaseModel):
|
14 |
default_conf = {
|
15 |
"keypoint_threshold": 0.1,
|
|
|
23 |
logger.info("Loading RDD model...")
|
24 |
model_path = self._download_model(
|
25 |
repo_id=MODEL_REPO_ID,
|
26 |
+
filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
|
|
|
|
|
27 |
)
|
28 |
config_path = rdd_path / "configs/default.yaml"
|
29 |
with open(config_path, "r") as file:
|
|
|
37 |
|
38 |
def _forward(self, data):
|
39 |
image = data["image"]
|
40 |
+
self.net.set_softdetect(
|
41 |
+
top_k=self.conf["max_keypoints"], scores_th=self.conf["keypoint_threshold"]
|
42 |
+
)
|
43 |
pred = self.net.extract(image)[0]
|
44 |
keypoints = pred["keypoints"]
|
45 |
descriptors = pred["descriptors"]
|