Realcat commited on
Commit
c2c673f
·
1 Parent(s): 4d1f87b
Files changed (1) hide show
  1. imcui/hloc/extractors/rdd.py +5 -3
imcui/hloc/extractors/rdd.py CHANGED
@@ -9,6 +9,7 @@ sys.path.append(str(rdd_path))
9
 
10
  from RDD.RDD import build as build_rdd
11
 
 
12
  class Rdd(BaseModel):
13
  default_conf = {
14
  "keypoint_threshold": 0.1,
@@ -22,9 +23,7 @@ class Rdd(BaseModel):
22
  logger.info("Loading RDD model...")
23
  model_path = self._download_model(
24
  repo_id=MODEL_REPO_ID,
25
- filename="{}/{}".format(
26
- Path(__file__).stem, self.conf["model_name"]
27
- ),
28
  )
29
  config_path = rdd_path / "configs/default.yaml"
30
  with open(config_path, "r") as file:
@@ -38,6 +37,9 @@ class Rdd(BaseModel):
38
 
39
  def _forward(self, data):
40
  image = data["image"]
 
 
 
41
  pred = self.net.extract(image)[0]
42
  keypoints = pred["keypoints"]
43
  descriptors = pred["descriptors"]
 
9
 
10
  from RDD.RDD import build as build_rdd
11
 
12
+
13
  class Rdd(BaseModel):
14
  default_conf = {
15
  "keypoint_threshold": 0.1,
 
23
  logger.info("Loading RDD model...")
24
  model_path = self._download_model(
25
  repo_id=MODEL_REPO_ID,
26
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
 
 
27
  )
28
  config_path = rdd_path / "configs/default.yaml"
29
  with open(config_path, "r") as file:
 
37
 
38
  def _forward(self, data):
39
  image = data["image"]
40
+ self.net.set_softdetect(
41
+ top_k=self.conf["max_keypoints"], scores_th=self.conf["keypoint_threshold"]
42
+ )
43
  pred = self.net.extract(image)[0]
44
  keypoints = pred["keypoints"]
45
  descriptors = pred["descriptors"]