XingyiHe committed on
Commit 6f13a83 · 1 Parent(s): 3f5b583

ADD: MatchAnything

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. imcui/third_party/MatchAnything/LICENSE +202 -0
  2. imcui/third_party/MatchAnything/README.md +104 -0
  3. imcui/third_party/MatchAnything/configs/models/eloftr_model.py +128 -0
  4. imcui/third_party/MatchAnything/configs/models/roma_model.py +27 -0
  5. imcui/third_party/MatchAnything/environment.yaml +14 -0
  6. imcui/third_party/MatchAnything/notebooks/notebooks_utils/__init__.py +1 -0
  7. imcui/third_party/MatchAnything/notebooks/notebooks_utils/plotting.py +344 -0
  8. imcui/third_party/MatchAnything/requirements.txt +22 -0
  9. imcui/third_party/MatchAnything/scripts/evaluate/eval_harvard_brain.sh +17 -0
  10. imcui/third_party/MatchAnything/scripts/evaluate/eval_liver_ct_mr.sh +17 -0
  11. imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_ground.sh +17 -0
  12. imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_mtv.sh +17 -0
  13. imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_remote_sense.sh +17 -0
  14. imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_sar.sh +17 -0
  15. imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_vectorized_map.sh +17 -0
  16. imcui/third_party/MatchAnything/src/__init__.py +0 -0
  17. imcui/third_party/MatchAnything/src/config/default.py +344 -0
  18. imcui/third_party/MatchAnything/src/lightning/lightning_loftr.py +343 -0
  19. imcui/third_party/MatchAnything/src/loftr/__init__.py +1 -0
  20. imcui/third_party/MatchAnything/src/loftr/backbone/__init__.py +61 -0
  21. imcui/third_party/MatchAnything/src/loftr/backbone/repvgg.py +319 -0
  22. imcui/third_party/MatchAnything/src/loftr/backbone/resnet_fpn.py +1094 -0
  23. imcui/third_party/MatchAnything/src/loftr/backbone/s2dnet.py +131 -0
  24. imcui/third_party/MatchAnything/src/loftr/loftr.py +273 -0
  25. imcui/third_party/MatchAnything/src/loftr/loftr_module/__init__.py +2 -0
  26. imcui/third_party/MatchAnything/src/loftr/loftr_module/fine_preprocess.py +350 -0
  27. imcui/third_party/MatchAnything/src/loftr/loftr_module/linear_attention.py +217 -0
  28. imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer.py +1768 -0
  29. imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer_utils.py +76 -0
  30. imcui/third_party/MatchAnything/src/loftr/utils/coarse_matching.py +266 -0
  31. imcui/third_party/MatchAnything/src/loftr/utils/fine_matching.py +493 -0
  32. imcui/third_party/MatchAnything/src/loftr/utils/geometry.py +298 -0
  33. imcui/third_party/MatchAnything/src/loftr/utils/position_encoding.py +131 -0
  34. imcui/third_party/MatchAnything/src/loftr/utils/supervision.py +475 -0
  35. imcui/third_party/MatchAnything/src/optimizers/__init__.py +50 -0
  36. imcui/third_party/MatchAnything/src/utils/__init__.py +0 -0
  37. imcui/third_party/MatchAnything/src/utils/augment.py +55 -0
  38. imcui/third_party/MatchAnything/src/utils/colmap.py +530 -0
  39. imcui/third_party/MatchAnything/src/utils/colmap/__init__.py +0 -0
  40. imcui/third_party/MatchAnything/src/utils/colmap/database.py +417 -0
  41. imcui/third_party/MatchAnything/src/utils/colmap/eval_helper.py +232 -0
  42. imcui/third_party/MatchAnything/src/utils/colmap/read_write_model.py +509 -0
  43. imcui/third_party/MatchAnything/src/utils/comm.py +265 -0
  44. imcui/third_party/MatchAnything/src/utils/dataloader.py +23 -0
  45. imcui/third_party/MatchAnything/src/utils/dataset.py +518 -0
  46. imcui/third_party/MatchAnything/src/utils/easydict.py +148 -0
  47. imcui/third_party/MatchAnything/src/utils/geometry.py +366 -0
  48. imcui/third_party/MatchAnything/src/utils/homography_utils.py +366 -0
  49. imcui/third_party/MatchAnything/src/utils/metrics.py +445 -0
  50. imcui/third_party/MatchAnything/src/utils/misc.py +101 -0
imcui/third_party/MatchAnything/LICENSE ADDED
@@ -0,0 +1,202 @@

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
imcui/third_party/MatchAnything/README.md ADDED
@@ -0,0 +1,104 @@
# MatchAnything: Universal Cross-Modality Image Matching with Large-Scale Pre-Training
### [Project Page](https://zju3dv.github.io/MatchAnything) | [Paper](??)

> MatchAnything: Universal Cross-Modality Image Matching with Large-Scale Pre-Training\
> [Xingyi He](https://hxy-123.github.io/),
[Hao Yu](https://ritianyu.github.io/),
[Sida Peng](https://pengsida.net),
[Dongli Tan](https://github.com/Cuistiano),
[Zehong Shen](https://zehongs.github.io),
[Xiaowei Zhou](https://xzhou.me/),
[Hujun Bao](http://www.cad.zju.edu.cn/home/bao/)<sup>†</sup>\
> arXiv 2025

<p align="center">
<img src=docs/teaser_demo.gif alt="animated" />
</p>

## TODO List
- [x] Pre-trained models and inference code
- [x] Huggingface demo
- [ ] Data generation and training code
- [ ] Fine-tuning code to further train on your own data
- [ ] Incorporate more synthetic modalities and image generation methods

## Quick Start

### [<img src="https://s2.loli.net/2024/09/15/aw3rElfQAsOkNCn.png" width="20"> HuggingFace demo for MatchAnything](https://huggingface.co/spaces/LittleFrog/MatchAnything)

## Setup
Create the python environment by:
```
conda env create -f environment.yaml
conda activate env
```
We have tested our code on devices with CUDA 11.7.

Download the pretrained weights from [here](https://drive.google.com/file/d/12L3g9-w8rR9K2L4rYaGaDJ7NqX1D713d/view?usp=sharing) and place them under the repo directory. Then unzip them by running the following command:
```
unzip weights.zip
rm -rf weights.zip
```

## Test
We evaluate the models pretrained by our framework using a single network weight on all cross-modality matching and registration tasks.

### Data Preparation
Download the `test_data` directory from [here](https://drive.google.com/drive/folders/1jpxIOcgnQfl9IEPPifdXQ7S7xuj9K4j7?usp=sharing) and place it under `repo_directory/data`. Then, unzip all datasets by:
```shell
cd repo_directory/data/test_data

for file in *.zip; do
    unzip "$file" && rm "$file"
done
```

The data structure should look like:
```
repo_directory/data/test_data
    - Liver_CT-MR
    - havard_medical_matching
    - remote_sense_thermal
    - MTV_cross_modal_data
    - thermal_visible_ground
    - visible_sar_dataset
    - visible_vectorized_map
```

### Evaluation
```shell
# For tomography datasets:
sh scripts/evaluate/eval_liver_ct_mr.sh
sh scripts/evaluate/eval_harvard_brain.sh

# For visible-thermal datasets:
sh scripts/evaluate/eval_thermal_remote_sense.sh
sh scripts/evaluate/eval_thermal_mtv.sh
sh scripts/evaluate/eval_thermal_ground.sh

# For visible-SAR dataset:
sh scripts/evaluate/eval_visible_sar.sh

# For visible-vectorized map dataset:
sh scripts/evaluate/eval_visible_vectorized_map.sh
```

# Citation

If you find this code useful for your research, please use the following BibTeX entry.

```
@inproceedings{he2025matchanything,
  title={MatchAnything: Universal Cross-Modality Image Matching with Large-Scale Pre-Training},
  author={He, Xingyi and Yu, Hao and Peng, Sida and Tan, Dongli and Shen, Zehong and Bao, Hujun and Zhou, Xiaowei},
  booktitle={Arxiv},
  year={2025}
}
```

# Acknowledgement
We thank the authors of
[ELoFTR](https://github.com/zju3dv/EfficientLoFTR) and
[ROMA](https://github.com/Parskatt/RoMa) for their great works, without which our project/code would not be possible.
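
Editor's note: the README stops at "matching and registration tasks" without showing how matches become a registration. Below is a minimal, hedged sketch (not part of the repo) of turning matched keypoints into an affine warp with OpenCV RANSAC, mirroring the `ransac_affine` estimator named in the evaluation scripts; `mkpts0`/`mkpts1` are hypothetical (N, 2) arrays from any matcher (e.g., the Huggingface demo).

```python
# Illustrative only: affine registration from correspondences with OpenCV RANSAC.
import cv2
import numpy as np

def register_with_affine_ransac(img_src, mkpts0, mkpts1, dst_shape, thr=0.5):
    # Estimate a 2x3 affine transform mapping mkpts0 -> mkpts1 with RANSAC.
    A, inliers = cv2.estimateAffine2D(
        mkpts0.astype(np.float32), mkpts1.astype(np.float32),
        method=cv2.RANSAC, ransacReprojThreshold=thr,
    )
    if A is None:
        raise RuntimeError("RANSAC failed to find an affine model")
    h, w = dst_shape[:2]
    warped = cv2.warpAffine(img_src, A, (w, h))   # resample source into target frame
    return warped, A, inliers.ravel().astype(bool)
```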
imcui/third_party/MatchAnything/configs/models/eloftr_model.py ADDED
@@ -0,0 +1,128 @@
from src.config.default import _CN as cfg

cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False

cfg.TRAINER.CANONICAL_LR = 8e-3
cfg.TRAINER.WARMUP_STEP = 1875  # 3 epochs
cfg.TRAINER.WARMUP_RATIO = 0.1

cfg.TRAINER.MSLR_MILESTONES = [4, 6, 8, 10, 12, 14, 16]

# pose estimation
cfg.TRAINER.RANSAC_PIXEL_THR = 0.5

cfg.TRAINER.OPTIMIZER = "adamw"
cfg.TRAINER.ADAMW_DECAY = 0.1

cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.1

cfg.LOFTR.MATCH_COARSE.MTD_SPVS = True
cfg.LOFTR.FINE.MTD_SPVS = True

cfg.LOFTR.RESOLUTION = (8, 1)  # options: [(8, 2), (16, 4)]
cfg.LOFTR.FINE_WINDOW_SIZE = 8  # window_size in fine_level, must be odd
cfg.LOFTR.MATCH_FINE.THR = 0
cfg.LOFTR.LOSS.FINE_TYPE = 'l2'  # ['l2_with_std', 'l2']

cfg.TRAINER.EPI_ERR_THR = 5e-4  # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)

cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = True

# PAN
cfg.LOFTR.COARSE.PAN = True
cfg.LOFTR.COARSE.POOl_SIZE = 4
cfg.LOFTR.COARSE.BN = False
cfg.LOFTR.COARSE.XFORMER = True
cfg.LOFTR.COARSE.ATTENTION = 'full'  # options: ['linear', 'full']

cfg.LOFTR.FINE.PAN = False
cfg.LOFTR.FINE.POOl_SIZE = 4
cfg.LOFTR.FINE.BN = False
cfg.LOFTR.FINE.XFORMER = False

# noalign
cfg.LOFTR.ALIGN_CORNER = False

# fp16
cfg.DATASET.FP16 = False
cfg.LOFTR.FP16 = False

# DEBUG
cfg.LOFTR.FP16LOG = False
cfg.LOFTR.MATCH_COARSE.FP16LOG = False

# fine skip
cfg.LOFTR.FINE.SKIP = True

# clip
cfg.TRAINER.GRADIENT_CLIPPING = 0.5

# backbone
cfg.LOFTR.BACKBONE_TYPE = 'RepVGG'

# A1
cfg.LOFTR.RESNETFPN.INITIAL_DIM = 64
cfg.LOFTR.RESNETFPN.BLOCK_DIMS = [64, 128, 256]  # s1, s2, s3
cfg.LOFTR.COARSE.D_MODEL = 256
cfg.LOFTR.FINE.D_MODEL = 64

# FPN backbone_inter_feat with coarse_attn.
cfg.LOFTR.COARSE_FEAT_ONLY = True
cfg.LOFTR.INTER_FEAT = True
cfg.LOFTR.RESNETFPN.COARSE_FEAT_ONLY = True
cfg.LOFTR.RESNETFPN.INTER_FEAT = True

# loop back spv coarse match
cfg.LOFTR.FORCE_LOOP_BACK = False

# fix norm fine match
cfg.LOFTR.MATCH_FINE.NORMFINEM = True

# loss cf weight
cfg.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT = True
cfg.LOFTR.LOSS.FINE_OVERLAP_WEIGHT = True

# leaky relu
cfg.LOFTR.RESNETFPN.LEAKY = False
cfg.LOFTR.COARSE.LEAKY = 0.01

# prevent FP16 OVERFLOW in dirty data
cfg.LOFTR.NORM_FPNFEAT = True
cfg.LOFTR.REPLACE_NAN = True

# force mutual nearest
cfg.LOFTR.MATCH_COARSE.FORCE_NEAREST = True
cfg.LOFTR.MATCH_COARSE.THR = 0.1

# fix fine matching
cfg.LOFTR.MATCH_FINE.FIX_FINE_MATCHING = True

# dwconv
cfg.LOFTR.COARSE.DWCONV = True

# localreg
cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS = True
cfg.LOFTR.LOSS.LOCAL_WEIGHT = 0.25

# it5
cfg.LOFTR.EVAL_TIMES = 1

# rope
cfg.LOFTR.COARSE.ROPE = True

# local regress temperature
cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 10.0

# SLICE
cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICE = True
cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8

# inner with no mask [64,100]
cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_INNER = True
cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_NOMASK = True

cfg.LOFTR.MATCH_FINE.TOPK = 1
cfg.LOFTR.MATCH_COARSE.FINE_TOPK = 1

cfg.LOFTR.MATCH_COARSE.FP16MATMUL = False
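
Editor's note: this model config follows the common yacs pattern in which importing the module mutates the shared defaults from `src/config/default.py`. The snippet below is a self-contained sketch of that pattern (it does not import the repo); the names `_CN`, `apply_eloftr_overrides`, and the override values shown are illustrative stand-ins.

```python
# Minimal, self-contained sketch of the yacs override pattern assumed above.
from yacs.config import CfgNode as CN

_CN = CN()                        # stands in for src.config.default._CN
_CN.LOFTR = CN()
_CN.LOFTR.BACKBONE_TYPE = 'ResNetFPN'
_CN.LOFTR.RESOLUTION = (8, 2)

def apply_eloftr_overrides(cfg):
    # mirrors what importing configs/models/eloftr_model.py does to the defaults
    cfg.LOFTR.BACKBONE_TYPE = 'RepVGG'
    cfg.LOFTR.RESOLUTION = (8, 1)
    return cfg

cfg = apply_eloftr_overrides(_CN.clone())   # clone() keeps the global defaults untouched
cfg.freeze()                                # make the merged config read-only
print(cfg.LOFTR.BACKBONE_TYPE)              # -> RepVGG
```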
imcui/third_party/MatchAnything/configs/models/roma_model.py ADDED
@@ -0,0 +1,27 @@
from src.config.default import _CN as cfg
cfg.ROMA.RESIZE_BY_STRETCH = True
cfg.DATASET.RESIZE_BY_STRETCH = True

cfg.TRAINER.CANONICAL_LR = 8e-3
cfg.TRAINER.WARMUP_STEP = 1875  # 3 epochs
cfg.TRAINER.WARMUP_RATIO = 0.1

cfg.TRAINER.MSLR_MILESTONES = [4, 6, 8, 10, 12, 14, 16, 18, 20]

# pose estimation
cfg.TRAINER.RANSAC_PIXEL_THR = 0.5

cfg.TRAINER.OPTIMIZER = "adamw"
cfg.TRAINER.ADAMW_DECAY = 0.1
cfg.TRAINER.OPTIMIZER_EPS = 5e-7

cfg.TRAINER.EPI_ERR_THR = 5e-4

# fp16
cfg.DATASET.FP16 = False
cfg.LOFTR.FP16 = True

# clip
cfg.TRAINER.GRADIENT_CLIPPING = 0.5

cfg.LOFTR.ROMA_LOSS.IGNORE_EMPTY_IN_SPARSE_MATCH_SPV = True
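
Editor's note: the ROMA pipeline configured above runs at the test-time resolutions defined in `src/config/default.py`, whose comments require the coarse resolution to be divisible by 14 (DINOv2 patch size) and 8. A small hedged helper (illustrative only, not repo code) for snapping an arbitrary size to the nearest valid value:

```python
# Snap a requested coarse test-time resolution to a multiple of lcm(14, 8) = 56,
# satisfying the "divisible by 14 & 8" comment on ROMA.TEST_TIME.COARSE_RES.
from math import gcd

def snap_to_valid_res(size: int, factors=(14, 8)) -> int:
    step = factors[0] * factors[1] // gcd(*factors)   # lcm(14, 8) = 56
    return max(step, round(size / step) * step)

assert snap_to_valid_res(560) == 560   # the shipped default
assert snap_to_valid_res(600) == 616   # nearest valid size
```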
imcui/third_party/MatchAnything/environment.yaml ADDED
@@ -0,0 +1,14 @@
name: env
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - python=3.8
  - pytorch-cuda=11.7
  - pytorch=1.12.1
  - torchvision=0.13.1
  - pip
  - pip:
    - -r requirements.txt
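
Editor's note: after `conda env create -f environment.yaml`, a quick optional sanity check (not part of the repo) confirms the pinned PyTorch 1.12.1 + CUDA 11.7 toolchain is actually active in the environment:

```python
# Optional environment sanity check for the pins in environment.yaml.
import torch

print(torch.__version__)            # expected to start with "1.12.1"
print(torch.version.cuda)           # expected "11.7"
print(torch.cuda.is_available())    # True if a compatible GPU driver is present
```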
imcui/third_party/MatchAnything/notebooks/notebooks_utils/__init__.py ADDED
@@ -0,0 +1 @@
from .plotting import *
imcui/third_party/MatchAnything/notebooks/notebooks_utils/plotting.py ADDED
@@ -0,0 +1,344 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.colors import hsv_to_rgb
import pylab as pl
import matplotlib.cm as cm
from PIL import Image
import cv2


def visualize_features(feat, img_h, img_w, save_path=None):
    from sklearn.decomposition import PCA
    pca = PCA(n_components=3, svd_solver="arpack")
    img = pca.fit_transform(feat).reshape(img_h * 2, img_w, 3)
    img_norm = cv2.normalize(
        img, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC3
    )
    img_resized = cv2.resize(
        img_norm, (img_w * 8, img_h * 2 * 8), interpolation=cv2.INTER_LINEAR
    )
    img_colormap = img_resized
    img1, img2 = img_colormap[: img_h * 8, :, :], img_colormap[img_h * 8 :, :, :]
    img_gapped = np.hstack(
        (img1, np.ones((img_h * 8, 10, 3), dtype=np.uint8) * 255, img2)
    )
    if save_path is not None:
        cv2.imwrite(save_path, img_gapped)

    fig, axes = plt.subplots(1, 1, dpi=200)
    axes.imshow(img_gapped)
    axes.get_yaxis().set_ticks([])
    axes.get_xaxis().set_ticks([])
    plt.tight_layout(pad=0.5)
    return fig

def make_matching_figure(
    img0,
    img1,
    mkpts0,
    mkpts1,
    color,
    kpts0=None,
    kpts1=None,
    text=[],
    path=None,
    draw_detection=False,
    draw_match_type='corres',  # ['color', 'corres', None]
    r_normalize_factor=0.4,
    white_center=True,
    vertical=False,
    use_position_color=False,
    draw_local_window=False,
    window_size=(9, 9),
    plot_size_factor=1,  # Point size and line width
    anchor_pts0=None,
    anchor_pts1=None,
    rescale_thr=5000,
):
    if (max(img0.shape) > rescale_thr) or (max(img1.shape) > rescale_thr):
        scale_factor = 0.5
        img0 = np.array(Image.fromarray((img0 * 255).astype(np.uint8)).resize((int(img0.shape[1] * scale_factor), int(img0.shape[0] * scale_factor)))) / 255.
        img1 = np.array(Image.fromarray((img1 * 255).astype(np.uint8)).resize((int(img1.shape[1] * scale_factor), int(img1.shape[0] * scale_factor)))) / 255.
        mkpts0, mkpts1 = mkpts0 * scale_factor, mkpts1 * scale_factor
        if kpts0 is not None:
            kpts0, kpts1 = kpts0 * scale_factor, kpts1 * scale_factor

    # draw image pair
    fig, axes = (
        plt.subplots(2, 1, figsize=(10, 6), dpi=600)
        if vertical
        else plt.subplots(1, 2, figsize=(10, 6), dpi=600)
    )
    axes[0].imshow(img0, aspect='auto')
    axes[1].imshow(img1, aspect='auto')

    # axes[0].imshow(img0, aspect='equal')
    # axes[1].imshow(img1, aspect='equal')
    for i in range(2):  # clear all frames
        axes[i].get_yaxis().set_ticks([])
        axes[i].get_xaxis().set_ticks([])
        for spine in axes[i].spines.values():
            spine.set_visible(False)
    plt.tight_layout(pad=1)

    if use_position_color:
        mean_coord = np.mean(mkpts0, axis=0)
        x_center, y_center = mean_coord
        # NOTE: setting r_normalize_factor to a smaller number makes the plotted figure more contrastive.
        position_color = matching_coord2color(
            mkpts0,
            x_center,
            y_center,
            r_normalize_factor=r_normalize_factor,
            white_center=white_center,
        )
        color[:, :3] = position_color

    if draw_detection and kpts0 is not None and kpts1 is not None:
        # color = 'g'
        color = 'r'
        axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=1 * plot_size_factor)
        axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=1 * plot_size_factor)

    if draw_match_type == 'corres':
        # draw matches
        fig.canvas.draw()
        plt.pause(2.0)
        transFigure = fig.transFigure.inverted()
        fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
        fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
        fig.lines = [
            matplotlib.lines.Line2D(
                (fkpts0[i, 0], fkpts1[i, 0]),
                (fkpts0[i, 1], fkpts1[i, 1]),
                transform=fig.transFigure,
                c=color[i],
                linewidth=1 * plot_size_factor,
            )
            for i in range(len(mkpts0))
        ]

        axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=2 * plot_size_factor)
        axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=2 * plot_size_factor)
    elif draw_match_type == 'color':
        # x_center = img0.shape[-1] / 2
        # y_center = img1.shape[-2] / 2

        mean_coord = np.mean(mkpts0, axis=0)
        x_center, y_center = mean_coord
        # NOTE: setting r_normalize_factor to a smaller number makes the plotted figure more contrastive.
        kpts_color = matching_coord2color(
            mkpts0,
            x_center,
            y_center,
            r_normalize_factor=r_normalize_factor,
            white_center=white_center,
        )
        axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=kpts_color, s=1 * plot_size_factor)
        axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=kpts_color, s=1 * plot_size_factor)

    if draw_local_window:
        anchor_pts0 = mkpts0 if anchor_pts0 is None else anchor_pts0
        anchor_pts1 = mkpts1 if anchor_pts1 is None else anchor_pts1
        plot_local_windows(
            anchor_pts0, color=(1, 0, 0, 0.4), lw=0.2, ax_=0, window_size=window_size
        )
        plot_local_windows(
            anchor_pts1, color=(1, 0, 0, 0.4), lw=0.2, ax_=1, window_size=window_size
        )  # lw = 0.2

    # put txts
    txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
    fig.text(
        0.01,
        0.99,
        "\n".join(text),
        transform=fig.axes[0].transAxes,
        fontsize=15,
        va="top",
        ha="left",
        color=txt_color,
    )
    plt.tight_layout(pad=1)

    # save or return figure
    if path:
        plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
        plt.close()
    else:
        return fig

def make_triple_matching_figure(
    img0,
    img1,
    img2,
    mkpts01,
    mkpts12,
    color01,
    color12,
    text=[],
    path=None,
    draw_match=True,
    r_normalize_factor=0.4,
    white_center=True,
    vertical=False,
    draw_local_window=False,
    window_size=(9, 9),
    anchor_pts0=None,
    anchor_pts1=None,
):
    # draw image pair
    fig, axes = (
        plt.subplots(3, 1, figsize=(10, 6), dpi=600)
        if vertical
        else plt.subplots(1, 3, figsize=(10, 6), dpi=600)
    )
    axes[0].imshow(img0)
    axes[1].imshow(img1)
    axes[2].imshow(img2)
    for i in range(3):  # clear all frames
        axes[i].get_yaxis().set_ticks([])
        axes[i].get_xaxis().set_ticks([])
        for spine in axes[i].spines.values():
            spine.set_visible(False)
    plt.tight_layout(pad=1)

    if draw_match:
        # draw matches for [0,1]
        fig.canvas.draw()
        transFigure = fig.transFigure.inverted()
        fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts01[0]))
        fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts01[1]))
        fig.lines = [
            matplotlib.lines.Line2D(
                (fkpts0[i, 0], fkpts1[i, 0]),
                (fkpts0[i, 1], fkpts1[i, 1]),
                transform=fig.transFigure,
                c=color01[i],
                linewidth=1,
            )
            for i in range(len(mkpts01[0]))
        ]

        axes[0].scatter(mkpts01[0][:, 0], mkpts01[0][:, 1], c=color01[:, :3], s=1)
        axes[1].scatter(mkpts01[1][:, 0], mkpts01[1][:, 1], c=color01[:, :3], s=1)

        fig.canvas.draw()
        # draw matches for [1,2]
        fkpts1_1 = transFigure.transform(axes[1].transData.transform(mkpts12[0]))
        fkpts2 = transFigure.transform(axes[2].transData.transform(mkpts12[1]))
        fig.lines += [
            matplotlib.lines.Line2D(
                (fkpts1_1[i, 0], fkpts2[i, 0]),
                (fkpts1_1[i, 1], fkpts2[i, 1]),
                transform=fig.transFigure,
                c=color12[i],
                linewidth=1,
            )
            for i in range(len(mkpts12[0]))
        ]

        axes[1].scatter(mkpts12[0][:, 0], mkpts12[0][:, 1], c=color12[:, :3], s=1)
        axes[2].scatter(mkpts12[1][:, 0], mkpts12[1][:, 1], c=color12[:, :3], s=1)

    # # put txts
    # txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
    # fig.text(
    #     0.01,
    #     0.99,
    #     "\n".join(text),
    #     transform=fig.axes[0].transAxes,
    #     fontsize=15,
    #     va="top",
    #     ha="left",
    #     color=txt_color,
    # )
    plt.tight_layout(pad=0.1)

    # save or return figure
    if path:
        plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
        plt.close()
    else:
        return fig


def matching_coord2color(kpts, x_center, y_center, r_normalize_factor=0.4, white_center=True):
    """
    r_normalize_factor is used to visualize more clearly according to the spatial distribution of points.
    r_normalize_factor maximum = 1; larger -> points darker/brighter.
    """
    if not white_center:
        # dark center points
        V, H = np.mgrid[0:1:10j, 0:1:360j]
        S = np.ones_like(V)
    else:
        # white center points
        S, H = np.mgrid[0:1:10j, 0:1:360j]
        V = np.ones_like(S)

    HSV = np.dstack((H, S, V))
    RGB = hsv_to_rgb(HSV)
    """
    # used to visualize hsv
    pl.imshow(RGB, origin="lower", extent=[0, 360, 0, 1], aspect=150)
    pl.xlabel("H")
    pl.ylabel("S")
    pl.title("$V_{HSV}=1$")
    pl.show()
    """
    kpts = np.copy(kpts)
    distance = kpts - np.array([x_center, y_center])[None]
    r_max = np.percentile(np.linalg.norm(distance, axis=1), 85)
    # r_max = np.sqrt((x_center) ** 2 + (y_center) ** 2)
    kpts[:, 0] = kpts[:, 0] - x_center  # x
    kpts[:, 1] = kpts[:, 1] - y_center  # y

    r = np.sqrt(kpts[:, 0] ** 2 + kpts[:, 1] ** 2) + 1e-6
    r_normalized = r / (r_max * r_normalize_factor)
    r_normalized[r_normalized > 1] = 1
    r_normalized = (r_normalized) * 9

    cos_theta = kpts[:, 0] / r  # x / r
    theta = np.arccos(cos_theta)  # from 0 to pi
    change_angle_mask = kpts[:, 1] < 0
    theta[change_angle_mask] = 2 * np.pi - theta[change_angle_mask]
    theta_degree = np.degrees(theta)
    theta_degree[theta_degree == 360] = 0  # to avoid overflow
    theta_degree = theta_degree / 360 * 360
    kpts_color = RGB[r_normalized.astype(int), theta_degree.astype(int)]
    return kpts_color


def show_image_pair(img0, img1, path=None):
    fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=200)
    axes[0].imshow(img0, cmap="gray")
    axes[1].imshow(img1, cmap="gray")
    for i in range(2):  # clear all frames
        axes[i].get_yaxis().set_ticks([])
        axes[i].get_xaxis().set_ticks([])
        for spine in axes[i].spines.values():
            spine.set_visible(False)
    plt.tight_layout(pad=1)
    if path:
        plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
    return fig

def plot_local_windows(kpts, color="r", lw=1, ax_=0, window_size=(9, 9)):
    ax = plt.gcf().axes
    for kpt in kpts:
        ax[ax_].add_patch(
            matplotlib.patches.Rectangle(
                (
                    kpt[0] - (window_size[0] // 2) - 1,
                    kpt[1] - (window_size[1] // 2) - 1,
                ),
                window_size[0] + 1,
                window_size[1] + 1,
                lw=lw,
                color=color,
                fill=False,
            )
        )
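
Editor's note: the notebook that drives these helpers is not part of this diff, so here is a hedged usage sketch for `make_matching_figure` with synthetic data. It assumes the `notebooks/` directory is on `PYTHONPATH` so that `notebooks_utils` is importable, and that the installed matplotlib version accepts the `fig.lines = [...]` assignment used inside the function; the arrays below are made-up placeholders, not repo data.

```python
# Illustrative usage of notebooks_utils.plotting.make_matching_figure.
import numpy as np
import matplotlib.cm as cm
from notebooks_utils.plotting import make_matching_figure  # assumes notebooks/ on sys.path

rng = np.random.default_rng(0)
img0 = rng.random((480, 640))                         # grayscale images in [0, 1]
img1 = rng.random((480, 640))
mkpts0 = rng.uniform([0, 0], [640, 480], size=(100, 2))
mkpts1 = mkpts0 + rng.normal(0, 2, size=(100, 2))     # fake correspondences
conf = rng.random(100)
color = cm.jet(conf)                                  # one RGBA color per match

fig = make_matching_figure(img0, img1, mkpts0, mkpts1, color,
                           text=["MatchAnything", f"{len(mkpts0)} matches"])
fig.savefig("matches_demo.png", bbox_inches="tight")
```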
imcui/third_party/MatchAnything/requirements.txt ADDED
@@ -0,0 +1,22 @@
opencv_python==4.4.0.46
albumentations==0.5.1 --no-binary=imgaug,albumentations
Pillow==9.5.0
ray==2.9.3
einops==0.3.0
kornia==0.4.1
loguru==0.5.3
yacs>=0.1.8
tqdm
autopep8
pylint
ipython
jupyterlab
matplotlib
h5py==3.1.0
pytorch-lightning==1.3.5
torchmetrics==0.6.0  # version problem: https://github.com/NVIDIA/DeepLearningExamples/issues/1113#issuecomment-1102969461
joblib>=1.0.1
pynvml
gpustat
safetensors
timm==0.6.7
imcui/third_party/MatchAnything/scripts/evaluate/eval_harvard_brain.sh ADDED
@@ -0,0 +1,17 @@
#!/bin/bash -l

SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

cd $PROJECT_DIR

DEVICE_ID='0'
NPZ_ROOT=data/test_data/havard_medical_matching/all_eval
NPZ_LIST_PATH=data/test_data/havard_medical_matching/all_eval/val_list.txt
OUTPUT_PATH=results/havard_medical_matching

# ELoFTR pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --thr 0.05 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH

# ROMA pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
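
Editor's note: each evaluation script points `tools/evaluate_datasets.py` at a directory of `.npz` index files plus a `val_list.txt`. A hedged snippet (illustrative only) for peeking at those indices before launching a run; the `"{name}.npz"` naming and the keys inside each archive are assumptions, since neither is documented in this diff.

```python
# Inspect the evaluation index files referenced by NPZ_ROOT / NPZ_LIST_PATH.
import numpy as np
from pathlib import Path

npz_root = Path("data/test_data/havard_medical_matching/all_eval")
with open(npz_root / "val_list.txt") as f:
    scene_names = [line.strip() for line in f if line.strip()]

for name in scene_names[:3]:                      # look at the first few scenes
    npz_path = npz_root / f"{name}.npz"           # assumed naming convention
    data = np.load(npz_path, allow_pickle=True)
    print(name, data.files)                       # list the arrays stored per scene
```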
imcui/third_party/MatchAnything/scripts/evaluate/eval_liver_ct_mr.sh ADDED
@@ -0,0 +1,17 @@
#!/bin/bash -l

SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

cd $PROJECT_DIR

DEVICE_ID='0'
NPZ_ROOT=data/test_data/Liver_CT-MR/eval_indexs
NPZ_LIST_PATH=data/test_data/Liver_CT-MR/eval_indexs/val_list.txt
OUTPUT_PATH=results/Liver_CT-MR

# ELoFTR pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH

# ROMA pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_ground.sh ADDED
@@ -0,0 +1,17 @@
#!/bin/bash -l

SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

cd $PROJECT_DIR

DEVICE_ID='0'
NPZ_ROOT=data/test_data/thermal_visible_ground/eval_indexs
NPZ_LIST_PATH=data/test_data/thermal_visible_ground/eval_indexs/val_list.txt
OUTPUT_PATH=results/thermal_visible_ground

# ELoFTR pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH

# ROMA pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_mtv.sh ADDED
@@ -0,0 +1,17 @@
#!/bin/bash -l

SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

cd $PROJECT_DIR

DEVICE_ID='0'
NPZ_ROOT=data/test_data/MTV_cross_modal_data/scene_info/scene_info
NPZ_LIST_PATH=data/test_data/MTV_cross_modal_data/scene_info/test_list.txt
OUTPUT_PATH=results/MTV_cross_modal_data

# ELoFTR pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH

# ROMA pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_remote_sense.sh ADDED
@@ -0,0 +1,17 @@
#!/bin/bash -l

SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

cd $PROJECT_DIR

DEVICE_ID='0'
NPZ_ROOT=data/test_data/remote_sense_thermal/eval_Optical-Infrared
NPZ_LIST_PATH=data/test_data/remote_sense_thermal/eval_Optical-Infrared/val_list.txt
OUTPUT_PATH=results/remote_sense_thermal

# ELoFTR pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH

# ROMA pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_sar.sh ADDED
@@ -0,0 +1,17 @@
#!/bin/bash -l

SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

cd $PROJECT_DIR

DEVICE_ID='0'
NPZ_ROOT=data/test_data/visible_sar_dataset/eval
NPZ_LIST_PATH=data/test_data/visible_sar_dataset/eval/val_list.txt
OUTPUT_PATH=results/visible_sar_dataset

# ELoFTR pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --thr 0.05 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH

# ROMA pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_vectorized_map.sh ADDED
@@ -0,0 +1,17 @@
#!/bin/bash -l

SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

cd $PROJECT_DIR

DEVICE_ID='0'
NPZ_ROOT=data/test_data/visible_vectorized_map/scene_indices
NPZ_LIST_PATH=data/test_data/visible_vectorized_map/scene_indices/val_list.txt
OUTPUT_PATH=results/visible_vectorized_map

# ELoFTR pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH

# ROMA pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
imcui/third_party/MatchAnything/src/__init__.py ADDED
File without changes
imcui/third_party/MatchAnything/src/config/default.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from yacs.config import CfgNode as CN
2
+ _CN = CN()
3
+ ############## ROMA Pipeline #########
4
+ _CN.ROMA = CN()
5
+ _CN.ROMA.MATCH_THRESH = 0.0
6
+ _CN.ROMA.RESIZE_BY_STRETCH = False # Used for test mode
7
+ _CN.ROMA.NORMALIZE_IMG = False # Used for test mode
8
+
9
+ _CN.ROMA.MODE = "train_framework" # Used in Lightning Train & Val
10
+ _CN.ROMA.MODEL = CN()
11
+ _CN.ROMA.MODEL.COARSE_BACKBONE = 'DINOv2_large'
12
+ _CN.ROMA.MODEL.COARSE_FEAT_DIM = 1024
13
+ _CN.ROMA.MODEL.MEDIUM_FEAT_DIM = 512
14
+ _CN.ROMA.MODEL.COARSE_PATCH_SIZE = 14
15
+ _CN.ROMA.MODEL.AMP = True # FP16 mode
16
+
17
+ _CN.ROMA.SAMPLE = CN()
18
+ _CN.ROMA.SAMPLE.METHOD = "threshold_balanced"
19
+ _CN.ROMA.SAMPLE.N_SAMPLE = 5000
20
+ _CN.ROMA.SAMPLE.THRESH = 0.05
21
+
22
+ _CN.ROMA.TEST_TIME = CN()
23
+ _CN.ROMA.TEST_TIME.COARSE_RES = (560, 560) # need to divisable by 14 & 8
24
+ _CN.ROMA.TEST_TIME.UPSAMPLE = True
25
+ _CN.ROMA.TEST_TIME.UPSAMPLE_RES = (864, 864) # need to divisable by 8
26
+ _CN.ROMA.TEST_TIME.SYMMETRIC = True
27
+ _CN.ROMA.TEST_TIME.ATTENUTATE_CERT = True
28
+
29
+ ############## ↓ LoFTR Pipeline ↓ ##############
30
+ _CN.LOFTR = CN()
31
+ _CN.LOFTR.BACKBONE_TYPE = 'ResNetFPN'
32
+ _CN.LOFTR.ALIGN_CORNER = True
33
+ _CN.LOFTR.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
34
+ _CN.LOFTR.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
35
+ _CN.LOFTR.FINE_WINDOW_MATCHING_SIZE = 5 # window_size for loftr fine-matching, odd for select and even for average
36
+ _CN.LOFTR.FINE_CONCAT_COARSE_FEAT = True
37
+ _CN.LOFTR.FINE_SAMPLE_COARSE_FEAT = False
38
+ _CN.LOFTR.COARSE_FEAT_ONLY = False # TO BE DONE
39
+ _CN.LOFTR.INTER_FEAT = False # FPN backbone inter feat with coarse_attn.
40
+ _CN.LOFTR.FP16 = False
41
+ _CN.LOFTR.FIX_BIAS = False
42
+ _CN.LOFTR.MATCHABILITY = False
43
+ _CN.LOFTR.FORCE_LOOP_BACK = False
44
+ _CN.LOFTR.NORM_FPNFEAT = False
45
+ _CN.LOFTR.NORM_FPNFEAT2 = False
46
+ _CN.LOFTR.REPLACE_NAN = False
47
+ _CN.LOFTR.PLOT_SCORES = False
48
+ _CN.LOFTR.REP_FPN = False
49
+ _CN.LOFTR.REP_DEPLOY = False
50
+ _CN.LOFTR.EVAL_TIMES = 1
51
+
52
+ # 1. LoFTR-backbone (local feature CNN) config
53
+ _CN.LOFTR.RESNETFPN = CN()
54
+ _CN.LOFTR.RESNETFPN.INITIAL_DIM = 128
55
+ _CN.LOFTR.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3
56
+ _CN.LOFTR.RESNETFPN.SAMPLE_FINE = False
57
+ _CN.LOFTR.RESNETFPN.COARSE_FEAT_ONLY = False # TO BE DONE
58
+ _CN.LOFTR.RESNETFPN.INTER_FEAT = False # FPN backbone inter feat with coarse_attn.
59
+ _CN.LOFTR.RESNETFPN.LEAKY = False
60
+ _CN.LOFTR.RESNETFPN.REPVGGMODEL = None
61
+
62
+ # 2. LoFTR-coarse module config
63
+ _CN.LOFTR.COARSE = CN()
64
+ _CN.LOFTR.COARSE.D_MODEL = 256
65
+ _CN.LOFTR.COARSE.D_FFN = 256
66
+ _CN.LOFTR.COARSE.NHEAD = 8
67
+ _CN.LOFTR.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
68
+ _CN.LOFTR.COARSE.ATTENTION = 'linear' # options: ['linear', 'full']
69
+ _CN.LOFTR.COARSE.TEMP_BUG_FIX = True
70
+ _CN.LOFTR.COARSE.NPE = False
71
+ _CN.LOFTR.COARSE.PAN = False
72
+ _CN.LOFTR.COARSE.POOl_SIZE = 4
73
+ _CN.LOFTR.COARSE.POOl_SIZE2 = 4
74
+ _CN.LOFTR.COARSE.BN = True
75
+ _CN.LOFTR.COARSE.XFORMER = False
76
+ _CN.LOFTR.COARSE.BIDIRECTION = False
77
+ _CN.LOFTR.COARSE.DEPTH_CONFIDENCE = -1.0
78
+ _CN.LOFTR.COARSE.WIDTH_CONFIDENCE = -1.0
79
+ _CN.LOFTR.COARSE.LEAKY = -1.0
80
+ _CN.LOFTR.COARSE.ASYMMETRIC = False
81
+ _CN.LOFTR.COARSE.ASYMMETRIC_SELF = False
82
+ _CN.LOFTR.COARSE.ROPE = False
83
+ _CN.LOFTR.COARSE.TOKEN_MIXER = None
84
+ _CN.LOFTR.COARSE.SKIP = False
85
+ _CN.LOFTR.COARSE.DWCONV = False
86
+ _CN.LOFTR.COARSE.DWCONV2 = False
87
+ _CN.LOFTR.COARSE.SCATTER = False
88
+ _CN.LOFTR.COARSE.ROPE = False
89
+ _CN.LOFTR.COARSE.NPE = None
90
+ _CN.LOFTR.COARSE.NORM_BEFORE = True
91
+ _CN.LOFTR.COARSE.VIT_NORM = False
92
+ _CN.LOFTR.COARSE.ROPE_DWPROJ = False
93
+ _CN.LOFTR.COARSE.ABSPE = False
94
+
95
+
96
+ # 3. Coarse-Matching config
97
+ _CN.LOFTR.MATCH_COARSE = CN()
98
+ _CN.LOFTR.MATCH_COARSE.THR = 0.2
99
+ _CN.LOFTR.MATCH_COARSE.BORDER_RM = 2
100
+ _CN.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn']
101
+ _CN.LOFTR.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
102
+ _CN.LOFTR.MATCH_COARSE.SKH_ITERS = 3
103
+ _CN.LOFTR.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
104
+ _CN.LOFTR.MATCH_COARSE.SKH_PREFILTER = False
105
+ _CN.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory
106
+ _CN.LOFTR.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
107
+ _CN.LOFTR.MATCH_COARSE.SPARSE_SPVS = True
108
+ _CN.LOFTR.MATCH_COARSE.MTD_SPVS = False
109
+ _CN.LOFTR.MATCH_COARSE.FIX_BIAS = False
110
+ _CN.LOFTR.MATCH_COARSE.BINARY = False
111
+ _CN.LOFTR.MATCH_COARSE.BINARY_SPV = 'l2'
112
+ _CN.LOFTR.MATCH_COARSE.NORMFEAT = False
113
+ _CN.LOFTR.MATCH_COARSE.NORMFEATMUL = False
114
+ _CN.LOFTR.MATCH_COARSE.DIFFSIGN2 = False
115
+ _CN.LOFTR.MATCH_COARSE.DIFFSIGN3 = False
116
+ _CN.LOFTR.MATCH_COARSE.CLASSIFY = False
117
+ _CN.LOFTR.MATCH_COARSE.D_CLASSIFY = 256
118
+ _CN.LOFTR.MATCH_COARSE.SKIP_SOFTMAX = False
119
+ _CN.LOFTR.MATCH_COARSE.FORCE_NEAREST = False # in case binary is True, force nearest neighbor, preventing finding a reasonable threshold
120
+ _CN.LOFTR.MATCH_COARSE.FP16MATMUL = False
121
+ _CN.LOFTR.MATCH_COARSE.SEQSOFTMAX = False
122
+ _CN.LOFTR.MATCH_COARSE.SEQSOFTMAX2 = False
123
+ _CN.LOFTR.MATCH_COARSE.RATIO_TEST = False
124
+ _CN.LOFTR.MATCH_COARSE.RATIO_TEST_VAL = -1.0
125
+ _CN.LOFTR.MATCH_COARSE.USE_GT_COARSE = False
126
+ _CN.LOFTR.MATCH_COARSE.CROSS_SOFTMAX = False
127
+ _CN.LOFTR.MATCH_COARSE.PLOT_ORIGIN_SCORES = False
128
+ _CN.LOFTR.MATCH_COARSE.USE_PERCENT_THR = False
129
+ _CN.LOFTR.MATCH_COARSE.PERCENT_THR = 0.1
130
+ _CN.LOFTR.MATCH_COARSE.ADD_SIGMOID = False
131
+ _CN.LOFTR.MATCH_COARSE.SIGMOID_BIAS = 20.0
132
+ _CN.LOFTR.MATCH_COARSE.SIGMOID_SIGMA = 2.5
133
+ _CN.LOFTR.MATCH_COARSE.CAL_PER_OF_GT = False
134
+
135
+ # 4. LoFTR-fine module config
136
+ _CN.LOFTR.FINE = CN()
137
+ _CN.LOFTR.FINE.SKIP = False
138
+ _CN.LOFTR.FINE.D_MODEL = 128
139
+ _CN.LOFTR.FINE.D_FFN = 128
140
+ _CN.LOFTR.FINE.NHEAD = 8
141
+ _CN.LOFTR.FINE.LAYER_NAMES = ['self', 'cross'] * 1
142
+ _CN.LOFTR.FINE.ATTENTION = 'linear'
143
+ _CN.LOFTR.FINE.MTD_SPVS = False
144
+ _CN.LOFTR.FINE.PAN = False
145
+ _CN.LOFTR.FINE.POOl_SIZE = 4
146
+ _CN.LOFTR.FINE.BN = True
147
+ _CN.LOFTR.FINE.XFORMER = False
148
+ _CN.LOFTR.FINE.BIDIRECTION = False
149
+
150
+
151
+ # Fine-Matching config
152
+ _CN.LOFTR.MATCH_FINE = CN()
153
+ _CN.LOFTR.MATCH_FINE.THR = 0
154
+ _CN.LOFTR.MATCH_FINE.TOPK = 3
155
+ _CN.LOFTR.MATCH_FINE.NORMFINEM = False
156
+ _CN.LOFTR.MATCH_FINE.USE_GT_FINE = False
157
+ _CN.LOFTR.MATCH_COARSE.FINE_TOPK = _CN.LOFTR.MATCH_FINE.TOPK
158
+ _CN.LOFTR.MATCH_FINE.FIX_FINE_MATCHING = False
159
+ _CN.LOFTR.MATCH_FINE.SKIP_FINE_SOFTMAX = False
160
+ _CN.LOFTR.MATCH_FINE.USE_SIGMOID = False
161
+ _CN.LOFTR.MATCH_FINE.SIGMOID_BIAS = 0.0
162
+ _CN.LOFTR.MATCH_FINE.NORMFEAT = False
163
+ _CN.LOFTR.MATCH_FINE.SPARSE_SPVS = True
164
+ _CN.LOFTR.MATCH_FINE.FORCE_NEAREST = False
165
+ _CN.LOFTR.MATCH_FINE.LOCAL_REGRESS = False
166
+ _CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_RMBORDER = False
167
+ _CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_NOMASK = False
168
+ _CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 1.0
169
+ _CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_PADONE = False
170
+ _CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICE = False
171
+ _CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8
172
+ _CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_INNER = False
173
+ _CN.LOFTR.MATCH_FINE.MULTI_REGRESS = False
174
+
175
+
176
+
177
+ # 5. LoFTR Losses
178
+ # -- # coarse-level
179
+ _CN.LOFTR.LOSS = CN()
180
+ _CN.LOFTR.LOSS.COARSE_TYPE = 'focal' # ['focal', 'cross_entropy']
181
+ _CN.LOFTR.LOSS.COARSE_WEIGHT = 1.0
182
+ _CN.LOFTR.LOSS.COARSE_SIGMOID_WEIGHT = 1.0
183
+ _CN.LOFTR.LOSS.LOCAL_WEIGHT = 0.5
184
+ _CN.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT = False
185
+ _CN.LOFTR.LOSS.FINE_OVERLAP_WEIGHT = False
186
+ _CN.LOFTR.LOSS.FINE_OVERLAP_WEIGHT2 = False
187
+ # _CN.LOFTR.LOSS.SPARSE_SPVS = False
188
+ # -- - -- # focal loss (coarse)
189
+ _CN.LOFTR.LOSS.FOCAL_ALPHA = 0.25
190
+ _CN.LOFTR.LOSS.FOCAL_GAMMA = 2.0
191
+ _CN.LOFTR.LOSS.POS_WEIGHT = 1.0
192
+ _CN.LOFTR.LOSS.NEG_WEIGHT = 1.0
193
+ _CN.LOFTR.LOSS.CORRECT_NEG_WEIGHT = False
194
+ # _CN.LOFTR.LOSS.DUAL_SOFTMAX = False # whether coarse-level use dual-softmax or not.
195
+ # use `_CN.LOFTR.MATCH_COARSE.MATCH_TYPE`
196
+
197
+ # -- # fine-level
198
+ _CN.LOFTR.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2']
199
+ _CN.LOFTR.LOSS.FINE_WEIGHT = 1.0
200
+ _CN.LOFTR.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window)
201
+
202
+ # -- # ROMA:
203
+ _CN.LOFTR.ROMA_LOSS = CN()
204
+ _CN.LOFTR.ROMA_LOSS.IGNORE_EMPTY_IN_SPARSE_MATCH_SPV = False # ['l2_with_std', 'l2']
205
+
206
+ # -- # DKM:
207
+ _CN.LOFTR.DKM_LOSS = CN()
208
+ _CN.LOFTR.DKM_LOSS.IGNORE_EMPTY_IN_SPARSE_MATCH_SPV = False # ['l2_with_std', 'l2']
209
+
210
+ ############## Dataset ##############
211
+ _CN.DATASET = CN()
212
+ # 1. data config
213
+ # training and validating
214
+ _CN.DATASET.TB_LOG_DIR= "logs/tb_logs" # options: ['ScanNet', 'MegaDepth']
215
+ _CN.DATASET.TRAIN_DATA_SAMPLE_RATIO = [1.0] # options: ['ScanNet', 'MegaDepth']
216
+ _CN.DATASET.TRAIN_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth']
217
+ _CN.DATASET.TRAIN_DATA_ROOT = None
218
+ _CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses)
219
+ _CN.DATASET.TRAIN_NPZ_ROOT = None
220
+ _CN.DATASET.TRAIN_LIST_PATH = None
221
+ _CN.DATASET.TRAIN_INTRINSIC_PATH = None
222
+ _CN.DATASET.VAL_DATA_ROOT = None
223
+ _CN.DATASET.VAL_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth']
224
+ _CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses)
225
+ _CN.DATASET.VAL_NPZ_ROOT = None
226
+ _CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file
227
+ _CN.DATASET.VAL_INTRINSIC_PATH = None
228
+ _CN.DATASET.FP16 = False
229
+ _CN.DATASET.TRAIN_GT_MATCHES_PADDING_N = 8000
230
+ # testing
231
+ _CN.DATASET.TEST_DATA_SOURCE = None
232
+ _CN.DATASET.TEST_DATA_ROOT = None
233
+ _CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses)
234
+ _CN.DATASET.TEST_NPZ_ROOT = None
235
+ _CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file
236
+ _CN.DATASET.TEST_INTRINSIC_PATH = None
237
+
238
+ # 2. dataset config
239
+ # general options
240
+ _CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score
241
+ _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
242
+ _CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile']
243
+
244
+ # debug options
245
+ _CN.DATASET.TEST_N_PAIRS = None # Debug first N pairs
246
+ # DEBUG
247
+ _CN.LOFTR.FP16LOG = False
248
+ _CN.LOFTR.MATCH_COARSE.FP16LOG = False
249
+
250
+ # scanNet options
251
+ _CN.DATASET.SCAN_IMG_RESIZEX = 640 # resize the longer side, zero-pad bottom-right to square.
252
+ _CN.DATASET.SCAN_IMG_RESIZEY = 480 # resize the shorter side, zero-pad bottom-right to square.
253
+
254
+ # MegaDepth options
255
+ _CN.DATASET.MGDPT_IMG_RESIZE = (640, 640) # resize the longer side, zero-pad bottom-right to square.
256
+ _CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE
257
+ _CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000
258
+ _CN.DATASET.MGDPT_DF = 8
259
+ _CN.DATASET.LOAD_ORIGIN_RGB = False # Only enable in test mode; needed by RGB-based baselines such as DKM and ROMA.
260
+ _CN.DATASET.READ_GRAY = True
261
+ _CN.DATASET.RESIZE_BY_STRETCH = False
262
+ _CN.DATASET.NORMALIZE_IMG = False # For backbones using pretrained DINO features, True may work better.
263
+ _CN.DATASET.HOMO_WARP_USE_MASK = False
264
+
265
+ _CN.DATASET.NPE_NAME = "megadepth"
266
+
267
+ ############## Trainer ##############
268
+ _CN.TRAINER = CN()
269
+ _CN.TRAINER.WORLD_SIZE = 1
270
+ _CN.TRAINER.CANONICAL_BS = 64
271
+ _CN.TRAINER.CANONICAL_LR = 6e-3
272
+ _CN.TRAINER.SCALING = None # this will be calculated automatically
273
+ _CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning
274
+
275
+ # optimizer
276
+ _CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw]
277
+ _CN.TRAINER.OPTIMIZER_EPS = 1e-8 # default for optimizers; use a larger value, e.g., 1e-7, for fp16 mixed-precision training
278
+ _CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime
279
+ _CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam
280
+ _CN.TRAINER.ADAMW_DECAY = 0.1
281
+
282
+ # step-based warm-up
283
+ _CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant]
284
+ _CN.TRAINER.WARMUP_RATIO = 0.
285
+ _CN.TRAINER.WARMUP_STEP = 4800
286
+
287
+ # learning rate scheduler
288
+ _CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR]
289
+ _CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step]
290
+ _CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR
291
+ _CN.TRAINER.MSLR_GAMMA = 0.5
292
+ _CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing
293
+ _CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval
294
+
295
+ # plotting related
296
+ _CN.TRAINER.ENABLE_PLOTTING = True
297
+ _CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 8 # number of val/test pairs for plotting
298
+ _CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence']
299
+ _CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic'
300
+
301
+ # geometric metrics and pose solver
302
+ _CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
303
+ _CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H']
304
+ _CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC]
305
+ _CN.TRAINER.WARP_ESTIMATOR_MODEL = 'affine' # warp model used for match warp evaluation
306
+ _CN.TRAINER.RANSAC_PIXEL_THR = 0.5
307
+ _CN.TRAINER.RANSAC_CONF = 0.99999
308
+ _CN.TRAINER.RANSAC_MAX_ITERS = 10000
309
+ _CN.TRAINER.USE_MAGSACPP = False
310
+ _CN.TRAINER.THRESHOLDS = [5, 10, 20]
311
+
312
+ # data sampler for train_dataloader
313
+ _CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal']
314
+ # 'scene_balance' config
315
+ _CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
316
+ _CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether to sample each scene with replacement
317
+ _CN.TRAINER.SB_SUBSET_SHUFFLE = True # whether to shuffle within the epoch after sampling from scenes
318
+ _CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data
319
+ _CN.TRAINER.AUC_METHOD = 'exact_auc'
320
+ # 'random' config
321
+ _CN.TRAINER.RDM_REPLACEMENT = True
322
+ _CN.TRAINER.RDM_NUM_SAMPLES = None
323
+
324
+ # gradient clipping
325
+ _CN.TRAINER.GRADIENT_CLIPPING = 0.5
326
+
327
+ # Finetune Mode:
328
+ _CN.FINETUNE = CN()
329
+ _CN.FINETUNE.ENABLE = False
330
+ _CN.FINETUNE.METHOD = "lora" #['lora', 'whole_network']
331
+
332
+ _CN.FINETUNE.LORA = CN()
333
+ _CN.FINETUNE.LORA.RANK = 2
334
+ _CN.FINETUNE.LORA.MODE = "linear&conv" # ["linear&conv", "linear_only"]
335
+ _CN.FINETUNE.LORA.SCALE = 1.0
336
+
337
+ _CN.TRAINER.SEED = 66
338
+
339
+
340
+ def get_cfg_defaults():
341
+ """Get a yacs CfgNode object with default values for my_project."""
342
+ # Return a clone so that the defaults will not be altered
343
+ # This is for the "local variable" use pattern
344
+ return _CN.clone()
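
A minimal usage sketch for the yacs defaults defined above. `get_cfg_defaults()` is the function from this file; the experiment YAML path and the override values are placeholders, and the import path assumes the repository root is on `PYTHONPATH`.

```python
# Hedged sketch: consuming the defaults above with the standard yacs CfgNode API.
from src.config.default import get_cfg_defaults  # assumes repo root on PYTHONPATH

cfg = get_cfg_defaults()                      # clone of _CN; safe to modify locally
# cfg.merge_from_file("configs/experiment.yaml")  # placeholder experiment config
cfg.merge_from_list(["TRAINER.SEED", 42,      # override existing keys in place
                     "DATASET.READ_GRAY", True])
cfg.freeze()                                  # lock the config before handing it to training code
print(cfg.TRAINER.CANONICAL_LR, cfg.LOFTR.LOSS.COARSE_TYPE)
```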
imcui/third_party/MatchAnything/src/lightning/lightning_loftr.py ADDED
@@ -0,0 +1,343 @@
1
+
2
+ from collections import defaultdict
3
+ import pprint
4
+ from loguru import logger
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ import numpy as np
9
+ import pytorch_lightning as pl
10
+ from matplotlib import pyplot as plt
11
+
12
+ from src.loftr import LoFTR
13
+ from src.loftr.utils.supervision import compute_supervision_coarse, compute_supervision_fine, compute_roma_supervision
14
+ from src.optimizers import build_optimizer, build_scheduler
15
+ from src.utils.metrics import (
16
+ compute_symmetrical_epipolar_errors,
17
+ compute_pose_errors,
18
+ compute_homo_corner_warp_errors,
19
+ compute_homo_match_warp_errors,
20
+ compute_warp_control_pts_errors,
21
+ aggregate_metrics
22
+ )
23
+ from src.utils.plotting import make_matching_figures, make_scores_figures
24
+ from src.utils.comm import gather, all_gather
25
+ from src.utils.misc import lower_config, flattenList
26
+ from src.utils.profiler import PassThroughProfiler
27
+ from third_party.ROMA.roma.matchanything_roma_model import MatchAnything_Model
28
+
29
+ import pynvml
30
+
31
+ def reparameter(matcher):
32
+ module = matcher.backbone.layer0
33
+ if hasattr(module, 'switch_to_deploy'):
34
+ module.switch_to_deploy()
35
+ print('m0 switch to deploy ok')
36
+ for modules in [matcher.backbone.layer1, matcher.backbone.layer2, matcher.backbone.layer3]:
37
+ for module in modules:
38
+ if hasattr(module, 'switch_to_deploy'):
39
+ module.switch_to_deploy()
40
+ print('backbone switch to deploy ok')
41
+ for modules in [matcher.fine_preprocess.layer2_outconv2, matcher.fine_preprocess.layer1_outconv2]:
42
+ for module in modules:
43
+ if hasattr(module, 'switch_to_deploy'):
44
+ module.switch_to_deploy()
45
+ print('fpn switch to deploy ok')
46
+ return matcher
47
+
48
+ class PL_LoFTR(pl.LightningModule):
49
+ def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None, test_mode=False, baseline_config=None):
50
+ """
51
+ TODO:
52
+ - use the new version of PL logging API.
53
+ """
54
+ super().__init__()
55
+ # Misc
56
+ self.config = config # full config
57
+ _config = lower_config(self.config)
58
+ self.profiler = profiler or PassThroughProfiler()
59
+ self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)
60
+
61
+ if config.METHOD == "matchanything_eloftr":
62
+ self.matcher = LoFTR(config=_config['loftr'], profiler=self.profiler)
63
+ elif config.METHOD == "matchanything_roma":
64
+ self.matcher = MatchAnything_Model(config=_config['roma'], test_mode=test_mode)
65
+ else:
66
+ raise NotImplementedError
67
+
68
+ if config.FINETUNE.ENABLE and test_mode:
69
+ # At inference time, change the model architecture before loading the pretrained weights:
70
+ raise NotImplementedError
71
+
72
+ # Pretrained weights
73
+ if pretrained_ckpt:
74
+ if config.METHOD in ["matchanything_eloftr", "matchanything_roma"]:
75
+ state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
76
+ logger.info(f"Load model from:{self.matcher.load_state_dict(state_dict, strict=False)}")
77
+ else:
78
+ raise NotImplementedError
79
+
80
+ if self.config.LOFTR.BACKBONE_TYPE == 'RepVGG' and test_mode and (config.METHOD == 'loftr'):
81
+ module = self.matcher.backbone.layer0
82
+ if hasattr(module, 'switch_to_deploy'):
83
+ module.switch_to_deploy()
84
+ print('m0 switch to deploy ok')
85
+ for modules in [self.matcher.backbone.layer1, self.matcher.backbone.layer2, self.matcher.backbone.layer3]:
86
+ for module in modules:
87
+ if hasattr(module, 'switch_to_deploy'):
88
+ module.switch_to_deploy()
89
+ print('m switch to deploy ok')
90
+
91
+ # Testing
92
+ self.dump_dir = dump_dir
93
+ self.max_gpu_memory = 0
94
+ self.GPUID = 0
95
+ self.warmup = False
96
+
97
+ def gpumem(self, des, gpuid=None):
98
+ NUM_EXPAND = 1024 * 1024 * 1024
99
+ gpu_id = gpuid if gpuid is not None else self.GPUID # prefer an explicit gpuid argument, fall back to self.GPUID
100
+ handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
101
+ info = pynvml.nvmlDeviceGetMemoryInfo(handle)
102
+ gpu_Used = info.used
103
+ logger.info(f"GPU {gpu_id} memory used: {gpu_Used / NUM_EXPAND} GB while {des}")
104
+ # print(des, gpu_Used / NUM_EXPAND)
105
+ if gpu_Used / NUM_EXPAND > self.max_gpu_memory:
106
+ self.max_gpu_memory = gpu_Used / NUM_EXPAND
107
+ logger.info(f"[MAX]GPU {gpu_id} memory used: {gpu_Used / NUM_EXPAND} GB while {des}")
108
+ print('max_gpu_memory', self.max_gpu_memory)
109
+
110
+ def configure_optimizers(self):
111
+ optimizer = build_optimizer(self, self.config)
112
+ scheduler = build_scheduler(self.config, optimizer)
113
+ return [optimizer], [scheduler]
114
+
115
+ def optimizer_step(
116
+ self, epoch, batch_idx, optimizer, optimizer_idx,
117
+ optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
118
+ # learning rate warm up
119
+ warmup_step = self.config.TRAINER.WARMUP_STEP
120
+ if self.trainer.global_step < warmup_step:
121
+ if self.config.TRAINER.WARMUP_TYPE == 'linear':
122
+ base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
123
+ lr = base_lr + \
124
+ (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
125
+ abs(self.config.TRAINER.TRUE_LR - base_lr)
126
+ for pg in optimizer.param_groups:
127
+ pg['lr'] = lr
128
+ elif self.config.TRAINER.WARMUP_TYPE == 'constant':
129
+ pass
130
+ else:
131
+ raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')
132
+
133
+ # update params
134
+ if self.config.LOFTR.FP16:
135
+ optimizer.step(closure=optimizer_closure)
136
+ else:
137
+ optimizer.step(closure=optimizer_closure)
138
+ optimizer.zero_grad()
139
+
140
+ def _trainval_inference(self, batch):
141
+ with self.profiler.profile("Compute coarse supervision"):
142
+
143
+ with torch.autocast(enabled=False, device_type='cuda'):
144
+ if ("roma" in self.config.METHOD) or ('dkm' in self.config.METHOD):
145
+ pass
146
+ else:
147
+ compute_supervision_coarse(batch, self.config)
148
+
149
+ with self.profiler.profile("LoFTR"):
150
+ with torch.autocast(enabled=self.config.LOFTR.FP16, device_type='cuda'):
151
+ self.matcher(batch)
152
+
153
+ with self.profiler.profile("Compute fine supervision"):
154
+ with torch.autocast(enabled=False, device_type='cuda'):
155
+ if ("roma" in self.config.METHOD) or ('dkm' in self.config.METHOD):
156
+ compute_roma_supervision(batch, self.config)
157
+ else:
158
+ compute_supervision_fine(batch, self.config, self.logger)
159
+
160
+ with self.profiler.profile("Compute losses"):
161
+ pass
162
+
163
+ def _compute_metrics(self, batch):
164
+ if 'gt_2D_matches' in batch:
165
+ compute_warp_control_pts_errors(batch, self.config)
166
+ elif batch['homography'].sum() != 0 and batch['T_0to1'].sum() == 0:
167
+ compute_homo_match_warp_errors(batch, self.config) # compute warp_errors for each match
168
+ compute_homo_corner_warp_errors(batch, self.config) # compute mean corner warp error each pair
169
+ else:
170
+ compute_symmetrical_epipolar_errors(batch, self.config) # compute epi_errs for each match
171
+ compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair
172
+
173
+ rel_pair_names = list(zip(*batch['pair_names']))
174
+ bs = batch['image0'].size(0)
175
+ if self.config.LOFTR.FINE.MTD_SPVS:
176
+ topk = self.config.LOFTR.MATCH_FINE.TOPK
177
+ metrics = {
178
+ # to filter duplicate pairs caused by DistributedSampler
179
+ 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
180
+ 'epi_errs': [(batch['epi_errs'].reshape(-1,topk))[batch['m_bids'] == b].reshape(-1).cpu().numpy() for b in range(bs)],
181
+ 'R_errs': batch['R_errs'],
182
+ 't_errs': batch['t_errs'],
183
+ 'inliers': batch['inliers'],
184
+ 'num_matches': [batch['mconf'].shape[0]], # batch size = 1 only
185
+ 'percent_inliers': [ batch['inliers'][0].shape[0] / batch['mconf'].shape[0] if batch['mconf'].shape[0]!=0 else 1], # batch size = 1 only
186
+ }
187
+ else:
188
+ metrics = {
189
+ # to filter duplicate pairs caused by DistributedSampler
190
+ 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
191
+ 'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
192
+ 'R_errs': batch['R_errs'],
193
+ 't_errs': batch['t_errs'],
194
+ 'inliers': batch['inliers'],
195
+ 'num_matches': [batch['mconf'].shape[0]], # batch size = 1 only
196
+ 'percent_inliers': [ batch['inliers'][0].shape[0] / batch['mconf'].shape[0] if batch['mconf'].shape[0]!=0 else 1], # batch size = 1 only
197
+ }
198
+ ret_dict = {'metrics': metrics}
199
+ return ret_dict, rel_pair_names
200
+
201
+ def training_step(self, batch, batch_idx):
202
+ self._trainval_inference(batch)
203
+
204
+ # logging
205
+ if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
206
+ # scalars
207
+ for k, v in batch['loss_scalars'].items():
208
+ self.logger.experiment.add_scalar(f'train/{k}', v, self.global_step)
209
+
210
+ # net-params
211
+ method = 'LOFTR'
212
+ if self.config[method]['MATCH_COARSE']['MATCH_TYPE'] == 'sinkhorn':
213
+ self.logger.experiment.add_scalar(
214
+ f'skh_bin_score', self.matcher.coarse_matching.bin_score.clone().detach().cpu().data, self.global_step)
215
+
216
+ figures = {}
217
+ if self.config.TRAINER.ENABLE_PLOTTING:
218
+ compute_symmetrical_epipolar_errors(batch, self.config) # compute epi_errs for each match
219
+ figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
220
+ for k, v in figures.items():
221
+ self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step)
222
+
223
+ return {'loss': batch['loss']}
224
+
225
+ def training_epoch_end(self, outputs):
226
+ avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
227
+ if self.trainer.global_rank == 0:
228
+ self.logger.experiment.add_scalar(
229
+ 'train/avg_loss_on_epoch', avg_loss,
230
+ global_step=self.current_epoch)
231
+
232
+ def validation_step(self, batch, batch_idx, dataloader_idx=0):
233
+ self._trainval_inference(batch)
234
+
235
+ ret_dict, _ = self._compute_metrics(batch)
236
+
237
+ val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
238
+ figures = {self.config.TRAINER.PLOT_MODE: []}
239
+ if batch_idx % val_plot_interval == 0:
240
+ figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE)
241
+ if self.config.LOFTR.PLOT_SCORES:
242
+ figs = make_scores_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
243
+ figures[self.config.TRAINER.PLOT_MODE] += figs[self.config.TRAINER.PLOT_MODE]
244
+ del figs
245
+
246
+ return {
247
+ **ret_dict,
248
+ 'loss_scalars': batch['loss_scalars'],
249
+ 'figures': figures,
250
+ }
251
+
252
+ def validation_epoch_end(self, outputs):
253
+ # handle multiple validation sets
254
+ multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
255
+ multi_val_metrics = defaultdict(list)
256
+
257
+ for valset_idx, outputs in enumerate(multi_outputs):
258
+ # since pl performs sanity_check at the very beginning of the training
259
+ cur_epoch = self.trainer.current_epoch
260
+ if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
261
+ cur_epoch = -1
262
+
263
+ # 1. loss_scalars: dict of list, on cpu
264
+ _loss_scalars = [o['loss_scalars'] for o in outputs]
265
+ loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
266
+
267
+ # 2. val metrics: dict of list, numpy
268
+ _metrics = [o['metrics'] for o in outputs]
269
+ metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
270
+ # NOTE: all ranks need to `aggregate_metrics`, but only log at rank-0
271
+ val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR, self.config.LOFTR.EVAL_TIMES)
272
+ for thr in [5, 10, 20]:
273
+ multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}'])
274
+
275
+ # 3. figures
276
+ _figures = [o['figures'] for o in outputs]
277
+ figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]}
278
+
279
+ # tensorboard records only on rank 0
280
+ if self.trainer.global_rank == 0:
281
+ for k, v in loss_scalars.items():
282
+ mean_v = torch.stack(v).mean()
283
+ self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
284
+
285
+ for k, v in val_metrics_4tb.items():
286
+ self.logger.experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch)
287
+
288
+ for k, v in figures.items():
289
+ if self.trainer.global_rank == 0:
290
+ for plot_idx, fig in enumerate(v):
291
+ self.logger.experiment.add_figure(
292
+ f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True)
293
+ plt.close('all')
294
+
295
+ for thr in [5, 10, 20]:
296
+ self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this
297
+
298
+ def test_step(self, batch, batch_idx):
299
+ if self.warmup:
300
+ for i in range(50):
301
+ self.matcher(batch)
302
+ self.warmup = False
303
+
304
+ with torch.autocast(enabled=self.config.LOFTR.FP16, device_type='cuda'):
305
+ with self.profiler.profile("LoFTR"):
306
+ self.matcher(batch)
307
+
308
+ ret_dict, rel_pair_names = self._compute_metrics(batch)
309
+ print(ret_dict['metrics']['num_matches'])
310
+ self.dump_dir = None
311
+
312
+ return ret_dict
313
+
314
+ def test_epoch_end(self, outputs):
315
+ print(self.config)
316
+ print('max GPU memory: ', self.max_gpu_memory)
317
+ print(self.profiler.summary())
318
+ # metrics: dict of list, numpy
319
+ _metrics = [o['metrics'] for o in outputs]
320
+ metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
321
+
322
+ # [{key: [{...}, *#bs]}, *#batch]
323
+ if self.dump_dir is not None:
324
+ Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
325
+ _dumps = flattenList([o['dumps'] for o in outputs]) # [{...}, #bs*#batch]
326
+ dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch]
327
+ logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}')
328
+
329
+ if self.trainer.global_rank == 0:
330
+ NUM_EXPAND = 1024 * 1024 * 1024
331
+ gpu_id=self.GPUID
332
+ handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
333
+ info = pynvml.nvmlDeviceGetMemoryInfo(handle)
334
+ gpu_Used = info.used
335
+ print('pynvml', gpu_Used / NUM_EXPAND)
336
+ if gpu_Used / NUM_EXPAND > self.max_gpu_memory:
337
+ self.max_gpu_memory = gpu_Used / NUM_EXPAND
338
+
339
+ print(self.profiler.summary())
340
+ val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR, self.config.LOFTR.EVAL_TIMES, self.config.TRAINER.THRESHOLDS, method=self.config.TRAINER.AUC_METHOD)
341
+ logger.info('\n' + pprint.pformat(val_metrics_4tb))
342
+ if self.dump_dir is not None:
343
+ np.save(Path(self.dump_dir) / 'LoFTR_pred_eval', dumps)
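
A rough, hedged sketch of how the PL_LoFTR module above might be driven for testing with a pre-2.0 pytorch-lightning Trainer (the overridden `validation_epoch_end` / `test_epoch_end` hooks imply the old API). The checkpoint path and the datamodule are placeholders rather than documented entry points of this repository, and `config.METHOD` is assumed to be a key defined elsewhere in the default config.

```python
# Hedged sketch only: wiring PL_LoFTR into an old-style pytorch-lightning test run.
import pytorch_lightning as pl
from src.config.default import get_cfg_defaults
from src.lightning.lightning_loftr import PL_LoFTR

config = get_cfg_defaults()
config.METHOD = "matchanything_eloftr"        # assumed key; selects the ELoFTR matcher branch above
model = PL_LoFTR(
    config,
    pretrained_ckpt="weights/matchanything_eloftr.ckpt",  # placeholder path
    test_mode=True,
    dump_dir=None,
)

trainer = pl.Trainer(gpus=1, logger=False)    # pre-2.0 Trainer arguments
# trainer.test(model, datamodule=my_datamodule)  # my_datamodule is hypothetical
```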
imcui/third_party/MatchAnything/src/loftr/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .loftr import LoFTR
imcui/third_party/MatchAnything/src/loftr/backbone/__init__.py ADDED
@@ -0,0 +1,61 @@
1
+ from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4, ResNetFPN_8_1, ResNetFPN_8_2_align, ResNetFPN_8_1_align, ResNetFPN_8_2_fix, ResNet_8_1_align, VGG_8_1_align, RepVGG_8_1_align, \
2
+ RepVGGnfpn_8_1_align, RepVGG_8_2_fix, s2dnet_8_1_align
3
+
4
+ def build_backbone(config):
5
+ if config['backbone_type'] == 'ResNetFPN':
6
+ if config['align_corner'] is None or config['align_corner'] is True:
7
+ if config['resolution'] == (8, 2):
8
+ return ResNetFPN_8_2(config['resnetfpn'])
9
+ elif config['resolution'] == (16, 4):
10
+ return ResNetFPN_16_4(config['resnetfpn'])
11
+ elif config['resolution'] == (8, 1):
12
+ return ResNetFPN_8_1(config['resnetfpn'])
13
+ elif config['align_corner'] is False:
14
+ if config['resolution'] == (8, 2):
15
+ return ResNetFPN_8_2_align(config['resnetfpn'])
16
+ elif config['resolution'] == (16, 4):
17
+ return ResNetFPN_16_4(config['resnetfpn'])
18
+ elif config['resolution'] == (8, 1):
19
+ return ResNetFPN_8_1_align(config['resnetfpn'])
20
+ elif config['backbone_type'] == 'ResNetFPNFIX':
21
+ if config['align_corner'] is None or config['align_corner'] is True:
22
+ if config['resolution'] == (8, 2):
23
+ return ResNetFPN_8_2_fix(config['resnetfpn'])
24
+ elif config['backbone_type'] == 'ResNet':
25
+ if config['align_corner'] is None or config['align_corner'] is True:
26
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
27
+ elif config['align_corner'] is False:
28
+ if config['resolution'] == (8, 1):
29
+ return ResNet_8_1_align(config['resnetfpn'])
30
+ elif config['backbone_type'] == 'VGG':
31
+ if config['align_corner'] is None or config['align_corner'] is True:
32
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
33
+ elif config['align_corner'] is False:
34
+ if config['resolution'] == (8, 1):
35
+ return VGG_8_1_align(config['resnetfpn'])
36
+ elif config['backbone_type'] == 'RepVGG':
37
+ if config['align_corner'] is None or config['align_corner'] is True:
38
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
39
+ elif config['align_corner'] is False:
40
+ if config['resolution'] == (8, 1):
41
+ return RepVGG_8_1_align(config['resnetfpn'])
42
+ elif config['backbone_type'] == 'RepVGGNFPN':
43
+ if config['align_corner'] is None or config['align_corner'] is True:
44
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
45
+ elif config['align_corner'] is False:
46
+ if config['resolution'] == (8, 1):
47
+ return RepVGGnfpn_8_1_align(config['resnetfpn'])
48
+ elif config['backbone_type'] == 'RepVGGFPNFIX':
49
+ if config['align_corner'] is None or config['align_corner'] is True:
50
+ if config['resolution'] == (8, 2):
51
+ return RepVGG_8_2_fix(config['resnetfpn'])
52
+ elif config['align_corner'] is False:
53
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
54
+ elif config['backbone_type'] == 's2dnet':
55
+ if config['align_corner'] is None or config['align_corner'] is True:
56
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
57
+ elif config['align_corner'] is False:
58
+ if config['resolution'] == (8, 1):
59
+ return s2dnet_8_1_align(config['resnetfpn'])
60
+ else:
61
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
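
A hedged sketch of how `build_backbone` above dispatches on its config dict. The keys mirror exactly what the function reads (`backbone_type`, `align_corner`, `resolution`, `resnetfpn`); the concrete dimension values are illustrative, not the shipped model configuration, and the import path assumes the repository root is on `PYTHONPATH`.

```python
# Hedged sketch: selecting the ResNetFPN_8_2_fix variant through build_backbone.
from src.loftr.backbone import build_backbone

backbone_cfg = {
    'backbone_type': 'ResNetFPNFIX',
    'align_corner': True,             # this variant is only registered for align_corner None/True
    'resolution': (8, 2),             # coarse stride 8, fine stride 2
    'resnetfpn': {
        'initial_dim': 128,           # illustrative dims; real values live in configs/models/*
        'block_dims': [128, 196, 256],
        'coarse_feat_only': False,    # also compute the fine-level FPN feature
        'inter_feat': False,
    },
}
backbone = build_backbone(backbone_cfg)   # -> ResNetFPN_8_2_fix (expects 1-channel input images)
```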
imcui/third_party/MatchAnything/src/loftr/backbone/repvgg.py ADDED
@@ -0,0 +1,319 @@
1
+ # --------------------------------------------------------
2
+ # RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
3
+ # Github source: https://github.com/DingXiaoH/RepVGG
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import torch.nn as nn
7
+ import numpy as np
8
+ import torch
9
+ import copy
10
+ # from se_block import SEBlock
11
+ import torch.utils.checkpoint as checkpoint
12
+ from loguru import logger
13
+
14
+ def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
15
+ result = nn.Sequential()
16
+ result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
17
+ kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False))
18
+ result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
19
+ return result
20
+
21
+ class RepVGGBlock(nn.Module):
22
+
23
+ def __init__(self, in_channels, out_channels, kernel_size,
24
+ stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False, leaky=-1.0):
25
+ super(RepVGGBlock, self).__init__()
26
+ self.deploy = deploy
27
+ self.groups = groups
28
+ self.in_channels = in_channels
29
+
30
+ assert kernel_size == 3
31
+ assert padding == 1
32
+
33
+ padding_11 = padding - kernel_size // 2
34
+
35
+ if leaky == -2:
36
+ self.nonlinearity = nn.Identity()
37
+ logger.info(f"Using Identity nonlinearity in repvgg_block")
38
+ elif leaky < 0:
39
+ self.nonlinearity = nn.ReLU()
40
+ else:
41
+ self.nonlinearity = nn.LeakyReLU(leaky)
42
+
43
+ if use_se:
44
+ # Note that RepVGG-D2se uses SE before nonlinearity. But RepVGGplus models use SE after nonlinearity.
45
+ # self.se = SEBlock(out_channels, internal_neurons=out_channels // 16)
46
+ raise ValueError(f"SEBlock not supported")
47
+ else:
48
+ self.se = nn.Identity()
49
+
50
+ if deploy:
51
+ self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
52
+ padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
53
+
54
+ else:
55
+ self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
56
+ self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
57
+ self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
58
+ print('RepVGG Block, identity = ', self.rbr_identity)
59
+
60
+
61
+ def forward(self, inputs):
62
+ if hasattr(self, 'rbr_reparam'):
63
+ return self.nonlinearity(self.se(self.rbr_reparam(inputs)))
64
+
65
+ if self.rbr_identity is None:
66
+ id_out = 0
67
+ else:
68
+ id_out = self.rbr_identity(inputs)
69
+
70
+ return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))
71
+
72
+
73
+ # Optional. This may improve the accuracy and facilitates quantization in some cases.
74
+ # 1. Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
75
+ # 2. Use like this.
76
+ # loss = criterion(....)
77
+ # for every RepVGGBlock blk:
78
+ # loss += weight_decay_coefficient * 0.5 * blk.get_custom_L2()
79
+ # optimizer.zero_grad()
80
+ # loss.backward()
81
+ def get_custom_L2(self):
82
+ K3 = self.rbr_dense.conv.weight
83
+ K1 = self.rbr_1x1.conv.weight
84
+ t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
85
+ t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
86
+
87
+ l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
88
+ eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1 # The equivalent resultant central point of 3x3 kernel.
89
+ l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum() # Normalize for an L2 coefficient comparable to regular L2.
90
+ return l2_loss_eq_kernel + l2_loss_circle
91
+
92
+
93
+
94
+ # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
95
+ # You can get the equivalent kernel and bias at any time and do whatever you want,
96
+ # for example, apply some penalties or constraints during training, just like you do to the other models.
97
+ # May be useful for quantization or pruning.
98
+ def get_equivalent_kernel_bias(self):
99
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
100
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
101
+ kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
102
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
103
+
104
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
105
+ if kernel1x1 is None:
106
+ return 0
107
+ else:
108
+ return torch.nn.functional.pad(kernel1x1, [1,1,1,1])
109
+
110
+ def _fuse_bn_tensor(self, branch):
111
+ if branch is None:
112
+ return 0, 0
113
+ if isinstance(branch, nn.Sequential):
114
+ kernel = branch.conv.weight
115
+ running_mean = branch.bn.running_mean
116
+ running_var = branch.bn.running_var
117
+ gamma = branch.bn.weight
118
+ beta = branch.bn.bias
119
+ eps = branch.bn.eps
120
+ else:
121
+ assert isinstance(branch, nn.BatchNorm2d)
122
+ if not hasattr(self, 'id_tensor'):
123
+ input_dim = self.in_channels // self.groups
124
+ kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
125
+ for i in range(self.in_channels):
126
+ kernel_value[i, i % input_dim, 1, 1] = 1
127
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
128
+ kernel = self.id_tensor
129
+ running_mean = branch.running_mean
130
+ running_var = branch.running_var
131
+ gamma = branch.weight
132
+ beta = branch.bias
133
+ eps = branch.eps
134
+ std = (running_var + eps).sqrt()
135
+ t = (gamma / std).reshape(-1, 1, 1, 1)
136
+ return kernel * t, beta - running_mean * gamma / std
137
+
138
+ def switch_to_deploy(self):
139
+ if hasattr(self, 'rbr_reparam'):
140
+ return
141
+ kernel, bias = self.get_equivalent_kernel_bias()
142
+ self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels,
143
+ kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
144
+ padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True)
145
+ self.rbr_reparam.weight.data = kernel
146
+ self.rbr_reparam.bias.data = bias
147
+ self.__delattr__('rbr_dense')
148
+ self.__delattr__('rbr_1x1')
149
+ if hasattr(self, 'rbr_identity'):
150
+ self.__delattr__('rbr_identity')
151
+ if hasattr(self, 'id_tensor'):
152
+ self.__delattr__('id_tensor')
153
+ self.deploy = True
154
+
155
+
156
+
157
+ class RepVGG(nn.Module):
158
+
159
+ def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None, deploy=False, use_se=False, use_checkpoint=False, leaky=-1.0):
160
+ super(RepVGG, self).__init__()
161
+ assert len(width_multiplier) == 4
162
+ self.deploy = deploy
163
+ self.override_groups_map = override_groups_map or dict()
164
+ assert 0 not in self.override_groups_map
165
+ self.use_se = use_se
166
+ self.use_checkpoint = use_checkpoint
167
+
168
+ self.in_planes = min(64, int(64 * width_multiplier[0]))
169
+ self.stage0 = RepVGGBlock(in_channels=1, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1, deploy=self.deploy, use_se=self.use_se, leaky=leaky)
170
+ self.cur_layer_idx = 1
171
+ self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=1, leaky=leaky)
172
+ self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2, leaky=leaky)
173
+ self.stage3 = self._make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride=2, leaky=leaky)
174
+ # self.stage4 = self._make_stage(int(512 * width_multiplier[3]), num_blocks[3], stride=1)
175
+ # self.gap = nn.AdaptiveAvgPool2d(output_size=1)
176
+ # self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes)
177
+
178
+ def _make_stage(self, planes, num_blocks, stride, leaky=-1.0):
179
+ strides = [stride] + [1]*(num_blocks-1)
180
+ blocks = []
181
+ for stride in strides:
182
+ cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
183
+ blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
184
+ stride=stride, padding=1, groups=cur_groups, deploy=self.deploy, use_se=self.use_se, leaky=leaky))
185
+ self.in_planes = planes
186
+ self.cur_layer_idx += 1
187
+ return nn.ModuleList(blocks)
188
+
189
+ def forward(self, x):
190
+ out = self.stage0(x)
191
+ for stage in (self.stage1, self.stage2, self.stage3): # , self.stage4):
192
+ for block in stage:
193
+ if self.use_checkpoint:
194
+ out = checkpoint.checkpoint(block, out)
195
+ else:
196
+ out = block(out)
197
+ # self.gap and self.linear are commented out in __init__ (this adapted RepVGG is used as a backbone only), so return the last-stage feature map directly
200
+ return out
201
+
202
+
203
+ optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
204
+ g2_map = {l: 2 for l in optional_groupwise_layers}
205
+ g4_map = {l: 4 for l in optional_groupwise_layers}
206
+
207
+ def create_RepVGG_A0(deploy=False, use_checkpoint=False):
208
+ return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
209
+ width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
210
+
211
+ def create_RepVGG_A1(deploy=False, use_checkpoint=False):
212
+ return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
213
+ width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
214
+ def create_RepVGG_A15(deploy=False, use_checkpoint=False):
215
+ return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
216
+ width_multiplier=[1.25, 1.25, 1.25, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
217
+ def create_RepVGG_A1_leaky(deploy=False, use_checkpoint=False):
218
+ return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
219
+ width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint, leaky=0.01)
220
+
221
+ def create_RepVGG_A2(deploy=False, use_checkpoint=False):
222
+ return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
223
+ width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
224
+
225
+ def create_RepVGG_B0(deploy=False, use_checkpoint=False):
226
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
227
+ width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
228
+
229
+ def create_RepVGG_B1(deploy=False, use_checkpoint=False):
230
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
231
+ width_multiplier=[2, 2, 2, 4], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
232
+
233
+ def create_RepVGG_B1g2(deploy=False, use_checkpoint=False):
234
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
235
+ width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint)
236
+
237
+ def create_RepVGG_B1g4(deploy=False, use_checkpoint=False):
238
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
239
+ width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint)
240
+
241
+
242
+ def create_RepVGG_B2(deploy=False, use_checkpoint=False):
243
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
244
+ width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
245
+
246
+ def create_RepVGG_B2g2(deploy=False, use_checkpoint=False):
247
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
248
+ width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint)
249
+
250
+ def create_RepVGG_B2g4(deploy=False, use_checkpoint=False):
251
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
252
+ width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint)
253
+
254
+
255
+ def create_RepVGG_B3(deploy=False, use_checkpoint=False):
256
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
257
+ width_multiplier=[3, 3, 3, 5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
258
+
259
+ def create_RepVGG_B3g2(deploy=False, use_checkpoint=False):
260
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
261
+ width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint)
262
+
263
+ def create_RepVGG_B3g4(deploy=False, use_checkpoint=False):
264
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
265
+ width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint)
266
+
267
+ def create_RepVGG_D2se(deploy=False, use_checkpoint=False):
268
+ return RepVGG(num_blocks=[8, 14, 24, 1], num_classes=1000,
269
+ width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_se=True, use_checkpoint=use_checkpoint)
270
+
271
+
272
+ func_dict = {
273
+ 'RepVGG-A0': create_RepVGG_A0,
274
+ 'RepVGG-A1': create_RepVGG_A1,
275
+ 'RepVGG-A15': create_RepVGG_A15,
276
+ 'RepVGG-A1_leaky': create_RepVGG_A1_leaky,
277
+ 'RepVGG-A2': create_RepVGG_A2,
278
+ 'RepVGG-B0': create_RepVGG_B0,
279
+ 'RepVGG-B1': create_RepVGG_B1,
280
+ 'RepVGG-B1g2': create_RepVGG_B1g2,
281
+ 'RepVGG-B1g4': create_RepVGG_B1g4,
282
+ 'RepVGG-B2': create_RepVGG_B2,
283
+ 'RepVGG-B2g2': create_RepVGG_B2g2,
284
+ 'RepVGG-B2g4': create_RepVGG_B2g4,
285
+ 'RepVGG-B3': create_RepVGG_B3,
286
+ 'RepVGG-B3g2': create_RepVGG_B3g2,
287
+ 'RepVGG-B3g4': create_RepVGG_B3g4,
288
+ 'RepVGG-D2se': create_RepVGG_D2se, # Updated at April 25, 2021. This is not reported in the CVPR paper.
289
+ }
290
+ def get_RepVGG_func_by_name(name):
291
+ return func_dict[name]
292
+
293
+
294
+
295
+ # Use this for converting a RepVGG model or a bigger model with RepVGG as its component
296
+ # Use like this
297
+ # model = create_RepVGG_A0(deploy=False)
298
+ # train model or load weights
299
+ # repvgg_model_convert(model, save_path='repvgg_deploy.pth')
300
+ # If you want to preserve the original model, call with do_copy=True
301
+
302
+ # ====================== for using RepVGG as the backbone of a bigger model, e.g., PSPNet, the pseudo code will be like
303
+ # train_backbone = create_RepVGG_B2(deploy=False)
304
+ # train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth'))
305
+ # train_pspnet = build_pspnet(backbone=train_backbone)
306
+ # segmentation_train(train_pspnet)
307
+ # deploy_pspnet = repvgg_model_convert(train_pspnet)
308
+ # segmentation_test(deploy_pspnet)
309
+ # ===================== example_pspnet.py shows an example
310
+
311
+ def repvgg_model_convert(model:torch.nn.Module, save_path=None, do_copy=True):
312
+ if do_copy:
313
+ model = copy.deepcopy(model)
314
+ for module in model.modules():
315
+ if hasattr(module, 'switch_to_deploy'):
316
+ module.switch_to_deploy()
317
+ if save_path is not None:
318
+ torch.save(model.state_dict(), save_path)
319
+ return model
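
A hedged sketch of the train-to-deploy flow described in the comments above, using only functions defined in this file. The checkpoint path is a placeholder, the import path assumes the repository root is on `PYTHONPATH`, and note that this adapted RepVGG takes 1-channel input.

```python
# Hedged sketch: structural reparameterization of the multi-branch blocks into single 3x3 convs.
import torch
from src.loftr.backbone.repvgg import create_RepVGG_A1, repvgg_model_convert

train_model = create_RepVGG_A1(deploy=False)                  # multi-branch blocks for training
# train_model.load_state_dict(torch.load("RepVGG-A1-train.pth"))  # placeholder weights

deploy_model = repvgg_model_convert(train_model, save_path=None, do_copy=True)
# Every RepVGGBlock now holds a fused 'rbr_reparam' conv; outputs match the training-time
# model up to numerical precision.
x = torch.randn(1, 1, 64, 64)                                 # this variant expects 1-channel input
with torch.no_grad():
    y = deploy_model.stage0(x)                                # first fused block, stride 2
print(y.shape)
```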
imcui/third_party/MatchAnything/src/loftr/backbone/resnet_fpn.py ADDED
@@ -0,0 +1,1094 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ from .repvgg import get_RepVGG_func_by_name
4
+ from .s2dnet import S2DNet
5
+
6
+
7
+ def conv1x1(in_planes, out_planes, stride=1):
8
+ """1x1 convolution without padding"""
9
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
10
+
11
+
12
+ def conv3x3(in_planes, out_planes, stride=1):
13
+ """3x3 convolution with padding"""
14
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
15
+
16
+
17
+ class BasicBlock(nn.Module):
18
+ def __init__(self, in_planes, planes, stride=1):
19
+ super().__init__()
20
+ self.conv1 = conv3x3(in_planes, planes, stride)
21
+ self.conv2 = conv3x3(planes, planes)
22
+ self.bn1 = nn.BatchNorm2d(planes)
23
+ self.bn2 = nn.BatchNorm2d(planes)
24
+ self.relu = nn.ReLU(inplace=True)
25
+
26
+ if stride == 1:
27
+ self.downsample = None
28
+ else:
29
+ self.downsample = nn.Sequential(
30
+ conv1x1(in_planes, planes, stride=stride),
31
+ nn.BatchNorm2d(planes)
32
+ )
33
+
34
+ def forward(self, x):
35
+ y = x
36
+ y = self.relu(self.bn1(self.conv1(y)))
37
+ y = self.bn2(self.conv2(y))
38
+
39
+ if self.downsample is not None:
40
+ x = self.downsample(x)
41
+
42
+ return self.relu(x+y)
43
+
44
+
45
+ class ResNetFPN_8_2(nn.Module):
46
+ """
47
+ ResNet+FPN, output resolution are 1/8 and 1/2.
48
+ Each block has 2 layers.
49
+ """
50
+
51
+ def __init__(self, config):
52
+ super().__init__()
53
+ # Config
54
+ block = BasicBlock
55
+ initial_dim = config['initial_dim']
56
+ block_dims = config['block_dims']
57
+
58
+ # Class Variable
59
+ self.in_planes = initial_dim
60
+
61
+ # Networks
62
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
63
+ self.bn1 = nn.BatchNorm2d(initial_dim)
64
+ self.relu = nn.ReLU(inplace=True)
65
+
66
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
67
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
68
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
69
+
70
+ # 3. FPN upsample
71
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
72
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
73
+ self.layer2_outconv2 = nn.Sequential(
74
+ conv3x3(block_dims[2], block_dims[2]),
75
+ nn.BatchNorm2d(block_dims[2]),
76
+ nn.LeakyReLU(),
77
+ conv3x3(block_dims[2], block_dims[1]),
78
+ )
79
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
80
+ self.layer1_outconv2 = nn.Sequential(
81
+ conv3x3(block_dims[1], block_dims[1]),
82
+ nn.BatchNorm2d(block_dims[1]),
83
+ nn.LeakyReLU(),
84
+ conv3x3(block_dims[1], block_dims[0]),
85
+ )
86
+
87
+ for m in self.modules():
88
+ if isinstance(m, nn.Conv2d):
89
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
90
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
91
+ nn.init.constant_(m.weight, 1)
92
+ nn.init.constant_(m.bias, 0)
93
+
94
+ def _make_layer(self, block, dim, stride=1):
95
+ layer1 = block(self.in_planes, dim, stride=stride)
96
+ layer2 = block(dim, dim, stride=1)
97
+ layers = (layer1, layer2)
98
+
99
+ self.in_planes = dim
100
+ return nn.Sequential(*layers)
101
+
102
+ def forward(self, x):
103
+ # ResNet Backbone
104
+ x0 = self.relu(self.bn1(self.conv1(x)))
105
+ x1 = self.layer1(x0) # 1/2
106
+ x2 = self.layer2(x1) # 1/4
107
+ x3 = self.layer3(x2) # 1/8
108
+
109
+ # FPN
110
+ x3_out = self.layer3_outconv(x3)
111
+
112
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
113
+ x2_out = self.layer2_outconv(x2)
114
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
115
+
116
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
117
+ x1_out = self.layer1_outconv(x1)
118
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
119
+
120
+ return {'feats_c': x3_out, 'feats_f': x1_out}
121
+
122
+ def pro(self, x, profiler):
123
+ with profiler.profile('ResNet Backbone'):
124
+ # ResNet Backbone
125
+ x0 = self.relu(self.bn1(self.conv1(x)))
126
+ x1 = self.layer1(x0) # 1/2
127
+ x2 = self.layer2(x1) # 1/4
128
+ x3 = self.layer3(x2) # 1/8
129
+
130
+ with profiler.profile('ResNet FPN'):
131
+ # FPN
132
+ x3_out = self.layer3_outconv(x3)
133
+
134
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
135
+ x2_out = self.layer2_outconv(x2)
136
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
137
+
138
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
139
+ x1_out = self.layer1_outconv(x1)
140
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
141
+
142
+ return {'feats_c': x3_out, 'feats_f': x1_out}
143
+
144
+ class ResNetFPN_8_2_fix(nn.Module):
145
+ """
146
+ ResNet+FPN, output resolution are 1/8 and 1/2.
147
+ Each block has 2 layers.
148
+ """
149
+
150
+ def __init__(self, config):
151
+ super().__init__()
152
+ # Config
153
+ block = BasicBlock
154
+ initial_dim = config['initial_dim']
155
+ block_dims = config['block_dims']
156
+
157
+ # Class Variable
158
+ self.in_planes = initial_dim
159
+ self.skip_fine_feature = config['coarse_feat_only']
160
+ self.inter_feat = config['inter_feat']
161
+
162
+ # Networks
163
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
164
+ self.bn1 = nn.BatchNorm2d(initial_dim)
165
+ self.relu = nn.ReLU(inplace=True)
166
+
167
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
168
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
169
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
170
+
171
+ # 3. FPN upsample
172
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
173
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
174
+ self.layer2_outconv2 = nn.Sequential(
175
+ conv3x3(block_dims[2], block_dims[2]),
176
+ nn.BatchNorm2d(block_dims[2]),
177
+ nn.LeakyReLU(),
178
+ conv3x3(block_dims[2], block_dims[1]),
179
+ )
180
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
181
+ self.layer1_outconv2 = nn.Sequential(
182
+ conv3x3(block_dims[1], block_dims[1]),
183
+ nn.BatchNorm2d(block_dims[1]),
184
+ nn.LeakyReLU(),
185
+ conv3x3(block_dims[1], block_dims[0]),
186
+ )
187
+
188
+ for m in self.modules():
189
+ if isinstance(m, nn.Conv2d):
190
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
191
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
192
+ nn.init.constant_(m.weight, 1)
193
+ nn.init.constant_(m.bias, 0)
194
+
195
+ def _make_layer(self, block, dim, stride=1):
196
+ layer1 = block(self.in_planes, dim, stride=stride)
197
+ layer2 = block(dim, dim, stride=1)
198
+ layers = (layer1, layer2)
199
+
200
+ self.in_planes = dim
201
+ return nn.Sequential(*layers)
202
+
203
+ def forward(self, x):
204
+ # ResNet Backbone
205
+ x0 = self.relu(self.bn1(self.conv1(x)))
206
+ x1 = self.layer1(x0) # 1/2
207
+ x2 = self.layer2(x1) # 1/4
208
+ x3 = self.layer3(x2) # 1/8
209
+
210
+ # FPN
211
+ if self.skip_fine_feature:
212
+ if self.inter_feat:
213
+ return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
214
+ else:
215
+ return {'feats_c': x3, 'feats_f': None}
216
+
217
+
218
+ x3_out = self.layer3_outconv(x3) # n+1
219
+
220
+ x3_out_2x = F.interpolate(x3_out, size=((x3_out.size(-2)-1)*2+1, (x3_out.size(-1)-1)*2+1), mode='bilinear', align_corners=True) # 2n+1
221
+ x2_out = self.layer2_outconv(x2)
222
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
223
+
224
+ x2_out_2x = F.interpolate(x2_out, size=((x2_out.size(-2)-1)*2+1, (x2_out.size(-1)-1)*2+1), mode='bilinear', align_corners=True) # 4n+1
225
+ x1_out = self.layer1_outconv(x1)
226
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
227
+
228
+ return {'feats_c': x3_out, 'feats_f': x1_out}
229
+
230
+
231
+ class ResNetFPN_16_4(nn.Module):
232
+ """
233
+ ResNet+FPN, output resolution are 1/16 and 1/4.
234
+ Each block has 2 layers.
235
+ """
236
+
237
+ def __init__(self, config):
238
+ super().__init__()
239
+ # Config
240
+ block = BasicBlock
241
+ initial_dim = config['initial_dim']
242
+ block_dims = config['block_dims']
243
+
244
+ # Class Variable
245
+ self.in_planes = initial_dim
246
+
247
+ # Networks
248
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
249
+ self.bn1 = nn.BatchNorm2d(initial_dim)
250
+ self.relu = nn.ReLU(inplace=True)
251
+
252
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
253
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
254
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
255
+ self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16
256
+
257
+ # 3. FPN upsample
258
+ self.layer4_outconv = conv1x1(block_dims[3], block_dims[3])
259
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[3])
260
+ self.layer3_outconv2 = nn.Sequential(
261
+ conv3x3(block_dims[3], block_dims[3]),
262
+ nn.BatchNorm2d(block_dims[3]),
263
+ nn.LeakyReLU(),
264
+ conv3x3(block_dims[3], block_dims[2]),
265
+ )
266
+
267
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
268
+ self.layer2_outconv2 = nn.Sequential(
269
+ conv3x3(block_dims[2], block_dims[2]),
270
+ nn.BatchNorm2d(block_dims[2]),
271
+ nn.LeakyReLU(),
272
+ conv3x3(block_dims[2], block_dims[1]),
273
+ )
274
+
275
+ for m in self.modules():
276
+ if isinstance(m, nn.Conv2d):
277
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
278
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
279
+ nn.init.constant_(m.weight, 1)
280
+ nn.init.constant_(m.bias, 0)
281
+
282
+ def _make_layer(self, block, dim, stride=1):
283
+ layer1 = block(self.in_planes, dim, stride=stride)
284
+ layer2 = block(dim, dim, stride=1)
285
+ layers = (layer1, layer2)
286
+
287
+ self.in_planes = dim
288
+ return nn.Sequential(*layers)
289
+
290
+ def forward(self, x):
291
+ # ResNet Backbone
292
+ x0 = self.relu(self.bn1(self.conv1(x)))
293
+ x1 = self.layer1(x0) # 1/2
294
+ x2 = self.layer2(x1) # 1/4
295
+ x3 = self.layer3(x2) # 1/8
296
+ x4 = self.layer4(x3) # 1/16
297
+
298
+ # FPN
299
+ x4_out = self.layer4_outconv(x4)
300
+
301
+ x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True)
302
+ x3_out = self.layer3_outconv(x3)
303
+ x3_out = self.layer3_outconv2(x3_out+x4_out_2x)
304
+
305
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
306
+ x2_out = self.layer2_outconv(x2)
307
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
308
+
309
+ return {'feats_c': x4_out, 'feats_f': x2_out}
310
+
311
+
312
+ class ResNetFPN_8_1(nn.Module):
313
+ """
314
+ ResNet+FPN, output resolution are 1/8 and 1.
315
+ Each block has 2 layers.
316
+ """
317
+
318
+ def __init__(self, config):
319
+ super().__init__()
320
+ # Config
321
+ block = BasicBlock
322
+ initial_dim = config['initial_dim']
323
+ block_dims = config['block_dims']
324
+
325
+ # Class Variable
326
+ self.in_planes = initial_dim
327
+ self.skip_fine_feature = config['coarse_feat_only']
328
+ self.inter_feat = config['inter_feat']
329
+
330
+ # Networks
331
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
332
+ self.bn1 = nn.BatchNorm2d(initial_dim)
333
+ self.relu = nn.ReLU(inplace=True)
334
+
335
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
336
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
337
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
338
+
339
+ # 3. FPN upsample
340
+ if not self.skip_fine_feature:
341
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
342
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
343
+ self.layer2_outconv2 = nn.Sequential(
344
+ conv3x3(block_dims[2], block_dims[2]),
345
+ nn.BatchNorm2d(block_dims[2]),
346
+ nn.LeakyReLU(),
347
+ conv3x3(block_dims[2], block_dims[1]),
348
+ )
349
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
350
+ self.layer1_outconv2 = nn.Sequential(
351
+ conv3x3(block_dims[1], block_dims[1]),
352
+ nn.BatchNorm2d(block_dims[1]),
353
+ nn.LeakyReLU(),
354
+ conv3x3(block_dims[1], block_dims[0]),
355
+ )
356
+
357
+ for m in self.modules():
358
+ if isinstance(m, nn.Conv2d):
359
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
360
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
361
+ nn.init.constant_(m.weight, 1)
362
+ nn.init.constant_(m.bias, 0)
363
+
364
+ def _make_layer(self, block, dim, stride=1):
365
+ layer1 = block(self.in_planes, dim, stride=stride)
366
+ layer2 = block(dim, dim, stride=1)
367
+ layers = (layer1, layer2)
368
+
369
+ self.in_planes = dim
370
+ return nn.Sequential(*layers)
371
+
372
+ def forward(self, x):
373
+ # ResNet Backbone
374
+ x0 = self.relu(self.bn1(self.conv1(x)))
375
+ x1 = self.layer1(x0) # 1/2
376
+ x2 = self.layer2(x1) # 1/4
377
+ x3 = self.layer3(x2) # 1/8
378
+
379
+ # FPN
380
+ if self.skip_fine_feature:
381
+ if self.inter_feat:
382
+ return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
383
+ else:
384
+ return {'feats_c': x3, 'feats_f': None}
385
+
386
+ x3_out = self.layer3_outconv(x3)
387
+
388
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
389
+ x2_out = self.layer2_outconv(x2)
390
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
391
+
392
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
393
+ x1_out = self.layer1_outconv(x1)
394
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
395
+
396
+ x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)
397
+
398
+ if not self.inter_feat:
399
+ return {'feats_c': x3, 'feats_f': x0_out}
400
+ else:
401
+ return {'feats_c': x3, 'feats_f': x0_out, 'feats_x2': x2, 'feats_x1': x1}
402
+
403
+
404
+ class ResNetFPN_8_1_align(nn.Module):
405
+ """
406
+ ResNet+FPN, output resolution are 1/8 and 1.
407
+ Each block has 2 layers.
408
+ """
409
+
410
+ def __init__(self, config):
411
+ super().__init__()
412
+ # Config
413
+ block = BasicBlock
414
+ initial_dim = config['initial_dim']
415
+ block_dims = config['block_dims']
416
+
417
+ # Class Variable
418
+ self.in_planes = initial_dim
419
+ self.skip_fine_feature = config['coarse_feat_only']
420
+ self.inter_feat = config['inter_feat']
421
+ # Networks
422
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
423
+ self.bn1 = nn.BatchNorm2d(initial_dim)
424
+ self.relu = nn.ReLU(inplace=True)
425
+
426
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
427
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
428
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
429
+
430
+ # 3. FPN upsample
431
+ if not self.skip_fine_feature:
432
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
433
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
434
+ self.layer2_outconv2 = nn.Sequential(
435
+ conv3x3(block_dims[2], block_dims[2]),
436
+ nn.BatchNorm2d(block_dims[2]),
437
+ nn.LeakyReLU(),
438
+ conv3x3(block_dims[2], block_dims[1]),
439
+ )
440
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
441
+ self.layer1_outconv2 = nn.Sequential(
442
+ conv3x3(block_dims[1], block_dims[1]),
443
+ nn.BatchNorm2d(block_dims[1]),
444
+ nn.LeakyReLU(),
445
+ conv3x3(block_dims[1], block_dims[0]),
446
+ )
447
+
448
+ for m in self.modules():
449
+ if isinstance(m, nn.Conv2d):
450
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
451
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
452
+ nn.init.constant_(m.weight, 1)
453
+ nn.init.constant_(m.bias, 0)
454
+
455
+ def _make_layer(self, block, dim, stride=1):
456
+ layer1 = block(self.in_planes, dim, stride=stride)
457
+ layer2 = block(dim, dim, stride=1)
458
+ layers = (layer1, layer2)
459
+
460
+ self.in_planes = dim
461
+ return nn.Sequential(*layers)
462
+
463
+ def forward(self, x):
464
+ # ResNet Backbone
465
+ x0 = self.relu(self.bn1(self.conv1(x)))
466
+ x1 = self.layer1(x0) # 1/2
467
+ x2 = self.layer2(x1) # 1/4
468
+ x3 = self.layer3(x2) # 1/8
469
+
470
+ # FPN
471
+
472
+ if self.skip_fine_feature:
473
+ if self.inter_feat:
474
+ return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
475
+ else:
476
+ return {'feats_c': x3, 'feats_f': None}
477
+
478
+ x3_out = self.layer3_outconv(x3)
479
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
480
+ x2_out = self.layer2_outconv(x2)
481
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
482
+
483
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
484
+ x1_out = self.layer1_outconv(x1)
485
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
486
+
487
+ x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)
488
+
489
+ if not self.inter_feat:
490
+ return {'feats_c': x3, 'feats_f': x0_out}
491
+ else:
492
+ return {'feats_c': x3, 'feats_f': x0_out, 'feats_x2': x2, 'feats_x1': x1}
493
+
494
+ def pro(self, x, profiler):
495
+ with profiler.profile('ResNet Backbone'):
496
+ # ResNet Backbone
497
+ x0 = self.relu(self.bn1(self.conv1(x)))
498
+ x1 = self.layer1(x0) # 1/2
499
+ x2 = self.layer2(x1) # 1/4
500
+ x3 = self.layer3(x2) # 1/8
501
+
502
+ with profiler.profile('FPN'):
503
+ # FPN
504
+ x3_out = self.layer3_outconv(x3)
505
+
506
+ if self.skip_fine_feature:
507
+ return [x3_out, None]
508
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
509
+ x2_out = self.layer2_outconv(x2)
510
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
511
+
512
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
513
+ x1_out = self.layer1_outconv(x1)
514
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
515
+
516
+ with profiler.profile('upsample*1'):
517
+ x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)
518
+
519
+ return {'feats_c': x3_out, 'feats_f': x0_out}
520
+
521
+
522
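ResNetFPN_8_1_align above differs from ResNetFPN_8_1 mainly in passing align_corners=False to every F.interpolate call on the FPN path. A tiny sketch (illustrative values, not part of the commit) of what that flag changes in bilinear upsampling:

import torch
import torch.nn.functional as F

x = torch.tensor([[[[0., 1.],
                    [2., 3.]]]])  # 1x1x2x2 toy feature map
# align_corners=True maps input and output corner pixels onto each other;
# align_corners=False treats pixels as area centers, so the values differ.
print(F.interpolate(x, scale_factor=2., mode='bilinear', align_corners=True))
print(F.interpolate(x, scale_factor=2., mode='bilinear', align_corners=False))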
+ class ResNetFPN_8_2_align(nn.Module):
523
+ """
524
+ ResNet+FPN, output resolutions are 1/8 and 1/2.
525
+ Each block has 2 layers.
526
+ """
527
+
528
+ def __init__(self, config):
529
+ super().__init__()
530
+ # Config
531
+ block = BasicBlock
532
+ initial_dim = config['initial_dim']
533
+ block_dims = config['block_dims']
534
+
535
+ # Class Variable
536
+ self.in_planes = initial_dim
537
+ self.skip_fine_feature = config['coarse_feat_only']
538
+ self.inter_feat = config['inter_feat']
539
+ # Networks
540
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
541
+ self.bn1 = nn.BatchNorm2d(initial_dim)
542
+ self.relu = nn.ReLU(inplace=True)
543
+
544
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
545
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
546
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
547
+
548
+ # 3. FPN upsample
549
+ if not self.skip_fine_feature:
550
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
551
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
552
+ self.layer2_outconv2 = nn.Sequential(
553
+ conv3x3(block_dims[2], block_dims[2]),
554
+ nn.BatchNorm2d(block_dims[2]),
555
+ nn.LeakyReLU(),
556
+ conv3x3(block_dims[2], block_dims[1]),
557
+ )
558
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
559
+ self.layer1_outconv2 = nn.Sequential(
560
+ conv3x3(block_dims[1], block_dims[1]),
561
+ nn.BatchNorm2d(block_dims[1]),
562
+ nn.LeakyReLU(),
563
+ conv3x3(block_dims[1], block_dims[0]),
564
+ )
565
+
566
+ for m in self.modules():
567
+ if isinstance(m, nn.Conv2d):
568
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
569
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
570
+ nn.init.constant_(m.weight, 1)
571
+ nn.init.constant_(m.bias, 0)
572
+
573
+ def _make_layer(self, block, dim, stride=1):
574
+ layer1 = block(self.in_planes, dim, stride=stride)
575
+ layer2 = block(dim, dim, stride=1)
576
+ layers = (layer1, layer2)
577
+
578
+ self.in_planes = dim
579
+ return nn.Sequential(*layers)
580
+
581
+ def forward(self, x):
582
+ # ResNet Backbone
583
+ x0 = self.relu(self.bn1(self.conv1(x)))
584
+ x1 = self.layer1(x0) # 1/2
585
+ x2 = self.layer2(x1) # 1/4
586
+ x3 = self.layer3(x2) # 1/8
587
+
588
+ if self.skip_fine_feature:
589
+ if self.inter_feat:
590
+ return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
591
+ else:
592
+ return {'feats_c': x3, 'feats_f': None}
593
+
594
+ # FPN
595
+ x3_out = self.layer3_outconv(x3)
596
+
597
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
598
+ x2_out = self.layer2_outconv(x2)
599
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
600
+
601
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
602
+ x1_out = self.layer1_outconv(x1)
603
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
604
+
605
+ if not self.inter_feat:
606
+ return {'feats_c': x3, 'feats_f': x1_out}
607
+ else:
608
+ return {'feats_c': x3, 'feats_f': x1_out, 'feats_x2': x2, 'feats_x1': x1}
609
+
610
+
611
+ class ResNet_8_1_align(nn.Module):
612
+ """
613
+ ResNet, output resolutions are 1/8 and 1.
614
+ Each block has 2 layers.
615
+ """
616
+
617
+ def __init__(self, config):
618
+ super().__init__()
619
+ # Config
620
+ block = BasicBlock
621
+ initial_dim = config['initial_dim']
622
+ block_dims = config['block_dims']
623
+
624
+ # Class Variable
625
+ self.in_planes = initial_dim
626
+
627
+ # Networks
628
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
629
+ self.bn1 = nn.BatchNorm2d(initial_dim)
630
+ self.relu = nn.ReLU(inplace=True)
631
+
632
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
633
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
634
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
635
+
636
+ # 3. FPN upsample
637
+ # self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
638
+ # self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
639
+ # self.layer2_outconv2 = nn.Sequential(
640
+ # conv3x3(block_dims[2], block_dims[2]),
641
+ # nn.BatchNorm2d(block_dims[2]),
642
+ # nn.LeakyReLU(),
643
+ # conv3x3(block_dims[2], block_dims[1]),
644
+ # )
645
+ # self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
646
+ # self.layer1_outconv2 = nn.Sequential(
647
+ # conv3x3(block_dims[1], block_dims[1]),
648
+ # nn.BatchNorm2d(block_dims[1]),
649
+ # nn.LeakyReLU(),
650
+ # conv3x3(block_dims[1], block_dims[0]),
651
+ # )
652
+ self.layer0_outconv = conv1x1(block_dims[2], block_dims[0])
653
+ for m in self.modules():
654
+ if isinstance(m, nn.Conv2d):
655
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
656
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
657
+ nn.init.constant_(m.weight, 1)
658
+ nn.init.constant_(m.bias, 0)
659
+
660
+ def _make_layer(self, block, dim, stride=1):
661
+ layer1 = block(self.in_planes, dim, stride=stride)
662
+ layer2 = block(dim, dim, stride=1)
663
+ layers = (layer1, layer2)
664
+
665
+ self.in_planes = dim
666
+ return nn.Sequential(*layers)
667
+
668
+ def forward(self, x):
669
+ # ResNet Backbone
670
+ x0 = self.relu(self.bn1(self.conv1(x)))
671
+ x1 = self.layer1(x0) # 1/2
672
+ x2 = self.layer2(x1) # 1/4
673
+ x3 = self.layer3(x2) # 1/8
674
+
675
+ # FPN
676
+ # x3_out = self.layer3_outconv(x3)
677
+
678
+ # x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
679
+ # x2_out = self.layer2_outconv(x2)
680
+ # x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
681
+
682
+ # x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
683
+ # x1_out = self.layer1_outconv(x1)
684
+ # x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
685
+
686
+ x0_out = F.interpolate(x3, scale_factor=8., mode='bilinear', align_corners=False)
687
+ x0_out = self.layer0_outconv(x0_out)
688
+
689
+ return {'feats_c': x3, 'feats_f': x0_out}
690
+
691
+ class VGG_8_1_align(nn.Module):
692
+ """
693
+ VGG-like backbone, output resolutions are 1/8 and 1.
694
+ Each block has 2 layers.
695
+ """
696
+
697
+ def __init__(self, config):
698
+ super().__init__()
699
+ # Config
700
+ block = BasicBlock
701
+ initial_dim = config['initial_dim']
702
+ block_dims = config['block_dims']
703
+
704
+ self.relu = nn.ReLU(inplace=True)
705
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
706
+ c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
707
+
708
+ self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
709
+ self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
710
+ self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
711
+ self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
712
+ self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
713
+ self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
714
+ self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
715
+ self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
716
+
717
+ # self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
718
+ # self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
719
+
720
+ self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
721
+ self.convDb = nn.Conv2d(
722
+ c5, 256,
723
+ kernel_size=1, stride=1, padding=0)
724
+ self.layer0_outconv = conv1x1(block_dims[2], block_dims[0])
725
+ for m in self.modules():
726
+ if isinstance(m, nn.Conv2d):
727
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
728
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
729
+ nn.init.constant_(m.weight, 1)
730
+ nn.init.constant_(m.bias, 0)
731
+
732
+ def _make_layer(self, block, dim, stride=1):
733
+ layer1 = block(self.in_planes, dim, stride=stride)
734
+ layer2 = block(dim, dim, stride=1)
735
+ layers = (layer1, layer2)
736
+
737
+ self.in_planes = dim
738
+ return nn.Sequential(*layers)
739
+
740
+ def forward(self, x):
741
+ # Shared Encoder
742
+ x = self.relu(self.conv1a(x))
743
+ x = self.relu(self.conv1b(x))
744
+ x = self.pool(x)
745
+ x = self.relu(self.conv2a(x))
746
+ x = self.relu(self.conv2b(x))
747
+ x = self.pool(x)
748
+ x = self.relu(self.conv3a(x))
749
+ x = self.relu(self.conv3b(x))
750
+ x = self.pool(x)
751
+ x = self.relu(self.conv4a(x))
752
+ x = self.relu(self.conv4b(x))
753
+
754
+ cDa = self.relu(self.convDa(x))
755
+ descriptors = self.convDb(cDa)
756
+ x3_out = nn.functional.normalize(descriptors, p=2, dim=1)
757
+
758
+ x0_out = F.interpolate(x3_out, scale_factor=8., mode='bilinear', align_corners=False)
759
+ x0_out = self.layer0_outconv(x0_out)
760
+ # ResNet Backbone
761
+ # x0 = self.relu(self.bn1(self.conv1(x)))
762
+ # x1 = self.layer1(x0) # 1/2
763
+ # x2 = self.layer2(x1) # 1/4
764
+ # x3 = self.layer3(x2) # 1/8
765
+
766
+ # # FPN
767
+ # x3_out = self.layer3_outconv(x3)
768
+
769
+ # x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
770
+ # x2_out = self.layer2_outconv(x2)
771
+ # x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
772
+
773
+ # x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
774
+ # x1_out = self.layer1_outconv(x1)
775
+ # x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
776
+
777
+ # x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)
778
+
779
+ return {'feats_c': x3_out, 'feats_f': x0_out}
780
+
781
+ class RepVGG_8_1_align(nn.Module):
782
+ """
783
+ RepVGG backbone, output resolutions are 1/8 and 1.
784
+ Each block has 2 layers.
785
+ """
786
+
787
+ def __init__(self, config):
788
+ super().__init__()
789
+ # Config
790
+ # block = BasicBlock
791
+ # initial_dim = config['initial_dim']
792
+ block_dims = config['block_dims']
793
+ self.skip_fine_feature = config['coarse_feat_only']
794
+ self.inter_feat = config['inter_feat']
795
+ self.leaky = config['leaky']
796
+
797
+ # backbone_name='RepVGG-B0'
798
+ if config.get('repvggmodel') is not None:
799
+ backbone_name=config['repvggmodel']
800
+ elif self.leaky:
801
+ backbone_name='RepVGG-A1_leaky'
802
+ else:
803
+ backbone_name='RepVGG-A1'
804
+ repvgg_fn = get_RepVGG_func_by_name(backbone_name)
805
+ backbone = repvgg_fn(False)
806
+ self.layer0, self.layer1, self.layer2, self.layer3 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3 #, backbone.stage4
807
+ # self.layer0, self.layer1, self.layer2, self.layer3, self.layer4 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3, backbone.stage4
808
+
809
+ # 3. FPN upsample
810
+ if not self.skip_fine_feature:
811
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
812
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
813
+ self.layer2_outconv2 = nn.Sequential(
814
+ conv3x3(block_dims[2], block_dims[2]),
815
+ nn.BatchNorm2d(block_dims[2]),
816
+ nn.LeakyReLU(),
817
+ conv3x3(block_dims[2], block_dims[1]),
818
+ )
819
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
820
+ self.layer1_outconv2 = nn.Sequential(
821
+ conv3x3(block_dims[1], block_dims[1]),
822
+ nn.BatchNorm2d(block_dims[1]),
823
+ nn.LeakyReLU(),
824
+ conv3x3(block_dims[1], block_dims[0]),
825
+ )
826
+
827
+ # self.layer0_outconv = conv1x1(192, 48)
828
+
829
+ for layer in [self.layer0, self.layer1, self.layer2, self.layer3]:
830
+ for m in layer.modules():
831
+ if isinstance(m, nn.Conv2d):
832
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
833
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
834
+ nn.init.constant_(m.weight, 1)
835
+ nn.init.constant_(m.bias, 0)
836
+ # for layer in [self.layer0, self.layer1, self.layer2, self.layer3, self.layer4]:
837
+ # for m in layer.modules():
838
+ # if isinstance(m, nn.Conv2d):
839
+ # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
840
+ # elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
841
+ # nn.init.constant_(m.weight, 1)
842
+ # nn.init.constant_(m.bias, 0)
843
+
844
+ def forward(self, x):
845
+
846
+ out = self.layer0(x) # 1/2
847
+ for module in self.layer1:
848
+ out = module(out) # 1/2
849
+ x1 = out
850
+ for module in self.layer2:
851
+ out = module(out) # 1/4
852
+ x2 = out
853
+ for module in self.layer3:
854
+ out = module(out) # 1/8
855
+ x3 = out
856
+ # for module in self.layer4:
857
+ # out = module(out)
858
+ # x3 = out
859
+
860
+ if self.skip_fine_feature:
861
+ if self.inter_feat:
862
+ return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
863
+ else:
864
+ return {'feats_c': x3, 'feats_f': None}
865
+ x3_out = self.layer3_outconv(x3)
866
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
867
+ x2_out = self.layer2_outconv(x2)
868
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
869
+
870
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
871
+ x1_out = self.layer1_outconv(x1)
872
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
873
+
874
+ x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)
875
+
876
+ # x_f = F.interpolate(x_c, scale_factor=8., mode='bilinear', align_corners=False)
877
+ # x_f = self.layer0_outconv(x_f)
878
+ return {'feats_c': x3_out, 'feats_f': x0_out}
879
+
880
+
881
+ class RepVGG_8_2_fix(nn.Module):
882
+ """
883
+ RepVGG backbone, output resolutions are 1/8 and 1/2.
884
+ Each block has 2 layers.
885
+ """
886
+
887
+ def __init__(self, config):
888
+ super().__init__()
889
+ # Config
890
+ # block = BasicBlock
891
+ # initial_dim = config['initial_dim']
892
+ block_dims = config['block_dims']
893
+ self.skip_fine_feature = config['coarse_feat_only']
894
+ self.inter_feat = config['inter_feat']
895
+
896
+ # backbone_name='RepVGG-B0'
897
+ backbone_name='RepVGG-A1'
898
+ repvgg_fn = get_RepVGG_func_by_name(backbone_name)
899
+ backbone = repvgg_fn(False)
900
+ self.layer0, self.layer1, self.layer2, self.layer3 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3 #, backbone.stage4
901
+
902
+ # 3. FPN upsample
903
+ if not self.skip_fine_feature:
904
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
905
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
906
+ self.layer2_outconv2 = nn.Sequential(
907
+ conv3x3(block_dims[2], block_dims[2]),
908
+ nn.BatchNorm2d(block_dims[2]),
909
+ nn.LeakyReLU(),
910
+ conv3x3(block_dims[2], block_dims[1]),
911
+ )
912
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
913
+ self.layer1_outconv2 = nn.Sequential(
914
+ conv3x3(block_dims[1], block_dims[1]),
915
+ nn.BatchNorm2d(block_dims[1]),
916
+ nn.LeakyReLU(),
917
+ conv3x3(block_dims[1], block_dims[0]),
918
+ )
919
+
920
+ # self.layer0_outconv = conv1x1(192, 48)
921
+
922
+ for layer in [self.layer0, self.layer1, self.layer2, self.layer3]:
923
+ for m in layer.modules():
924
+ if isinstance(m, nn.Conv2d):
925
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
926
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
927
+ nn.init.constant_(m.weight, 1)
928
+ nn.init.constant_(m.bias, 0)
929
+
930
+ def forward(self, x):
931
+
932
+ x0 = self.layer0(x) # 1/2
933
+ out = x0
934
+ for module in self.layer1:
935
+ out = module(out) # 1/2
936
+ x1 = out
937
+ for module in self.layer2:
938
+ out = module(out) # 1/4
939
+ x2 = out
940
+ for module in self.layer3:
941
+ out = module(out) # 1/8
942
+ x3 = out
943
+ # for module in self.layer4:
944
+ # out = module(out)
945
+
946
+ if self.skip_fine_feature:
947
+ if self.inter_feat:
948
+ return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
949
+ else:
950
+ return {'feats_c': x3, 'feats_f': None}
951
+ x3_out = self.layer3_outconv(x3)
952
+ x3_out_2x = F.interpolate(x3_out, size=((x3_out.size(-2)-1)*2+1, (x3_out.size(-1)-1)*2+1), mode='bilinear', align_corners=True)
953
+ x2_out = self.layer2_outconv(x2)
954
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
955
+
956
+ x2_out_2x = F.interpolate(x2_out, size=((x2_out.size(-2)-1)*2+1, (x2_out.size(-1)-1)*2+1), mode='bilinear', align_corners=True)
957
+ x1_out = self.layer1_outconv(x1)
958
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
959
+
960
+ # x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)
961
+
962
+ # x_f = F.interpolate(x_c, scale_factor=8., mode='bilinear', align_corners=False)
963
+ # x_f = self.layer0_outconv(x_f)
964
+ return {'feats_c': x3_out, 'feats_f': x1_out}
965
+
966
+
967
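The "fix" in RepVGG_8_2_fix above is the interpolation target size: instead of scale_factor=2, a map of spatial size n is upsampled to (n - 1) * 2 + 1 with align_corners=True, so the first and last grid points of consecutive pyramid levels coincide. The same (n - 1) * mul + 1 convention reappears in loftr.py below when fix_bias is enabled. A short arithmetic sketch with illustrative sizes:

# Illustrative only: a 1/8-resolution map of 60 x 80 is upsampled to 119 x 159
# rather than 120 x 160, keeping grid corners aligned across levels.
for n in (60, 80):
    print(n, '->', (n - 1) * 2 + 1)  # 60 -> 119, 80 -> 159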
+ class RepVGGnfpn_8_1_align(nn.Module):
968
+ """
969
+ RepVGG backbone, output resolutions are 1/8 and 1.
970
+ Each block has 2 layers.
971
+ """
972
+
973
+ def __init__(self, config):
974
+ super().__init__()
975
+ # Config
976
+ # block = BasicBlock
977
+ # initial_dim = config['initial_dim']
978
+ block_dims = config['block_dims']
979
+ self.skip_fine_feature = config['coarse_feat_only']
980
+ self.inter_feat = config['inter_feat']
981
+
982
+ # backbone_name='RepVGG-B0'
983
+ backbone_name='RepVGG-A1'
984
+ repvgg_fn = get_RepVGG_func_by_name(backbone_name)
985
+ backbone = repvgg_fn(False)
986
+ self.layer0, self.layer1, self.layer2, self.layer3 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3 #, backbone.stage4
987
+
988
+ # 3. FPN upsample
989
+ if not self.skip_fine_feature:
990
+ self.layer0_outconv = conv1x1(block_dims[2], block_dims[0])
991
+ # self.layer0_outconv = conv1x1(192, 48)
992
+
993
+ for layer in [self.layer0, self.layer1, self.layer2, self.layer3]:
994
+ for m in layer.modules():
995
+ if isinstance(m, nn.Conv2d):
996
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
997
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
998
+ nn.init.constant_(m.weight, 1)
999
+ nn.init.constant_(m.bias, 0)
1000
+
1001
+ def forward(self, x):
1002
+
1003
+ x0 = self.layer0(x) # 1/2
1004
+ out = x0
1005
+ for module in self.layer1:
1006
+ out = module(out) # 1/2
1007
+ x1 = out
1008
+ for module in self.layer2:
1009
+ out = module(out) # 1/4
1010
+ x2 = out
1011
+ for module in self.layer3:
1012
+ out = module(out) # 1/8
1013
+ x3 = out
1014
+ # for module in self.layer4:
1015
+ # out = module(out)
1016
+
1017
+ if self.skip_fine_feature:
1018
+ if self.inter_feat:
1019
+ return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
1020
+ else:
1021
+ return {'feats_c': x3, 'feats_f': None}
1022
+ # x3_out = self.layer3_outconv(x3)
1023
+ # x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
1024
+ # x2_out = self.layer2_outconv(x2)
1025
+ # x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
1026
+
1027
+ # x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
1028
+ # x1_out = self.layer1_outconv(x1)
1029
+ # x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
1030
+
1031
+ # x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)
1032
+
1033
+ x_f = F.interpolate(x3, scale_factor=8., mode='bilinear', align_corners=False)
1034
+ x_f = self.layer0_outconv(x_f)
1035
+ # x_f2 = F.interpolate(x3, scale_factor=8., mode='bilinear', align_corners=False)
1036
+ # x_f2 = self.layer0_outconv(x_f2)
1037
+ return {'feats_c': x3, 'feats_f': x_f}
1038
+
1039
+
1040
+ class s2dnet_8_1_align(nn.Module):
1041
+ """
1042
+ S2DNet (VGG-16 hypercolumn) backbone; only coarse features at 1/8 resolution are produced.
1043
+ Fine features are not computed here ('feats_f' is None).
1044
+ """
1045
+
1046
+ def __init__(self, config):
1047
+ super().__init__()
1048
+ # Config
1049
+ block = BasicBlock
1050
+ initial_dim = config['initial_dim']
1051
+ block_dims = config['block_dims']
1052
+
1053
+ # Class Variable
1054
+ self.in_planes = initial_dim
1055
+ self.skip_fine_feature = config['coarse_feat_only']
1056
+ self.inter_feat = config['inter_feat']
1057
+ # Networks
1058
+ self.backbone = S2DNet(checkpoint_path = '/cephfs-mvs/3dv-research/hexingyi/code_yf/loftrdev/weights/s2dnet/s2dnet_weights.pth')
1059
+ # 3. FPN upsample
1060
+ # self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
1061
+ # if not self.skip_fine_feature:
1062
+ # self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
1063
+ # self.layer2_outconv2 = nn.Sequential(
1064
+ # conv3x3(block_dims[2], block_dims[2]),
1065
+ # nn.BatchNorm2d(block_dims[2]),
1066
+ # nn.LeakyReLU(),
1067
+ # conv3x3(block_dims[2], block_dims[1]),
1068
+ # )
1069
+ # self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
1070
+ # self.layer1_outconv2 = nn.Sequential(
1071
+ # conv3x3(block_dims[1], block_dims[1]),
1072
+ # nn.BatchNorm2d(block_dims[1]),
1073
+ # nn.LeakyReLU(),
1074
+ # conv3x3(block_dims[1], block_dims[0]),
1075
+ # )
1076
+
1077
+ # for m in self.modules():
1078
+ # if isinstance(m, nn.Conv2d):
1079
+ # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
1080
+ # elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
1081
+ # nn.init.constant_(m.weight, 1)
1082
+ # nn.init.constant_(m.bias, 0)
1083
+
1084
+ def forward(self, x):
1085
+ ret = self.backbone(x)
1086
+ ret[2] = F.interpolate(ret[2], scale_factor=2., mode='bilinear', align_corners=False)
1087
+ if self.skip_fine_feature:
1088
+ if self.inter_feat:
1089
+ return {'feats_c': ret[2], 'feats_f': None, 'feats_x2': ret[1], 'feats_x1': ret[0]}
1090
+ else:
1091
+ return {'feats_c': ret[2], 'feats_f': None,}
1092
+
1093
+ def pro(self, x, profiler):
1094
+ pass
imcui/third_party/MatchAnything/src/loftr/backbone/s2dnet.py ADDED
@@ -0,0 +1,131 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from torchvision import models  # required below by models.vgg16(...)
6
+ from typing import List, Dict
7
+
8
+ # VGG-16 Layer Names and Channels
9
+ vgg16_layers = {
10
+ "conv1_1": 64,
11
+ "relu1_1": 64,
12
+ "conv1_2": 64,
13
+ "relu1_2": 64,
14
+ "pool1": 64,
15
+ "conv2_1": 128,
16
+ "relu2_1": 128,
17
+ "conv2_2": 128,
18
+ "relu2_2": 128,
19
+ "pool2": 128,
20
+ "conv3_1": 256,
21
+ "relu3_1": 256,
22
+ "conv3_2": 256,
23
+ "relu3_2": 256,
24
+ "conv3_3": 256,
25
+ "relu3_3": 256,
26
+ "pool3": 256,
27
+ "conv4_1": 512,
28
+ "relu4_1": 512,
29
+ "conv4_2": 512,
30
+ "relu4_2": 512,
31
+ "conv4_3": 512,
32
+ "relu4_3": 512,
33
+ "pool4": 512,
34
+ "conv5_1": 512,
35
+ "relu5_1": 512,
36
+ "conv5_2": 512,
37
+ "relu5_2": 512,
38
+ "conv5_3": 512,
39
+ "relu5_3": 512,
40
+ "pool5": 512,
41
+ }
42
+
43
+ class AdapLayers(nn.Module):
44
+ """Small adaptation layers.
45
+ """
46
+
47
+ def __init__(self, hypercolumn_layers: List[str], output_dim: int = 128):
48
+ """Initialize one adaptation layer for every extraction point.
49
+
50
+ Args:
51
+ hypercolumn_layers: The list of the hypercolumn layer names.
52
+ output_dim: The output channel dimension.
53
+ """
54
+ super(AdapLayers, self).__init__()
55
+ self.layers = []
56
+ channel_sizes = [vgg16_layers[name] for name in hypercolumn_layers]
57
+ for i, l in enumerate(channel_sizes):
58
+ layer = nn.Sequential(
59
+ nn.Conv2d(l, 64, kernel_size=1, stride=1, padding=0),
60
+ nn.ReLU(),
61
+ nn.Conv2d(64, output_dim, kernel_size=5, stride=1, padding=2),
62
+ nn.BatchNorm2d(output_dim),
63
+ )
64
+ self.layers.append(layer)
65
+ self.add_module("adap_layer_{}".format(i), layer)
66
+
67
+ def forward(self, features: List[torch.Tensor]):
68
+ """Apply adaptation layers.
69
+ """
70
+ for i, _ in enumerate(features):
71
+ features[i] = getattr(self, "adap_layer_{}".format(i))(features[i])
72
+ return features
73
+
74
+ class S2DNet(nn.Module):
75
+ """The S2DNet model
76
+ """
77
+
78
+ def __init__(
79
+ self,
80
+ # hypercolumn_layers: List[str] = ["conv2_2", "conv3_3", "relu4_3"],
81
+ hypercolumn_layers: List[str] = ["conv1_2", "conv3_3", "conv5_3"],
82
+ checkpoint_path: str = None,
83
+ ):
84
+ """Initialize S2DNet.
85
+
86
+ Args:
87
+ device: The torch device to put the model on
88
+ hypercolumn_layers: Names of the layers to extract features from
89
+ checkpoint_path: Path to the pre-trained model.
90
+ """
91
+ super(S2DNet, self).__init__()
92
+ self._checkpoint_path = checkpoint_path
93
+ self.layer_to_index = dict((k, v) for v, k in enumerate(vgg16_layers.keys()))
94
+ self._hypercolumn_layers = hypercolumn_layers
95
+
96
+ # Initialize architecture
97
+ vgg16 = models.vgg16(pretrained=False)
98
+ # layers = list(vgg16.features.children())[:-2]
99
+ layers = list(vgg16.features.children())[:-1]
100
+ # layers = list(vgg16.features.children())[:23] # relu4_3
101
+ self.encoder = nn.Sequential(*layers)
102
+ self.adaptation_layers = AdapLayers(self._hypercolumn_layers) # .to(self._device)
103
+ self.eval()
104
+
105
+ # Restore params from checkpoint
106
+ if checkpoint_path:
107
+ print(">> Loading weights from {}".format(checkpoint_path))
108
+ self._checkpoint = torch.load(checkpoint_path)
109
+ self._hypercolumn_layers = self._checkpoint["hypercolumn_layers"]
110
+ self.load_state_dict(self._checkpoint["state_dict"])
111
+
112
+ def forward(self, image_tensor: torch.FloatTensor):
113
+ """Compute intermediate feature maps at the provided extraction levels.
114
+
115
+ Args:
116
+ image_tensor: The [N x 3 x H x Ws] input image tensor.
117
+ Returns:
118
+ feature_maps: The list of output feature maps.
119
+ """
120
+ feature_maps, j = [], 0
121
+ feature_map = image_tensor.repeat(1,3,1,1)
122
+ layer_list = list(self.encoder.modules())[0]
123
+ for i, layer in enumerate(layer_list):
124
+ feature_map = layer(feature_map)
125
+ if j < len(self._hypercolumn_layers):
126
+ next_extraction_index = self.layer_to_index[self._hypercolumn_layers[j]]
127
+ if i == next_extraction_index:
128
+ feature_maps.append(feature_map)
129
+ j += 1
130
+ feature_maps = self.adaptation_layers(feature_maps)
131
+ return feature_maps
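A minimal usage sketch of the S2DNet wrapper above (editorial, not part of the commit; it assumes the torchvision import fixed above and no checkpoint): the single-channel input is replicated to 3 channels in forward(), pushed through the VGG-16 encoder, and feature maps are collected at the named hypercolumn layers before the adaptation layers project each of them to a common dimension.

import torch

net = S2DNet(checkpoint_path=None)  # no pre-trained weights are loaded in this sketch
net.eval()
with torch.no_grad():
    feats = net(torch.rand(1, 1, 256, 256))  # (N, 1, H, W) grayscale input
for f in feats:
    # one adapted map per entry in hypercolumn_layers, progressively coarser
    print(f.shape)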
imcui/third_party/MatchAnything/src/loftr/loftr.py ADDED
@@ -0,0 +1,273 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from einops.einops import rearrange
4
+
5
+ from .backbone import build_backbone
6
+ # from third_party.matchformer.model.backbone import build_backbone as build_backbone_matchformer
7
+ from .utils.position_encoding import PositionEncodingSine
8
+ from .loftr_module import LocalFeatureTransformer, FinePreprocess
9
+ from .utils.coarse_matching import CoarseMatching
10
+ from .utils.fine_matching import FineMatching
11
+
12
+ from loguru import logger
13
+
14
+ class LoFTR(nn.Module):
15
+ def __init__(self, config, profiler=None):
16
+ super().__init__()
17
+ # Misc
18
+ self.config = config
19
+ self.profiler = profiler
20
+
21
+ # Modules
22
+ self.backbone = build_backbone(config)
23
+ if not (self.config['coarse']['skip'] or self.config['coarse']['rope'] or self.config['coarse']['pan'] or self.config['coarse']['token_mixer'] is not None):
24
+ self.pos_encoding = PositionEncodingSine(
25
+ config['coarse']['d_model'],
26
+ temp_bug_fix=config['coarse']['temp_bug_fix'],
27
+ npe=config['coarse']['npe'],
28
+ )
29
+ if self.config['coarse']['abspe']:
30
+ self.pos_encoding = PositionEncodingSine(
31
+ config['coarse']['d_model'],
32
+ temp_bug_fix=config['coarse']['temp_bug_fix'],
33
+ npe=config['coarse']['npe'],
34
+ )
35
+
36
+ if self.config['coarse']['skip'] is False:
37
+ self.loftr_coarse = LocalFeatureTransformer(config)
38
+ self.coarse_matching = CoarseMatching(config['match_coarse'])
39
+ # self.fine_preprocess = FinePreprocess(config).float()
40
+ self.fine_preprocess = FinePreprocess(config)
41
+ if self.config['fine']['skip'] is False:
42
+ self.loftr_fine = LocalFeatureTransformer(config["fine"])
43
+ self.fine_matching = FineMatching(config)
44
+
45
+ def forward(self, data):
46
+ """
47
+ Update:
48
+ data (dict): {
49
+ 'image0': (torch.Tensor): (N, 1, H, W)
50
+ 'image1': (torch.Tensor): (N, 1, H, W)
51
+ 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
52
+ 'mask1'(optional) : (torch.Tensor): (N, H, W)
53
+ }
54
+ """
55
+ # 1. Local Feature CNN
56
+ data.update({
57
+ 'bs': data['image0'].size(0),
58
+ 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
59
+ })
60
+
61
+ if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
62
+ # feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
63
+ ret_dict = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
64
+ feats_c, feats_f = ret_dict['feats_c'], ret_dict['feats_f']
65
+ if self.config['inter_feat']:
66
+ data.update({
67
+ 'feats_x2': ret_dict['feats_x2'],
68
+ 'feats_x1': ret_dict['feats_x1'],
69
+ })
70
+ if self.config['coarse_feat_only']:
71
+ (feat_c0, feat_c1) = feats_c.split(data['bs'])
72
+ feat_f0, feat_f1 = None, None
73
+ else:
74
+ (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs'])
75
+ else: # handle different input shapes
76
+ # (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1'])
77
+ ret_dict0, ret_dict1 = self.backbone(data['image0']), self.backbone(data['image1'])
78
+ feat_c0, feat_f0 = ret_dict0['feats_c'], ret_dict0['feats_f']
79
+ feat_c1, feat_f1 = ret_dict1['feats_c'], ret_dict1['feats_f']
80
+ if self.config['inter_feat']:
81
+ data.update({
82
+ 'feats_x2_0': ret_dict0['feats_x2'],
83
+ 'feats_x1_0': ret_dict0['feats_x1'],
84
+ 'feats_x2_1': ret_dict1['feats_x2'],
85
+ 'feats_x1_1': ret_dict1['feats_x1'],
86
+ })
87
+ if self.config['coarse_feat_only']:
88
+ feat_f0, feat_f1 = None, None
89
+
90
+
91
+ mul = self.config['resolution'][0] // self.config['resolution'][1]
92
+ # mul = 4
93
+ if self.config['fix_bias']:
94
+ data.update({
95
+ 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
96
+ 'hw0_f': feat_f0.shape[2:] if feat_f0 is not None else [(feat_c0.shape[2]-1) * mul+1, (feat_c0.shape[3]-1) * mul+1] ,
97
+ 'hw1_f': feat_f1.shape[2:] if feat_f1 is not None else [(feat_c1.shape[2]-1) * mul+1, (feat_c1.shape[3]-1) * mul+1]
98
+ })
99
+ else:
100
+ data.update({
101
+ 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
102
+ 'hw0_f': feat_f0.shape[2:] if feat_f0 is not None else [feat_c0.shape[2] * mul, feat_c0.shape[3] * mul] ,
103
+ 'hw1_f': feat_f1.shape[2:] if feat_f1 is not None else [feat_c1.shape[2] * mul, feat_c1.shape[3] * mul]
104
+ })
105
+
106
+ # 2. coarse-level loftr module
107
+ # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
108
+ if self.config['coarse']['skip']:
109
+ mask_c0 = mask_c1 = None # mask is useful in training
110
+ if 'mask0' in data:
111
+ mask_c0, mask_c1 = data['mask0'], data['mask1']
112
+ feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
113
+ feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
114
+
115
+ elif self.config['coarse']['pan']:
116
+ # assert feat_c0.shape[0] == 1, 'batch size must be 1 when using mask Xformer now'
117
+ if self.config['coarse']['abspe']:
118
+ feat_c0 = self.pos_encoding(feat_c0)
119
+ feat_c1 = self.pos_encoding(feat_c1)
120
+
121
+ mask_c0 = mask_c1 = None # mask is useful in training
122
+ if 'mask0' in data:
123
+ mask_c0, mask_c1 = data['mask0'], data['mask1']
124
+ if self.config['matchability']: # else match in loftr_coarse
125
+ feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1, data=data)
126
+ else:
127
+ feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
128
+
129
+ feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
130
+ feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
131
+ else:
132
+ if not (self.config['coarse']['rope'] or self.config['coarse']['token_mixer'] is not None):
133
+ feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c')
134
+ feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c')
135
+
136
+ mask_c0 = mask_c1 = None # mask is useful in training
137
+ if self.config['coarse']['rope']:
138
+ if 'mask0' in data:
139
+ mask_c0, mask_c1 = data['mask0'], data['mask1']
140
+ else:
141
+ if 'mask0' in data:
142
+ mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
143
+ feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
144
+ if self.config['coarse']['rope']:
145
+ feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
146
+ feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
147
+
148
+ # detect nan
149
+ if self.config['replace_nan'] and (torch.any(torch.isnan(feat_c0)) or torch.any(torch.isnan(feat_c1))):
150
+ logger.info(f'replace nan in coarse attention')
151
+ logger.info(f"feat_c0_nan_num: {torch.isnan(feat_c0).int().sum()}, feat_c1_nan_num: {torch.isnan(feat_c1).int().sum()}")
152
+ logger.info(f"feat_c0: {feat_c0}, feat_c1: {feat_c1}")
153
+ logger.info(f"feat_c0_max: {feat_c0.abs().max()}, feat_c1_max: {feat_c1.abs().max()}")
154
+ feat_c0[torch.isnan(feat_c0)] = 0
155
+ feat_c1[torch.isnan(feat_c1)] = 0
156
+ logger.info(f"feat_c0_nanmax: {feat_c0.abs().max()}, feat_c1_nanmax: {feat_c1.abs().max()}")
157
+
158
+ # 3. match coarse-level
159
+ if not self.config['matchability']: # else match in loftr_coarse
160
+ self.coarse_matching(feat_c0, feat_c1, data,
161
+ mask_c0=mask_c0.view(mask_c0.size(0), -1) if mask_c0 is not None else mask_c0,
162
+ mask_c1=mask_c1.view(mask_c1.size(0), -1) if mask_c1 is not None else mask_c1
163
+ )
164
+
165
+ #return data['conf_matrix'],feat_c0,feat_c1,data['feats_x2'],data['feats_x1']
166
+
167
+ # norm FPNfeat
168
+ if self.config['norm_fpnfeat']:
169
+ feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
170
+ [feat_c0, feat_c1])
171
+ if self.config['norm_fpnfeat2']:
172
+ assert self.config['inter_feat']
173
+ logger.info(f'before norm_fpnfeat2 max of feat_c0, feat_c1:{feat_c0.abs().max()}, {feat_c1.abs().max()}')
174
+ if data['hw0_i'] == data['hw1_i']:
175
+ logger.info(f'before norm_fpnfeat2 max of data[feats_x2], data[feats_x1]:{data["feats_x2"].abs().max()}, {data["feats_x1"].abs().max()}')
176
+ feat_c0, feat_c1, data['feats_x2'], data['feats_x1'] = map(lambda feat: feat / feat.shape[-1]**.5,
177
+ [feat_c0, feat_c1, data['feats_x2'], data['feats_x1']])
178
+ else:
179
+ feat_c0, feat_c1, data['feats_x2_0'], data['feats_x2_1'], data['feats_x1_0'], data['feats_x1_1'] = map(lambda feat: feat / feat.shape[-1]**.5,
180
+ [feat_c0, feat_c1, data['feats_x2_0'], data['feats_x2_1'], data['feats_x1_0'], data['feats_x1_1']])
181
+
182
+
183
+ # 4. fine-level refinement
184
+ with torch.autocast(enabled=False, device_type="cuda"):
185
+ feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
186
+
187
+ # detect nan
188
+ if self.config['replace_nan'] and (torch.any(torch.isnan(feat_f0_unfold)) or torch.any(torch.isnan(feat_f1_unfold))):
189
+ logger.info(f'replace nan in fine_preprocess')
190
+ logger.info(f"feat_f0_unfold_nan_num: {torch.isnan(feat_f0_unfold).int().sum()}, feat_f1_unfold_nan_num: {torch.isnan(feat_f1_unfold).int().sum()}")
191
+ logger.info(f"feat_f0_unfold: {feat_f0_unfold}, feat_f1_unfold: {feat_f1_unfold}")
192
+ logger.info(f"feat_f0_unfold_max: {feat_f0_unfold}, feat_f1_unfold_max: {feat_f1_unfold}")
193
+ feat_f0_unfold[torch.isnan(feat_f0_unfold)] = 0
194
+ feat_f1_unfold[torch.isnan(feat_f1_unfold)] = 0
195
+ logger.info(f"feat_f0_unfold_nanmax: {feat_f0_unfold}, feat_f1_unfold_nanmax: {feat_f1_unfold}")
196
+
197
+ if self.config['fp16log'] and feat_c0 is not None:
198
+ logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}")
199
+ del feat_c0, feat_c1, mask_c0, mask_c1
200
+ if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
201
+ if self.config['fine']['pan']:
202
+ m, ww, c = feat_f0_unfold.size() # [m, ww, c]
203
+ w = self.config['fine_window_size']
204
+ feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold.reshape(m, c, w, w), feat_f1_unfold.reshape(m, c, w, w))
205
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'm c w h -> m (w h) c')
206
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'm c w h -> m (w h) c')
207
+ elif self.config['fine']['skip']:
208
+ pass
209
+ else:
210
+ feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
211
+ # 5. match fine-level
212
+ # log forward nan
213
+ if self.config['fp16log']:
214
+ if feat_f0_unfold.size(0) != 0 and feat_f0 is not None:
215
+ logger.info(f"f0: {feat_f0.abs().max()}, f1: {feat_f1.abs().max()}, uf0: {feat_f0_unfold.abs().max()}, uf1: {feat_f1_unfold.abs().max()}")
216
+ elif feat_f0_unfold.size(0) != 0:
217
+ logger.info(f"uf0: {feat_f0_unfold.abs().max()}, uf1: {feat_f1_unfold.abs().max()}")
218
+ # elif feat_c0 is not None:
219
+ # logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}")
220
+
221
+ with torch.autocast(enabled=False, device_type="cuda"):
222
+ self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
223
+
224
+ return data
225
+
226
+ def load_state_dict(self, state_dict, *args, **kwargs):
227
+ for k in list(state_dict.keys()):
228
+ if k.startswith('matcher.'):
229
+ state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
230
+ return super().load_state_dict(state_dict, *args, **kwargs)
231
+
232
+ def refine(self, data):
233
+ """
234
+ Update:
235
+ data (dict): {
236
+ 'image0': (torch.Tensor): (N, 1, H, W)
237
+ 'image1': (torch.Tensor): (N, 1, H, W)
238
+ 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
239
+ 'mask1'(optional) : (torch.Tensor): (N, H, W)
240
+ }
241
+ """
242
+ # 1. Local Feature CNN
243
+ data.update({
244
+ 'bs': data['image0'].size(0),
245
+ 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
246
+ })
247
+ feat_f0, feat_f1 = None, None
248
+ feat_c0, feat_c1 = data['feat_c0'], data['feat_c1']
249
+ # 4. fine-level refinement
250
+ feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
251
+ if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
252
+ if self.config['fine']['pan']:
253
+ m, ww, c = feat_f0_unfold.size() # [m, ww, c]
254
+ w = self.config['fine_window_size']
255
+ feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold.reshape(m, c, w, w), feat_f1_unfold.reshape(m, c, w, w))
256
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'm c w h -> m (w h) c')
257
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'm c w h -> m (w h) c')
258
+ elif self.config['fine']['skip']:
259
+ pass
260
+ else:
261
+ feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
262
+ # 5. match fine-level
263
+ # log forward nan
264
+ if self.config['fp16log']:
265
+ if feat_f0_unfold.size(0) != 0 and feat_f0 is not None and feat_c0 is not None:
266
+ logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}, f0: {feat_f0.abs().max()}, f1: {feat_f1.abs().max()}, uf0: {feat_f0_unfold.abs().max()}, uf1: {feat_f1_unfold.abs().max()}")
267
+ elif feat_f0 is not None and feat_c0 is not None:
268
+ logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}, f0: {feat_f0.abs().max()}, f1: {feat_f1.abs().max()}")
269
+ elif feat_c0 is not None:
270
+ logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}")
271
+
272
+ self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
273
+ return data
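A usage sketch of the matcher above (editorial, not part of the commit). `cfg` stands for the full nested model config this class expects and is not reproduced here; the checkpoint remark refers to the load_state_dict override above.

import torch

matcher = LoFTR(config=cfg)  # cfg: the nested model config dict expected by this class (placeholder here)
matcher.eval()
# Checkpoints whose keys carry a leading 'matcher.' prefix (e.g. saved from a training
# wrapper) are accepted as-is: the load_state_dict override strips that prefix.

data = {
    'image0': torch.rand(1, 1, 480, 640),  # (N, 1, H, W) grayscale, per the forward() docstring
    'image1': torch.rand(1, 1, 480, 640),
    # optional: 'mask0' / 'mask1' of shape (N, H, W), with 0 marking padded positions
}
with torch.no_grad():
    matcher(data)  # results are written in place into `data` by the coarse/fine matching modules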
imcui/third_party/MatchAnything/src/loftr/loftr_module/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .transformer import LocalFeatureTransformer
2
+ from .fine_preprocess import FinePreprocess
imcui/third_party/MatchAnything/src/loftr/loftr_module/fine_preprocess.py ADDED
@@ -0,0 +1,350 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops.einops import rearrange, repeat
5
+ from ..backbone.repvgg import RepVGGBlock
6
+
7
+ from loguru import logger
8
+
9
+ def conv1x1(in_planes, out_planes, stride=1):
10
+ """1x1 convolution without padding"""
11
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
12
+
13
+
14
+ def conv3x3(in_planes, out_planes, stride=1):
15
+ """3x3 convolution with padding"""
16
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
17
+
18
+ class FinePreprocess(nn.Module):
19
+ def __init__(self, config):
20
+ super().__init__()
21
+
22
+ self.config = config
23
+ self.cat_c_feat = config['fine_concat_coarse_feat']
24
+ self.sample_c_feat = config['fine_sample_coarse_feat']
25
+ self.fpn_inter_feat = config['inter_feat']
26
+ self.rep_fpn = config['rep_fpn']
27
+ self.deploy = config['rep_deploy']
28
+ self.multi_regress = config['match_fine']['multi_regress']
29
+ self.local_regress = config['match_fine']['local_regress']
30
+ self.local_regress_inner = config['match_fine']['local_regress_inner']
31
+ block_dims = config['resnetfpn']['block_dims']
32
+
33
+ self.mtd_spvs = self.config['fine']['mtd_spvs']
34
+ self.align_corner = self.config['align_corner']
35
+ self.fix_bias = self.config['fix_bias']
36
+
37
+ if self.mtd_spvs:
38
+ self.W = self.config['fine_window_size']
39
+ else:
40
+ # assert False, 'fine_window_matching_size to be revised' # good notification!
41
+ # self.W = self.config['fine_window_matching_size']
42
+ self.W = self.config['fine_window_size']
43
+
44
+ self.backbone_type = self.config['backbone_type']
45
+
46
+ d_model_c = self.config['coarse']['d_model']
47
+ d_model_f = self.config['fine']['d_model']
48
+ self.d_model_f = d_model_f
49
+ if self.fpn_inter_feat:
50
+ if self.rep_fpn:
51
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
52
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
53
+ self.layer2_outconv2 = []
54
+ self.layer2_outconv2.append(RepVGGBlock(in_channels=block_dims[2], out_channels=block_dims[2], kernel_size=3,
55
+ stride=1, padding=1, groups=1, deploy=self.deploy, use_se=False, leaky=0.01))
56
+ self.layer2_outconv2.append(RepVGGBlock(in_channels=block_dims[2], out_channels=block_dims[1], kernel_size=3,
57
+ stride=1, padding=1, groups=1, deploy=self.deploy, use_se=False, leaky=-2))
58
+ self.layer2_outconv2 = nn.ModuleList(self.layer2_outconv2)
59
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
60
+ self.layer1_outconv2 = []
61
+ self.layer1_outconv2.append(RepVGGBlock(in_channels=block_dims[1], out_channels=block_dims[1], kernel_size=3,
62
+ stride=1, padding=1, groups=1, deploy=self.deploy, use_se=False, leaky=0.01))
63
+ self.layer1_outconv2.append(RepVGGBlock(in_channels=block_dims[1], out_channels=block_dims[0], kernel_size=3,
64
+ stride=1, padding=1, groups=1, deploy=self.deploy, use_se=False, leaky=-2))
65
+ self.layer1_outconv2 = nn.ModuleList(self.layer1_outconv2)
66
+
67
+ else:
68
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
69
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
70
+ self.layer2_outconv2 = nn.Sequential(
71
+ conv3x3(block_dims[2], block_dims[2]),
72
+ nn.BatchNorm2d(block_dims[2]),
73
+ nn.LeakyReLU(),
74
+ conv3x3(block_dims[2], block_dims[1]),
75
+ )
76
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
77
+ self.layer1_outconv2 = nn.Sequential(
78
+ conv3x3(block_dims[1], block_dims[1]),
79
+ nn.BatchNorm2d(block_dims[1]),
80
+ nn.LeakyReLU(),
81
+ conv3x3(block_dims[1], block_dims[0]),
82
+ )
83
+ elif self.cat_c_feat:
84
+ self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
85
+ self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
86
+ if self.sample_c_feat:
87
+ self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
88
+
89
+
90
+ self._reset_parameters()
91
+
92
+ def _reset_parameters(self):
93
+ for p in self.parameters():
94
+ if p.dim() > 1:
95
+ nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")
96
+
97
+ def inter_fpn(self, feat_c, x2, x1, stride):
98
+ feat_c = self.layer3_outconv(feat_c)
99
+ feat_c = F.interpolate(feat_c, scale_factor=2., mode='bilinear', align_corners=False)
100
+ x2 = self.layer2_outconv(x2)
101
+ if self.rep_fpn:
102
+ x2 = x2 + feat_c
103
+ for layer in self.layer2_outconv2:
104
+ x2 = layer(x2)
105
+ else:
106
+ x2 = self.layer2_outconv2(x2+feat_c)
107
+
108
+ x2 = F.interpolate(x2, scale_factor=2., mode='bilinear', align_corners=False)
109
+ x1 = self.layer1_outconv(x1)
110
+ if self.rep_fpn:
111
+ x1 = x1 + x2
112
+ for layer in self.layer1_outconv2:
113
+ x1 = layer(x1)
114
+ else:
115
+ x1 = self.layer1_outconv2(x1+x2)
116
+
117
+ if stride == 4:
118
+ logger.info('stride == 4')
119
+
120
+ elif stride == 8:
121
+ logger.info('stride == 8')
122
+ x1 = F.interpolate(x1, scale_factor=2., mode='bilinear', align_corners=False)
123
+ else:
124
+ logger.info('stride not in {4,8}')
125
+ assert False
126
+ return x1
127
+
128
+ def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
129
+ W = self.W
130
+ if self.fix_bias:
131
+ stride = 4
132
+ else:
133
+ stride = data['hw0_f'][0] // data['hw0_c'][0]
134
+
135
+ data.update({'W': W})
136
+ if data['b_ids'].shape[0] == 0:
137
+ feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_c0.device)
138
+ feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_c0.device)
139
+ # return feat0, feat1
140
+ return feat0.float(), feat1.float()
141
+
142
+ if self.fpn_inter_feat:
143
+ if data['hw0_i'] != data['hw1_i']:
144
+ if self.align_corner is False:
145
+ assert self.backbone_type != 's2dnet'
146
+
147
+ feat_c0 = rearrange(feat_c0, 'b (h w) c -> b c h w', h=data['hw0_c'][0])
148
+ feat_c1 = rearrange(feat_c1, 'b (h w) c -> b c h w', h=data['hw1_c'][0])
149
+ x2_0, x1_0 = data['feats_x2_0'], data['feats_x1_0']
150
+ x2_1, x1_1 = data['feats_x2_1'], data['feats_x1_1']
151
+ del data['feats_x2_0'], data['feats_x1_0'], data['feats_x2_1'], data['feats_x1_1']
152
+ feat_f0, feat_f1 = self.inter_fpn(feat_c0, x2_0, x1_0, stride), self.inter_fpn(feat_c1, x2_1, x1_1, stride)
153
+
154
+ if self.local_regress_inner:
155
+ assert W == 8
156
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
157
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
158
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W+2, W+2), stride=stride, padding=1)
159
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=(W+2)**2)
160
+ elif W == 10 and self.multi_regress:
161
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=1)
162
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
163
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=1)
164
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
165
+ elif W == 10:
166
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=1)
167
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
168
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=1)
169
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
170
+ else:
171
+ assert not self.multi_regress
172
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
173
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
174
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=0)
175
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
176
+
177
+ # 2. select only the predicted matches
178
+ feat_f0 = feat_f0[data['b_ids'], data['i_ids']] # [n, ww, cf]
179
+ feat_f1 = feat_f1[data['b_ids'], data['j_ids']]
180
+
181
+ return feat_f0, feat_f1
182
+
183
+ else:
184
+ if self.align_corner is False:
185
+ feat_c = torch.cat([feat_c0, feat_c1], 0)
186
+ feat_c = rearrange(feat_c, 'b (h w) c -> b c h w', h=data['hw0_c'][0]) # 1/8 256
187
+ x2 = data['feats_x2'].float() # 1/4 128
188
+ x1 = data['feats_x1'].float() # 1/2 64
189
+ del data['feats_x2'], data['feats_x1']
190
+ assert self.backbone_type != 's2dnet'
191
+ feat_c = self.layer3_outconv(feat_c)
192
+ feat_c = F.interpolate(feat_c, scale_factor=2., mode='bilinear', align_corners=False)
193
+ x2 = self.layer2_outconv(x2)
194
+ if self.rep_fpn:
195
+ x2 = x2 + feat_c
196
+ for layer in self.layer2_outconv2:
197
+ x2 = layer(x2)
198
+ else:
199
+ x2 = self.layer2_outconv2(x2+feat_c)
200
+
201
+ x2 = F.interpolate(x2, scale_factor=2., mode='bilinear', align_corners=False)
202
+ x1 = self.layer1_outconv(x1)
203
+ if self.rep_fpn:
204
+ x1 = x1 + x2
205
+ for layer in self.layer1_outconv2:
206
+ x1 = layer(x1)
207
+ else:
208
+ x1 = self.layer1_outconv2(x1+x2)
209
+
210
+ if stride == 4:
211
+ # logger.info('stride == 4')
212
+ pass
213
+ elif stride == 8:
214
+ # logger.info('stride == 8')
215
+ x1 = F.interpolate(x1, scale_factor=2., mode='bilinear', align_corners=False)
216
+ else:
217
+ # logger.info('stride not in {4,8}')
218
+ assert False
219
+
220
+ feat_f0, feat_f1 = torch.chunk(x1, 2, dim=0)
221
+
222
+ # 1. unfold(crop) all local windows
223
+ if self.local_regress_inner:
224
+ assert W == 8
225
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
226
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
227
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W+2, W+2), stride=stride, padding=1)
228
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=(W+2)**2)
229
+ elif self.multi_regress or (self.local_regress and W == 10):
230
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=1)
231
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
232
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=1)
233
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
234
+ elif W == 10:
235
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=1)
236
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
237
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=1)
238
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
239
+
240
+ else:
241
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
242
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
243
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=0)
244
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
245
+
246
+ # 2. select only the predicted matches
247
+ feat_f0 = feat_f0[data['b_ids'], data['i_ids']] # [n, ww, cf]
248
+ feat_f1 = feat_f1[data['b_ids'], data['j_ids']]
249
+
250
+ return feat_f0, feat_f1
251
+ elif self.fix_bias:
252
+ feat_c = torch.cat([feat_c0, feat_c1], 0)
253
+ feat_c = rearrange(feat_c, 'b (h w) c -> b c h w', h=data['hw0_c'][0])
254
+ x2 = data['feats_x2'].float()
255
+ x1 = data['feats_x1'].float()
256
+ assert self.backbone_type != 's2dnet'
257
+ x3_out = self.layer3_outconv(feat_c)
258
+ x3_out_2x = F.interpolate(x3_out, size=((x3_out.size(-2)-1)*2+1, (x3_out.size(-1)-1)*2+1), mode='bilinear', align_corners=False)
259
+ x2 = self.layer2_outconv(x2)
260
+ x2 = self.layer2_outconv2(x2+x3_out_2x)
261
+
262
+ x2 = F.interpolate(x2, size=((x2.size(-2)-1)*2+1, (x2.size(-1)-1)*2+1), mode='bilinear', align_corners=False)
263
+ x1_out = self.layer1_outconv(x1)
264
+ x1_out = self.layer1_outconv2(x1_out+x2)
265
+ x0_out = x1_out
266
+
267
+ feat_f0, feat_f1 = torch.chunk(x0_out, 2, dim=0)
268
+
269
+ # 1. unfold(crop) all local windows
270
+ feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
271
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
272
+ feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
273
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
274
+
275
+ # 2. select only the predicted matches
276
+ feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
277
+ feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
278
+
279
+ return feat_f0_unfold, feat_f1_unfold
280
+
281
+
282
+
283
+ elif self.sample_c_feat:
284
+ if self.align_corner is False:
285
+ # easy to implement but memory-consuming
286
+ feat_c = self.down_proj(torch.cat([feat_c0,
287
+ feat_c1], 0)) # [n, (h w), c] -> [2n, (h w), cf]
288
+ feat_c = rearrange(feat_c, 'n (h w) c -> n c h w', h=data['hw0_c'][0], w=data['hw0_c'][1])
289
+ feat_f = F.interpolate(feat_c, scale_factor=8., mode='bilinear', align_corners=False) # [2n, cf, hf, wf]
290
+ feat_f_unfold = F.unfold(feat_f, kernel_size=(W, W), stride=stride, padding=0)
291
+ feat_f_unfold = rearrange(feat_f_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
292
+ feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_f_unfold, 2, dim=0)
293
+ feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [m, ww, cf]
294
+ feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] # [m, ww, cf]
295
+ # return feat_f0_unfold, feat_f1_unfold
296
+ return feat_f0_unfold.float(), feat_f1_unfold.float()
297
+ else:
298
+ if self.align_corner is False:
299
+ # 1. unfold(crop) all local windows
300
+ assert False, 'this branch may contain bugs'
301
+ feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
302
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
303
+ feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=0)
304
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
305
+
306
+ # 2. select only the predicted matches
307
+ feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
308
+ feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
309
+
310
+ # option: use coarse-level loftr feature as context: concat and linear
311
+ if self.cat_c_feat:
312
+ feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
313
+ feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c]
314
+ feat_cf_win = self.merge_feat(torch.cat([
315
+ torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
316
+ repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf]
317
+ ], -1))
318
+ feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
319
+
320
+ return feat_f0_unfold, feat_f1_unfold
321
+
322
+ else:
323
+ # 1. unfold(crop) all local windows
324
+ if self.fix_bias:
325
+ feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
326
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
327
+ feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
328
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
329
+ else:
330
+ feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
331
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
332
+ feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
333
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
334
+
335
+ # 2. select only the predicted matches
336
+ feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
337
+ feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
338
+
339
+ # option: use coarse-level loftr feature as context: concat and linear
340
+ if self.cat_c_feat:
341
+ feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
342
+ feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c]
343
+ feat_cf_win = self.merge_feat(torch.cat([
344
+ torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
345
+ repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf]
346
+ ], -1))
347
+ feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
348
+
349
+ # return feat_f0_unfold, feat_f1_unfold
350
+ return feat_f0_unfold.float(), feat_f1_unfold.float()
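+
+ # Note on the window cropping above: F.unfold extracts one W x W (or, on the
+ # local_regress_inner path, (W+2) x (W+2)) fine-level window per coarse-grid
+ # location, and the b_ids / i_ids / j_ids indexing then keeps only the windows
+ # centred on predicted coarse matches. As a shape sketch (assuming W=8,
+ # stride=4 and a 64 x 64 fine feature map; values are illustrative, not taken
+ # from this repo's configs): F.unfold(feat, kernel_size=(8, 8), stride=4,
+ # padding=0) gives [n, c*64, l] with l = ((64 - 8) // 4 + 1) ** 2 = 225 windows.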
imcui/third_party/MatchAnything/src/loftr/loftr_module/linear_attention.py ADDED
@@ -0,0 +1,217 @@
1
+ """
2
+ Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
3
+ Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
4
+ """
5
+
6
+ import torch
7
+ from torch.nn import Module, Dropout
8
+ import torch.nn.functional as F
9
+
10
+ # if hasattr(F, 'scaled_dot_product_attention'):
11
+ # FLASH_AVAILABLE = True
12
+ # else: # v100
13
+ FLASH_AVAILABLE = False
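+ # NOTE: the scaled_dot_product_attention check above is commented out and
+ # FLASH_AVAILABLE is forced to False, presumably for GPUs without flash
+ # attention support (see the '# v100' note); the einsum fallback further
+ # below is used instead.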
14
+ # import xformers.ops
15
+ from ..utils.position_encoding import PositionEncodingSine, RoPEPositionEncodingSine
16
+ from einops.einops import rearrange
17
+ from loguru import logger
18
+
19
+
20
+ # flash_attn_func_ok = True
21
+ # try:
22
+ # from flash_attn import flash_attn_func
23
+ # except ModuleNotFoundError:
24
+ # flash_attn_func_ok = False
25
+
26
+ def elu_feature_map(x):
27
+ return torch.nn.functional.elu(x) + 1
28
+
29
+
30
+ class LinearAttention(Module):
31
+ def __init__(self, eps=1e-6):
32
+ super().__init__()
33
+ self.feature_map = elu_feature_map
34
+ self.eps = eps
35
+
36
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
37
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
38
+ """ Multi-Head linear attention proposed in "Transformers are RNNs"
39
+ Args:
40
+ queries: [N, L, H, D]
41
+ keys: [N, S, H, D]
42
+ values: [N, S, H, D]
43
+ q_mask: [N, L]
44
+ kv_mask: [N, S]
45
+ Returns:
46
+ queried_values: (N, L, H, D)
47
+ """
48
+ Q = self.feature_map(queries)
49
+ K = self.feature_map(keys)
50
+
51
+ # set padded position to zero
52
+ if q_mask is not None:
53
+ Q = Q * q_mask[:, :, None, None]
54
+ if kv_mask is not None:
55
+ K = K * kv_mask[:, :, None, None]
56
+ values = values * kv_mask[:, :, None, None]
57
+
58
+ v_length = values.size(1)
59
+ values = values / v_length # prevent fp16 overflow
60
+ KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
61
+ Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
62
+ # queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
63
+ queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
64
+
65
+ return queried_values.contiguous()
66
+
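+ # Quick reference for the linear attention above: with phi(x) = elu(x) + 1,
+ #   out_l = phi(Q_l) @ (sum_s phi(K_s) V_s^T) / (phi(Q_l) . sum_s phi(K_s) + eps)
+ # (dividing values by v_length and rescaling at the end only guards against
+ # fp16 overflow), so the cost is O(N * D^2) in sequence length N rather than
+ # the O(N^2 * D) of softmax attention.
+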
67
+ class RoPELinearAttention(Module):
68
+ def __init__(self, eps=1e-6):
69
+ super().__init__()
70
+ self.feature_map = elu_feature_map
71
+ self.eps = eps
72
+ self.RoPE = RoPEPositionEncodingSine(256, max_shape=(256, 256))
73
+
74
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
75
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None, H=None, W=None):
76
+ """ Multi-Head linear attention proposed in "Transformers are RNNs"
77
+ Args:
78
+ queries: [N, L, H, D]
79
+ keys: [N, S, H, D]
80
+ values: [N, S, H, D]
81
+ q_mask: [N, L]
82
+ kv_mask: [N, S]
83
+ Returns:
84
+ queried_values: (N, L, H, D)
85
+ """
86
+ Q = self.feature_map(queries)
87
+ K = self.feature_map(keys)
88
+ nhead, d = Q.size(2), Q.size(3)
89
+ # set padded position to zero
90
+ if q_mask is not None:
91
+ Q = Q * q_mask[:, :, None, None]
92
+ if kv_mask is not None:
93
+ K = K * kv_mask[:, :, None, None]
94
+ values = values * kv_mask[:, :, None, None]
95
+
96
+ v_length = values.size(1)
97
+ values = values / v_length # prevent fp16 overflow
98
+ # Q = Q / Q.size(1)
99
+ # logger.info(f"Q: {Q.dtype}, K: {K.dtype}, values: {values.dtype}")
100
+
101
+ Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
102
+ # logger.info(f"Z_max: {Z.abs().max()}")
103
+ Q = rearrange(Q, 'n (h w) nhead d -> n h w (nhead d)', h=H, w=W)
104
+ K = rearrange(K, 'n (h w) nhead d -> n h w (nhead d)', h=H, w=W)
105
+ Q, K = self.RoPE(Q), self.RoPE(K)
106
+ # logger.info(f"Q_rope: {Q.abs().max()}, K_rope: {K.abs().max()}")
107
+ Q = rearrange(Q, 'n h w (nhead d) -> n (h w) nhead d', nhead=nhead, d=d)
108
+ K = rearrange(K, 'n h w (nhead d) -> n (h w) nhead d', nhead=nhead, d=d)
109
+ KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
110
+ del K, values
111
+ # logger.info(f"KV_max: {KV.abs().max()}")
112
+ # queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
113
+ # Q = torch.einsum("nlhd,nlh->nlhd", Q, Z)
114
+ # logger.info(f"QZ_max: {Q.abs().max()}")
115
+ # queried_values = torch.einsum("nlhd,nhdv->nlhv", Q, KV) * v_length
116
+ # logger.info(f"message_max: {queried_values.abs().max()}")
117
+ queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
118
+
119
+ return queried_values.contiguous()
120
+
121
+
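+ # RoPELinearAttention above is the same linear attention, except that after the
+ # elu feature map Q and K are reshaped onto the (H, W) grid, rotated by
+ # RoPEPositionEncodingSine, and flattened back before the KV aggregation, so
+ # the similarity phi(Q) . phi(K) becomes position-aware.
+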
122
+ class FullAttention(Module):
123
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
124
+ super().__init__()
125
+ self.use_dropout = use_dropout
126
+ self.dropout = Dropout(attention_dropout)
127
+
128
+ # @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
129
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
130
+ """ Multi-head scaled dot-product attention, a.k.a full attention.
131
+ Args:
132
+ queries: [N, L, H, D]
133
+ keys: [N, S, H, D]
134
+ values: [N, S, H, D]
135
+ q_mask: [N, L]
136
+ kv_mask: [N, S]
137
+ Returns:
138
+ queried_values: (N, L, H, D)
139
+ """
140
+ # assert kv_mask is None
141
+ # mask = torch.zeros(queries.size(0)*queries.size(2), queries.size(1), keys.size(1), device=queries.device)
142
+ # mask.masked_fill(~(q_mask[:, :, None] * kv_mask[:, None, :]), float('-inf'))
143
+ # if keys.size(1) % 8 != 0:
144
+ # mask = torch.cat([mask, torch.zeros(queries.size(0)*queries.size(2), queries.size(1), 8-keys.size(1)%8, device=queries.device)], dim=-1)
145
+ # out = xformers.ops.memory_efficient_attention(queries, keys, values, attn_bias=mask[...,:keys.size(1)])
146
+ # return out
147
+
148
+ # N = queries.size(0)
149
+ # list_q = [queries[i, :q_mask[i].sum, ...] for i in N]
150
+ # list_k = [keys[i, :kv_mask[i].sum, ...] for i in N]
151
+ # list_v = [values[i, :kv_mask[i].sum, ...] for i in N]
152
+ # assert N == 1
153
+ # out = xformers.ops.memory_efficient_attention(queries[:,:q_mask.sum(),...], keys[:,:kv_mask.sum(),...], values[:,:kv_mask.sum(),...])
154
+ # out = torch.cat([out, torch.zeros(out.size(0), queries.size(1)-q_mask.sum(), queries.size(2), queries.size(3), device=queries.device)], dim=1)
155
+ # return out
156
+ # Compute the unnormalized attention and apply the masks
157
+ QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
158
+ if kv_mask is not None:
159
+ QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), -1e5) # float('-inf')
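+ # Presumably -1e5 is used instead of float('-inf') so the softmax below stays
+ # finite under fp16/autocast: a large negative constant acts like -inf after
+ # softmax but cannot produce NaNs (e.g. for fully masked rows).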
160
+
161
+ # Compute the attention and the weighted average
162
+ softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
163
+ A = torch.softmax(softmax_temp * QK, dim=2)
164
+ if self.use_dropout:
165
+ A = self.dropout(A)
166
+
167
+ queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
168
+
169
+ return queried_values.contiguous()
170
+
171
+
172
+ class XAttention(Module):
173
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
174
+ super().__init__()
175
+ self.use_dropout = use_dropout
176
+ if use_dropout:
177
+ self.dropout = Dropout(attention_dropout)
178
+
179
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
180
+ """ Multi-head scaled dot-product attention, a.k.a full attention.
181
+ Args:
182
+ if FLASH_AVAILABLE: # pytorch scaled_dot_product_attention
183
+ queries: [N, H, L, D]
184
+ keys: [N, H, S, D]
185
+ values: [N, H, S, D]
186
+ else:
187
+ queries: [N, L, H, D]
188
+ keys: [N, S, H, D]
189
+ values: [N, S, H, D]
190
+ q_mask: [N, L]
191
+ kv_mask: [N, S]
192
+ Returns:
193
+ queried_values: (N, L, H, D)
194
+ """
195
+
196
+ assert q_mask is None and kv_mask is None, "masks should already have been applied by slicing"
197
+ if FLASH_AVAILABLE:
198
+ # args = [x.half().contiguous() for x in [queries, keys, values]]
199
+ # out = F.scaled_dot_product_attention(*args, attn_mask=mask).to(queries.dtype)
200
+ args = [x.contiguous() for x in [queries, keys, values]]
201
+ out = F.scaled_dot_product_attention(*args)
202
+ else:
203
+ # if flash_attn_func_ok:
204
+ # out = flash_attn_func(queries, keys, values)
205
+ # else:
206
+ QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
207
+
208
+ # Compute the attention and the weighted average
209
+ softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
210
+ A = torch.softmax(softmax_temp * QK, dim=2)
211
+
212
+ out = torch.einsum("nlsh,nshd->nlhd", A, values)
213
+
214
+ # out = xformers.ops.memory_efficient_attention(queries, keys, values)
215
+ # out = xformers.ops.memory_efficient_attention(queries[:,:q_mask.sum(),...], keys[:,:kv_mask.sum(),...], values[:,:kv_mask.sum(),...])
216
+ # out = torch.cat([out, torch.zeros(out.size(0), queries.size(1)-q_mask.sum(), queries.size(2), queries.size(3), device=queries.device)], dim=1)
217
+ return out
imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer.py ADDED
@@ -0,0 +1,1768 @@
1
+ import copy
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .linear_attention import LinearAttention, RoPELinearAttention, FullAttention, XAttention
6
+ from einops.einops import rearrange
7
+ from collections import OrderedDict
8
+ from .transformer_utils import TokenConfidence, MatchAssignment, filter_matches
9
+ from ..utils.coarse_matching import CoarseMatching
10
+ from ..utils.position_encoding import RoPEPositionEncodingSine
11
+ import numpy as np
12
+ from loguru import logger
13
+
14
+ PFLASH_AVAILABLE = False
15
+
16
+ class PANEncoderLayer(nn.Module):
17
+ def __init__(self,
18
+ d_model,
19
+ nhead,
20
+ attention='linear',
21
+ pool_size=4,
22
+ bn=True,
23
+ xformer=False,
24
+ leaky=-1.0,
25
+ dw_conv=False,
26
+ scatter=False,
27
+ ):
28
+ super(PANEncoderLayer, self).__init__()
29
+
30
+ self.pool_size = pool_size
31
+ self.dw_conv = dw_conv
32
+ self.scatter = scatter
33
+ if self.dw_conv:
34
+ self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0, stride=pool_size, bias=False, groups=d_model)
35
+
36
+ assert not self.scatter, 'the scatter path is buggy here'
37
+ self.dim = d_model // nhead
38
+ self.nhead = nhead
39
+
40
+ self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size)
41
+ # multi-head attention
42
+ if bn:
43
+ method = 'dw_bn'
44
+ else:
45
+ method = 'dw'
46
+ self.q_proj_conv = self._build_projection(d_model, d_model, method=method)
47
+ self.k_proj_conv = self._build_projection(d_model, d_model, method=method)
48
+ self.v_proj_conv = self._build_projection(d_model, d_model, method=method)
49
+
50
+ # self.q_proj = nn.Linear(d_model, d_model, bias=False)
51
+ # self.k_proj = nn.Linear(d_model, d_model, bias=False)
52
+ # self.v_proj = nn.Linear(d_model, d_model, bias=False)
53
+ if xformer:
54
+ self.attention = XAttention()
55
+ else:
56
+ self.attention = LinearAttention() if attention == 'linear' else FullAttention()
57
+ self.merge = nn.Linear(d_model, d_model, bias=False)
58
+
59
+ # feed-forward network
60
+ if leaky > 0:
61
+ self.mlp = nn.Sequential(
62
+ nn.Linear(d_model*2, d_model*2, bias=False),
63
+ nn.LeakyReLU(leaky, True),
64
+ nn.Linear(d_model*2, d_model, bias=False),
65
+ )
66
+
67
+ else:
68
+ self.mlp = nn.Sequential(
69
+ nn.Linear(d_model*2, d_model*2, bias=False),
70
+ nn.ReLU(True),
71
+ nn.Linear(d_model*2, d_model, bias=False),
72
+ )
73
+
74
+ # norm and dropout
75
+ self.norm1 = nn.LayerNorm(d_model)
76
+ self.norm2 = nn.LayerNorm(d_model)
77
+
78
+ # self.norm1 = nn.BatchNorm2d(d_model)
79
+
80
+ def forward(self, x, source, x_mask=None, source_mask=None):
81
+ """
82
+ Args:
83
+ x (torch.Tensor): [N, C, H1, W1]
84
+ source (torch.Tensor): [N, C, H2, W2]
85
+ x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
86
+ source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
87
+ """
88
+ bs = x.size(0)
89
+ H1, W1 = x.size(-2), x.size(-1)
90
+ H2, W2 = source.size(-2), source.size(-1)
91
+
92
+ query, key, value = x, source, source
93
+
94
+ if self.dw_conv:
95
+ query = self.norm1(self.aggregate(query).permute(0,2,3,1)).permute(0,3,1,2)
96
+ else:
97
+ query = self.norm1(self.max_pool(query).permute(0,2,3,1)).permute(0,3,1,2)
98
+ # only need to compute the pooled key (or value) once; key and value share the same pooled source
99
+ key = self.norm1(self.max_pool(key).permute(0,2,3,1)).permute(0,3,1,2)
100
+ value = self.norm1(self.max_pool(value).permute(0,2,3,1)).permute(0,3,1,2)
101
+
102
+ # After 0617 bnorm to prevent permute*6
103
+ # query = self.norm1(self.max_pool(query))
104
+ # key = self.norm1(self.max_pool(key))
105
+ # value = self.norm1(self.max_pool(value))
106
+ # multi-head attention
107
+ query = self.q_proj_conv(query) # [N, C, H1//pool, W1//pool]
108
+ key = self.k_proj_conv(key)
109
+ value = self.v_proj_conv(value)
110
+
111
+ C = query.shape[-3]
112
+
113
+ ismask = x_mask is not None and source_mask is not None
114
+ if bs == 1 or not ismask:
115
+ if ismask:
116
+ x_mask = self.max_pool(x_mask.float()).bool() # [N, H1//pool, W1//pool]
117
+ source_mask = self.max_pool(source_mask.float()).bool()
118
+
119
+ mask_h0, mask_w0 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0]
120
+ mask_h1, mask_w1 = source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0]
121
+
122
+ query = query[:, :, :mask_h0, :mask_w0]
123
+ key = key[:, :, :mask_h1, :mask_w1]
124
+ value = value[:, :, :mask_h1, :mask_w1]
125
+
126
+ else:
127
+ assert x_mask is None and source_mask is None
128
+
129
+ # query = query.reshape(bs, -1, self.nhead, self.dim) # [N, L, H, D]
130
+ # key = key.reshape(bs, -1, self.nhead, self.dim) # [N, S, H, D]
131
+ # value = value.reshape(bs, -1, self.nhead, self.dim) # [N, S, H, D]
132
+ if PFLASH_AVAILABLE: # N H L D
133
+ query = rearrange(query, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
134
+ key = rearrange(key, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
135
+ value = rearrange(value, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
136
+
137
+ else: # N L H D
138
+ query = rearrange(query, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D]
139
+ key = rearrange(key, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
140
+ value = rearrange(value, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
141
+
142
+ message = self.attention(query, key, value, q_mask=None, kv_mask=None) # [N, L, H, D] or [N, H, L, D]
143
+
144
+ if PFLASH_AVAILABLE: # N H L D
145
+ message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
146
+
147
+ if ismask:
148
+ message = message.view(bs, mask_h0, mask_w0, self.nhead, self.dim)
149
+ if mask_h0 != x_mask.size(-2):
150
+ message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=1)
151
+ elif mask_w0 != x_mask.size(-1):
152
+ message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=2)
153
+ # message = message.view(bs, -1, self.nhead*self.dim) # [N, L, C]
154
+
155
+ else:
156
+ assert x_mask is None and source_mask is None
157
+
158
+
159
+ message = self.merge(message.reshape(bs, -1, self.nhead*self.dim)) # [N, L, C]
160
+ # message = message.reshape(bs, C, H1//self.pool_size, W1//self.pool_size) # [N, C, H, W] bug???
161
+ message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W]
162
+
163
+ if self.scatter:
164
+ message = torch.repeat_interleave(message, self.pool_size, dim=-2)
165
+ message = torch.repeat_interleave(message, self.pool_size, dim=-1)
166
+ # message = self.aggregate(message)
167
+ message = message * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1,1,message.shape[-2]//self.pool_size,message.shape[-1]//self.pool_size)
168
+ else:
169
+ message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
170
+
171
+ # message = self.norm1(message)
172
+
173
+ # feed-forward network
174
+ message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
175
+ message = self.norm2(message).permute(0, 3, 1, 2) # [N, C, H1, W1]
176
+
177
+ return x + message
178
+ else:
179
+ x_mask = self.max_pool(x_mask.float()).bool()
180
+ source_mask = self.max_pool(source_mask.float()).bool()
181
+ m_list = []
182
+ for i in range(bs):
183
+ mask_h0, mask_w0 = x_mask[i].sum(-2)[0], x_mask[i].sum(-1)[0]
184
+ mask_h1, mask_w1 = source_mask[i].sum(-2)[0], source_mask[i].sum(-1)[0]
185
+
186
+ q = query[i:i+1, :, :mask_h0, :mask_w0]
187
+ k = key[i:i+1, :, :mask_h1, :mask_w1]
188
+ v = value[i:i+1, :, :mask_h1, :mask_w1]
189
+
190
+ if PFLASH_AVAILABLE: # N H L D
191
+ q = rearrange(q, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
192
+ k = rearrange(k, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
193
+ v = rearrange(v, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
194
+
195
+ else: # N L H D
196
+
197
+ q = rearrange(q, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D]
198
+ k = rearrange(k, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
199
+ v = rearrange(v, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
200
+
201
+ m = self.attention(q, k, v, q_mask=None, kv_mask=None) # [N, L, H, D]
202
+
203
+ if PFLASH_AVAILABLE: # N H L D
204
+ m = rearrange(m, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
205
+
206
+ m = m.view(1, mask_h0, mask_w0, self.nhead, self.dim)
207
+ if mask_h0 != x_mask.size(-2):
208
+ m = torch.cat([m, torch.zeros(1, x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=1)
209
+ elif mask_w0 != x_mask.size(-1):
210
+ m = torch.cat([m, torch.zeros(1, x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=2)
211
+ m_list.append(m)
212
+ message = torch.cat(m_list, dim=0)
213
+
214
+
215
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
216
+ # message = message.reshape(bs, C, H1//self.pool_size, W1//self.pool_size) # [N, C, H, W] bug???
217
+ message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W]
218
+
219
+ if self.scatter:
220
+ message = torch.repeat_interleave(message, self.pool_size, dim=-2)
221
+ message = torch.repeat_interleave(message, self.pool_size, dim=-1)
222
+ # message = self.aggregate(message)
223
+ # assert False
224
+ else:
225
+ message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
226
+
227
+ # message = self.norm1(message)
228
+
229
+ # feed-forward network
230
+ message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
231
+ message = self.norm2(message).permute(0, 3, 1, 2) # [N, C, H1, W1]
232
+
233
+ return x + message
234
+
235
+
236
+ def pro(self, x, source, x_mask=None, source_mask=None, profiler=None):
237
+ """
238
+ Args:
239
+ x (torch.Tensor): [N, C, H1, W1]
240
+ source (torch.Tensor): [N, C, H2, W2]
241
+ x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
242
+ source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
243
+ """
244
+ bs = x.size(0)
245
+ H1, W1 = x.size(-2), x.size(-1)
246
+ H2, W2 = source.size(-2), source.size(-1)
247
+
248
+ query, key, value = x, source, source
249
+
250
+ with profiler.profile("permute*6+norm1*3+max_pool*3"):
251
+ if self.dw_conv:
252
+ query = self.norm1(self.aggregate(query).permute(0,2,3,1)).permute(0,3,1,2)
253
+ else:
254
+ query = self.norm1(self.max_pool(query).permute(0,2,3,1)).permute(0,3,1,2)
255
+ # only need to cal key or value...
256
+ key = self.norm1(self.max_pool(key).permute(0,2,3,1)).permute(0,3,1,2)
257
+ value = self.norm1(self.max_pool(value).permute(0,2,3,1)).permute(0,3,1,2)
258
+
259
+ with profiler.profile("permute*6"):
260
+ query = query.permute(0, 2, 3, 1)
261
+ key = key.permute(0, 2, 3, 1)
262
+ value = value.permute(0, 2, 3, 1)
263
+
264
+ query = query.permute(0,3,1,2)
265
+ key = key.permute(0,3,1,2)
266
+ value = value.permute(0,3,1,2)
267
+
268
+ # query = self.bnorm1(self.max_pool(query))
269
+ # key = self.bnorm1(self.max_pool(key))
270
+ # value = self.bnorm1(self.max_pool(value))
271
+ # multi-head attention
272
+
273
+ with profiler.profile("q_conv+k_conv+v_conv"):
274
+ query = self.q_proj_conv(query) # [N, C, H1//pool, W1//pool]
275
+ key = self.k_proj_conv(key)
276
+ value = self.v_proj_conv(value)
277
+
278
+ C = query.shape[-3]
279
+ # TODO: Need to be consistent with bs=1 (where mask region do not in attention at all)
280
+ if x_mask is not None and source_mask is not None:
281
+ x_mask = self.max_pool(x_mask.float()).bool() # [N, H1//pool, W1//pool]
282
+ source_mask = self.max_pool(source_mask.float()).bool()
283
+
284
+ mask_h0, mask_w0 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0]
285
+ mask_h1, mask_w1 = source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0]
286
+
287
+ query = query[:, :, :mask_h0, :mask_w0]
288
+ key = key[:, :, :mask_h1, :mask_w1]
289
+ value = value[:, :, :mask_h1, :mask_w1]
290
+
291
+ # mask_h0, mask_w0 = data['mask0'][0].sum(-2)[0], data['mask0'][0].sum(-1)[0]
292
+ # mask_h1, mask_w1 = data['mask1'][0].sum(-2)[0], data['mask1'][0].sum(-1)[0]
293
+ # C = feat_c0.shape[-3]
294
+ # feat_c0 = feat_c0[:, :, :mask_h0, :mask_w0]
295
+ # feat_c1 = feat_c1[:, :, :mask_h1, :mask_w1]
296
+
297
+
298
+ # feat_c0 = feat_c0.reshape(-1, mask_h0, mask_w0, C)
299
+ # feat_c1 = feat_c1.reshape(-1, mask_h1, mask_w1, C)
300
+ # if mask_h0 != data['mask0'].size(-2):
301
+ # feat_c0 = torch.cat([feat_c0, torch.zeros(feat_c0.size(0), data['hw0_c'][0]-mask_h0, data['hw0_c'][1], C, device=feat_c0.device)], dim=1)
302
+ # elif mask_w0 != data['mask0'].size(-1):
303
+ # feat_c0 = torch.cat([feat_c0, torch.zeros(feat_c0.size(0), data['hw0_c'][0], data['hw0_c'][1]-mask_w0, C, device=feat_c0.device)], dim=2)
304
+
305
+ # if mask_h1 != data['mask1'].size(-2):
306
+ # feat_c1 = torch.cat([feat_c1, torch.zeros(feat_c1.size(0), data['hw1_c'][0]-mask_h1, data['hw1_c'][1], C, device=feat_c1.device)], dim=1)
307
+ # elif mask_w1 != data['mask1'].size(-1):
308
+ # feat_c1 = torch.cat([feat_c1, torch.zeros(feat_c1.size(0), data['hw1_c'][0], data['hw1_c'][1]-mask_w1, C, device=feat_c1.device)], dim=2)
309
+
310
+
311
+ else:
312
+ assert x_mask is None and source_mask is None
313
+
314
+
315
+
316
+ # query = query.reshape(bs, -1, self.nhead, self.dim) # [N, L, H, D]
317
+ # key = key.reshape(bs, -1, self.nhead, self.dim) # [N, S, H, D]
318
+ # value = value.reshape(bs, -1, self.nhead, self.dim) # [N, S, H, D]
319
+
320
+ with profiler.profile("rearrange*3"):
321
+ query = rearrange(query, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D]
322
+ key = rearrange(key, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
323
+ value = rearrange(value, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
324
+
325
+ with profiler.profile("attention"):
326
+ message = self.attention(query, key, value, q_mask=None, kv_mask=None) # [N, L, H, D]
327
+
328
+ if x_mask is not None and source_mask is not None:
329
+ message = message.view(bs, mask_h0, mask_w0, self.nhead, self.dim)
330
+ if mask_h0 != x_mask.size(-2):
331
+ message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=1)
332
+ elif mask_w0 != x_mask.size(-1):
333
+ message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=2)
334
+ # message = message.view(bs, -1, self.nhead*self.dim) # [N, L, C]
335
+
336
+ else:
337
+ assert x_mask is None and source_mask is None
338
+
339
+ with profiler.profile("merge"):
340
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
341
+ # message = message.reshape(bs, C, H1//self.pool_size, W1//self.pool_size) # [N, C, H, W] bug???
342
+
343
+ with profiler.profile("rearrange*1"):
344
+ message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W]
345
+
346
+ with profiler.profile("upsample"):
347
+ if self.scatter:
348
+ message = torch.repeat_interleave(message, self.pool_size, dim=-2)
349
+ message = torch.repeat_interleave(message, self.pool_size, dim=-1)
350
+ # message = self.aggregate(message)
351
+ # assert False
352
+ else:
353
+ message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
354
+
355
+ # message = self.norm1(message)
356
+
357
+ # feed-forward network
358
+ with profiler.profile("feed-forward_mlp+permute*2+norm2"):
359
+ message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
360
+ message = self.norm2(message).permute(0, 3, 1, 2) # [N, C, H1, W1]
361
+
362
+ return x + message
363
+
364
+
365
+ def _build_projection(self,
366
+ dim_in,
367
+ dim_out,
368
+ kernel_size=3,
369
+ padding=1,
370
+ stride=1,
371
+ method='dw_bn',
372
+ ):
373
+ if method == 'dw_bn':
374
+ proj = nn.Sequential(OrderedDict([
375
+ ('conv', nn.Conv2d(
376
+ dim_in,
377
+ dim_in,
378
+ kernel_size=kernel_size,
379
+ padding=padding,
380
+ stride=stride,
381
+ bias=False,
382
+ groups=dim_in
383
+ )),
384
+ ('bn', nn.BatchNorm2d(dim_in)),
385
+ # ('rearrage', Rearrange('b c h w -> b (h w) c')),
386
+ ]))
387
+ elif method == 'avg':
388
+ proj = nn.Sequential(OrderedDict([
389
+ ('avg', nn.AvgPool2d(
390
+ kernel_size=kernel_size,
391
+ padding=padding,
392
+ stride=stride,
393
+ ceil_mode=True
394
+ )),
395
+ # ('rearrage', Rearrange('b c h w -> b (h w) c')),
396
+ ]))
397
+ elif method == 'linear':
398
+ proj = None
399
+ elif method == 'dw':
400
+ proj = nn.Sequential(OrderedDict([
401
+ ('conv', nn.Conv2d(
402
+ dim_in,
403
+ dim_in,
404
+ kernel_size=kernel_size,
405
+ padding=padding,
406
+ stride=stride,
407
+ bias=False,
408
+ groups=dim_in
409
+ )),
410
+ # ('rearrage', Rearrange('b c h w -> b (h w) c')),
411
+ ]))
412
+ else:
413
+ raise ValueError('Unknown method ({})'.format(method))
414
+
415
+ return proj
416
+
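+ # PANEncoderLayer summary: the inputs are downsampled by pool_size (max-pool,
+ # or a depthwise conv when dw_conv=True), projected to Q/K/V with depthwise
+ # convolutions, attended at the reduced resolution, and the resulting message
+ # is upsampled back to full resolution (bilinear interpolation, or
+ # repeat-and-reweight when scatter=True) before the residual MLP fusion.
+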
417
+ class AG_RoPE_EncoderLayer(nn.Module):
418
+ def __init__(self,
419
+ d_model,
420
+ nhead,
421
+ attention='linear',
422
+ pool_size=4,
423
+ pool_size2=4,
424
+ xformer=False,
425
+ leaky=-1.0,
426
+ dw_conv=False,
427
+ dw_conv2=False,
428
+ scatter=False,
429
+ norm_before=True,
430
+ rope=False,
431
+ npe=None,
432
+ vit_norm=False,
433
+ dw_proj=False,
434
+ ):
435
+ super(AG_RoPE_EncoderLayer, self).__init__()
436
+
437
+ self.pool_size = pool_size
438
+ self.pool_size2 = pool_size2
439
+ self.dw_conv = dw_conv
440
+ self.dw_conv2 = dw_conv2
441
+ self.scatter = scatter
442
+ self.norm_before = norm_before
443
+ self.vit_norm = vit_norm
444
+ self.dw_proj = dw_proj
445
+ self.rope = rope
446
+ if self.dw_conv and self.pool_size != 1:
447
+ self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0, stride=pool_size, bias=False, groups=d_model)
448
+ if self.dw_conv2 and self.pool_size2 != 1:
449
+ self.aggregate2 = nn.Conv2d(d_model, d_model, kernel_size=pool_size2, padding=0, stride=pool_size2, bias=False, groups=d_model)
450
+
451
+ self.dim = d_model // nhead
452
+ self.nhead = nhead
453
+
454
+ self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size2, stride=self.pool_size2)
455
+
456
+ # multi-head attention
457
+ if self.dw_proj:
458
+ self.q_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
459
+ self.k_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
460
+ self.v_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
461
+ else:
462
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
463
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
464
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
465
+
466
+ if self.rope:
467
+ self.rope_pos_enc = RoPEPositionEncodingSine(d_model, max_shape=(256, 256), npe=npe, ropefp16=True)
468
+
469
+ if xformer:
470
+ self.attention = XAttention()
471
+ else:
472
+ self.attention = LinearAttention() if attention == 'linear' else FullAttention()
473
+ self.merge = nn.Linear(d_model, d_model, bias=False)
474
+
475
+ # feed-forward network
476
+ if leaky > 0:
477
+ if self.vit_norm:
478
+ self.mlp = nn.Sequential(
479
+ nn.Linear(d_model, d_model*2, bias=False),
480
+ nn.LeakyReLU(leaky, True),
481
+ nn.Linear(d_model*2, d_model, bias=False),
482
+ )
483
+ else:
484
+ self.mlp = nn.Sequential(
485
+ nn.Linear(d_model*2, d_model*2, bias=False),
486
+ nn.LeakyReLU(leaky, True),
487
+ nn.Linear(d_model*2, d_model, bias=False),
488
+ )
489
+
490
+ else:
491
+ if self.vit_norm:
492
+ self.mlp = nn.Sequential(
493
+ nn.Linear(d_model, d_model*2, bias=False),
494
+ nn.ReLU(True),
495
+ nn.Linear(d_model*2, d_model, bias=False),
496
+ )
497
+ else:
498
+ self.mlp = nn.Sequential(
499
+ nn.Linear(d_model*2, d_model*2, bias=False),
500
+ nn.ReLU(True),
501
+ nn.Linear(d_model*2, d_model, bias=False),
502
+ )
503
+
504
+ # norm and dropout
505
+ self.norm1 = nn.LayerNorm(d_model)
506
+ self.norm2 = nn.LayerNorm(d_model)
507
+
508
+ # self.norm1 = nn.BatchNorm2d(d_model)
509
+
510
+ def forward(self, x, source, x_mask=None, source_mask=None):
511
+ """
512
+ Args:
513
+ x (torch.Tensor): [N, C, H1, W1]
514
+ source (torch.Tensor): [N, C, H2, W2]
515
+ x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
516
+ source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
517
+ """
518
+ bs, C, H1, W1 = x.size()
519
+ H2, W2 = source.size(-2), source.size(-1)
520
+
521
+
522
+ if self.norm_before and not self.vit_norm:
523
+ if self.pool_size == 1:
524
+ query = self.norm1(x.permute(0,2,3,1)) # [N, H, W, C]
525
+ elif self.dw_conv:
526
+ query = self.norm1(self.aggregate(x).permute(0,2,3,1)) # [N, H, W, C]
527
+ else:
528
+ query = self.norm1(self.max_pool(x).permute(0,2,3,1)) # [N, H, W, C]
529
+ if self.pool_size2 == 1:
530
+ source = self.norm1(source.permute(0,2,3,1)) # [N, H, W, C]
531
+ elif self.dw_conv2:
532
+ source = self.norm1(self.aggregate2(source).permute(0,2,3,1)) # [N, H, W, C]
533
+ else:
534
+ source = self.norm1(self.max_pool(source).permute(0,2,3,1)) # [N, H, W, C]
535
+ elif self.vit_norm:
536
+ if self.pool_size == 1:
537
+ query = self.norm1(x.permute(0,2,3,1)) # [N, H, W, C]
538
+ elif self.dw_conv:
539
+ query = self.aggregate(self.norm1(x.permute(0,2,3,1)).permute(0,3,1,2)).permute(0,2,3,1) # [N, H, W, C]
540
+ else:
541
+ query = self.max_pool(self.norm1(x.permute(0,2,3,1)).permute(0,3,1,2)).permute(0,2,3,1) # [N, H, W, C]
542
+ if self.pool_size2 == 1:
543
+ source = self.norm1(source.permute(0,2,3,1)) # [N, H, W, C]
544
+ elif self.dw_conv2:
545
+ source = self.aggregate2(self.norm1(source.permute(0,2,3,1)).permute(0,3,1,2)).permute(0,2,3,1) # [N, H, W, C]
546
+ else:
547
+ source = self.max_pool(self.norm1(source.permute(0,2,3,1)).permute(0,3,1,2)).permute(0,2,3,1) # [N, H, W, C]
548
+ else:
549
+ if self.pool_size == 1:
550
+ query = x.permute(0,2,3,1) # [N, H, W, C]
551
+ elif self.dw_conv:
552
+ query = self.aggregate(x).permute(0,2,3,1) # [N, H, W, C]
553
+ else:
554
+ query = self.max_pool(x).permute(0,2,3,1) # [N, H, W, C]
555
+ if self.pool_size2 == 1:
556
+ source = source.permute(0,2,3,1) # [N, H, W, C]
557
+ elif self.dw_conv2:
558
+ source = self.aggregate2(source).permute(0,2,3,1) # [N, H, W, C]
559
+ else:
560
+ source = self.max_pool(source).permute(0,2,3,1) # [N, H, W, C]
561
+
562
+ # projection
563
+ if self.dw_proj:
564
+ query = self.q_proj(query.permute(0,3,1,2)).permute(0,2,3,1)
565
+ key = self.k_proj(source.permute(0,3,1,2)).permute(0,2,3,1)
566
+ value = self.v_proj(source.permute(0,3,1,2)).permute(0,2,3,1)
567
+ else:
568
+ query, key, value = self.q_proj(query), self.k_proj(source), self.v_proj(source)
569
+
570
+ # RoPE
571
+ if self.rope:
572
+ query = self.rope_pos_enc(query)
573
+ if self.pool_size == 1 and self.pool_size2 == 4:
574
+ key = self.rope_pos_enc(key, 4)
575
+ else:
576
+ key = self.rope_pos_enc(key)
577
+
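+ # Note: when the query keeps full resolution (pool_size == 1) while the
+ # key/value side is pooled by 4 (pool_size2 == 4), the key positions are
+ # encoded with a stride of 4, presumably so that both feature maps share the
+ # same spatial coordinate frame for the rotary phases.
+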
578
+ use_mask = x_mask is not None and source_mask is not None
579
+ if bs == 1 or not use_mask:
580
+ if use_mask:
581
+ # downsample mask
582
+ if self.pool_size == 1:
583
+ pass
584
+ else:
585
+ x_mask = self.max_pool(x_mask.float()).bool() # [N, H1//pool, W1//pool]
586
+
587
+ if self.pool_size2 == 1:
588
+ pass
589
+ else:
590
+ source_mask = self.max_pool(source_mask.float()).bool()
591
+
592
+ mask_h0, mask_w0 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0]
593
+ mask_h1, mask_w1 = source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0]
594
+
595
+ query = query[:, :mask_h0, :mask_w0, :]
596
+ key = key[:, :mask_h1, :mask_w1, :]
597
+ value = value[:, :mask_h1, :mask_w1, :]
598
+ else:
599
+ assert x_mask is None and source_mask is None
600
+
601
+ if PFLASH_AVAILABLE: # [N, H, W, C] -> [N, h, L, D]
602
+ query = rearrange(query, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
603
+ key = rearrange(key, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
604
+ value = rearrange(value, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
605
+ else: # N L H D
606
+ query = rearrange(query, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)
607
+ key = rearrange(key, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)
608
+ value = rearrange(value, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)
609
+
610
+ message = self.attention(query, key, value, q_mask=None, kv_mask=None) # [N, L, h, D] or [N, h, L, D]
611
+
612
+ if PFLASH_AVAILABLE: # [N, h, L, D] -> [N, L, h, D]
613
+ message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
614
+
615
+ if use_mask: # padding zero
616
+ message = message.view(bs, mask_h0, mask_w0, self.nhead, self.dim) # [N L h D]
617
+ if mask_h0 != x_mask.size(-2):
618
+ message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=1)
619
+ elif mask_w0 != x_mask.size(-1):
620
+ message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=2)
621
+ else:
622
+ assert x_mask is None and source_mask is None
623
+
624
+ message = self.merge(message.reshape(bs, -1, self.nhead*self.dim)) # [N, L, C]
625
+ message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W]
626
+
627
+ if self.pool_size == 1:
628
+ pass
629
+ else:
630
+ if self.scatter:
631
+ message = torch.repeat_interleave(message, self.pool_size, dim=-2)
632
+ message = torch.repeat_interleave(message, self.pool_size, dim=-1)
633
+ message = message * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1,1,message.shape[-2]//self.pool_size,message.shape[-1]//self.pool_size)
634
+ else:
635
+ message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
636
+
637
+ if not self.norm_before and not self.vit_norm:
638
+ message = self.norm1(message.permute(0,2,3,1)).permute(0,3,1,2) # [N, C, H, W]
639
+
640
+ # feed-forward network
641
+ if self.vit_norm:
642
+ message_inter = (x + message)
643
+ del x
644
+ message = self.norm2(message_inter.permute(0, 2, 3, 1))
645
+ message = self.mlp(message).permute(0, 3, 1, 2) # [N, C, H1, W1]
646
+ return message_inter + message
647
+ else:
648
+ message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
649
+ message = self.norm2(message).permute(0, 3, 1, 2) # [N, C, H1, W1]
650
+
651
+ return x + message
652
+ else: # mask with bs > 1
653
+ if self.pool_size == 1:
654
+ pass
655
+ else:
656
+ x_mask = self.max_pool(x_mask.float()).bool()
657
+
658
+ if self.pool_size2 == 1:
659
+ pass
660
+ else:
661
+ source_mask = self.max_pool(source_mask.float()).bool()
662
+ m_list = []
663
+ for i in range(bs):
664
+ mask_h0, mask_w0 = x_mask[i].sum(-2)[0], x_mask[i].sum(-1)[0]
665
+ mask_h1, mask_w1 = source_mask[i].sum(-2)[0], source_mask[i].sum(-1)[0]
666
+
667
+ q = query[i:i+1, :mask_h0, :mask_w0, :]
668
+ k = key[i:i+1, :mask_h1, :mask_w1, :]
669
+ v = value[i:i+1, :mask_h1, :mask_w1, :]
670
+
671
+ if PFLASH_AVAILABLE: # [N, H, W, C] -> [N, h, L, D]
672
+ q = rearrange(q, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
673
+ k = rearrange(k, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
674
+ v = rearrange(v, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
675
+ else: # N L H D
676
+ q = rearrange(q, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)
677
+ k = rearrange(k, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)
678
+ v = rearrange(v, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)
679
+
680
+ m = self.attention(q, k, v, q_mask=None, kv_mask=None) # [N, L, h, D] or [N, h, L, D]
681
+
682
+ if PFLASH_AVAILABLE: # [N, h, L, D] -> [N, L, h, D]
683
+ m = rearrange(m, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
684
+
685
+ m = m.view(1, mask_h0, mask_w0, self.nhead, self.dim)
686
+ if mask_h0 != x_mask.size(-2):
687
+ m = torch.cat([m, torch.zeros(1, x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=1)
688
+ elif mask_w0 != x_mask.size(-1):
689
+ m = torch.cat([m, torch.zeros(1, x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=2)
690
+ m_list.append(m)
691
+ m = torch.cat(m_list, dim=0)
692
+
693
+ m = self.merge(m.reshape(bs, -1, self.nhead*self.dim)) # [N, L, C]
694
+ # m = m.reshape(bs, C, H1//self.pool_size, W1//self.pool_size) # [N, C, H, W] why this bug worked
695
+ m = rearrange(m, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W]
696
+
697
+ if self.pool_size == 1:
698
+ pass
699
+ else:
700
+ if self.scatter:
701
+ m = torch.repeat_interleave(m, self.pool_size, dim=-2)
702
+ m = torch.repeat_interleave(m, self.pool_size, dim=-1)
703
+ m = m * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1,1,m.shape[-2]//self.pool_size,m.shape[-1]//self.pool_size)
704
+ else:
705
+ m = torch.nn.functional.interpolate(m, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
706
+
707
+
708
+ if not self.norm_before and not self.vit_norm:
709
+ m = self.norm1(m.permute(0,2,3,1)).permute(0,3,1,2) # [N, C, H, W]
710
+
711
+ # feed-forward network
712
+ if self.vit_norm:
713
+ m_inter = (x + m)
714
+ del x
715
+ m = self.norm2(m_inter.permute(0, 2, 3, 1))
716
+ m = self.mlp(m).permute(0, 3, 1, 2) # [N, C, H1, W1]
717
+ return m_inter + m
718
+ else:
719
+ m = self.mlp(torch.cat([x, m], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
720
+ m = self.norm2(m).permute(0, 3, 1, 2) # [N, C, H1, W1]
721
+
722
+ return x + m
723
+
724
+ return x + m
725
+
726
+ class AG_Conv_EncoderLayer(nn.Module):
727
+ def __init__(self,
728
+ d_model,
729
+ nhead,
730
+ attention='linear',
731
+ pool_size=4,
732
+ bn=True,
733
+ xformer=False,
734
+ leaky=-1.0,
735
+ dw_conv=False,
736
+ dw_conv2=False,
737
+ scatter=False,
738
+ norm_before=True,
739
+ ):
740
+ super(AG_Conv_EncoderLayer, self).__init__()
741
+
742
+ self.pool_size = pool_size
743
+ self.dw_conv = dw_conv
744
+ self.dw_conv2 = dw_conv2
745
+ self.scatter = scatter
746
+ self.norm_before = norm_before
747
+ if self.dw_conv:
748
+ self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0, stride=pool_size, bias=False, groups=d_model)
749
+ if self.dw_conv2:
750
+ self.aggregate2 = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0, stride=pool_size, bias=False, groups=d_model)
751
+ self.dim = d_model // nhead
752
+ self.nhead = nhead
753
+
754
+ self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size)
755
+
756
+ # multi-head attention
757
+ if bn:
758
+ method = 'dw_bn'
759
+ else:
760
+ method = 'dw'
761
+ self.q_proj_conv = self._build_projection(d_model, d_model, method=method)
762
+ self.k_proj_conv = self._build_projection(d_model, d_model, method=method)
763
+ self.v_proj_conv = self._build_projection(d_model, d_model, method=method)
764
+
765
+ if xformer:
766
+ self.attention = XAttention()
767
+ else:
768
+ self.attention = LinearAttention() if attention == 'linear' else FullAttention()
769
+ self.merge = nn.Linear(d_model, d_model, bias=False)
770
+
771
+ # feed-forward network
772
+ if leaky > 0:
773
+ self.mlp = nn.Sequential(
774
+ nn.Linear(d_model*2, d_model*2, bias=False),
775
+ nn.LeakyReLU(leaky, True),
776
+ nn.Linear(d_model*2, d_model, bias=False),
777
+ )
778
+
779
+ else:
780
+ self.mlp = nn.Sequential(
781
+ nn.Linear(d_model*2, d_model*2, bias=False),
782
+ nn.ReLU(True),
783
+ nn.Linear(d_model*2, d_model, bias=False),
784
+ )
785
+
786
+ # norm and dropout
787
+ self.norm1 = nn.LayerNorm(d_model)
788
+ self.norm2 = nn.LayerNorm(d_model)
789
+
790
+ def forward(self, x, source, x_mask=None, source_mask=None):
791
+ """
792
+ Args:
793
+ x (torch.Tensor): [N, C, H1, W1]
794
+ source (torch.Tensor): [N, C, H2, W2]
795
+ x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
796
+ source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
797
+ """
798
+ bs = x.size(0)
799
+ H1, W1 = x.size(-2), x.size(-1)
800
+ H2, W2 = source.size(-2), source.size(-1)
801
+ C = x.shape[-3]
802
+
803
+ if self.norm_before:
804
+ if self.dw_conv:
805
+ query = self.norm1(self.aggregate(x).permute(0,2,3,1)).permute(0,3,1,2)
806
+ else:
807
+ query = self.norm1(self.max_pool(x).permute(0,2,3,1)).permute(0,3,1,2)
808
+ if self.dw_conv2:
809
+ source = self.norm1(self.aggregate2(source).permute(0,2,3,1)).permute(0,3,1,2)
810
+ else:
811
+ source = self.norm1(self.max_pool(source).permute(0,2,3,1)).permute(0,3,1,2)
812
+ else:
813
+ if self.dw_conv:
814
+ query = self.aggregate(x)
815
+ else:
816
+ query = self.max_pool(x)
817
+ if self.dw_conv2:
818
+ source = self.aggregate2(source)
819
+ else:
820
+ source = self.max_pool(source)
821
+
822
+ key, value = source, source
823
+
824
+ query = self.q_proj_conv(query) # [N, C, H1//pool, W1//pool]
825
+ key = self.k_proj_conv(key)
826
+ value = self.v_proj_conv(value)
827
+
828
+ use_mask = x_mask is not None and source_mask is not None
829
+ if bs == 1 or not use_mask:
830
+ if use_mask:
831
+ x_mask = self.max_pool(x_mask.float()).bool() # [N, H1//pool, W1//pool]
832
+ source_mask = self.max_pool(source_mask.float()).bool()
833
+
834
+ mask_h0, mask_w0 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0]
835
+ mask_h1, mask_w1 = source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0]
836
+
837
+ query = query[:, :, :mask_h0, :mask_w0]
838
+ key = key[:, :, :mask_h1, :mask_w1]
839
+ value = value[:, :, :mask_h1, :mask_w1]
840
+
841
+ else:
842
+ assert x_mask is None and source_mask is None
843
+
844
+ if PFLASH_AVAILABLE: # N H L D
845
+ query = rearrange(query, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
846
+ key = rearrange(key, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
847
+ value = rearrange(value, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
848
+
849
+ else: # N L H D
850
+ query = rearrange(query, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D]
851
+ key = rearrange(key, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
852
+ value = rearrange(value, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
853
+
854
+ message = self.attention(query, key, value, q_mask=None, kv_mask=None) # [N, L, H, D] or [N, H, L, D]
855
+
856
+ if PFLASH_AVAILABLE: # N H L D
857
+ message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
858
+
859
+ if use_mask: # padding zero
860
+ message = message.view(bs, mask_h0, mask_w0, self.nhead, self.dim) # [N L H D]
861
+ if mask_h0 != x_mask.size(-2):
862
+ message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=1)
863
+ elif mask_w0 != x_mask.size(-1):
864
+ message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=2)
865
+ else:
866
+ assert x_mask is None and source_mask is None
867
+
868
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
869
+ message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W]
870
+
871
+ if self.scatter:
872
+ message = torch.repeat_interleave(message, self.pool_size, dim=-2)
873
+ message = torch.repeat_interleave(message, self.pool_size, dim=-1)
874
+ message = message * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1,1,message.shape[-2]//self.pool_size,message.shape[-1]//self.pool_size)
875
+ else:
876
+ message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
877
+
878
+ if not self.norm_before:
879
+ message = self.norm1(message.permute(0,2,3,1)).permute(0,3,1,2) # [N, C, H, W]
880
+
881
+ # feed-forward network
882
+ message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
883
+ message = self.norm2(message).permute(0, 3, 1, 2) # [N, C, H1, W1]
884
+
885
+ return x + message
886
+ else: # mask with bs > 1
887
+ x_mask = self.max_pool(x_mask.float()).bool()
888
+ source_mask = self.max_pool(source_mask.float()).bool()
889
+ m_list = []
890
+ for i in range(bs):
891
+ mask_h0, mask_w0 = x_mask[i].sum(-2)[0], x_mask[i].sum(-1)[0]
892
+ mask_h1, mask_w1 = source_mask[i].sum(-2)[0], source_mask[i].sum(-1)[0]
893
+
894
+ q = query[i:i+1, :, :mask_h0, :mask_w0]
895
+ k = key[i:i+1, :, :mask_h1, :mask_w1]
896
+ v = value[i:i+1, :, :mask_h1, :mask_w1]
897
+
898
+ if PFLASH_AVAILABLE: # N H L D
899
+ q = rearrange(q, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
900
+ k = rearrange(k, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
901
+ v = rearrange(v, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
902
+
903
+ else: # N L H D
904
+ q = rearrange(q, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D]
905
+ k = rearrange(k, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
906
+ v = rearrange(v, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
907
+
908
+ m = self.attention(q, k, v, q_mask=None, kv_mask=None) # [N, L, H, D]
909
+
910
+ if PFLASH_AVAILABLE: # N H L D
911
+ m = rearrange(m, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
912
+
913
+ m = m.view(1, mask_h0, mask_w0, self.nhead, self.dim)
914
+ if mask_h0 != x_mask.size(-2):
915
+ m = torch.cat([m, torch.zeros(1, x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=1)
916
+ elif mask_w0 != x_mask.size(-1):
917
+ m = torch.cat([m, torch.zeros(1, x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=2)
918
+ m_list.append(m)
919
+ m = torch.cat(m_list, dim=0)
920
+
921
+ m = self.merge(m.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
922
+
923
+ # m = m.reshape(bs, C, H1//self.pool_size, W1//self.pool_size) # [N, C, H, W] why this bug worked
924
+ m = rearrange(m, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W]
925
+
926
+ if self.scatter:
927
+ m = torch.repeat_interleave(m, self.pool_size, dim=-2)
928
+ m = torch.repeat_interleave(m, self.pool_size, dim=-1)
929
+ m = m * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1,1,m.shape[-2]//self.pool_size,m.shape[-1]//self.pool_size)
930
+ else:
931
+ m = torch.nn.functional.interpolate(m, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
932
+
933
+ if not self.norm_before:
934
+ m = self.norm1(m.permute(0,2,3,1)).permute(0,3,1,2) # [N, C, H, W]
935
+
936
+ # feed-forward network
937
+ m = self.mlp(torch.cat([x, m], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
938
+ m = self.norm2(m).permute(0, 3, 1, 2) # [N, C, H1, W1]
939
+
940
+ return x + m
941
+
942
+ def _build_projection(self,
943
+ dim_in,
944
+ dim_out,
945
+ kernel_size=3,
946
+ padding=1,
947
+ stride=1,
948
+ method='dw_bn',
949
+ ):
950
+ if method == 'dw_bn':
951
+ proj = nn.Sequential(OrderedDict([
952
+ ('conv', nn.Conv2d(
953
+ dim_in,
954
+ dim_in,
955
+ kernel_size=kernel_size,
956
+ padding=padding,
957
+ stride=stride,
958
+ bias=False,
959
+ groups=dim_in
960
+ )),
961
+ ('bn', nn.BatchNorm2d(dim_in)),
962
+ # ('rearrange', Rearrange('b c h w -> b (h w) c')),
963
+ ]))
964
+ elif method == 'avg':
965
+ proj = nn.Sequential(OrderedDict([
966
+ ('avg', nn.AvgPool2d(
967
+ kernel_size=kernel_size,
968
+ padding=padding,
969
+ stride=stride,
970
+ ceil_mode=True
971
+ )),
972
+ # ('rearrange', Rearrange('b c h w -> b (h w) c')),
973
+ ]))
974
+ elif method == 'linear':
975
+ proj = None
976
+ elif method == 'dw':
977
+ proj = nn.Sequential(OrderedDict([
978
+ ('conv', nn.Conv2d(
979
+ dim_in,
980
+ dim_in,
981
+ kernel_size=kernel_size,
982
+ padding=padding,
983
+ stride=stride,
984
+ bias=False,
985
+ groups=dim_in
986
+ )),
987
+ # ('rearrange', Rearrange('b c h w -> b (h w) c')),
988
+ ]))
989
+ else:
990
+ raise ValueError('Unknown method ({})'.format(method))
991
+
992
+ return proj
993
+
994
+
995
+ class RoPELoFTREncoderLayer(nn.Module):
996
+ def __init__(self,
997
+ d_model,
998
+ nhead,
999
+ attention='linear',
1000
+ rope=False,
1001
+ token_mixer=None,
1002
+ ):
1003
+ super(RoPELoFTREncoderLayer, self).__init__()
1004
+
1005
+ self.dim = d_model // nhead
1006
+ self.nhead = nhead
1007
+
1008
+ # multi-head attention
1009
+ if token_mixer is None:
1010
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
1011
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
1012
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
1013
+
1014
+ self.rope = rope
1015
+ self.token_mixer = None
1016
+ if token_mixer is not None:
1017
+ self.token_mixer = token_mixer
1018
+ if token_mixer == 'dwcn':
1019
+ self.attention = nn.Sequential(OrderedDict([
1020
+ ('conv', nn.Conv2d(
1021
+ d_model,
1022
+ d_model,
1023
+ kernel_size=3,
1024
+ padding=1,
1025
+ stride=1,
1026
+ bias=False,
1027
+ groups=d_model
1028
+ )),
1029
+ ]))
1030
+ elif self.rope:
1031
+ assert attention == 'linear'
1032
+ self.attention = RoPELinearAttention()
1033
+
1034
+ if token_mixer is None:
1035
+ self.merge = nn.Linear(d_model, d_model, bias=False)
1036
+
1037
+ # feed-forward network
1038
+ if token_mixer is None:
1039
+ self.mlp = nn.Sequential(
1040
+ nn.Linear(d_model*2, d_model*2, bias=False),
1041
+ nn.ReLU(True),
1042
+ nn.Linear(d_model*2, d_model, bias=False),
1043
+ )
1044
+ else:
1045
+ self.mlp = nn.Sequential(
1046
+ nn.Linear(d_model, d_model, bias=False),
1047
+ nn.ReLU(True),
1048
+ nn.Linear(d_model, d_model, bias=False),
1049
+ )
1050
+ # norm and dropout
1051
+ self.norm1 = nn.LayerNorm(d_model)
1052
+ self.norm2 = nn.LayerNorm(d_model)
1053
+
1054
+ def forward(self, x, source, x_mask=None, source_mask=None, H=None, W=None):
1055
+ """
1056
+ Args:
1057
+ x (torch.Tensor): [N, L, C]
1058
+ source (torch.Tensor): [N, S, C]
1059
+ x_mask (torch.Tensor): [N, L] (optional)
1060
+ source_mask (torch.Tensor): [N, S] (optional)
1061
+ """
1062
+ bs = x.size(0)
1063
+ assert H*W == x.size(-2)
1064
+
1065
+ # x = rearrange(x, 'n c h w -> n (h w) c')
1066
+ # source = rearrange(source, 'n c h w -> n (h w) c')
1067
+ query, key, value = x, source, source
1068
+
1069
+ if self.token_mixer is not None:
1070
+ # multi-head attention
1071
+ m = self.norm1(x)
1072
+ m = rearrange(m, 'n (h w) c -> n c h w', h=H, w=W)
1073
+ m = self.attention(m)
1074
+ m = rearrange(m, 'n c h w -> n (h w) c')
1075
+
1076
+ x = x + m
1077
+ x = x + self.mlp(self.norm2(x))
1078
+ return x
1079
+ else:
1080
+ # multi-head attention
1081
+ query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
1082
+ key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
1083
+ value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
1084
+ message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask, H=H, W=W) # [N, L, (H, D)]
1085
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
1086
+ message = self.norm1(message)
1087
+
1088
+ # feed-forward network
1089
+ message = self.mlp(torch.cat([x, message], dim=2))
1090
+ message = self.norm2(message)
1091
+
1092
+ return x + message
1093
+
1094
+ class LoFTREncoderLayer(nn.Module):
1095
+ def __init__(self,
1096
+ d_model,
1097
+ nhead,
1098
+ attention='linear',
1099
+ xformer=False,
1100
+ ):
1101
+ super(LoFTREncoderLayer, self).__init__()
1102
+
1103
+ self.dim = d_model // nhead
1104
+ self.nhead = nhead
1105
+
1106
+ # multi-head attention
1107
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
1108
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
1109
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
1110
+
1111
+ if xformer:
1112
+ self.attention = XAttention()
1113
+ else:
1114
+ self.attention = LinearAttention() if attention == 'linear' else FullAttention()
1115
+ self.merge = nn.Linear(d_model, d_model, bias=False)
1116
+
1117
+ # feed-forward network
1118
+ self.mlp = nn.Sequential(
1119
+ nn.Linear(d_model*2, d_model*2, bias=False),
1120
+ nn.ReLU(True),
1121
+ nn.Linear(d_model*2, d_model, bias=False),
1122
+ )
1123
+
1124
+ # norm and dropout
1125
+ self.norm1 = nn.LayerNorm(d_model)
1126
+ self.norm2 = nn.LayerNorm(d_model)
1127
+
1128
+ def forward(self, x, source, x_mask=None, source_mask=None):
1129
+ """
1130
+ Args:
1131
+ x (torch.Tensor): [N, L, C]
1132
+ source (torch.Tensor): [N, S, C]
1133
+ x_mask (torch.Tensor): [N, L] (optional)
1134
+ source_mask (torch.Tensor): [N, S] (optional)
1135
+ """
1136
+ bs = x.size(0)
1137
+ query, key, value = x, source, source
1138
+
1139
+ # multi-head attention
1140
+ query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
1141
+ key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
1142
+ value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
1143
+ message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
1144
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
1145
+ message = self.norm1(message)
1146
+
1147
+ # feed-forward network
1148
+ message = self.mlp(torch.cat([x, message], dim=2))
1149
+ message = self.norm2(message)
1150
+
1151
+ return x + message
1152
+
1153
+ def pro(self, x, source, x_mask=None, source_mask=None, profiler=None):
1154
+ """
1155
+ Args:
1156
+ x (torch.Tensor): [N, L, C]
1157
+ source (torch.Tensor): [N, S, C]
1158
+ x_mask (torch.Tensor): [N, L] (optional)
1159
+ source_mask (torch.Tensor): [N, S] (optional)
1160
+ """
1161
+ bs = x.size(0)
1162
+ query, key, value = x, source, source
1163
+
1164
+ # multi-head attention
1165
+ with profiler.profile("proj*3"):
1166
+ query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
1167
+ key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
1168
+ value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
1169
+ with profiler.profile("attention"):
1170
+ message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
1171
+ with profiler.profile("merge"):
1172
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
1173
+ with profiler.profile("norm1"):
1174
+ message = self.norm1(message)
1175
+
1176
+ # feed-forward network
1177
+ with profiler.profile("mlp"):
1178
+ message = self.mlp(torch.cat([x, message], dim=2))
1179
+ with profiler.profile("norm2"):
1180
+ message = self.norm2(message)
1181
+
1182
+ return x + message
1183
+
1184
+ class PANEncoderLayer_cross(nn.Module):
1185
+ def __init__(self,
1186
+ d_model,
1187
+ nhead,
1188
+ attention='linear',
1189
+ pool_size=4,
1190
+ bn=True,
1191
+ ):
1192
+ super(PANEncoderLayer_cross, self).__init__()
1193
+
1194
+ self.pool_size = pool_size
1195
+
1196
+ self.dim = d_model // nhead
1197
+ self.nhead = nhead
1198
+
1199
+ self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size)
1200
+ # multi-head attention
1201
+ if bn:
1202
+ method = 'dw_bn'
1203
+ else:
1204
+ method = 'dw'
1205
+ self.qk_proj_conv = self._build_projection(d_model, d_model, method=method)
1206
+ self.v_proj_conv = self._build_projection(d_model, d_model, method=method)
1207
+
1208
+ # self.q_proj = nn.Linear(d_model, d_model, bias=False)
1209
+ # self.k_proj = nn.Linear(d_model, d_model, bias=False)
1210
+ # self.v_proj = nn.Linear(d_model, d_model, bias=False)
1211
+ self.attention = FullAttention()
1212
+ self.merge = nn.Linear(d_model, d_model, bias=False)
1213
+
1214
+ # feed-forward network
1215
+ self.mlp = nn.Sequential(
1216
+ nn.Linear(d_model*2, d_model*2, bias=False),
1217
+ nn.ReLU(True),
1218
+ nn.Linear(d_model*2, d_model, bias=False),
1219
+ )
1220
+
1221
+ # norm and dropout
1222
+ self.norm1 = nn.LayerNorm(d_model)
1223
+ self.norm2 = nn.LayerNorm(d_model)
1224
+
1225
+ # self.norm1 = nn.BatchNorm2d(d_model)
1226
+
1227
+ def forward(self, x1, x2, x1_mask=None, x2_mask=None):
1228
+ """
1229
+ Args:
1230
+ x1 (torch.Tensor): [N, C, H1, W1]
1231
+ x2 (torch.Tensor): [N, C, H2, W2]
1232
+ x1_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
1233
+ x2_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
1234
+ """
1235
+ bs = x1.size(0)
1236
+ H1, W1 = x1.size(-2) // self.pool_size, x1.size(-1) // self.pool_size
1237
+ H2, W2 = x2.size(-2) // self.pool_size, x2.size(-1) // self.pool_size
1238
+
1239
+ query = self.norm1(self.max_pool(x1).permute(0,2,3,1)).permute(0,3,1,2)
1240
+ key = self.norm1(self.max_pool(x2).permute(0,2,3,1)).permute(0,3,1,2)
1241
+ v2 = self.norm1(self.max_pool(x2).permute(0,2,3,1)).permute(0,3,1,2)
1242
+ v1 = self.norm1(self.max_pool(x1).permute(0,2,3,1)).permute(0,3,1,2)
1243
+
1244
+ # multi-head attention
1245
+ query = self.qk_proj_conv(query) # [N, C, H1//pool, W1//pool]
1246
+ key = self.qk_proj_conv(key)
1247
+ v2 = self.v_proj_conv(v2)
1248
+ v1 = self.v_proj_conv(v1)
1249
+
1250
+ C = query.shape[-3]
1251
+ if x1_mask is not None and x2_mask is not None:
1252
+ x1_mask = self.max_pool(x1_mask.float()).bool() # [N, H1//pool, W1//pool]
1253
+ x2_mask = self.max_pool(x2_mask.float()).bool()
1254
+
1255
+ mask_h1, mask_w1 = x1_mask[0].sum(-2)[0], x1_mask[0].sum(-1)[0]
1256
+ mask_h2, mask_w2 = x2_mask[0].sum(-2)[0], x2_mask[0].sum(-1)[0]
1257
+
1258
+ query = query[:, :, :mask_h1, :mask_w1]
1259
+ key = key[:, :, :mask_h2, :mask_w2]
1260
+ v1 = v1[:, :, :mask_h1, :mask_w1]
1261
+ v2 = v2[:, :, :mask_h2, :mask_w2]
1262
+ x1_mask = x1_mask[:, :mask_h1, :mask_w1]
1263
+ x2_mask = x2_mask[:, :mask_h2, :mask_w2]
1264
+
1265
+ else:
1266
+ assert x1_mask is None and x2_mask is None
1267
+
1268
+ query = rearrange(query, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D]
1269
+ key = rearrange(key, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
1270
+ v2 = rearrange(v2, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
1271
+ v1 = rearrange(v1, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
1272
+ if x2_mask is not None or x1_mask is not None:
1273
+ x1_mask = x1_mask.flatten(-2)
1274
+ x2_mask = x2_mask.flatten(-2)
1275
+
1276
+
1277
+ QK = torch.einsum("nlhd,nshd->nlsh", query, key)
1278
+ with torch.autocast(enabled=False, device_type='cuda'):
1279
+ if x2_mask is not None or x1_mask is not None:
1280
+ # S1 = S2.transpose(-2,-3).masked_fill(~(x_mask[:, None, :, None] * source_mask[:, :, None, None]), -1e9) # float('-inf')
1281
+ QK = QK.float().masked_fill_(~(x1_mask[:, :, None, None] * x2_mask[:, None, :, None]), -1e9) # float('-inf')
1282
+
1283
+
1284
+ # Compute the attention and the weighted average
1285
+ softmax_temp = 1. / query.size(3)**.5 # sqrt(D)
1286
+ S1 = torch.softmax(softmax_temp * QK, dim=2)
1287
+ S2 = torch.softmax(softmax_temp * QK, dim=1) # normalize over the query dimension (QK is [N, L, S, H]; dim=3 would normalize over attention heads)
1288
+
1289
+ m1 = torch.einsum("nlsh,nshd->nlhd", S1, v2)
1290
+ m2 = torch.einsum("nlsh,nlhd->nshd", S2, v1)
1291
+
1292
+ if x1_mask is not None and x2_mask is not None:
1293
+ m1 = m1.view(bs, mask_h1, mask_w1, self.nhead, self.dim)
1294
+ if mask_h1 != H1:
1295
+ m1 = torch.cat([m1, torch.zeros(m1.size(0), H1-mask_h1, W1, self.nhead, self.dim, device=m1.device, dtype=m1.dtype)], dim=1)
1296
+ elif mask_w1 != W1:
1297
+ m1 = torch.cat([m1, torch.zeros(m1.size(0), H1, W1-mask_w1, self.nhead, self.dim, device=m1.device, dtype=m1.dtype)], dim=2)
1298
+ else:
1299
+ assert x1_mask is None and x2_mask is None
1300
+
1301
+ m1 = self.merge(m1.reshape(bs, -1, self.nhead*self.dim)) # [N, L, C]
1302
+ m1 = rearrange(m1, 'b (h w) c -> b c h w', h=H1, w=W1) # [N, C, H, W]
1303
+ m1 = torch.nn.functional.interpolate(m1, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
1304
+ # feed-forward network
1305
+ m1 = self.mlp(torch.cat([x1, m1], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
1306
+ m1 = self.norm2(m1).permute(0, 3, 1, 2) # [N, C, H1, W1]
1307
+
1308
+ if x1_mask is not None and x2_mask is not None:
1309
+ m2 = m2.view(bs, mask_h2, mask_w2, self.nhead, self.dim)
1310
+ if mask_h2 != H2:
1311
+ m2 = torch.cat([m2, torch.zeros(m2.size(0), H2-mask_h2, W2, self.nhead, self.dim, device=m2.device, dtype=m2.dtype)], dim=1)
1312
+ elif mask_w2 != W2:
1313
+ m2 = torch.cat([m2, torch.zeros(m2.size(0), H2, W2-mask_w2, self.nhead, self.dim, device=m2.device, dtype=m2.dtype)], dim=2)
1314
+ else:
1315
+ assert x1_mask is None and x2_mask is None
1316
+
1317
+ m2 = self.merge(m2.reshape(bs, -1, self.nhead*self.dim)) # [N, L, C]
1318
+ m2 = rearrange(m2, 'b (h w) c -> b c h w', h=H2, w=W2) # [N, C, H, W]
1319
+ m2 = torch.nn.functional.interpolate(m2, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
1320
+ # feed-forward network
1321
+ m2 = self.mlp(torch.cat([x2, m2], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
1322
+ m2 = self.norm2(m2).permute(0, 3, 1, 2) # [N, C, H1, W1]
1323
+
1324
+ return x1 + m1, x2 + m2
1325
+
1326
+ def _build_projection(self,
1327
+ dim_in,
1328
+ dim_out,
1329
+ kernel_size=3,
1330
+ padding=1,
1331
+ stride=1,
1332
+ method='dw_bn',
1333
+ ):
1334
+ if method == 'dw_bn':
1335
+ proj = nn.Sequential(OrderedDict([
1336
+ ('conv', nn.Conv2d(
1337
+ dim_in,
1338
+ dim_in,
1339
+ kernel_size=kernel_size,
1340
+ padding=padding,
1341
+ stride=stride,
1342
+ bias=False,
1343
+ groups=dim_in
1344
+ )),
1345
+ ('bn', nn.BatchNorm2d(dim_in)),
1346
+ # ('rearrange', Rearrange('b c h w -> b (h w) c')),
1347
+ ]))
1348
+ elif method == 'avg':
1349
+ proj = nn.Sequential(OrderedDict([
1350
+ ('avg', nn.AvgPool2d(
1351
+ kernel_size=kernel_size,
1352
+ padding=padding,
1353
+ stride=stride,
1354
+ ceil_mode=True
1355
+ )),
1356
+ # ('rearrange', Rearrange('b c h w -> b (h w) c')),
1357
+ ]))
1358
+ elif method == 'linear':
1359
+ proj = None
1360
+ elif method == 'dw':
1361
+ proj = nn.Sequential(OrderedDict([
1362
+ ('conv', nn.Conv2d(
1363
+ dim_in,
1364
+ dim_in,
1365
+ kernel_size=kernel_size,
1366
+ padding=padding,
1367
+ stride=stride,
1368
+ bias=False,
1369
+ groups=dim_in
1370
+ )),
1371
+ # ('rearrange', Rearrange('b c h w -> b (h w) c')),
1372
+ ]))
1373
+ else:
1374
+ raise ValueError('Unknown method ({})'.format(method))
1375
+
1376
+ return proj
1377
+
1378
+ class LocalFeatureTransformer(nn.Module):
1379
+ """A Local Feature Transformer (LoFTR) module."""
1380
+
1381
+ def __init__(self, config):
1382
+ super(LocalFeatureTransformer, self).__init__()
1383
+
1384
+ self.full_config = config
1385
+ self.fine = False
1386
+ if 'coarse' not in config:
1387
+ self.fine = True # fine attention
1388
+ else:
1389
+ config = config['coarse']
1390
+ self.d_model = config['d_model']
1391
+ self.nhead = config['nhead']
1392
+ self.layer_names = config['layer_names']
1393
+ self.pan = config['pan']
1394
+ self.bidirect = config['bidirection']
1395
+ # prune
1396
+ self.pool_size = config['pool_size']
1397
+ self.matchability = False
1398
+ self.depth_confidence = -1.0
1399
+ self.width_confidence = -1.0
1400
+ # self.depth_confidence = config['depth_confidence']
1401
+ # self.width_confidence = config['width_confidence']
1402
+ # self.matchability = self.depth_confidence > 0 or self.width_confidence > 0
1403
+ # self.thr = self.full_config['match_coarse']['thr']
1404
+ if not self.fine:
1405
+ # asy
1406
+ self.asymmetric = config['asymmetric']
1407
+ self.asymmetric_self = config['asymmetric_self']
1408
+ # aggregate
1409
+ self.aggregate = config['dwconv']
1410
+ # RoPE
1411
+ self.rope = config['rope']
1412
+ # absPE
1413
+ self.abspe = config['abspe']
1414
+
1415
+ else:
1416
+ self.rope, self.asymmetric, self.asymmetric_self, self.aggregate = False, False, False, False
1417
+ if self.matchability:
1418
+ self.n_layers = len(self.layer_names) // 2
1419
+ assert self.n_layers == 4
1420
+ self.log_assignment = nn.ModuleList(
1421
+ [MatchAssignment(self.d_model) for _ in range(self.n_layers)])
1422
+ self.token_confidence = nn.ModuleList([
1423
+ TokenConfidence(self.d_model) for _ in range(self.n_layers-1)])
1424
+
1425
+ self.CoarseMatching = CoarseMatching(self.full_config['match_coarse'])
1426
+
1427
+ # self only
1428
+ # if self.rope:
1429
+ # self_layer = RoPELoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'], config['rope'], config['token_mixer'])
1430
+ # self.layers = nn.ModuleList([copy.deepcopy(self_layer) for _ in range(len(self.layer_names))])
1431
+
1432
+ if self.bidirect:
1433
+ assert config['xformer'] is False and config['pan'] is True
1434
+ self_layer = PANEncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['bn'], config['xformer'])
1435
+ cross_layer = PANEncoderLayer_cross(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['bn'])
1436
+ self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names])
1437
+ else:
1438
+ if self.aggregate:
1439
+ if self.rope:
1440
+ # assert config['npe'][0] == 832 and config['npe'][1] == 832 and config['npe'][2] == 832 and config['npe'][3] == 832
1441
+ logger.info(f'npe trainH,trainW,testH,testW: {config["npe"][0]}, {config["npe"][1]}, {config["npe"][2]}, {config["npe"][3]}')
1442
+ self_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
1443
+ config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
1444
+ config['norm_before'], config['rope'], config['npe'], config['vit_norm'], config['rope_dwproj'])
1445
+ cross_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
1446
+ config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
1447
+ config['norm_before'], False, config['npe'], config['vit_norm'], config['rope_dwproj'])
1448
+ self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names])
1449
+ elif self.abspe:
1450
+ logger.info(f'npe trainH,trainW,testH,testW: {config["npe"][0]}, {config["npe"][1]}, {config["npe"][2]}, {config["npe"][3]}')
1451
+ self_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
1452
+ config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
1453
+ config['norm_before'], False, config['npe'], config['vit_norm'], config['rope_dwproj'])
1454
+ cross_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
1455
+ config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
1456
+ config['norm_before'], False, config['npe'], config['vit_norm'], config['rope_dwproj'])
1457
+ self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names])
1458
+
1459
+ else:
1460
+ encoder_layer = AG_Conv_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['bn'],
1461
+ config['xformer'], config['leaky'], config['dwconv'], config['scatter'],
1462
+ config['norm_before'])
1463
+ self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
1464
+ else:
1465
+ encoder_layer = PANEncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'],
1466
+ config['bn'], config['xformer'], config['leaky'], config['dwconv'], config['scatter']) \
1467
+ if config['pan'] else LoFTREncoderLayer(config['d_model'], config['nhead'],
1468
+ config['attention'], config['xformer'])
1469
+ self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
1470
+ self._reset_parameters()
1471
+
1472
+ def _reset_parameters(self):
1473
+ for p in self.parameters():
1474
+ if p.dim() > 1:
1475
+ nn.init.xavier_uniform_(p)
1476
+
1477
+ def forward(self, feat0, feat1, mask0=None, mask1=None, data=None):
1478
+ """
1479
+ Args:
1480
+ feat0 (torch.Tensor): [N, C, H, W]
1481
+ feat1 (torch.Tensor): [N, C, H, W]
1482
+ mask0 (torch.Tensor): [N, L] (optional)
1483
+ mask1 (torch.Tensor): [N, S] (optional)
1484
+ """
1485
+ # nchw for pan and n(hw)c for loftr
1486
+ assert self.d_model == feat0.size(1) or self.d_model == feat0.size(-1), "the feature dimension of the inputs must equal the transformer d_model"
1487
+ H0, W0, H1, W1 = feat0.size(-2), feat0.size(-1), feat1.size(-2), feat1.size(-1)
1488
+ bs = feat0.shape[0]
1489
+ padding = False
1490
+ if bs == 1 and mask0 is not None and mask1 is not None and self.pan: # NCHW for pan
1491
+ mask_H0, mask_W0 = mask0.size(-2), mask0.size(-1)
1492
+ mask_H1, mask_W1 = mask1.size(-2), mask1.size(-1)
1493
+ mask_h0, mask_w0 = mask0[0].sum(-2)[0], mask0[0].sum(-1)[0]
1494
+ mask_h1, mask_w1 = mask1[0].sum(-2)[0], mask1[0].sum(-1)[0]
1495
+
1496
+ #round to self.pool_size
1497
+ if self.pan:
1498
+ mask_h0, mask_w0, mask_h1, mask_w1 = mask_h0//self.pool_size*self.pool_size, mask_w0//self.pool_size*self.pool_size, mask_h1//self.pool_size*self.pool_size, mask_w1//self.pool_size*self.pool_size
1499
+
1500
+ feat0 = feat0[:, :, :mask_h0, :mask_w0]
1501
+ feat1 = feat1[:, :, :mask_h1, :mask_w1]
1502
+
1503
+ padding = True
1504
+
1505
+ # rope self only
1506
+ # if self.rope:
1507
+ # feat0, feat1 = rearrange(feat0, 'b c h w -> b (h w) c'), rearrange(feat1, 'b c h w -> b (h w) c')
1508
+ # prune
1509
+ if padding:
1510
+ l0, l1 = mask_h0 * mask_w0, mask_h1 * mask_w1
1511
+ else:
1512
+ l0, l1 = H0 * W0, H1 * W1
1513
+ do_early_stop = self.depth_confidence > 0
1514
+ do_point_pruning = self.width_confidence > 0
1515
+ if do_point_pruning:
1516
+ ind0 = torch.arange(0, l0, device=feat0.device)[None]
1517
+ ind1 = torch.arange(0, l1, device=feat0.device)[None]
1518
+ # We store the index of the layer at which pruning is detected.
1519
+ prune0 = torch.ones_like(ind0)
1520
+ prune1 = torch.ones_like(ind1)
1521
+ if do_early_stop:
1522
+ token0, token1 = None, None
1523
+
1524
+ for i, (layer, name) in enumerate(zip(self.layers, self.layer_names)):
1525
+ if padding:
1526
+ mask0, mask1 = None, None
1527
+ if name == 'self':
1528
+ # if self.rope:
1529
+ # feat0 = layer(feat0, feat0, mask0, mask1, H0, W0)
1530
+ # feat1 = layer(feat1, feat1, mask0, mask1, H1, W1)
1531
+ if self.asymmetric:
1532
+ assert False, 'asymmetric self-attention is not supported'
1533
+ # feat0 = layer(feat0, feat0, mask0, mask1)
1534
+ feat1 = layer(feat1, feat1, mask1, mask1)
1535
+ else:
1536
+ feat0 = layer(feat0, feat0, mask0, mask0)
1537
+ feat1 = layer(feat1, feat1, mask1, mask1)
1538
+ elif name == 'cross':
1539
+ if self.bidirect:
1540
+ feat0, feat1 = layer(feat0, feat1, mask0, mask1)
1541
+ else:
1542
+ if self.asymmetric or self.asymmetric_self:
1543
+ assert False, 'asymmetric cross-attention is not supported'
1544
+ feat0 = layer(feat0, feat1, mask0, mask1)
1545
+ else:
1546
+ feat0 = layer(feat0, feat1, mask0, mask1)
1547
+ feat1 = layer(feat1, feat0, mask1, mask0)
1548
+
1549
+ if i == len(self.layer_names) - 1 and not self.training:
1550
+ continue
1551
+ if self.matchability:
1552
+ desc0, desc1 = rearrange(feat0, 'b c h w -> b (h w) c'), rearrange(feat1, 'b c h w -> b (h w) c')
1553
+ if do_early_stop:
1554
+ token0, token1 = self.token_confidence[i//2](desc0, desc1)
1555
+ if self.check_if_stop(token0, token1, i, l0+l1) and not self.training:
1556
+ break
1557
+ if do_point_pruning:
1558
+ scores0, scores1 = self.log_assignment[i//2].scores(desc0, desc1)
1559
+ mask0 = self.get_pruning_mask(token0, scores0, i)
1560
+ mask1 = self.get_pruning_mask(token1, scores1, i)
1561
+ ind0, ind1 = ind0[mask0][None], ind1[mask1][None]
1562
+ feat0, feat1 = desc0[mask0][None], desc1[mask1][None]
1563
+ if feat0.shape[-2] == 0 or feat1.shape[-2] == 0: # stop if either pruned feature set is empty
1564
+ break
1565
+ prune0[:, ind0] += 1
1566
+ prune1[:, ind1] += 1
1567
+ if self.training and self.matchability:
1568
+ scores, _, matchability0, matchability1 = self.log_assignment[i//2](desc0, desc1)
1569
+ m0_full = torch.zeros((bs, mask_h0 * mask_w0), device=matchability0.device, dtype=matchability0.dtype)
1570
+ m0_full.scatter_(1, ind0, matchability0.squeeze(-1)) # in-place scatter_; out-of-place scatter() would discard the update
1571
+ if padding and self.d_model == feat0.size(1):
1572
+ m0_full = m0_full.reshape(bs, mask_h0, mask_w0)
1573
+ bs, c, mask_h0, mask_w0 = feat0.size()
1574
+ if mask_h0 != mask_H0:
1575
+ m0_full = torch.cat([m0_full, torch.zeros(bs, mask_H0-mask_h0, mask_w0, device=m0_full.device, dtype=m0_full.dtype)], dim=1)
1576
+ elif mask_w0 != mask_W0:
1577
+ m0_full = torch.cat([m0_full, torch.zeros(bs, mask_h0, mask_W0-mask_w0, device=m0_full.device, dtype=m0_full.dtype)], dim=2)
1578
+ m0_full = m0_full.reshape(bs, mask_H0*mask_W0)
1579
+ m1_full = torch.zeros((bs, mask_h1 * mask_w1), device=matchability0.device, dtype=matchability0.dtype)
1580
+ m1_full.scatter_(1, ind1, matchability1.squeeze(-1))
1581
+ if padding and self.d_model == feat1.size(1):
1582
+ m1_full = m1_full.reshape(bs, mask_h1, mask_w1)
1583
+ bs, c, mask_h1, mask_w1 = feat1.size()
1584
+ if mask_h1 != mask_H1:
1585
+ m1_full = torch.cat([m1_full, torch.zeros(bs, mask_H1-mask_h1, mask_w1, device=m1_full.device, dtype=m1_full.dtype)], dim=1)
1586
+ elif mask_w1 != mask_W1:
1587
+ m1_full = torch.cat([m1_full, torch.zeros(bs, mask_h1, mask_W1-mask_w1, device=m1_full.device, dtype=m1_full.dtype)], dim=2)
1588
+ m1_full = m1_full.reshape(bs, mask_H1*mask_W1)
1589
+ data.update({'matchability0_'+str(i//2): m0_full, 'matchability1_'+str(i//2): m1_full})
1590
+ m0, m1, mscores0, mscores1 = filter_matches(
1591
+ scores, self.thr)
1592
+ if do_point_pruning:
1593
+ m0_ = torch.full((bs, l0), -1, device=m0.device, dtype=m0.dtype)
1594
+ m1_ = torch.full((bs, l1), -1, device=m1.device, dtype=m1.dtype)
1595
+ m0_[:, ind0] = torch.where(
1596
+ m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
1597
+ m1_[:, ind1] = torch.where(
1598
+ m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
1599
+ mscores0_ = torch.zeros((bs, l0), device=mscores0.device)
1600
+ mscores1_ = torch.zeros((bs, l1), device=mscores1.device)
1601
+ mscores0_[:, ind0] = mscores0
1602
+ mscores1_[:, ind1] = mscores1
1603
+ m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
1604
+ if padding and self.d_model == feat0.size(1):
1605
+ m0 = m0.reshape(bs, mask_h0, mask_w0)
1606
+ bs, c, mask_h0, mask_w0 = feat0.size()
1607
+ if mask_h0 != mask_H0:
1608
+ m0 = torch.cat([m0, -torch.ones(bs, mask_H0-mask_h0, mask_w0, device=m0.device, dtype=m0.dtype)], dim=1)
1609
+ elif mask_w0 != mask_W0:
1610
+ m0 = torch.cat([m0, -torch.ones(bs, mask_h0, mask_W0-mask_w0, device=m0.device, dtype=m0.dtype)], dim=2)
1611
+ m0 = m0.reshape(bs, mask_H0*mask_W0)
1612
+ if padding and self.d_model == feat1.size(1):
1613
+ m1 = m1.reshape(bs, mask_h1, mask_w1)
1614
+ bs, c, mask_h1, mask_w1 = feat1.size()
1615
+ if mask_h1 != mask_H1:
1616
+ m1 = torch.cat([m1, -torch.ones(bs, mask_H1-mask_h1, mask_w1, device=m1.device, dtype=m1.dtype)], dim=1)
1617
+ elif mask_w1 != mask_W1:
1618
+ m1 = torch.cat([m1, -torch.ones(bs, mask_h1, mask_W1-mask_w1, device=m1.device, dtype=m1.dtype)], dim=2)
1619
+ m1 = m1.reshape(bs, mask_H1*mask_W1)
1620
+ data.update({'matches0_'+str(i//2): m0, 'matches1_'+str(i//2): m1})
1621
+ conf = torch.zeros((bs, l0 * l1), device=scores.device, dtype=scores.dtype)
1622
+ ind = ind0[...,None] * l1 + ind1[:,None,:]
1623
+ # conf[ind.reshape(bs, -1)] = scores.reshape(bs, -1).exp()
1624
+ conf.scatter_(1, ind.reshape(bs, -1), scores.reshape(bs, -1).exp()) # in-place scatter_; out-of-place scatter() would discard the update
1625
+ if padding and self.d_model == feat0.size(1):
1626
+ conf = conf.reshape(bs, mask_h0, mask_w0, mask_h1, mask_w1)
1627
+ bs, c, mask_h0, mask_w0 = feat0.size()
1628
+ if mask_h0 != mask_H0:
1629
+ conf = torch.cat([conf, torch.zeros(bs, mask_H0-mask_h0, mask_w0, mask_h1, mask_w1, device=conf.device, dtype=conf.dtype)], dim=1)
1630
+ elif mask_w0 != mask_W0:
1631
+ conf = torch.cat([conf, torch.zeros(bs, mask_h0, mask_W0-mask_w0, mask_h1, mask_w1, device=conf.device, dtype=conf.dtype)], dim=2)
1632
+ bs, c, mask_h1, mask_w1 = feat1.size()
1633
+ if mask_h1 != mask_H1:
1634
+ conf = torch.cat([conf, torch.zeros(bs, mask_H0, mask_W0, mask_H1-mask_h1, mask_W1, device=conf.device, dtype=conf.dtype)], dim=3)
1635
+ elif mask_w1 != mask_W1:
1636
+ conf = torch.cat([conf, torch.zeros(bs, mask_H0, mask_W0, mask_H1, mask_W1-mask_w1, device=conf.device, dtype=conf.dtype)], dim=4)
1637
+ conf = conf.reshape(bs, mask_H0*mask_W0, mask_H1*mask_W1)
1638
+ data.update({'conf_matrix_'+str(i//2): conf})
1639
+
1640
+
1641
+
1642
+ else:
1643
+ raise KeyError
1644
+
1645
+ if self.matchability and not self.training:
1646
+ scores, _, matchability0, matchability1 = self.log_assignment[i//2](desc0, desc1)
1647
+ conf = torch.zeros((bs, l0 * l1), device=scores.device, dtype=scores.dtype)
1648
+ ind = ind0[...,None] * l1 + ind1[:,None,:]
1649
+ # conf[ind.reshape(bs, -1)] = scores.reshape(bs, -1).exp()
1650
+ conf.scatter_(1, ind.reshape(bs, -1), scores.reshape(bs, -1).exp())
1651
+ if padding and self.d_model == feat0.size(1):
1652
+ conf = conf.reshape(bs, mask_h0, mask_w0, mask_h1, mask_w1)
1653
+ bs, c, mask_h0, mask_w0 = feat0.size()
1654
+ if mask_h0 != mask_H0:
1655
+ conf = torch.cat([conf, torch.zeros(bs, mask_H0-mask_h0, mask_w0, mask_h1, mask_w1, device=conf.device, dtype=conf.dtype)], dim=1)
1656
+ elif mask_w0 != mask_W0:
1657
+ conf = torch.cat([conf, torch.zeros(bs, mask_h0, mask_W0-mask_w0, mask_h1, mask_w1, device=conf.device, dtype=conf.dtype)], dim=2)
1658
+ bs, c, mask_h1, mask_w1 = feat1.size()
1659
+ if mask_h1 != mask_H1:
1660
+ conf = torch.cat([conf, torch.zeros(bs, mask_H0, mask_W0, mask_H1-mask_h1, mask_W1, device=conf.device, dtype=conf.dtype)], dim=3)
1661
+ elif mask_w1 != mask_W1:
1662
+ conf = torch.cat([conf, torch.zeros(bs, mask_H0, mask_W0, mask_H1, mask_W1-mask_w1, device=conf.device, dtype=conf.dtype)], dim=4)
1663
+ conf = conf.reshape(bs, mask_H0*mask_W0, mask_H1*mask_W1)
1664
+ data.update({'conf_matrix': conf})
1665
+ data.update(**self.CoarseMatching.get_coarse_match(conf, data))
1666
+ # m0, m1, mscores0, mscores1 = filter_matches(
1667
+ # scores, self.conf.filter_threshold)
1668
+
1669
+ # matches, mscores = [], []
1670
+ # for k in range(b):
1671
+ # valid = m0[k] > -1
1672
+ # m_indices_0 = torch.where(valid)[0]
1673
+ # m_indices_1 = m0[k][valid]
1674
+ # if do_point_pruning:
1675
+ # m_indices_0 = ind0[k, m_indices_0]
1676
+ # m_indices_1 = ind1[k, m_indices_1]
1677
+ # matches.append(torch.stack([m_indices_0, m_indices_1], -1))
1678
+ # mscores.append(mscores0[k][valid])
1679
+
1680
+ # # TODO: Remove when hloc switches to the compact format.
1681
+ # if do_point_pruning:
1682
+ # m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
1683
+ # m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
1684
+ # m0_[:, ind0] = torch.where(
1685
+ # m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
1686
+ # m1_[:, ind1] = torch.where(
1687
+ # m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
1688
+ # mscores0_ = torch.zeros((b, m), device=mscores0.device)
1689
+ # mscores1_ = torch.zeros((b, n), device=mscores1.device)
1690
+ # mscores0_[:, ind0] = mscores0
1691
+ # mscores1_[:, ind1] = mscores1
1692
+ # m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
1693
+
1694
+ # pred = {
1695
+ # 'matches0': m0,
1696
+ # 'matches1': m1,
1697
+ # 'matching_scores0': mscores0,
1698
+ # 'matching_scores1': mscores1,
1699
+ # 'stop': i+1,
1700
+ # 'matches': matches,
1701
+ # 'scores': mscores,
1702
+ # }
1703
+
1704
+ # if do_point_pruning:
1705
+ # pred.update(dict(prune0=prune0, prune1=prune1))
1706
+ # return pred
1707
+
1708
+
1709
+ if padding and self.d_model == feat0.size(1):
1710
+ bs, c, mask_h0, mask_w0 = feat0.size()
1711
+ if mask_h0 != mask_H0:
1712
+ feat0 = torch.cat([feat0, torch.zeros(bs, c, mask_H0-mask_h0, mask_W0, device=feat0.device, dtype=feat0.dtype)], dim=-2)
1713
+ elif mask_w0 != mask_W0:
1714
+ feat0 = torch.cat([feat0, torch.zeros(bs, c, mask_H0, mask_W0-mask_w0, device=feat0.device, dtype=feat0.dtype)], dim=-1)
1715
+ bs, c, mask_h1, mask_w1 = feat1.size()
1716
+ if mask_h1 != mask_H1:
1717
+ feat1 = torch.cat([feat1, torch.zeros(bs, c, mask_H1-mask_h1, mask_W1, device=feat1.device, dtype=feat1.dtype)], dim=-2)
1718
+ elif mask_w1 != mask_W1:
1719
+ feat1 = torch.cat([feat1, torch.zeros(bs, c, mask_H1, mask_W1-mask_w1, device=feat1.device, dtype=feat1.dtype)], dim=-1)
1720
+
1721
+ return feat0, feat1
1722
+
1723
+ def pro(self, feat0, feat1, mask0=None, mask1=None, profiler=None):
1724
+ """
1725
+ Args:
1726
+ feat0 (torch.Tensor): [N, C, H, W]
1727
+ feat1 (torch.Tensor): [N, C, H, W]
1728
+ mask0 (torch.Tensor): [N, L] (optional)
1729
+ mask1 (torch.Tensor): [N, S] (optional)
1730
+ """
1731
+
1732
+ assert self.d_model == feat0.size(1) or self.d_model == feat0.size(-1), "the feature dimension of the inputs must equal the transformer d_model"
1733
+ with profiler.profile("LoFTR_transformer_attention"):
1734
+ for layer, name in zip(self.layers, self.layer_names):
1735
+ if name == 'self':
1736
+ feat0 = layer.pro(feat0, feat0, mask0, mask0, profiler=profiler)
1737
+ feat1 = layer.pro(feat1, feat1, mask1, mask1, profiler=profiler)
1738
+ elif name == 'cross':
1739
+ feat0 = layer.pro(feat0, feat1, mask0, mask1, profiler=profiler)
1740
+ feat1 = layer.pro(feat1, feat0, mask1, mask0, profiler=profiler)
1741
+ else:
1742
+ raise KeyError
1743
+
1744
+ return feat0, feat1
1745
+
1746
+ def confidence_threshold(self, layer_index: int) -> float:
1747
+ """ scaled confidence threshold """
1748
+ threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.n_layers)
1749
+ return np.clip(threshold, 0, 1)
1750
+
1751
+ def get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor,
1752
+ layer_index: int) -> torch.Tensor:
1753
+ """ mask points which should be removed """
1754
+ threshold = self.confidence_threshold(layer_index)
1755
+ if confidences is not None:
1756
+ scores = torch.where(
1757
+ confidences > threshold, scores, scores.new_tensor(1.0))
1758
+ return scores > (1 - self.width_confidence)
1759
+
1760
+ def check_if_stop(self,
1761
+ confidences0: torch.Tensor,
1762
+ confidences1: torch.Tensor,
1763
+ layer_index: int, num_points: int) -> torch.Tensor:
1764
+ """ evaluate stopping condition"""
1765
+ confidences = torch.cat([confidences0, confidences1], -1)
1766
+ threshold = self.confidence_threshold(layer_index)
1767
+ pos = 1.0 - (confidences < threshold).float().sum() / num_points
1768
+ return pos > self.depth_confidence
imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer_utils.py ADDED
@@ -0,0 +1,76 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ class TokenConfidence(nn.Module):
6
+ def __init__(self, dim: int) -> None:
7
+ super().__init__()
8
+ self.token = nn.Sequential(
9
+ nn.Linear(dim, 1),
10
+ nn.Sigmoid()
11
+ )
12
+
13
+ def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
14
+ """ get confidence tokens """
15
+ return (
16
+ self.token(desc0.detach().float()).squeeze(-1),
17
+ self.token(desc1.detach().float()).squeeze(-1))
18
+
19
+ def sigmoid_log_double_softmax(
20
+ sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor) -> torch.Tensor:
21
+ """ create the log assignment matrix from logits and similarity"""
22
+ b, m, n = sim.shape
23
+ m0, m1 = torch.sigmoid(z0), torch.sigmoid(z1)
24
+ certainties = torch.log(m0) + torch.log(m1).transpose(1, 2)
25
+ scores0 = F.log_softmax(sim, 2)
26
+ scores1 = F.log_softmax(
27
+ sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
28
+ scores = scores0 + scores1 + certainties
29
+ # scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
30
+ # scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
31
+ return scores, m0, m1
32
+
33
+ class MatchAssignment(nn.Module):
34
+ def __init__(self, dim: int) -> None:
35
+ super().__init__()
36
+ self.dim = dim
37
+ self.matchability = nn.Linear(dim, 1, bias=True)
38
+ self.final_proj = nn.Linear(dim, dim, bias=True)
39
+
40
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
41
+ def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
42
+ """ build assignment matrix from descriptors """
43
+ mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
44
+ _, _, d = mdesc0.shape
45
+ mdesc0, mdesc1 = mdesc0 / d**.25, mdesc1 / d**.25
46
+ sim = torch.einsum('bmd,bnd->bmn', mdesc0, mdesc1)
47
+ z0 = self.matchability(desc0)
48
+ z1 = self.matchability(desc1)
49
+ scores, m0, m1 = sigmoid_log_double_softmax(sim, z0, z1)
50
+ return scores, sim, m0, m1
51
+
52
+ def scores(self, desc0: torch.Tensor, desc1: torch.Tensor):
53
+ m0 = torch.sigmoid(self.matchability(desc0)).squeeze(-1)
54
+ m1 = torch.sigmoid(self.matchability(desc1)).squeeze(-1)
55
+ return m0, m1
56
+
57
+ def filter_matches(scores: torch.Tensor, th: float):
58
+ """ obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
59
+ max0, max1 = scores.max(2), scores.max(1)
60
+ m0, m1 = max0.indices, max1.indices
61
+ indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
62
+ indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
63
+ mutual0 = indices0 == m1.gather(1, m0)
64
+ mutual1 = indices1 == m0.gather(1, m1)
65
+ max0_exp = max0.values.exp()
66
+ zero = max0_exp.new_tensor(0)
67
+ mscores0 = torch.where(mutual0, max0_exp, zero)
68
+ mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
69
+ if th is not None:
70
+ valid0 = mutual0 & (mscores0 > th)
71
+ else:
72
+ valid0 = mutual0
73
+ valid1 = mutual1 & valid0.gather(1, m1)
74
+ m0 = torch.where(valid0, m0, -1)
75
+ m1 = torch.where(valid1, m1, -1)
76
+ return m0, m1, mscores0, mscores1
imcui/third_party/MatchAnything/src/loftr/utils/coarse_matching.py ADDED
@@ -0,0 +1,266 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops.einops import rearrange, repeat
5
+
6
+ from loguru import logger
7
+
8
+ INF = 1e9
9
+
10
+ def mask_border(m, b: int, v):
11
+ """ Mask borders with value
12
+ Args:
13
+ m (torch.Tensor): [N, H0, W0, H1, W1]
14
+ b (int)
15
+ v (m.dtype)
16
+ """
17
+ if b <= 0:
18
+ return
19
+
20
+ m[:, :b] = v
21
+ m[:, :, :b] = v
22
+ m[:, :, :, :b] = v
23
+ m[:, :, :, :, :b] = v
24
+ m[:, -b:] = v
25
+ m[:, :, -b:] = v
26
+ m[:, :, :, -b:] = v
27
+ m[:, :, :, :, -b:] = v
28
+
29
+
30
+ def mask_border_with_padding(m, bd, v, p_m0, p_m1):
31
+ if bd <= 0:
32
+ return
33
+
34
+ m[:, :bd] = v
35
+ m[:, :, :bd] = v
36
+ m[:, :, :, :bd] = v
37
+ m[:, :, :, :, :bd] = v
38
+
39
+ h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
40
+ h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
41
+ for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
42
+ m[b_idx, h0 - bd:] = v
43
+ m[b_idx, :, w0 - bd:] = v
44
+ m[b_idx, :, :, h1 - bd:] = v
45
+ m[b_idx, :, :, :, w1 - bd:] = v
46
+
47
+
48
+ def compute_max_candidates(p_m0, p_m1):
49
+ """Compute the max candidates of all pairs within a batch
50
+
51
+ Args:
52
+ p_m0, p_m1 (torch.Tensor): padded masks
53
+ """
54
+ h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
55
+ h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
56
+ max_cand = torch.sum(
57
+ torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
58
+ return max_cand
59
+
60
+
61
+ class CoarseMatching(nn.Module):
62
+ def __init__(self, config):
63
+ super().__init__()
64
+ self.config = config
65
+ # general config
66
+ self.thr = config['thr']
67
+ self.border_rm = config['border_rm']
68
+ # -- # for training fine-level LoFTR
69
+ self.train_coarse_percent = config['train_coarse_percent']
70
+ self.train_pad_num_gt_min = config['train_pad_num_gt_min']
71
+
72
+ # we provide 2 options for differentiable matching
73
+ self.match_type = config['match_type']
74
+ if self.match_type == 'dual_softmax':
75
+ self.temperature = config['dsmax_temperature']
76
+ elif self.match_type == 'sinkhorn':
77
+ try:
78
+ from .superglue import log_optimal_transport
79
+ except ImportError:
80
+ raise ImportError("download superglue.py first!")
81
+ self.log_optimal_transport = log_optimal_transport
82
+ self.bin_score = nn.Parameter(
83
+ torch.tensor(config['skh_init_bin_score'], requires_grad=True))
84
+ self.skh_iters = config['skh_iters']
85
+ self.skh_prefilter = config['skh_prefilter']
86
+ else:
87
+ raise NotImplementedError()
88
+
89
+ self.mtd = config['mtd_spvs']
90
+ self.fix_bias = config['fix_bias']
91
+
92
+ def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
93
+ """
94
+ Args:
95
+ feat_c0 (torch.Tensor): [N, L, C]
96
+ feat_c1 (torch.Tensor): [N, S, C]
97
+ data (dict)
98
+ mask_c0 (torch.Tensor): [N, L] (optional)
99
+ mask_c1 (torch.Tensor): [N, S] (optional)
100
+ Update:
101
+ data (dict): {
102
+ 'b_ids' (torch.Tensor): [M'],
103
+ 'i_ids' (torch.Tensor): [M'],
104
+ 'j_ids' (torch.Tensor): [M'],
105
+ 'gt_mask' (torch.Tensor): [M'],
106
+ 'mkpts0_c' (torch.Tensor): [M, 2],
107
+ 'mkpts1_c' (torch.Tensor): [M, 2],
108
+ 'mconf' (torch.Tensor): [M]}
109
+ NOTE: M' != M during training.
110
+ """
111
+ N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
112
+
113
+ # normalize
114
+ feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
115
+ [feat_c0, feat_c1])
116
+
117
+ if self.match_type == 'dual_softmax':
118
+ with torch.autocast(enabled=False, device_type='cuda'):
119
+ sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
120
+ feat_c1) / self.temperature
121
+ if mask_c0 is not None:
122
+ sim_matrix = sim_matrix.float().masked_fill_(
123
+ ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
124
+ -INF
125
+ # float("-inf") if sim_matrix.dtype == torch.float16 else -INF
126
+ )
127
+ if self.config['fp16log']:
128
+ t1 = F.softmax(sim_matrix, 1)
129
+ t2 = F.softmax(sim_matrix, 2)
130
+ conf_matrix = t1*t2
131
+ logger.info(f'feat_c0absmax: {feat_c0.abs().max()}')
132
+ logger.info(f'feat_c1absmax: {feat_c1.abs().max()}')
133
+ logger.info(f'sim_matrix: {sim_matrix.dtype}')
134
+ logger.info(f'sim_matrixabsmax: {sim_matrix.abs().max()}')
135
+ logger.info(f't1: {t1.dtype}, t2: {t2.dtype}, conf_matrix: {conf_matrix.dtype}')
136
+ logger.info(f't1absmax: {t1.abs().max()}, t2absmax: {t2.abs().max()}, conf_matrixabsmax: {conf_matrix.abs().max()}')
137
+ else:
138
+ conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
139
+
140
+ data.update({'conf_matrix': conf_matrix})
141
+
142
+ # predict coarse matches from conf_matrix
143
+ data.update(**self.get_coarse_match(conf_matrix, data))
144
+
145
+ @torch.no_grad()
146
+ def get_coarse_match(self, conf_matrix, data):
147
+ """
148
+ Args:
149
+ conf_matrix (torch.Tensor): [N, L, S]
150
+ data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
151
+ Returns:
152
+ coarse_matches (dict): {
153
+ 'b_ids' (torch.Tensor): [M'],
154
+ 'i_ids' (torch.Tensor): [M'],
155
+ 'j_ids' (torch.Tensor): [M'],
156
+ 'gt_mask' (torch.Tensor): [M'],
157
+ 'm_bids' (torch.Tensor): [M],
158
+ 'mkpts0_c' (torch.Tensor): [M, 2],
159
+ 'mkpts1_c' (torch.Tensor): [M, 2],
160
+ 'mconf' (torch.Tensor): [M]}
161
+ """
162
+ axes_lengths = {
163
+ 'h0c': data['hw0_c'][0],
164
+ 'w0c': data['hw0_c'][1],
165
+ 'h1c': data['hw1_c'][0],
166
+ 'w1c': data['hw1_c'][1]
167
+ }
168
+ _device = conf_matrix.device
169
+ # 1. confidence thresholding
170
+ mask = conf_matrix > self.thr
171
+ mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
172
+ **axes_lengths)
173
+ if 'mask0' not in data:
174
+ mask_border(mask, self.border_rm, False)
175
+ else:
176
+ mask_border_with_padding(mask, self.border_rm, False,
177
+ data['mask0'], data['mask1'])
178
+ mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
179
+ **axes_lengths)
180
+
181
+ # 2. mutual nearest
182
+ if self.mtd:
183
+ b_ids, i_ids, j_ids = torch.where(mask)
184
+ mconf = conf_matrix[b_ids, i_ids, j_ids]
185
+ else:
186
+ mask = mask \
187
+ * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
188
+ * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
189
+
190
+ # 3. find all valid coarse matches
191
+ # this only works when at most one `True` in each row
192
+ mask_v, all_j_ids = mask.max(dim=2)
193
+ b_ids, i_ids = torch.where(mask_v)
194
+ j_ids = all_j_ids[b_ids, i_ids]
195
+ mconf = conf_matrix[b_ids, i_ids, j_ids]
196
+
197
+ # 4. Random sampling of training samples for fine-level LoFTR
198
+ # (optional) pad samples with gt coarse-level matches
199
+ if self.training:
200
+ # NOTE:
201
+ # The sampling is performed across all pairs in a batch without manually balancing
202
+ # #samples for fine-level increases w.r.t. batch_size
203
+ if 'mask0' not in data:
204
+ num_candidates_max = mask.size(0) * max(
205
+ mask.size(1), mask.size(2))
206
+ else:
207
+ num_candidates_max = compute_max_candidates(
208
+ data['mask0'], data['mask1'])
209
+ num_matches_train = int(num_candidates_max *
210
+ self.train_coarse_percent)
211
+ num_matches_pred = len(b_ids)
212
+ assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"
213
+
214
+ # pred_indices is to select from prediction
215
+ if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
216
+ pred_indices = torch.arange(num_matches_pred, device=_device)
217
+ else:
218
+ pred_indices = torch.randint(
219
+ num_matches_pred,
220
+ (num_matches_train - self.train_pad_num_gt_min, ),
221
+ device=_device)
222
+
223
+ # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
224
+ gt_pad_indices = torch.randint(
225
+ len(data['spv_b_ids']),
226
+ (max(num_matches_train - num_matches_pred,
227
+ self.train_pad_num_gt_min), ),
228
+ device=_device)
229
+ mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero
230
+
231
+ b_ids, i_ids, j_ids, mconf = map(
232
+ lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
233
+ dim=0),
234
+ *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
235
+ [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))
236
+
237
+ # These matches select patches that feed into fine-level network
238
+ coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
239
+
240
+ # 5. Update with matches in original image resolution
241
+ if self.fix_bias:
242
+ scale = 8
243
+ else:
244
+ scale = data['hw0_i'][0] / data['hw0_c'][0]
245
+ scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
246
+ scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
247
+ mkpts0_c = torch.stack(
248
+ [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
249
+ dim=1) * scale0
250
+ mkpts1_c = torch.stack(
251
+ [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
252
+ dim=1) * scale1
253
+
254
+ m_bids = b_ids[mconf != 0]
255
+
256
+ m_bids_f = repeat(m_bids, 'b -> b k', k = 3).reshape(-1)
257
+ coarse_matches.update({
258
+ 'gt_mask': mconf == 0,
259
+ 'm_bids': m_bids, # mconf == 0 => gt matches
260
+ 'm_bids_f': m_bids_f,
261
+ 'mkpts0_c': mkpts0_c[mconf != 0],
262
+ 'mkpts1_c': mkpts1_c[mconf != 0],
263
+ 'mconf': mconf[mconf != 0]
264
+ })
265
+
266
+ return coarse_matches
imcui/third_party/MatchAnything/src/loftr/utils/fine_matching.py ADDED
@@ -0,0 +1,493 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from kornia.geometry.subpix import dsnt
7
+ from kornia.utils.grid import create_meshgrid
8
+
9
+ from loguru import logger
10
+
11
+ class FineMatching(nn.Module):
12
+ """FineMatching with s2d paradigm"""
13
+
14
+ def __init__(self, config):
15
+ super().__init__()
16
+ self.config = config
17
+ self.topk = config['match_fine']['topk']
18
+ self.mtd_spvs = config['fine']['mtd_spvs']
19
+ self.align_corner = config['align_corner']
20
+ self.fix_bias = config['fix_bias']
21
+ self.normfinem = config['match_fine']['normfinem']
22
+ self.fix_fine_matching = config['match_fine']['fix_fine_matching']
23
+ self.mutual_nearest = config['match_fine']['force_nearest']
24
+ self.skip_fine_softmax = config['match_fine']['skip_fine_softmax']
25
+ self.normfeat = config['match_fine']['normfeat']
26
+ self.use_sigmoid = config['match_fine']['use_sigmoid']
27
+ self.local_regress = config['match_fine']['local_regress']
28
+ self.local_regress_rmborder = config['match_fine']['local_regress_rmborder']
29
+ self.local_regress_nomask = config['match_fine']['local_regress_nomask']
30
+ self.local_regress_temperature = config['match_fine']['local_regress_temperature']
31
+ self.local_regress_padone = config['match_fine']['local_regress_padone']
32
+ self.local_regress_slice = config['match_fine']['local_regress_slice']
33
+ self.local_regress_slicedim = config['match_fine']['local_regress_slicedim']
34
+ self.local_regress_inner = config['match_fine']['local_regress_inner']
35
+ self.multi_regress = config['match_fine']['multi_regress']
36
+ def forward(self, feat_0, feat_1, data):
37
+ """
38
+ Args:
39
+ feat_0 (torch.Tensor): [M, WW, C]
40
+ feat_1 (torch.Tensor): [M, WW, C]
41
+ data (dict)
42
+ Update:
43
+ data (dict):{
44
+ 'expec_f' (torch.Tensor): [M, 3],
45
+ 'mkpts0_f' (torch.Tensor): [M, 2],
46
+ 'mkpts1_f' (torch.Tensor): [M, 2]}
47
+ """
48
+ M, WW, C = feat_0.shape
49
+ W = int(math.sqrt(WW))
50
+ if self.fix_bias:
51
+ scale = 2
52
+ else:
53
+ scale = data['hw0_i'][0] / data['hw0_f'][0]
54
+ self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
55
+
56
+ # corner case: if no coarse matches found
57
+ if M == 0:
58
+ assert self.training == False, "M is always > 0 during training; see coarse_matching.py"
59
+ # logger.warning('No matches found in coarse-level.')
60
+ if self.mtd_spvs:
61
+ data.update({
62
+ 'conf_matrix_f': torch.empty(0, WW, WW, device=feat_0.device),
63
+ 'mkpts0_f': data['mkpts0_c'],
64
+ 'mkpts1_f': data['mkpts1_c'],
65
+ })
66
+ # if self.local_regress:
67
+ # data.update({
68
+ # 'sim_matrix_f': torch.empty(0, WW, WW, device=feat_0.device),
69
+ # })
70
+ return
71
+ else:
72
+ data.update({
73
+ 'expec_f': torch.empty(0, 3, device=feat_0.device),
74
+ 'mkpts0_f': data['mkpts0_c'],
75
+ 'mkpts1_f': data['mkpts1_c'],
76
+ })
77
+ return
78
+
79
+ if self.mtd_spvs:
80
+ with torch.autocast(enabled=False, device_type='cuda'):
81
+ # feat_0 = feat_0 / feat_0.size(-2)
82
+ if self.local_regress_slice:
83
+ feat_ff0, feat_ff1 = feat_0[...,-self.local_regress_slicedim:], feat_1[...,-self.local_regress_slicedim:]
84
+ feat_f0, feat_f1 = feat_0[...,:-self.local_regress_slicedim], feat_1[...,:-self.local_regress_slicedim]
85
+ conf_matrix_ff = torch.einsum('mlc,mrc->mlr', feat_ff0, feat_ff1 / (self.local_regress_slicedim)**.5)
86
+ else:
87
+ feat_f0, feat_f1 = feat_0, feat_1
88
+ if self.normfinem:
89
+ feat_f0 = feat_f0 / C**.5
90
+ feat_f1 = feat_f1 / C**.5
91
+ conf_matrix_f = torch.einsum('mlc,mrc->mlr', feat_f0, feat_f1)
92
+ else:
93
+ if self.local_regress_slice:
94
+ conf_matrix_f = torch.einsum('mlc,mrc->mlr', feat_f0, feat_f1 / (C - self.local_regress_slicedim)**.5)
95
+ else:
96
+ conf_matrix_f = torch.einsum('mlc,mrc->mlr', feat_f0, feat_f1 / C**.5)
97
+
98
+ if self.normfeat:
99
+ feat_f0, feat_f1 = torch.nn.functional.normalize(feat_f0.float(), p=2, dim=-1), torch.nn.functional.normalize(feat_f1.float(), p=2, dim=-1)
100
+
101
+ if self.config['fp16log']:
102
+ logger.info(f'sim_matrix: {conf_matrix_f.abs().max()}')
103
+ # sim_matrix *= 1. / C**.5 # normalize
104
+
105
+ if self.multi_regress:
106
+ assert not self.local_regress
107
+ assert not self.normfinem and not self.normfeat
108
+ heatmap = F.softmax(conf_matrix_f, 2).view(M, WW, W, W) # [M, WW, W, W]
109
+
110
+ assert (W - 2) == (self.config['resolution'][0] // self.config['resolution'][1]) # c8
111
+ windows_scale = (W - 1) / (self.config['resolution'][0] // self.config['resolution'][1])
112
+
113
+ coords_normalized = dsnt.spatial_expectation2d(heatmap, True) * windows_scale # [M, WW, 2]
114
+ grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2)[:,None,:,:] * windows_scale # [1, 1, WW, 2]
115
+
116
+ # compute std over <x, y>
117
+ var = torch.sum(grid_normalized**2 * heatmap.view(M, WW, WW, 1), dim=-2) - coords_normalized**2 # ([1,1,WW,2] * [M,WW,WW,1])->[M,WW,2]
118
+ std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M,WW] clamp needed for numerical stability
119
+
120
+ # for fine-level supervision
121
+ data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(-1)], -1)}) # [M, WW, 2]
122
+
123
+ # get the least uncertain matches
124
+ val, idx = torch.topk(std, self.topk, dim=-1, largest=False) # [M,topk]
125
+ coords_normalized = coords_normalized[torch.arange(M, device=conf_matrix_f.device, dtype=torch.long)[:,None], idx] # [M,topk]
126
+
127
+ grid = create_meshgrid(W, W, False, idx.device) - W // 2 + 0.5 # [1, W, W, 2]
128
+ grid = grid.reshape(1, -1, 2).expand(M, -1, -1) # [M, WW, 2]
129
+ delta_l = torch.gather(grid, 1, idx.unsqueeze(-1).expand(-1, -1, 2)) # [M, topk, 2] in (x, y)
130
+
131
+ # compute absolute kpt coords
132
+ self.get_multi_fine_match_align(delta_l, coords_normalized, data)
133
+
134
+
135
+ else:
136
+
137
+ if self.skip_fine_softmax:
138
+ pass
139
+ elif self.use_sigmoid:
140
+ conf_matrix_f = torch.sigmoid(conf_matrix_f)
141
+ else:
142
+ if self.local_regress:
143
+ del feat_f0, feat_f1
144
+ softmax_matrix_f = F.softmax(conf_matrix_f, 1) * F.softmax(conf_matrix_f, 2)
145
+ # softmax_matrix_f = conf_matrix_f
146
+ if self.local_regress_inner:
147
+ softmax_matrix_f = softmax_matrix_f.reshape(M, self.WW, self.W+2, self.W+2)
148
+ softmax_matrix_f = softmax_matrix_f[...,1:-1,1:-1].reshape(M, self.WW, self.WW)
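+ # in the inner-regression mode the second window is padded by 1 pixel ((W+2) x (W+2)); cropping the border above recovers the inner W x W grid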
149
+ # if self.training:
150
+ # for fine-level supervision
151
+ data.update({'conf_matrix_f': softmax_matrix_f})
152
+ if self.local_regress_slice:
153
+ data.update({'sim_matrix_ff': conf_matrix_ff})
154
+ else:
155
+ data.update({'sim_matrix_f': conf_matrix_f})
156
+
157
+ else:
158
+ conf_matrix_f = F.softmax(conf_matrix_f, 1) * F.softmax(conf_matrix_f, 2)
159
+
160
+ # for fine-level supervision
161
+ data.update({'conf_matrix_f': conf_matrix_f})
162
+
163
+ # compute absolute kpt coords
164
+ if self.local_regress:
165
+ self.get_fine_ds_match(softmax_matrix_f, data)
166
+ del softmax_matrix_f
167
+ idx_l, idx_r = data['idx_l'], data['idx_r']
168
+ del data['idx_l'], data['idx_r']
169
+ m_ids = torch.arange(M, device=idx_l.device, dtype=torch.long).unsqueeze(-1).expand(-1, self.topk)
170
+ # if self.training:
171
+ m_ids = m_ids[:len(data['mconf']) // self.topk]
172
+ idx_r_iids, idx_r_jids = idx_r // W, idx_r % W
173
+
174
+ # remove matches on the window border
175
+ if self.local_regress_nomask:
176
+ # log the fraction of inner (non-border) matches
177
+ # mask = (idx_r_iids >= 1) & (idx_r_iids <= W-2) & (idx_r_jids >= 1) & (idx_r_jids <= W-2)
178
+ # mask_sum = mask.sum()
179
+ # logger.info(f'total fine match: {mask.numel()}; regressed fine match: {mask_sum}, per: {mask_sum / mask.numel()}')
180
+ mask = None
181
+ m_ids, idx_l, idx_r_iids, idx_r_jids = m_ids.reshape(-1), idx_l.reshape(-1), idx_r_iids.reshape(-1), idx_r_jids.reshape(-1)
182
+ if self.local_regress_inner: # been sliced before
183
+ delta = create_meshgrid(3, 3, True, conf_matrix_f.device).to(torch.long) # [1, 3, 3, 2]
184
+ else:
185
+ # no mask + 1 for padding
186
+ delta = create_meshgrid(3, 3, True, conf_matrix_f.device).to(torch.long) + torch.tensor([1], dtype=torch.long, device=conf_matrix_f.device) # [1, 3, 3, 2]
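+ # create_meshgrid(3, 3, True) gives integer offsets in {-1, 0, 1}; the +1 shifts them into the coordinates of the border-padded score matrix used below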
187
+
188
+ m_ids = m_ids[...,None,None].expand(-1, 3, 3)
189
+ idx_l = idx_l[...,None,None].expand(-1, 3, 3) # [m*topk, 3, 3]
190
+
191
+ idx_r_iids = idx_r_iids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 1]
192
+ idx_r_jids = idx_r_jids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 0]
193
+
194
+ if idx_l.numel() == 0:
195
+ data.update({
196
+ 'mkpts0_f': data['mkpts0_c'],
197
+ 'mkpts1_f': data['mkpts1_c'],
198
+ })
199
+ return
200
+
201
+ if self.local_regress_slice:
202
+ conf_matrix_f = conf_matrix_ff
203
+ if self.local_regress_inner:
204
+ conf_matrix_f = conf_matrix_f.reshape(M, self.WW, self.W+2, self.W+2)
205
+ else:
206
+ conf_matrix_f = conf_matrix_f.reshape(M, self.WW, self.W, self.W)
207
+ conf_matrix_f = F.pad(conf_matrix_f, (1,1,1,1))
208
+ else:
209
+ mask = (idx_r_iids >= 1) & (idx_r_iids <= W-2) & (idx_r_jids >= 1) & (idx_r_jids <= W-2)
210
+ if W == 10:
211
+ idx_l_iids, idx_l_jids = idx_l // W, idx_l % W
212
+ mask = mask & (idx_l_iids >= 1) & (idx_l_iids <= W-2) & (idx_l_jids >= 1) & (idx_l_jids <= W-2)
213
+
214
+ m_ids = m_ids[mask].to(torch.long)
215
+ idx_l, idx_r_iids, idx_r_jids = idx_l[mask].to(torch.long), idx_r_iids[mask].to(torch.long), idx_r_jids[mask].to(torch.long)
216
+
217
+ m_ids, idx_l, idx_r_iids, idx_r_jids = m_ids.reshape(-1), idx_l.reshape(-1), idx_r_iids.reshape(-1), idx_r_jids.reshape(-1)
218
+ mask = mask.reshape(-1)
219
+
220
+ delta = create_meshgrid(3, 3, True, conf_matrix_f.device).to(torch.long) # [1, 3, 3, 2]
221
+
222
+ m_ids = m_ids[:,None,None].expand(-1, 3, 3)
223
+ idx_l = idx_l[:,None,None].expand(-1, 3, 3) # [m, 3, 3]
224
+ # note: delta's last dim is (x, y), so row offsets take delta[..., 1] and column offsets take delta[..., 0]; the commented lines below had the axes swapped
225
+ # idx_r_iids = idx_r_iids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 0]
226
+ # idx_r_jids = idx_r_jids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 1]
227
+ idx_r_iids = idx_r_iids[:,None,None].expand(-1, 3, 3) + delta[..., 1]
228
+ idx_r_jids = idx_r_jids[:,None,None].expand(-1, 3, 3) + delta[..., 0]
229
+
230
+ if idx_l.numel() == 0:
231
+ data.update({
232
+ 'mkpts0_f': data['mkpts0_c'],
233
+ 'mkpts1_f': data['mkpts1_c'],
234
+ })
235
+ return
236
+ if not self.local_regress_slice:
237
+ conf_matrix_f = conf_matrix_f.reshape(M, self.WW, self.W, self.W)
238
+ else:
239
+ conf_matrix_f = conf_matrix_ff.reshape(M, self.WW, self.W, self.W)
240
+
241
+ conf_matrix_f = conf_matrix_f[m_ids, idx_l, idx_r_iids, idx_r_jids]
242
+ conf_matrix_f = conf_matrix_f.reshape(-1, 9)
243
+ if self.local_regress_padone: # as in training, detach the gradient of the center cell
244
+ conf_matrix_f[:,4] = -1e4
245
+ heatmap = F.softmax(conf_matrix_f / self.local_regress_temperature, -1)
246
+ logger.info(f'maxmax&maxmean of heatmap: {heatmap.view(-1).max()}, {heatmap.view(-1).min(), heatmap.max(-1)[0].mean()}')
247
+ heatmap[:,4] = 1.0 # no gradient computation is needed at inference
248
+ logger.info(f'min of heatmap: {heatmap.view(-1).min()}')
249
+ heatmap = heatmap.reshape(-1, 3, 3)
250
+ # heatmap = torch.ones_like(softmax) # ones_like for detach the gradient of center
251
+ # heatmap[:,:4], heatmap[:,5:] = softmax[:,:4], softmax[:,5:]
252
+ # heatmap = heatmap.reshape(-1, 3, 3)
253
+ else:
254
+ conf_matrix_f = F.softmax(conf_matrix_f / self.local_regress_temperature, -1)
255
+ # logger.info(f'max&min&mean of heatmap: {conf_matrix_f.view(-1).max()}, {conf_matrix_f.view(-1).min(), conf_matrix_f.max(-1)[0].mean()}')
256
+ heatmap = conf_matrix_f.reshape(-1, 3, 3)
257
+
258
+ # compute coordinates from heatmap
259
+ coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]
260
+
261
+ # coords_normalized_l2 = coords_normalized.norm(p=2, dim=-1)
262
+ # logger.info(f'mean&max&min abs of local: {coords_normalized_l2.mean(), coords_normalized_l2.max(), coords_normalized_l2.min()}')
263
+
264
+ # compute absolute kpt coords
265
+
266
+ if data['bs'] == 1:
267
+ scale1 = scale * data['scale1'] if 'scale0' in data else scale
268
+ else:
269
+ if mask is not None:
270
+ scale1 = scale * data['scale1'][data['b_ids']][:len(data['mconf']) // self.topk,...][:,None,:].expand(-1, self.topk, 2).reshape(-1, 2)[mask] if 'scale0' in data else scale
271
+ else:
272
+ scale1 = scale * data['scale1'][data['b_ids']][:len(data['mconf']) // self.topk,...][:,None,:].expand(-1, self.topk, 2).reshape(-1, 2) if 'scale0' in data else scale
273
+
274
+ self.get_fine_match_local(coords_normalized, data, scale1, mask, True)
275
+
276
+ else:
277
+ self.get_fine_ds_match(conf_matrix_f, data)
278
+
279
+
280
+ else:
281
+ if self.align_corner is True:
282
+ feat_f0, feat_f1 = feat_0, feat_1
283
+ feat_f0_picked = feat_f0[:, WW//2, :]
284
+ sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
285
+ softmax_temp = 1. / C**.5
286
+ heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
287
+
288
+ # compute coordinates from heatmap
289
+ coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2]
290
+ grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2]
291
+
292
+ # compute std over <x, y>
293
+ var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2]
294
+ std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability
295
+
296
+ # for fine-level supervision
297
+ data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
298
+
299
+ # compute absolute kpt coords
300
+ self.get_fine_match(coords_normalized, data)
301
+ else:
302
+ feat_f0, feat_f1 = feat_0, feat_1
303
+ # even-sized matching windows when the coarse grid is not aligned to the fine grid
304
+ # assert W == 5, "others size not checked"
305
+ if self.fix_bias:
306
+ assert W % 2 == 1, "W must be odd when selecting the center feature"
307
+ feat_f0_picked = feat_f0[:, WW//2]
308
+
309
+ else:
310
+ # assert W == 6, "others size not checked"
311
+ assert W % 2 == 0, "W must be even when the coarse grid is not aligned to the fine grid (average)"
312
+ feat_f0_picked = (feat_f0[:, WW//2 - W//2 - 1] + feat_f0[:, WW//2 - W//2] + feat_f0[:, WW//2 + W//2] + feat_f0[:, WW//2 + W//2 - 1]) / 4
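+ # with an even W there is no single center cell, so the four cells adjacent to the window center are averaged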
313
+ sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
314
+ softmax_temp = 1. / C**.5
315
+ heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
316
+
317
+ # compute coordinates from heatmap
318
+ windows_scale = (W - 1) / (self.config['resolution'][0] // self.config['resolution'][1])
319
+
320
+ coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] * windows_scale # [M, 2]
321
+ grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) * windows_scale # [1, WW, 2]
322
+
323
+ # compute std over <x, y>
324
+ var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2]
325
+ std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability
326
+
327
+ # for fine-level supervision
328
+ data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
329
+
330
+ # compute absolute kpt coords
331
+ self.get_fine_match_align(coords_normalized, data)
332
+
333
+
334
+ @torch.no_grad()
335
+ def get_fine_match(self, coords_normed, data):
336
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
337
+
338
+ # mkpts0_f and mkpts1_f
339
+ mkpts0_f = data['mkpts0_c']
340
+ scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
341
+ mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])]
342
+
343
+ data.update({
344
+ "mkpts0_f": mkpts0_f,
345
+ "mkpts1_f": mkpts1_f
346
+ })
347
+
348
+ def get_fine_match_local(self, coords_normed, data, scale1, mask, reserve_border=True):
349
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
350
+
351
+ if mask is None:
352
+ mkpts0_c, mkpts1_c = data['mkpts0_c'], data['mkpts1_c']
353
+ else:
354
+ data['mkpts0_c'], data['mkpts1_c'] = data['mkpts0_c'].reshape(-1, 2), data['mkpts1_c'].reshape(-1, 2)
355
+ mkpts0_c, mkpts1_c = data['mkpts0_c'][mask], data['mkpts1_c'][mask]
356
+ mask_sum = mask.sum()
357
+ logger.info(f'total fine match: {mask.numel()}; regressed fine match: {mask_sum}, per: {mask_sum / mask.numel()}')
358
+ # print(mkpts0_c.shape, mkpts1_c.shape, coords_normed.shape, scale1.shape)
359
+ # print(data['mkpts0_c'].shape, data['mkpts1_c'].shape)
360
+ # mkpts0_f and mkpts1_f
361
+ mkpts0_f = mkpts0_c
362
+ mkpts1_f = mkpts1_c + (coords_normed * (3 // 2) * scale1)
363
+
364
+ if reserve_border and mask is not None:
365
+ mkpts0_f, mkpts1_f = torch.cat([mkpts0_f, data['mkpts0_c'][~mask].reshape(-1, 2)]), torch.cat([mkpts1_f, data['mkpts1_c'][~mask].reshape(-1, 2)])
366
+ else:
367
+ pass
368
+
369
+ del data['mkpts0_c'], data['mkpts1_c']
370
+ data.update({
371
+ "mkpts0_f": mkpts0_f,
372
+ "mkpts1_f": mkpts1_f
373
+ })
374
+
375
+ # can be used for both the aligned and non-aligned cases
376
+ @torch.no_grad()
377
+ def get_fine_match_align(self, coord_normed, data):
378
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
379
+ c2f = self.config['resolution'][0] // self.config['resolution'][1]
380
+ # mkpts0_f and mkpts1_f
381
+ mkpts0_f = data['mkpts0_c']
382
+ scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
383
+ mkpts1_f = data['mkpts1_c'] + (coord_normed * (c2f // 2) * scale1)[:len(data['mconf'])]
384
+
385
+ data.update({
386
+ "mkpts0_f": mkpts0_f,
387
+ "mkpts1_f": mkpts1_f
388
+ })
389
+
390
+ @torch.no_grad()
391
+ def get_multi_fine_match_align(self, delta_l, coord_normed, data):
392
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
393
+ c2f = self.config['resolution'][0] // self.config['resolution'][1]
394
+ # mkpts0_f and mkpts1_f
395
+ scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else torch.tensor([[scale, scale]], device=delta_l.device)
396
+ scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else torch.tensor([[scale, scale]], device=delta_l.device)
397
+ mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2)
398
+ mkpts1_f = (data['mkpts1_c'][:,None,:] + (coord_normed * (c2f // 2) * scale1[:,None,:])[:len(data['mconf'])]).reshape(-1, 2)
399
+
400
+ data.update({
401
+ "mkpts0_f": mkpts0_f,
402
+ "mkpts1_f": mkpts1_f,
403
+ "mconf": data['mconf'][:,None].expand(-1, self.topk).reshape(-1)
404
+ })
405
+
406
+ @torch.no_grad()
407
+ def get_fine_ds_match(self, conf_matrix, data):
408
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
409
+
410
+ # select topk matches
411
+ m, _, _ = conf_matrix.shape
412
+
413
+
414
+ if self.mutual_nearest:
415
+ pass
416
+
417
+
418
+ elif not self.fix_fine_matching: # each selected left cell keeps only its single best right cell; the else branch below allows fully many-to-many top-k selection
419
+
420
+ val, idx_r = conf_matrix.max(-1) # (m, WW), (m, WW)
421
+ val, idx_l = torch.topk(val, self.topk, dim = -1) # (m, topk), (m, topk)
422
+ idx_r = torch.gather(idx_r, 1, idx_l) # (m, topk)
423
+
424
+ # mkpts0_c uses xy coordinates, so there is no need to convert to hw coordinates
425
+ # grid = create_meshgrid(W, W, False, conf_matrix.device).transpose(-3,-2) - W // 2 + 0.5 # (1, W, W, 2)
426
+ grid = create_meshgrid(W, W, False, conf_matrix.device) - W // 2 + 0.5 # (1, W, W, 2)
427
+ grid = grid.reshape(1, -1, 2).expand(m, -1, -1) # (m, WW, 2)
428
+ delta_l = torch.gather(grid, 1, idx_l.unsqueeze(-1).expand(-1, -1, 2)) # (m, topk, 2)
429
+ delta_r = torch.gather(grid, 1, idx_r.unsqueeze(-1).expand(-1, -1, 2)) # (m, topk, 2)
430
+
431
+ # mkpts0_f and mkpts1_f
432
+ scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale
433
+ scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
434
+
435
+ if torch.is_tensor(scale0) and scale0.numel() > 1: # num of scale0 > 1
436
+ mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2)
437
+ mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2)
438
+ else: # scale0 is a float
439
+ mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0)[:len(data['mconf']),...]).reshape(-1, 2)
440
+ mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1)[:len(data['mconf']),...]).reshape(-1, 2)
441
+
442
+ else: # allow one-to-many, many-to-one and many-to-many matches
443
+ conf_matrix = conf_matrix.reshape(m, -1)
444
+ if self.local_regress: # for compatibility with earlier configs
445
+ conf_matrix = conf_matrix[:len(data['mconf']),...]
446
+ val, idx = torch.topk(conf_matrix, self.topk, dim = -1)
447
+ idx_l = idx // WW
448
+ idx_r = idx % WW
449
+
450
+ if self.local_regress:
451
+ data.update({'idx_l': idx_l, 'idx_r': idx_r})
452
+
453
+ # mkpts0_c uses xy coordinates, so there is no need to convert to hw coordinates
454
+ # grid = create_meshgrid(W, W, False, conf_matrix.device).transpose(-3,-2) - W // 2 + 0.5 # (1, W, W, 2)
455
+ grid = create_meshgrid(W, W, False, conf_matrix.device) - W // 2 + 0.5
456
+ grid = grid.reshape(1, -1, 2).expand(m, -1, -1)
457
+ delta_l = torch.gather(grid, 1, idx_l.unsqueeze(-1).expand(-1, -1, 2))
458
+ delta_r = torch.gather(grid, 1, idx_r.unsqueeze(-1).expand(-1, -1, 2))
459
+
460
+ # mkpts0_f and mkpts1_f
461
+ scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale
462
+ scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
463
+
464
+ if self.local_regress:
465
+ if torch.is_tensor(scale0) and scale0.numel() > 1: # num of scale0 > 1
466
+ mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:len(data['mconf']),...][:,None,:])).reshape(-1, 2)
467
+ mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1[:len(data['mconf']),...][:,None,:])).reshape(-1, 2)
468
+ else: # scale0 is a float
469
+ mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0)).reshape(-1, 2)
470
+ mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1)).reshape(-1, 2)
471
+
472
+ else:
473
+ if torch.is_tensor(scale0) and scale0.numel() > 1: # num of scale0 > 1
474
+ mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2)
475
+ mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2)
476
+ else: # scale0 is a float
477
+ mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0)[:len(data['mconf']),...]).reshape(-1, 2)
478
+ mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1)[:len(data['mconf']),...]).reshape(-1, 2)
479
+ del data['mkpts0_c'], data['mkpts1_c']
480
+ data['mconf'] = data['mconf'].reshape(-1, 1).expand(-1, self.topk).reshape(-1)
481
+ # data['mconf'] = val.reshape(-1)[:len(data['mconf'])]*0.1 + data['mconf']
482
+
483
+ if self.local_regress:
484
+ data.update({
485
+ "mkpts0_c": mkpts0_f,
486
+ "mkpts1_c": mkpts1_f
487
+ })
488
+ else:
489
+ data.update({
490
+ "mkpts0_f": mkpts0_f,
491
+ "mkpts1_f": mkpts1_f
492
+ })
493
+
imcui/third_party/MatchAnything/src/loftr/utils/geometry.py ADDED
@@ -0,0 +1,298 @@
1
+ import torch
2
+ from src.utils.homography_utils import warp_points_torch
3
+
4
+ def get_unique_indices(input_tensor):
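+ # for each unique row of input_tensor, return the index of its first occurrence (used to deduplicate matches)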
5
+ if input_tensor.shape[0] > 1:
6
+ unique, inverse = torch.unique(input_tensor, sorted=True, return_inverse=True, dim=0)
7
+ perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
8
+ inverse, perm = inverse.flip([0]), perm.flip([0])
9
+ perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
10
+ else:
11
+ perm = torch.zeros((input_tensor.shape[0],), dtype=torch.long, device=input_tensor.device)
12
+ return perm
13
+
14
+
15
+ @torch.no_grad()
16
+ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, consistency_thr=0.2, cycle_proj_distance_thr=3.0):
17
+ """ Warp kpts0 from I0 to I1 with depth, K and Rt
18
+ Also check covisibility and depth consistency.
19
+ Depth is consistent if the relative error is below consistency_thr (0.2 by default).
20
+
21
+ Args:
22
+ kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
23
+ depth0 (torch.Tensor): [N, H, W],
24
+ depth1 (torch.Tensor): [N, H, W],
25
+ T_0to1 (torch.Tensor): [N, 3, 4],
26
+ K0 (torch.Tensor): [N, 3, 3],
27
+ K1 (torch.Tensor): [N, 3, 3],
28
+ Returns:
29
+ calculable_mask (torch.Tensor): [N, L]
30
+ warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
31
+ """
32
+ kpts0_long = kpts0.round().long()
33
+
34
+ # Sample depth, get calculable_mask on depth != 0
35
+ kpts0_depth = torch.stack(
36
+ [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
37
+ ) # (N, L)
38
+ nonzero_mask = kpts0_depth != 0
39
+
40
+ # Unproject
41
+ kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
42
+ kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
43
+
44
+ # Rigid Transform
45
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
46
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
47
+
48
+ # Project
49
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
50
+ w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth
51
+
52
+ # Covisible Check
53
+ h, w = depth1.shape[1:3]
54
+ covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
55
+ (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
56
+ w_kpts0_long = w_kpts0.long()
57
+ w_kpts0_long[~covisible_mask, :] = 0
58
+
59
+ w_kpts0_depth = torch.stack(
60
+ [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
61
+ ) # (N, L)
62
+ consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < consistency_thr
63
+
64
+ # Cycle Consistency Check
65
+ dst_pts_h = torch.cat([w_kpts0, torch.ones_like(w_kpts0[..., [0]], device=w_kpts0.device)], dim=-1) * w_kpts0_depth[..., None] # (N, L, 3)
66
+ dst_pts_cam = K1.inverse() @ dst_pts_h.transpose(2, 1) # (N, 3, L)
67
+ dst_pose = T_0to1.inverse()
68
+ world_points_cycle_back = dst_pose[:, :3, :3] @ dst_pts_cam + dst_pose[:, :3, [3]]
69
+ src_warp_back_h = (K0 @ world_points_cycle_back).transpose(2, 1) # (N, L, 3)
70
+ src_back_proj_pts = src_warp_back_h[..., :2] / (src_warp_back_h[..., [2]] + 1e-4)
71
+ cycle_reproj_distance_mask = torch.linalg.norm(src_back_proj_pts - kpts0[:, None], dim=-1) < cycle_proj_distance_thr
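+ # cycle check: unproject with the sampled depth in image1, transform back to image0, and require the re-projected point to lie within cycle_proj_distance_thr pixels of the original keypoint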
72
+
73
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask * cycle_reproj_distance_mask
74
+
75
+ return valid_mask, w_kpts0
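+ # minimal usage sketch (shapes as in the docstring above):
+ # valid_mask, w_kpts0 = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1)
+ # valid_mask: [N, L] bool; w_kpts0: [N, L, 2] warped pixel coordinates in image1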
76
+
77
+ @torch.no_grad()
78
+ def warp_kpts_by_sparse_gt_matches_batches(kpts0, gt_matches, dist_thr):
79
+ B, n_pts = kpts0.shape[0], kpts0.shape[1]
80
+ if n_pts > 20 * 10000:
81
+ all_kpts_valid_mask, all_kpts_warpped = [], []
82
+ for b_id in range(B):
83
+ kpts_valid_mask, kpts_warpped = warp_kpts_by_sparse_gt_matches(kpts0[[b_id]], gt_matches[[b_id]], dist_thr[[b_id]])
84
+ all_kpts_valid_mask.append(kpts_valid_mask)
85
+ all_kpts_warpped.append(kpts_warpped)
86
+ return torch.cat(all_kpts_valid_mask, dim=0), torch.cat(all_kpts_warpped, dim=0)
87
+ else:
88
+ return warp_kpts_by_sparse_gt_matches(kpts0, gt_matches, dist_thr)
89
+
90
+ @torch.no_grad()
91
+ def warp_kpts_by_sparse_gt_matches(kpts0, gt_matches, dist_thr):
92
+ kpts_warpped = torch.zeros_like(kpts0)
93
+ kpts_valid_mask = torch.zeros_like(kpts0[..., 0], dtype=torch.bool)
94
+ gt_matches_non_padding_mask = gt_matches.sum(-1) > 0
95
+
96
+ dist_matrix = torch.cdist(kpts0, gt_matches[..., :2]) # B * N * M
97
+ if dist_thr is not None:
98
+ mask = dist_matrix < dist_thr[:, None, None]
99
+ else:
100
+ mask = torch.ones_like(dist_matrix, dtype=torch.bool)
101
+ # Mutual-Nearest check:
102
+ mask = mask \
103
+ * (dist_matrix == dist_matrix.min(dim=2, keepdim=True)[0]) \
104
+ * (dist_matrix == dist_matrix.min(dim=1, keepdim=True)[0])
105
+
106
+ mask_v, all_j_ids = mask.max(dim=2)
107
+ b_ids, i_ids = torch.where(mask_v)
108
+ j_ids = all_j_ids[b_ids, i_ids]
109
+
110
+ j_uq_indices = get_unique_indices(torch.stack([b_ids, j_ids], dim=-1))
111
+ b_ids, i_ids, j_ids = map(lambda x: x[j_uq_indices], [b_ids, i_ids, j_ids])
112
+
113
+ i_uq_indices = get_unique_indices(torch.stack([b_ids, i_ids], dim=-1))
114
+ b_ids, i_ids, j_ids = map(lambda x: x[i_uq_indices], [b_ids, i_ids, j_ids])
115
+
116
+ kpts_valid_mask[b_ids, i_ids] = gt_matches_non_padding_mask[b_ids, j_ids]
117
+ kpts_warpped[b_ids, i_ids] = gt_matches[..., 2:][b_ids, j_ids]
118
+
119
+ return kpts_valid_mask, kpts_warpped
120
+
121
+ @torch.no_grad()
122
+ def warp_kpts_by_sparse_gt_matches_fine_chunks(kpts0, gt_matches, dist_thr):
123
+ B, n_pts = kpts0.shape[0], kpts0.shape[1]
124
+ chunk_n = 500
125
+ all_kpts_valid_mask, all_kpts_warpped = [], []
126
+ for b_id in range(0, B, chunk_n):
127
+ kpts_valid_mask, kpts_warpped = warp_kpts_by_sparse_gt_matches_fine(kpts0[b_id : b_id+chunk_n], gt_matches, dist_thr)
128
+ all_kpts_valid_mask.append(kpts_valid_mask)
129
+ all_kpts_warpped.append(kpts_warpped)
130
+ return torch.cat(all_kpts_valid_mask, dim=0), torch.cat(all_kpts_warpped, dim=0)
131
+
132
+ @torch.no_grad()
133
+ def warp_kpts_by_sparse_gt_matches_fine(kpts0, gt_matches, dist_thr):
134
+ """
135
+ Only supports a single batch.
136
+ Input:
137
+ kpts0: N * ww * 2
138
+ gt_matches: 1 * M * 4, where [..., :2] are coordinates in image0 and [..., 2:] in image1
139
+ """
140
+ B = kpts0.shape[0] # B is the number of fine matches in a single pair
141
+ assert gt_matches.shape[0] == 1
142
+ kpts_warpped = torch.zeros_like(kpts0)
143
+ kpts_valid_mask = torch.zeros_like(kpts0[..., 0], dtype=torch.bool)
144
+ gt_matches_non_padding_mask = gt_matches.sum(-1) > 0
145
+
146
+ dist_matrix = torch.cdist(kpts0, gt_matches[..., :2]) # B * N * M
147
+ if dist_thr is not None:
148
+ mask = dist_matrix < dist_thr[:, None, None]
149
+ else:
150
+ mask = torch.ones_like(dist_matrix, dtype=torch.bool)
151
+ # Mutual-Nearest check:
152
+ mask = mask \
153
+ * (dist_matrix == dist_matrix.min(dim=2, keepdim=True)[0]) \
154
+ * (dist_matrix == dist_matrix.min(dim=1, keepdim=True)[0])
155
+
156
+ mask_v, all_j_ids = mask.max(dim=2)
157
+ b_ids, i_ids = torch.where(mask_v)
158
+ j_ids = all_j_ids[b_ids, i_ids]
159
+
160
+ j_uq_indices = get_unique_indices(torch.stack([b_ids, j_ids], dim=-1))
161
+ b_ids, i_ids, j_ids = map(lambda x: x[j_uq_indices], [b_ids, i_ids, j_ids])
162
+
163
+ i_uq_indices = get_unique_indices(torch.stack([b_ids, i_ids], dim=-1))
164
+ b_ids, i_ids, j_ids = map(lambda x: x[i_uq_indices], [b_ids, i_ids, j_ids])
165
+
166
+ kpts_valid_mask[b_ids, i_ids] = gt_matches_non_padding_mask[0, j_ids]
167
+ kpts_warpped[b_ids, i_ids] = gt_matches[..., 2:][0, j_ids]
168
+
169
+ return kpts_valid_mask, kpts_warpped
170
+
171
+ @torch.no_grad()
172
+ def warp_kpts_by_sparse_gt_matches_fast(kpts0, gt_matches, scale0, current_h, current_w):
173
+ B, n_gt_pts = gt_matches.shape[0], gt_matches.shape[1]
174
+ kpts_warpped = torch.zeros_like(kpts0)
175
+ kpts_valid_mask = torch.zeros_like(kpts0[..., 0], dtype=torch.bool)
176
+ gt_matches_non_padding_mask = gt_matches.sum(-1) > 0
177
+
178
+ all_j_idxs = torch.arange(gt_matches.shape[-2], device=gt_matches.device, dtype=torch.long)[None].expand(B, n_gt_pts)
179
+ all_b_idxs = torch.arange(B, device=gt_matches.device, dtype=torch.long)[:, None].expand(B, n_gt_pts)
180
+ gt_matches_rescale = gt_matches[..., :2] / scale0 # From original img scale to resized scale
181
+ in_boundary_mask = (gt_matches_rescale[..., 0] <= current_w-1) & (gt_matches_rescale[..., 0] >= 0) & (gt_matches_rescale[..., 1] <= current_h -1) & (gt_matches_rescale[..., 1] >= 0)
182
+
183
+ gt_matches_rescale = gt_matches_rescale.round().to(torch.long)
184
+ all_i_idxs = gt_matches_rescale[..., 1] * current_w + gt_matches_rescale[..., 0] # idx = y * w + x
185
+
186
+ # Filter:
187
+ b_ids, i_ids, j_ids = map(lambda x: x[gt_matches_non_padding_mask & in_boundary_mask], [all_b_idxs, all_i_idxs, all_j_idxs])
188
+
189
+ j_uq_indices = get_unique_indices(torch.stack([b_ids, j_ids], dim=-1))
190
+ b_ids, i_ids, j_ids = map(lambda x: x[j_uq_indices], [b_ids, i_ids, j_ids])
191
+
192
+ i_uq_indices = get_unique_indices(torch.stack([b_ids, i_ids], dim=-1))
193
+ b_ids, i_ids, j_ids = map(lambda x: x[i_uq_indices], [b_ids, i_ids, j_ids])
194
+
195
+ kpts_valid_mask[b_ids, i_ids] = gt_matches_non_padding_mask[b_ids, j_ids]
196
+ kpts_warpped[b_ids, i_ids] = gt_matches[..., 2:][b_ids, j_ids]
197
+
198
+ return kpts_valid_mask, kpts_warpped
199
+
200
+
201
+ @torch.no_grad()
202
+ def homo_warp_kpts(kpts0, norm_pixel_mat, homo_sample_normed, original_size0=None, original_size1=None):
203
+ """
204
+ original_size1: N * 2, (h, w)
205
+ """
206
+ normed_kpts0_h = norm_pixel_mat @ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1).transpose(2, 1) # (N * 3 * L)
207
+ kpts_warpped_h = (torch.linalg.inv(norm_pixel_mat) @ homo_sample_normed @ normed_kpts0_h).transpose(2, 1) # (N * L * 3)
208
+ kpts_warpped = kpts_warpped_h[..., :2] / kpts_warpped_h[..., [2]] # N * L * 2
209
+ valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \
210
+ & (kpts_warpped[..., 1] < original_size1[:, [0]]) # N * L
211
+ if original_size0 is not None:
212
+ valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \
213
+ & (kpts0[..., 1] < original_size0[:, [0]]) # N * L
214
+
215
+ return valid_mask, kpts_warpped
216
+
217
+ @torch.no_grad()
218
+ # if using mask in homo warp(for coarse supervision)
219
+ def homo_warp_kpts_with_mask(kpts0, scale, depth_mask, norm_pixel_mat, homo_sample_normed, original_size0=None, original_size1=None):
220
+ """
221
+ original_size1: N * 2, (h, w)
222
+ """
223
+ normed_kpts0_h = norm_pixel_mat @ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1).transpose(2, 1) # (N * 3 * L)
224
+ kpts_warpped_h = (torch.linalg.inv(norm_pixel_mat) @ homo_sample_normed @ normed_kpts0_h).transpose(2, 1) # (N * L * 3)
225
+ kpts_warpped = kpts_warpped_h[..., :2] / kpts_warpped_h[..., [2]] # N * L * 2
226
+ # get coarse-level depth_mask
227
+ depth_mask_coarse = depth_mask[:, :, ::scale, ::scale]
228
+ depth_mask_coarse = depth_mask_coarse.reshape(depth_mask.shape[0], -1)
229
+
230
+ valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \
231
+ & (kpts_warpped[..., 1] < original_size1[:, [0]]) & (depth_mask_coarse != 0) # N * L
232
+ if original_size0 is not None:
233
+ valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \
234
+ & (kpts0[..., 1] < original_size0[:, [0]]) & (depth_mask_coarse != 0) # N * L
235
+
236
+ return valid_mask, kpts_warpped
237
+
238
+ @torch.no_grad()
239
+ # if using a mask in homo warp (for fine supervision)
240
+ def homo_warp_kpts_with_mask_f(kpts0, depth_mask, norm_pixel_mat, homo_sample_normed, original_size0=None, original_size1=None):
241
+ """
242
+ original_size1: N * 2, (h, w)
243
+ """
244
+ normed_kpts0_h = norm_pixel_mat @ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1).transpose(2, 1) # (N * 3 * L)
245
+ kpts_warpped_h = (torch.linalg.inv(norm_pixel_mat) @ homo_sample_normed @ normed_kpts0_h).transpose(2, 1) # (N * L * 3)
246
+ kpts_warpped = kpts_warpped_h[..., :2] / kpts_warpped_h[..., [2]] # N * L * 2
247
+ valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \
248
+ & (kpts_warpped[..., 1] < original_size1[:, [0]]) & (depth_mask != 0) # N * L
249
+ if original_size0 is not None:
250
+ valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \
251
+ & (kpts0[..., 1] < original_size0[:, [0]]) & (depth_mask != 0) # N * L
252
+
253
+ return valid_mask, kpts_warpped
254
+
255
+ @torch.no_grad()
256
+ def homo_warp_kpts_glue(kpts0, homo, original_size0=None, original_size1=None):
257
+ """
258
+ original_size1: N * 2, (h, w)
259
+ """
260
+ kpts_warpped = warp_points_torch(kpts0, homo, inverse=False)
261
+ valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \
262
+ & (kpts_warpped[..., 1] < original_size1[:, [0]]) # N * L
263
+ if original_size0 is not None:
264
+ valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \
265
+ & (kpts0[..., 1] < original_size0[:, [0]]) # N * L
266
+ return valid_mask, kpts_warpped
267
+
268
+ @torch.no_grad()
269
+ # if using mask in homo warp(for coarse supervision)
270
+ def homo_warp_kpts_glue_with_mask(kpts0, scale, depth_mask, homo, original_size0=None, original_size1=None):
271
+ """
272
+ original_size1: N * 2, (h, w)
273
+ """
274
+ kpts_warpped = warp_points_torch(kpts0, homo, inverse=False)
275
+ # get coarse-level depth_mask
276
+ depth_mask_coarse = depth_mask[:, :, ::scale, ::scale]
277
+ depth_mask_coarse = depth_mask_coarse.reshape(depth_mask.shape[0], -1)
278
+
279
+ valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \
280
+ & (kpts_warpped[..., 1] < original_size1[:, [0]]) & (depth_mask_coarse != 0) # N * L
281
+ if original_size0 is not None:
282
+ valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \
283
+ & (kpts0[..., 1] < original_size0[:, [0]]) & (depth_mask_coarse != 0) # N * L
284
+ return valid_mask, kpts_warpped
285
+
286
+ @torch.no_grad()
287
+ # if using mask in homo warp(for fine supervision)
288
+ def homo_warp_kpts_glue_with_mask_f(kpts0, depth_mask, homo, original_size0=None, original_size1=None):
289
+ """
290
+ original_size1: N * 2, (h, w)
291
+ """
292
+ kpts_warpped = warp_points_torch(kpts0, homo, inverse=False)
293
+ valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \
294
+ & (kpts_warpped[..., 1] < original_size1[:, [0]]) & (depth_mask != 0) # N * L
295
+ if original_size0 is not None:
296
+ valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \
297
+ & (kpts0[..., 1] < original_size0[:, [0]]) & (depth_mask != 0) # N * L
298
+ return valid_mask, kpts_warpped
imcui/third_party/MatchAnything/src/loftr/utils/position_encoding.py ADDED
@@ -0,0 +1,131 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class PositionEncodingSine(nn.Module):
7
+ """
8
+ This is a sinusoidal position encoding that is generalized to 2-dimensional images.
9
+ """
10
+
11
+ def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True, npe=False):
12
+ """
13
+ Args:
14
+ max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
15
+ temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
16
+ the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
17
+ on the final performance. For now, we keep both impls for backward compatability.
18
+ We will remove the buggy impl after re-training all variants of our released models.
19
+ """
20
+ super().__init__()
21
+
22
+ pe = torch.zeros((d_model, *max_shape))
23
+ y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
24
+ x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
25
+
26
+ assert npe is not None
27
+ if npe is not None:
28
+ if isinstance(npe, bool):
29
+ train_res_H, train_res_W, test_res_H, test_res_W = 832, 832, 832, 832
30
+ print('loftr no npe!!!!', npe)
31
+ else:
32
+ print('absnpe!!!!', npe)
33
+ train_res_H, train_res_W, test_res_H, test_res_W = npe[0], npe[1], npe[2], npe[3] # train_res_H, train_res_W, test_res_H, test_res_W
34
+ y_position, x_position = y_position * train_res_H / test_res_H, x_position * train_res_W / test_res_W
35
+
36
+ if temp_bug_fix:
37
+ div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
38
+ else: # a buggy implementation (for backward compatability only)
39
+ div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
40
+ div_term = div_term[:, None, None] # [C//4, 1, 1]
41
+ pe[0::4, :, :] = torch.sin(x_position * div_term)
42
+ pe[1::4, :, :] = torch.cos(x_position * div_term)
43
+ pe[2::4, :, :] = torch.sin(y_position * div_term)
44
+ pe[3::4, :, :] = torch.cos(y_position * div_term)
45
+
46
+ self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]
47
+
48
+ def forward(self, x):
49
+ """
50
+ Args:
51
+ x: [N, C, H, W]
52
+ """
53
+ return x + self.pe[:, :, :x.size(2), :x.size(3)]
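+ # usage sketch (values are illustrative, not from a specific config):
+ # pos_enc = PositionEncodingSine(d_model=256, npe=[832, 832, 832, 832])
+ # feat = pos_enc(feat) # feat: [N, 256, H, W], with H and W no larger than max_shape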
54
+
55
+ class RoPEPositionEncodingSine(nn.Module):
56
+ """
57
+ This is a sinusoidal position encoding that is generalized to 2-dimensional images.
58
+ """
59
+
60
+ def __init__(self, d_model, max_shape=(256, 256), npe=None, ropefp16=True):
61
+ """
62
+ Args:
63
+ max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
64
+ temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
65
+ the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
66
+ on the final performance. For now, we keep both impls for backward compatibility.
67
+ We will remove the buggy impl after re-training all variants of our released models.
68
+ """
69
+ super().__init__()
70
+
71
+ # pe = torch.zeros((d_model, *max_shape))
72
+ # y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(-1)
73
+ # x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(-1)
74
+ i_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(-1) # [H, 1]
75
+ j_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(-1) # [W, 1]
76
+
77
+ assert npe is not None
78
+ if npe is not None:
79
+ train_res_H, train_res_W, test_res_H, test_res_W = npe[0], npe[1], npe[2], npe[3] # train_res_H, train_res_W, test_res_H, test_res_W
80
+ i_position, j_position = i_position * train_res_H / test_res_H, j_position * train_res_W / test_res_W
81
+
82
+ div_term = torch.exp(torch.arange(0, d_model//4, 1).float() * (-math.log(10000.0) / (d_model//4)))
83
+ div_term = div_term[None, None, :] # [1, 1, C//4]
84
+ # pe[0::4, :, :] = torch.sin(x_position * div_term)
85
+ # pe[1::4, :, :] = torch.cos(x_position * div_term)
86
+ # pe[2::4, :, :] = torch.sin(y_position * div_term)
87
+ # pe[3::4, :, :] = torch.cos(y_position * div_term)
88
+ sin = torch.zeros(*max_shape, d_model//2, dtype=torch.float16 if ropefp16 else torch.float32)
89
+ cos = torch.zeros(*max_shape, d_model//2, dtype=torch.float16 if ropefp16 else torch.float32)
90
+ sin[:, :, 0::2] = torch.sin(i_position * div_term).half() if ropefp16 else torch.sin(i_position * div_term)
91
+ sin[:, :, 1::2] = torch.sin(j_position * div_term).half() if ropefp16 else torch.sin(j_position * div_term)
92
+ cos[:, :, 0::2] = torch.cos(i_position * div_term).half() if ropefp16 else torch.cos(i_position * div_term)
93
+ cos[:, :, 1::2] = torch.cos(j_position * div_term).half() if ropefp16 else torch.cos(j_position * div_term)
94
+
95
+ sin = sin.repeat_interleave(2, dim=-1)
96
+ cos = cos.repeat_interleave(2, dim=-1)
97
+ # self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, H, W, C]
98
+ self.register_buffer('sin', sin.unsqueeze(0), persistent=False) # [1, H, W, C//2]
99
+ self.register_buffer('cos', cos.unsqueeze(0), persistent=False) # [1, H, W, C//2]
100
+
101
+ i_position4 = i_position.reshape(64,4,64,4,1)[...,0,:]
102
+ i_position4 = i_position4.mean(-3)
103
+ j_position4 = j_position.reshape(64,4,64,4,1)[:,0,...]
104
+ j_position4 = j_position4.mean(-2)
105
+ sin4 = torch.zeros(max_shape[0]//4, max_shape[1]//4, d_model//2, dtype=torch.float16 if ropefp16 else torch.float32)
106
+ cos4 = torch.zeros(max_shape[0]//4, max_shape[1]//4, d_model//2, dtype=torch.float16 if ropefp16 else torch.float32)
107
+ sin4[:, :, 0::2] = torch.sin(i_position4 * div_term).half() if ropefp16 else torch.sin(i_position4 * div_term)
108
+ sin4[:, :, 1::2] = torch.sin(j_position4 * div_term).half() if ropefp16 else torch.sin(j_position4 * div_term)
109
+ cos4[:, :, 0::2] = torch.cos(i_position4 * div_term).half() if ropefp16 else torch.cos(i_position4 * div_term)
110
+ cos4[:, :, 1::2] = torch.cos(j_position4 * div_term).half() if ropefp16 else torch.cos(j_position4 * div_term)
111
+ sin4 = sin4.repeat_interleave(2, dim=-1)
112
+ cos4 = cos4.repeat_interleave(2, dim=-1)
113
+ self.register_buffer('sin4', sin4.unsqueeze(0), persistent=False) # [1, H, W, C//2]
114
+ self.register_buffer('cos4', cos4.unsqueeze(0), persistent=False) # [1, H, W, C//2]
115
+
116
+
117
+
118
+ def forward(self, x, ratio=1):
119
+ """
120
+ Args:
121
+ x: [N, H, W, C]
122
+ """
123
+ if ratio == 4:
124
+ return (x * self.cos4[:, :x.size(1), :x.size(2), :]) + (self.rotate_half(x) * self.sin4[:, :x.size(1), :x.size(2), :])
125
+ else:
126
+ return (x * self.cos[:, :x.size(1), :x.size(2), :]) + (self.rotate_half(x) * self.sin[:, :x.size(1), :x.size(2), :])
127
+
128
+ def rotate_half(self, x):
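+ # standard RoPE half-rotation: feature pairs (x1, x2) on the last dim become (-x2, x1)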
129
+ x = x.unflatten(-1, (-1, 2))
130
+ x1, x2 = x.unbind(dim=-1)
131
+ return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
imcui/third_party/MatchAnything/src/loftr/utils/supervision.py ADDED
@@ -0,0 +1,475 @@
1
+ from math import log
2
+ from loguru import logger as loguru_logger
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from einops import rearrange, repeat
7
+ from kornia.utils import create_meshgrid
8
+
9
+ from .geometry import warp_kpts, homo_warp_kpts, homo_warp_kpts_glue, homo_warp_kpts_with_mask, homo_warp_kpts_with_mask_f, homo_warp_kpts_glue_with_mask, homo_warp_kpts_glue_with_mask_f, warp_kpts_by_sparse_gt_matches_fast, warp_kpts_by_sparse_gt_matches_fine_chunks
10
+
11
+ from kornia.geometry.subpix import dsnt
12
+ from kornia.utils.grid import create_meshgrid
13
+
14
+ def static_vars(**kwargs):
15
+ def decorate(func):
16
+ for k in kwargs:
17
+ setattr(func, k, kwargs[k])
18
+ return func
19
+ return decorate
20
+
21
+ ############## ↓ Coarse-Level supervision ↓ ##############
22
+
23
+ @torch.no_grad()
24
+ def mask_pts_at_padded_regions(grid_pt, mask):
25
+ """For megadepth dataset, zero-padding exists in images"""
26
+ mask = repeat(mask, 'n h w -> n (h w) c', c=2)
27
+ grid_pt[~mask.bool()] = 0
28
+ return grid_pt
29
+
30
+
31
+ @torch.no_grad()
32
+ def spvs_coarse(data, config):
33
+ """
34
+ Update:
35
+ data (dict): {
36
+ "conf_matrix_gt": [N, hw0, hw1],
37
+ 'spv_b_ids': [M]
38
+ 'spv_i_ids': [M]
39
+ 'spv_j_ids': [M]
40
+ 'spv_w_pt0_i': [N, hw0, 2], in original image resolution
41
+ 'spv_pt1_i': [N, hw1, 2], in original image resolution
42
+ }
43
+
44
+ NOTE:
45
+ - for the ScanNet dataset, there are 3 kinds of resolution {i, c, f}
46
+ - for the MegaDepth dataset, there are 4 kinds of resolution {i, i_resize, c, f}
47
+ """
48
+ # 1. misc
49
+ device = data['image0'].device
50
+ N, _, H0, W0 = data['image0'].shape
51
+ _, _, H1, W1 = data['image1'].shape
52
+
53
+ if 'loftr' in config.METHOD:
54
+ scale = config['LOFTR']['RESOLUTION'][0]
55
+
56
+ scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
57
+ scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale
58
+ h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
59
+
60
+ if config['LOFTR']['MATCH_COARSE']['MTD_SPVS'] and not config['LOFTR']['FORCE_LOOP_BACK']:
61
+ # 2. warp grids
62
+ # create kpts in meshgrid and resize them to image resolution
63
+ grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2]
64
+ grid_pt0_i = scale0 * grid_pt0_c
65
+ grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
66
+ grid_pt1_i = scale1 * grid_pt1_c
67
+
68
+ correct_0to1 = torch.zeros((grid_pt0_i.shape[0], grid_pt0_i.shape[1]), dtype=torch.bool, device=grid_pt0_i.device)
69
+ w_pt0_i = torch.zeros_like(grid_pt0_i)
70
+
71
+ valid_dpt_b_mask = data['T_0to1'].sum(dim=-1).sum(dim=-1) != 0
72
+ valid_homo_warp_mask = (data['homography'].sum(dim=-1).sum(dim=-1) != 0) | (data['homo_sample_normed'].sum(dim=-1).sum(dim=-1) != 0)
73
+ valid_gt_match_warp_mask = (data['gt_matches_mask'][:, 0] != 0) # N
74
+
75
+ if valid_homo_warp_mask.sum() != 0:
76
+ if data['homography'].sum()==0:
77
+ if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0): # the key 'depth_mask' only exists when using the dataset "CommonDataSetHomoWarp"
78
+ correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_with_mask(grid_pt0_i[valid_homo_warp_mask], scale, data['homo_mask0'][valid_homo_warp_mask], data['norm_pixel_mat'][valid_homo_warp_mask], data['homo_sample_normed'][valid_homo_warp_mask], original_size1=data['origin_img_size1'][valid_homo_warp_mask])
79
+ else:
80
+ correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts(grid_pt0_i[valid_homo_warp_mask], data['norm_pixel_mat'][valid_homo_warp_mask], \
81
+ data['homo_sample_normed'][valid_homo_warp_mask], original_size1=data['origin_img_size1'][valid_homo_warp_mask])
82
+ else:
83
+ if 'homo_mask0' in data and (data['homo_mask0']==0).sum()!=0:
84
+ correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_glue_with_mask(grid_pt0_i[valid_homo_warp_mask], scale, data['homo_mask0'][valid_homo_warp_mask], data['homography'][valid_homo_warp_mask], original_size1=data['origin_img_size1'][valid_homo_warp_mask])
85
+ else:
86
+ correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_glue(grid_pt0_i[valid_homo_warp_mask], data['homography'][valid_homo_warp_mask], \
87
+ original_size1=data['origin_img_size1'][valid_homo_warp_mask])
88
+ correct_0to1[valid_homo_warp_mask] = correct_0to1_homo
89
+ w_pt0_i[valid_homo_warp_mask] = w_pt0_i_homo
90
+
91
+ if valid_gt_match_warp_mask.sum() != 0:
92
+ correct_0to1_dpt, w_pt0_i_dpt = warp_kpts_by_sparse_gt_matches_fast(grid_pt0_i[valid_gt_match_warp_mask], data['gt_matches'][valid_gt_match_warp_mask], scale0=scale0[valid_gt_match_warp_mask], current_h=h0, current_w=w0)
93
+ correct_0to1[valid_gt_match_warp_mask] = correct_0to1_dpt
94
+ w_pt0_i[valid_gt_match_warp_mask] = w_pt0_i_dpt
95
+
96
+ if valid_dpt_b_mask.sum() != 0:
97
+ correct_0to1_dpt, w_pt0_i_dpt = warp_kpts(grid_pt0_i[valid_dpt_b_mask], data['depth0'][valid_dpt_b_mask], data['depth1'][valid_dpt_b_mask], data['T_0to1'][valid_dpt_b_mask], data['K0'][valid_dpt_b_mask], data['K1'][valid_dpt_b_mask], consistency_thr=0.05)
98
+ correct_0to1[valid_dpt_b_mask] = correct_0to1_dpt
99
+ w_pt0_i[valid_dpt_b_mask] = w_pt0_i_dpt
100
+
101
+ w_pt0_c = w_pt0_i / scale1
102
+
103
+ # 3. check if mutual nearest neighbor
104
+ w_pt0_c_round = w_pt0_c[:, :, :].round() # [N, hw, 2]
105
+ if config.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT:
106
+ w_pt0_c_error = (1.0 - 2*torch.abs(w_pt0_c - w_pt0_c_round)).prod(-1)
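+ # weight in [0, 1]: 1 when the warped point falls exactly on a coarse cell center, decaying linearly with the sub-cell offset in x and y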
107
+ w_pt0_c_round = w_pt0_c_round.long() # [N, hw, 2]
108
+ nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1 # [N, hw]
109
+
110
+ # corner case: out of boundary
111
+ def out_bound_mask(pt, w, h):
112
+ return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
113
+ nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = -1
114
+
115
+ correct_0to1[:, 0] = False # ignore the top-left corner
116
+
117
+ # 4. construct a gt conf_matrix
118
+ mask1 = torch.stack([data['mask1'].reshape(-1, h1*w1)[_b, _i] for _b, _i in enumerate(nearest_index1)], dim=0)
119
+ correct_0to1 = correct_0to1 * data['mask0'].reshape(-1, h0*w0) * mask1
120
+
121
+ conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device, dtype=torch.bool)
122
+ b_ids, i_ids = torch.where(correct_0to1 != 0)
123
+ j_ids = nearest_index1[b_ids, i_ids]
124
+ valid_j_ids = j_ids != -1
125
+ b_ids, i_ids, j_ids = map(lambda x: x[valid_j_ids], [b_ids, i_ids, j_ids])
126
+
127
+ conf_matrix_gt[b_ids, i_ids, j_ids] = 1
128
+
129
+ # overlap weight
130
+ if config.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT:
131
+ conf_matrix_error_gt = w_pt0_c_error[b_ids, i_ids]
132
+ assert torch.all(conf_matrix_error_gt >= -0.001)
133
+ assert torch.all(conf_matrix_error_gt <= 1.001)
134
+ data.update({'conf_matrix_error_gt': conf_matrix_error_gt})
135
+ data.update({'conf_matrix_gt': conf_matrix_gt})
136
+
137
+ # 5. save coarse matches(gt) for training fine level
138
+ if len(b_ids) == 0:
139
+ loguru_logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
140
+ # this won't affect fine-level loss calculation
141
+ b_ids = torch.tensor([0], device=device)
142
+ i_ids = torch.tensor([0], device=device)
143
+ j_ids = torch.tensor([0], device=device)
144
+
145
+ data.update({
146
+ 'spv_b_ids': b_ids,
147
+ 'spv_i_ids': i_ids,
148
+ 'spv_j_ids': j_ids
149
+ })
150
+
151
+ data.update({'mkpts0_c_gt_b_ids': b_ids})
152
+ data.update({'mkpts0_c_gt': torch.stack([i_ids % w0, i_ids // w0], dim=-1) * scale0[b_ids, 0]})
153
+ data.update({'mkpts1_c_gt': torch.stack([j_ids % w1, j_ids // w1], dim=-1) * scale1[b_ids, 0]})
154
+
155
+ # 6. save intermediate results (for fast fine-level computation)
156
+ data.update({
157
+ 'spv_w_pt0_i': w_pt0_i,
158
+ 'spv_pt1_i': grid_pt1_i,
159
+ # 'correct_0to1_c': correct_0to1
160
+ })
161
+ else:
162
+ raise NotImplementedError
163
+
164
+ def compute_supervision_coarse(data, config):
165
+ spvs_coarse(data, config)
166
+
167
+ @torch.no_grad()
168
+ def get_gt_flow(data, h, w):
169
+ device = data['image0'].device
170
+ B, _, H0, W0 = data['image0'].shape
171
+ scale = H0 / h
172
+
173
+ scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
174
+ scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale
175
+
176
+ x1_n = torch.meshgrid(
177
+ *[
178
+ torch.linspace(
179
+ -1 + 1 / n, 1 - 1 / n, n, device=device
180
+ )
181
+ for n in (B, h, w)
182
+ ]
183
+ )
184
+ grid_coord = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, h*w, 2) # normalized
185
+ grid_coord = torch.stack(
186
+ (w * (grid_coord[..., 0] + 1) / 2, h * (grid_coord[..., 1] + 1) / 2), dim=-1
187
+ ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
188
+ grid_coord_in_origin = grid_coord * scale0
189
+
190
+ correct_0to1 = torch.zeros((grid_coord_in_origin.shape[0], grid_coord_in_origin.shape[1]), dtype=torch.bool, device=device)
191
+ w_pt0_i = torch.zeros_like(grid_coord_in_origin)
192
+
193
+ valid_dpt_b_mask = data['T_0to1'].sum(dim=-1).sum(dim=-1) != 0
194
+ valid_homo_warp_mask = (data['homography'].sum(dim=-1).sum(dim=-1) != 0) | (data['homo_sample_normed'].sum(dim=-1).sum(dim=-1) != 0)
195
+ valid_gt_match_warp_mask = (data['gt_matches_mask'] != 0)[:, 0]
196
+
197
+ if valid_homo_warp_mask.sum() != 0:
198
+ if data['homography'].sum()==0:
199
+ if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0):
200
+ # data['load_mask'] = True or False, data['depth_mask'] = depth_mask or None
201
+ correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_with_mask(grid_coord_in_origin[valid_homo_warp_mask], int(scale), data['homo_mask0'][valid_homo_warp_mask], data['norm_pixel_mat'][valid_homo_warp_mask], \
202
+ data['homo_sample_normed'][valid_homo_warp_mask], original_size1=data['origin_img_size1'][valid_homo_warp_mask])
203
+ else:
204
+ correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts(grid_coord_in_origin[valid_homo_warp_mask], data['norm_pixel_mat'][valid_homo_warp_mask], data['homo_sample_normed'][valid_homo_warp_mask], \
205
+ original_size1=data['origin_img_size1'][valid_homo_warp_mask])
206
+ else:
207
+ if 'homo_mask0' in data and (data['homo_mask0']==0).sum()!=0:
208
+ correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_glue_with_mask(grid_coord_in_origin[valid_homo_warp_mask], int(scale), data['homo_mask0'][valid_homo_warp_mask], data['homography'][valid_homo_warp_mask], \
209
+ original_size1=data['origin_img_size1'][valid_homo_warp_mask])
210
+ else:
211
+ correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_glue(grid_coord_in_origin[valid_homo_warp_mask], data['homography'][valid_homo_warp_mask], \
212
+ original_size1=data['origin_img_size1'][valid_homo_warp_mask])
213
+ correct_0to1[valid_homo_warp_mask] = correct_0to1_homo
214
+ w_pt0_i[valid_homo_warp_mask] = w_pt0_i_homo
215
+
216
+ if valid_gt_match_warp_mask.sum() != 0:
217
+ correct_0to1_dpt, w_pt0_i_dpt = warp_kpts_by_sparse_gt_matches_fast(grid_coord_in_origin[valid_gt_match_warp_mask], data['gt_matches'][valid_gt_match_warp_mask], scale0=scale0[valid_gt_match_warp_mask], current_h=h, current_w=w)
218
+ correct_0to1[valid_gt_match_warp_mask] = correct_0to1_dpt
219
+ w_pt0_i[valid_gt_match_warp_mask] = w_pt0_i_dpt
220
+ if valid_dpt_b_mask.sum() != 0:
221
+ correct_0to1_dpt, w_pt0_i_dpt = warp_kpts(grid_coord_in_origin[valid_dpt_b_mask], data['depth0'][valid_dpt_b_mask], data['depth1'][valid_dpt_b_mask], data['T_0to1'][valid_dpt_b_mask], data['K0'][valid_dpt_b_mask], data['K1'][valid_dpt_b_mask], consistency_thr=0.05)
222
+ correct_0to1[valid_dpt_b_mask] = correct_0to1_dpt
223
+ w_pt0_i[valid_dpt_b_mask] = w_pt0_i_dpt
224
+
225
+ w_pt0_c = w_pt0_i / scale1
226
+
227
+ def out_bound_mask(pt, w, h):
228
+ return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
229
+ correct_0to1[out_bound_mask(w_pt0_c, w, h)] = 0
230
+
231
+ w_pt0_n = torch.stack(
232
+ (2 * w_pt0_c[..., 0] / w - 1, 2 * w_pt0_c[..., 1] / h - 1), dim=-1
233
+ ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
234
+ # w_pt1_c = w_pt1_i / scale0
235
+
236
+ if scale > 8:
237
+ data.update({'mkpts0_c_gt': grid_coord_in_origin[correct_0to1]})
238
+ data.update({'mkpts1_c_gt': w_pt0_i[correct_0to1]})
239
+
240
+ return w_pt0_n.reshape(B, h, w, 2), correct_0to1.float().reshape(B, h, w)
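+ # returns the dense GT warp in normalized [-1, 1] coordinates ([B, h, w, 2]) and a per-cell validity map ([B, h, w])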
241
+
242
+ @torch.no_grad()
243
+ def compute_roma_supervision(data, config):
244
+ gt_flow = {}
245
+ for scale in list(data["corresps"]):
246
+ scale_corresps = data["corresps"][scale]
247
+ flow_pre_delta = rearrange(scale_corresps['flow'] if 'flow'in scale_corresps else scale_corresps['dense_flow'], "b d h w -> b h w d")
248
+ b, h, w, d = flow_pre_delta.shape
249
+ gt_warp, gt_prob = get_gt_flow(data, h, w)
250
+ gt_flow[scale] = {'gt_warp': gt_warp, "gt_prob": gt_prob}
251
+
252
+ data.update({"gt": gt_flow})
253
+
254
+ ############## ↓ Fine-Level supervision ↓ ##############
255
+
256
+ @static_vars(counter = 0)
257
+ @torch.no_grad()
258
+ def spvs_fine(data, config, logger = None):
259
+ """
260
+ Update:
261
+ data (dict):{
262
+ "expec_f_gt": [M, 2]}
263
+ """
264
+ # 1. misc
265
+ # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
266
+ if config.LOFTR.FINE.MTD_SPVS:
267
+ pt1_i = data['spv_pt1_i']
268
+ else:
269
+ spv_w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i']
270
+ if 'loftr' in config.METHOD:
271
+ scale = config['LOFTR']['RESOLUTION'][1]
272
+ scale_c = config['LOFTR']['RESOLUTION'][0]
273
+ radius = config['LOFTR']['FINE_WINDOW_SIZE'] // 2
274
+
275
+ # 2. get coarse prediction
276
+ b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']
277
+
278
+ # 3. compute gt
279
+ scalei0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
280
+ scale0 = scale * data['scale0'] if 'scale0' in data else scale
281
+ scalei1 = scale * data['scale1'][b_ids] if 'scale0' in data else scale
282
+
283
+ if config.LOFTR.FINE.MTD_SPVS:
284
+ W = config['LOFTR']['FINE_WINDOW_SIZE']
285
+ WW = W*W
286
+ device = data['image0'].device
287
+
288
+ N, _, H0, W0 = data['image0'].shape
289
+ _, _, H1, W1 = data['image1'].shape
290
+
291
+ if config.LOFTR.ALIGN_CORNER is False:
292
+ hf0, wf0, hf1, wf1 = data['hw0_f'][0], data['hw0_f'][1], data['hw1_f'][0], data['hw1_f'][1]
293
+ hc0, wc0, hc1, wc1 = data['hw0_c'][0], data['hw0_c'][1], data['hw1_c'][0], data['hw1_c'][1]
294
+ # loguru_logger.info('hf0, wf0, hf1, wf1', hf0, wf0, hf1, wf1)
295
+ else:
296
+ hf0, wf0, hf1, wf1 = map(lambda x: x // scale, [H0, W0, H1, W1])
297
+ hc0, wc0, hc1, wc1 = map(lambda x: x // scale_c, [H0, W0, H1, W1])
298
+
299
+ m = b_ids.shape[0]
300
+ if m == 0:
301
+ conf_matrix_f_gt = torch.zeros(m, WW, WW, device=device)
302
+
303
+ data.update({'conf_matrix_f_gt': conf_matrix_f_gt})
304
+ if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT:
305
+ conf_matrix_f_error_gt = torch.zeros(1, device=device)
306
+ data.update({'conf_matrix_f_error_gt': conf_matrix_f_error_gt})
307
+ if config.LOFTR.MATCH_FINE.MULTI_REGRESS:
308
+ data.update({'expec_f': torch.zeros(1, 3, device=device)})
309
+ data.update({'expec_f_gt': torch.zeros(1, 2, device=device)})
310
+
311
+ if config.LOFTR.MATCH_FINE.LOCAL_REGRESS:
312
+ data.update({'expec_f': torch.zeros(1, 2, device=device)})
313
+ data.update({'expec_f_gt': torch.zeros(1, 2, device=device)})
314
+ else:
315
+ grid_pt0_f = create_meshgrid(hf0, wf0, False, device) - W // 2 + 0.5 # [1, hf0, wf0, 2] # use fine coordinates
316
+ # grid_pt0_f = create_meshgrid(hf0, wf0, False, device) + 0.5 # [1, hf0, wf0, 2] # use fine coordinates
317
+ grid_pt0_f = rearrange(grid_pt0_f, 'n h w c -> n c h w')
318
+ # 1. unfold(crop) all local windows
319
+ if config.LOFTR.ALIGN_CORNER is False: # even windows
320
+ if config.LOFTR.MATCH_FINE.MULTI_REGRESS or (config.LOFTR.MATCH_FINE.LOCAL_REGRESS and W == 10):
321
+ grid_pt0_f_unfold = F.unfold(grid_pt0_f, kernel_size=(W, W), stride=W-2, padding=1) # overlapping windows: stride W-2, padding 1
322
+ else:
323
+ grid_pt0_f_unfold = F.unfold(grid_pt0_f, kernel_size=(W, W), stride=W, padding=0)
324
+ else:
325
+ grid_pt0_f_unfold = F.unfold(grid_pt0_f[..., :-1, :-1], kernel_size=(W, W), stride=W, padding=W//2)
326
+ grid_pt0_f_unfold = rearrange(grid_pt0_f_unfold, 'n (c ww) l -> n l ww c', ww=W**2) # [1, hc0*wc0, W*W, 2]
327
+ grid_pt0_f_unfold = repeat(grid_pt0_f_unfold[0], 'l ww c -> N l ww c', N=N)
328
+
329
+ # 2. select only the predicted matches
330
+ grid_pt0_f_unfold = grid_pt0_f_unfold[data['b_ids'], data['i_ids']] # [m, ww, 2]
331
+ grid_pt0_f_unfold = scalei0[:,None,:] * grid_pt0_f_unfold # [m, ww, 2]
332
+
333
+ # use depth mask
334
+ if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0):
335
+ # depth_mask --> (n, 1, hf, wf)
336
+ homo_mask0 = data['homo_mask0']
337
+ homo_mask0 = F.unfold(homo_mask0[..., :-1, :-1], kernel_size=(W, W), stride=W, padding=W//2)
338
+ homo_mask0 = rearrange(homo_mask0, 'n (c ww) l -> n l ww c', ww=W**2) # [1, hc0*wc0, W*W, 1]
339
+ homo_mask0 = repeat(homo_mask0[0], 'l ww c -> N l ww c', N=N)
340
+ # select only the predicted matches
341
+ homo_mask0 = homo_mask0[data['b_ids'], data['i_ids']]
342
+
343
+ correct_0to1_f_list, w_pt0_i_list = [], []
344
+
345
+ correct_0to1_f = torch.zeros(m, WW, device=device, dtype=torch.bool)
346
+ w_pt0_i = torch.zeros(m, WW, 2, device=device, dtype=torch.float32)
347
+ for b in range(N):
348
+ mask = b_ids == b
349
+
350
+ match = int(mask.sum())
351
+ skip_reshape = False
352
+ if match == 0:
353
+ print(f"no pred fine matches, skip!")
354
+ continue
355
+ if (data['homography'][b].sum() != 0) | (data['homo_sample_normed'][b].sum() != 0):
356
+ if data['homography'][b].sum()==0:
357
+ if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0):
358
+ correct_0to1_f_mask, w_pt0_i_mask = homo_warp_kpts_with_mask_f(grid_pt0_f_unfold[mask].reshape(1,-1,2), homo_mask0[mask].reshape(1,-1), data['norm_pixel_mat'][[b]], \
359
+ data['homo_sample_normed'][[b]], data['origin_img_size0'][[b]], data['origin_img_size1'][[b]])
360
+ else:
361
+ correct_0to1_f_mask, w_pt0_i_mask = homo_warp_kpts(grid_pt0_f_unfold[mask].reshape(1,-1,2), data['norm_pixel_mat'][[b]], \
362
+ data['homo_sample_normed'][[b]], data['origin_img_size0'][[b]], data['origin_img_size1'][[b]])
363
+ else:
364
+ if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0):
365
+ correct_0to1_f_mask, w_pt0_i_mask = homo_warp_kpts_glue_with_mask_f(grid_pt0_f_unfold[mask].reshape(1,-1,2), homo_mask0[mask].reshape(1,-1), data['homography'][[b]], \
366
+ data['origin_img_size0'][[b]], data['origin_img_size1'][[b]])
367
+ else:
368
+ correct_0to1_f_mask, w_pt0_i_mask = homo_warp_kpts_glue(grid_pt0_f_unfold[mask].reshape(1,-1,2), data['homography'][[b]], \
369
+ data['origin_img_size0'][[b]], data['origin_img_size1'][[b]])
370
+ elif data['T_0to1'][b].sum() != 0:
371
+ correct_0to1_f_mask, w_pt0_i_mask = warp_kpts(grid_pt0_f_unfold[mask].reshape(1,-1,2), data['depth0'][[b],...],
372
+ data['depth1'][[b],...], data['T_0to1'][[b],...],
373
+ data['K0'][[b],...], data['K1'][[b],...]) # [k, WW], [k, WW, 2]
374
+ elif data['gt_matches_mask'][b].sum() != 0:
375
+ correct_0to1_f_mask, w_pt0_i_mask = warp_kpts_by_sparse_gt_matches_fine_chunks(grid_pt0_f_unfold[mask], gt_matches=data['gt_matches'][[b]], dist_thr=scale0[[b]].max(dim=-1)[0])
376
+ skip_reshape = True
377
+ correct_0to1_f[mask] = correct_0to1_f_mask.reshape(match, WW) if not skip_reshape else correct_0to1_f_mask
378
+ w_pt0_i[mask] = w_pt0_i_mask.reshape(match, WW, 2) if not skip_reshape else w_pt0_i_mask
379
+
380
+ delta_w_pt0_i = w_pt0_i - pt1_i[b_ids, j_ids][:,None,:] # [m, WW, 2]
381
+ delta_w_pt0_f = delta_w_pt0_i / scalei1[:,None,:] + W // 2 - 0.5
382
+ delta_w_pt0_f_round = delta_w_pt0_f[:, :, :].round()
383
+ if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT and config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT2:
384
+ w_pt0_f_error = (1.0 - torch.abs(delta_w_pt0_f - delta_w_pt0_f_round)).prod(-1) # [0.25, 1]
385
+ elif config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT:
386
+ w_pt0_f_error = (1.0 - 2*torch.abs(delta_w_pt0_f - delta_w_pt0_f_round)).prod(-1) # [0, 1]
387
+ delta_w_pt0_f_round = delta_w_pt0_f_round.long()
388
+
389
+
390
+ nearest_index1 = delta_w_pt0_f_round[..., 0] + delta_w_pt0_f_round[..., 1] * W # [m, WW]
391
+
392
+ def out_bound_mask(pt, w, h):
393
+ return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
394
+ ob_mask = out_bound_mask(delta_w_pt0_f_round, W, W)
395
+ nearest_index1[ob_mask] = 0
396
+
397
+ correct_0to1_f[ob_mask] = 0
398
+ m_ids_d, i_ids_d = torch.where(correct_0to1_f != 0)
399
+
400
+ j_ids_d = nearest_index1[m_ids_d, i_ids_d]
401
+
402
+ # For plotting:
403
+ mkpts0_f_gt = grid_pt0_f_unfold[m_ids_d, i_ids_d] # [m, 2]
404
+ mkpts1_f_gt = w_pt0_i[m_ids_d, i_ids_d] # [m, 2]
405
+ data.update({'mkpts0_f_gt_b_ids': m_ids_d})
406
+ data.update({'mkpts0_f_gt': mkpts0_f_gt})
407
+ data.update({'mkpts1_f_gt': mkpts1_f_gt})
408
+
409
+ if config.LOFTR.MATCH_FINE.MULTI_REGRESS:
410
+ assert not config.LOFTR.MATCH_FINE.LOCAL_REGRESS
411
+ expec_f_gt = delta_w_pt0_f - W // 2 + 0.5 # use delta(e.g. [-3.5,3.5]) in regression rather than [0,W] (e.g. [0,7])
412
+ expec_f_gt = expec_f_gt[m_ids_d, i_ids_d] / (W // 2 - 1) # specific radius for overlapping even windows & align_corner=False
413
+ data.update({'expec_f_gt': expec_f_gt})
414
+ data.update({'m_ids_d': m_ids_d, 'i_ids_d': i_ids_d})
415
+ else: # spv fine dual softmax
416
+ if config.LOFTR.MATCH_FINE.LOCAL_REGRESS:
417
+ expec_f_gt = delta_w_pt0_f - delta_w_pt0_f_round
418
+
419
+ # mask out the fine-window border
420
+ j_ids_d_il, j_ids_d_jl = j_ids_d // W, j_ids_d % W
421
+ if config.LOFTR.MATCH_FINE.LOCAL_REGRESS_NOMASK:
422
+ mask = None
423
+ m_ids_dl, i_ids_dl, j_ids_d_il, j_ids_d_jl = m_ids_d.to(torch.long), i_ids_d.to(torch.long), j_ids_d_il.to(torch.long), j_ids_d_jl.to(torch.long)
424
+ else:
425
+ mask = (j_ids_d_il >= 1) & (j_ids_d_il < W-1) & (j_ids_d_jl >= 1) & (j_ids_d_jl < W-1)
426
+ if W == 10:
427
+ i_ids_d_il, i_ids_d_jl = i_ids_d // W, i_ids_d % W
428
+ mask = mask & (i_ids_d_il >= 1) & (i_ids_d_il <= W-2) & (i_ids_d_jl >= 1) & (i_ids_d_jl <= W-2)
429
+
430
+ m_ids_dl, i_ids_dl, j_ids_d_il, j_ids_d_jl = m_ids_d[mask].to(torch.long), i_ids_d[mask].to(torch.long), j_ids_d_il[mask].to(torch.long), j_ids_d_jl[mask].to(torch.long)
431
+ if mask is not None:
432
+ loguru_logger.info(f'fraction of GT fine matches kept after border masking: {mask.sum().float()/mask.numel():.2f}')
433
+ if m_ids_dl.numel() == 0:
434
+ loguru_logger.warning(f"No groundtruth fine match found for local regress: {data['pair_names']}")
435
+ data.update({'expec_f_gt': torch.zeros(1, 2, device=device)})
436
+ data.update({'expec_f': torch.zeros(1, 2, device=device)})
437
+ else:
438
+ expec_f_gt = expec_f_gt[m_ids_dl, i_ids_dl]
439
+ data.update({"expec_f_gt": expec_f_gt})
440
+
441
+ data.update({"m_ids_dl": m_ids_dl,
442
+ "i_ids_dl": i_ids_dl,
443
+ "j_ids_d_il": j_ids_d_il,
444
+ "j_ids_d_jl": j_ids_d_jl
445
+ })
446
+ else: # no fine regress
447
+ pass
448
+
449
+ # spv fine dual softmax
450
+ conf_matrix_f_gt = torch.zeros(m, WW, WW, device=device, dtype=torch.bool)
451
+ conf_matrix_f_gt[m_ids_d, i_ids_d, j_ids_d] = 1
452
+ data.update({'conf_matrix_f_gt': conf_matrix_f_gt})
453
+ if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT:
454
+ w_pt0_f_error = w_pt0_f_error[m_ids_d, i_ids_d]
455
+ assert torch.all(w_pt0_f_error >= -0.001)
456
+ assert torch.all(w_pt0_f_error <= 1.001)
457
+ data.update({'conf_matrix_f_error_gt': w_pt0_f_error})
458
+
459
+ conf_matrix_f_gt_sum = conf_matrix_f_gt.sum()
460
+ if conf_matrix_f_gt_sum != 0:
461
+ pass
462
+ else:
463
+ loguru_logger.info('[no gt plot] no fine matches to supervise')
464
+
465
+ else:
466
+ expec_f_gt = (spv_w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scalei1 / 4 # [M, 2]
467
+ data.update({"expec_f_gt": expec_f_gt})
468
+
469
+
470
+ def compute_supervision_fine(data, config, logger=None):
471
+ data_source = data['dataset_name'][0]
472
+ if data_source.lower() in ['scannet', 'megadepth']:
473
+ spvs_fine(data, config, logger)
474
+ else:
475
+ raise NotImplementedError
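For reference, a minimal sketch of how this dispatcher is typically invoked from a training step; `batch` and `cfg` are placeholder names for the collated data dict and the parsed config, not identifiers defined in this file:

    # hypothetical call site inside a training step
    batch['dataset_name'] = ['MegaDepth']      # per-sample dataset tag expected by the dispatcher
    compute_supervision_fine(batch, cfg)       # fills 'conf_matrix_f_gt', 'expec_f_gt', ...
    fine_gt = batch['conf_matrix_f_gt']        # [M, W*W, W*W] boolean ground-truth matrix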
imcui/third_party/MatchAnything/src/optimizers/__init__.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR
3
+
4
+
5
+ def build_optimizer(model, config):
6
+ name = config.TRAINER.OPTIMIZER
7
+ lr = config.TRAINER.TRUE_LR
8
+
9
+ if name == "adam":
10
+ return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY, eps=config.TRAINER.OPTIMIZER_EPS)
11
+ elif name == "adamw":
12
+ if ("ROMA" in config.METHOD) or ("DKM" in config.METHOD):
13
+ # Separate the backbone parameters (trained with a lower LR) from the rest:
14
+ keyword = 'model.encoder'
15
+ backbone_params = [param for name, param in list(filter(lambda kv: keyword in kv[0], model.named_parameters()))]
16
+ base_params = [param for name, param in list(filter(lambda kv: keyword not in kv[0], model.named_parameters()))]
17
+ params = [{'params': backbone_params, 'lr': lr * 0.05}, {'params': base_params}]
18
+ return torch.optim.AdamW(params, lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY, eps=config.TRAINER.OPTIMIZER_EPS)
19
+ else:
20
+ return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY, eps=config.TRAINER.OPTIMIZER_EPS)
21
+ else:
22
+ raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
23
+
24
+
25
+ def build_scheduler(config, optimizer):
26
+ """
27
+ Returns:
28
+ scheduler (dict):{
29
+ 'scheduler': lr_scheduler,
30
+ 'interval': 'step', # or 'epoch'
31
+ 'monitor': 'val_f1', (optional)
32
+ 'frequency': x, (optional)
33
+ }
34
+ """
35
+ scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL}
36
+ name = config.TRAINER.SCHEDULER
37
+
38
+ if name == 'MultiStepLR':
39
+ scheduler.update(
40
+ {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)})
41
+ elif name == 'CosineAnnealing':
42
+ scheduler.update(
43
+ {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)})
44
+ elif name == 'ExponentialLR':
45
+ scheduler.update(
46
+ {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)})
47
+ else:
48
+ raise NotImplementedError()
49
+
50
+ return scheduler
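A minimal sketch of how these two builders are usually combined inside a PyTorch Lightning `configure_optimizers`; `self.model` and `self.config` are placeholder attribute names, not part of this file:

    # hypothetical LightningModule method
    def configure_optimizers(self):
        optimizer = build_optimizer(self.model, self.config)
        scheduler = build_scheduler(self.config, optimizer)
        return [optimizer], [scheduler]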
imcui/third_party/MatchAnything/src/utils/__init__.py ADDED
File without changes
imcui/third_party/MatchAnything/src/utils/augment.py ADDED
@@ -0,0 +1,55 @@
1
+ import albumentations as A
2
+
3
+
4
+ class DarkAug(object):
5
+ """
6
+ Extreme dark augmentation aiming at Aachen Day-Night
7
+ """
8
+
9
+ def __init__(self) -> None:
10
+ self.augmentor = A.Compose([
11
+ A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)),
12
+ A.Blur(p=0.1, blur_limit=(3, 9)),
13
+ A.MotionBlur(p=0.2, blur_limit=(3, 25)),
14
+ A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
15
+ A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40))
16
+ ], p=0.75)
17
+
18
+ def __call__(self, x):
19
+ return self.augmentor(image=x)['image']
20
+
21
+
22
+ class MobileAug(object):
23
+ """
24
+ Random augmentations aiming at images of mobile/handhold devices.
25
+ """
26
+
27
+ def __init__(self):
28
+ self.augmentor = A.Compose([
29
+ A.MotionBlur(p=0.25),
30
+ A.ColorJitter(p=0.5),
31
+ A.RandomRain(p=0.1), # random occlusion
32
+ A.RandomSunFlare(p=0.1),
33
+ A.JpegCompression(p=0.25),
34
+ A.ISONoise(p=0.25)
35
+ ], p=1.0)
36
+
37
+ def __call__(self, x):
38
+ return self.augmentor(image=x)['image']
39
+
40
+
41
+ def build_augmentor(method=None, **kwargs):
42
+ if method is not None:
43
+ raise NotImplementedError('Use of augmentation functions is not supported yet!')
44
+ if method == 'dark':
45
+ return DarkAug()
46
+ elif method == 'mobile':
47
+ return MobileAug()
48
+ elif method is None:
49
+ return None
50
+ else:
51
+ raise ValueError(f'Invalid augmentation method: {method}')
52
+
53
+
54
+ if __name__ == '__main__':
55
+ augmentor = build_augmentor('FDA')
imcui/third_party/MatchAnything/src/utils/colmap.py ADDED
@@ -0,0 +1,530 @@
1
+ # Copyright (c) 2022, ETH Zurich and UNC Chapel Hill.
2
+ # All rights reserved.
3
+ #
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+ #
7
+ # * Redistributions of source code must retain the above copyright
8
+ # notice, this list of conditions and the following disclaimer.
9
+ #
10
+ # * Redistributions in binary form must reproduce the above copyright
11
+ # notice, this list of conditions and the following disclaimer in the
12
+ # documentation and/or other materials provided with the distribution.
13
+ #
14
+ # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15
+ # its contributors may be used to endorse or promote products derived
16
+ # from this software without specific prior written permission.
17
+ #
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28
+ # POSSIBILITY OF SUCH DAMAGE.
29
+ #
30
+ # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
31
+
32
+ from typing import List, Tuple, Dict
33
+ import os
34
+ import collections
35
+ import numpy as np
36
+ import struct
37
+ import argparse
38
+
39
+
40
+ CameraModel = collections.namedtuple(
41
+ "CameraModel", ["model_id", "model_name", "num_params"])
42
+ BaseCamera = collections.namedtuple(
43
+ "Camera", ["id", "model", "width", "height", "params"])
44
+ BaseImage = collections.namedtuple(
45
+ "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
46
+ Point3D = collections.namedtuple(
47
+ "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
48
+
49
+
50
+ class Image(BaseImage):
51
+ def qvec2rotmat(self):
52
+ return qvec2rotmat(self.qvec)
53
+
54
+ @property
55
+ def world_to_camera(self) -> np.ndarray:
56
+ R = qvec2rotmat(self.qvec)
57
+ t = self.tvec
58
+ world2cam = np.eye(4)
59
+ world2cam[:3, :3] = R
60
+ world2cam[:3, 3] = t
61
+ return world2cam
62
+
63
+
64
+ class Camera(BaseCamera):
65
+ @property
66
+ def K(self):
67
+ K = np.eye(3)
68
+ if self.model == "SIMPLE_PINHOLE" or self.model == "SIMPLE_RADIAL" or self.model == "RADIAL" or self.model == "SIMPLE_RADIAL_FISHEYE" or self.model == "RADIAL_FISHEYE":
69
+ K[0, 0] = self.params[0]
70
+ K[1, 1] = self.params[0]
71
+ K[0, 2] = self.params[1]
72
+ K[1, 2] = self.params[2]
73
+ elif self.model == "PINHOLE" or self.model == "OPENCV" or self.model == "OPENCV_FISHEYE" or self.model == "FULL_OPENCV" or self.model == "FOV" or self.model == "THIN_PRISM_FISHEYE":
74
+ K[0, 0] = self.params[0]
75
+ K[1, 1] = self.params[1]
76
+ K[0, 2] = self.params[2]
77
+ K[1, 2] = self.params[3]
78
+ else:
79
+ raise NotImplementedError
80
+ return K
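For intuition, a small sketch of recovering the 3x3 intrinsics matrix from a PINHOLE camera record; the numeric values are placeholders:

    cam = Camera(id=1, model="PINHOLE", width=640, height=480,
                 params=np.array([500.0, 500.0, 320.0, 240.0]))  # fx, fy, cx, cy
    K = cam.K  # [[500, 0, 320], [0, 500, 240], [0, 0, 1]]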
81
+
82
+
83
+ CAMERA_MODELS = {
84
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
85
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
86
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
87
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
88
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
89
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
90
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
91
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
92
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
93
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
94
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
95
+ }
96
+ CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
97
+ for camera_model in CAMERA_MODELS])
98
+ CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
99
+ for camera_model in CAMERA_MODELS])
100
+
101
+
102
+ def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
103
+ """Read and unpack the next bytes from a binary file.
104
+ :param fid:
105
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
106
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
107
+ :param endian_character: Any of {@, =, <, >, !}
108
+ :return: Tuple of read and unpacked values.
109
+ """
110
+ data = fid.read(num_bytes)
111
+ return struct.unpack(endian_character + format_char_sequence, data)
112
+
113
+
114
+ def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
115
+ """pack and write to a binary file.
116
+ :param fid:
117
+ :param data: data to send, if multiple elements are sent at the same time,
118
+ they should be encapsulated either in a list or a tuple
119
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
120
+ should be the same length as the data list or tuple
121
+ :param endian_character: Any of {@, =, <, >, !}
122
+ """
123
+ if isinstance(data, (list, tuple)):
124
+ bytes = struct.pack(endian_character + format_char_sequence, *data)
125
+ else:
126
+ bytes = struct.pack(endian_character + format_char_sequence, data)
127
+ fid.write(bytes)
128
+
129
+
130
+ def read_cameras_text(path):
131
+ """
132
+ see: src/base/reconstruction.cc
133
+ void Reconstruction::WriteCamerasText(const std::string& path)
134
+ void Reconstruction::ReadCamerasText(const std::string& path)
135
+ """
136
+ cameras = {}
137
+ with open(path, "r") as fid:
138
+ while True:
139
+ line = fid.readline()
140
+ if not line:
141
+ break
142
+ line = line.strip()
143
+ if len(line) > 0 and line[0] != "#":
144
+ elems = line.split()
145
+ camera_id = int(elems[0])
146
+ model = elems[1]
147
+ width = int(elems[2])
148
+ height = int(elems[3])
149
+ params = np.array(tuple(map(float, elems[4:])))
150
+ cameras[camera_id] = Camera(id=camera_id, model=model,
151
+ width=width, height=height,
152
+ params=params)
153
+ return cameras
154
+
155
+
156
+ def read_cameras_binary(path_to_model_file):
157
+ """
158
+ see: src/base/reconstruction.cc
159
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
160
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
161
+ """
162
+ cameras = {}
163
+ with open(path_to_model_file, "rb") as fid:
164
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
165
+ for _ in range(num_cameras):
166
+ camera_properties = read_next_bytes(
167
+ fid, num_bytes=24, format_char_sequence="iiQQ")
168
+ camera_id = camera_properties[0]
169
+ model_id = camera_properties[1]
170
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
171
+ width = camera_properties[2]
172
+ height = camera_properties[3]
173
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
174
+ params = read_next_bytes(fid, num_bytes=8*num_params,
175
+ format_char_sequence="d"*num_params)
176
+ cameras[camera_id] = Camera(id=camera_id,
177
+ model=model_name,
178
+ width=width,
179
+ height=height,
180
+ params=np.array(params))
181
+ assert len(cameras) == num_cameras
182
+ return cameras
183
+
184
+
185
+ def write_cameras_text(cameras, path):
186
+ """
187
+ see: src/base/reconstruction.cc
188
+ void Reconstruction::WriteCamerasText(const std::string& path)
189
+ void Reconstruction::ReadCamerasText(const std::string& path)
190
+ """
191
+ HEADER = "# Camera list with one line of data per camera:\n" + \
192
+ "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + \
193
+ "# Number of cameras: {}\n".format(len(cameras))
194
+ with open(path, "w") as fid:
195
+ fid.write(HEADER)
196
+ for _, cam in cameras.items():
197
+ to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
198
+ line = " ".join([str(elem) for elem in to_write])
199
+ fid.write(line + "\n")
200
+
201
+
202
+ def write_cameras_binary(cameras, path_to_model_file):
203
+ """
204
+ see: src/base/reconstruction.cc
205
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
206
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
207
+ """
208
+ with open(path_to_model_file, "wb") as fid:
209
+ write_next_bytes(fid, len(cameras), "Q")
210
+ for _, cam in cameras.items():
211
+ model_id = CAMERA_MODEL_NAMES[cam.model].model_id
212
+ camera_properties = [cam.id,
213
+ model_id,
214
+ cam.width,
215
+ cam.height]
216
+ write_next_bytes(fid, camera_properties, "iiQQ")
217
+ for p in cam.params:
218
+ write_next_bytes(fid, float(p), "d")
219
+ return cameras
220
+
221
+
222
+ def read_images_text(path):
223
+ """
224
+ see: src/base/reconstruction.cc
225
+ void Reconstruction::ReadImagesText(const std::string& path)
226
+ void Reconstruction::WriteImagesText(const std::string& path)
227
+ """
228
+ images = {}
229
+ with open(path, "r") as fid:
230
+ while True:
231
+ line = fid.readline()
232
+ if not line:
233
+ break
234
+ line = line.strip()
235
+ if len(line) > 0 and line[0] != "#":
236
+ elems = line.split()
237
+ image_id = int(elems[0])
238
+ qvec = np.array(tuple(map(float, elems[1:5])))
239
+ tvec = np.array(tuple(map(float, elems[5:8])))
240
+ camera_id = int(elems[8])
241
+ image_name = elems[9]
242
+ elems = fid.readline().split()
243
+ xys = np.column_stack([tuple(map(float, elems[0::3])),
244
+ tuple(map(float, elems[1::3]))])
245
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
246
+ images[image_id] = Image(
247
+ id=image_id, qvec=qvec, tvec=tvec,
248
+ camera_id=camera_id, name=image_name,
249
+ xys=xys, point3D_ids=point3D_ids)
250
+ return images
251
+
252
+
253
+ def read_images_binary(path_to_model_file):
254
+ """
255
+ see: src/base/reconstruction.cc
256
+ void Reconstruction::ReadImagesBinary(const std::string& path)
257
+ void Reconstruction::WriteImagesBinary(const std::string& path)
258
+ """
259
+ images = {}
260
+ with open(path_to_model_file, "rb") as fid:
261
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
262
+ for _ in range(num_reg_images):
263
+ binary_image_properties = read_next_bytes(
264
+ fid, num_bytes=64, format_char_sequence="idddddddi")
265
+ image_id = binary_image_properties[0]
266
+ qvec = np.array(binary_image_properties[1:5])
267
+ tvec = np.array(binary_image_properties[5:8])
268
+ camera_id = binary_image_properties[8]
269
+ image_name = ""
270
+ current_char = read_next_bytes(fid, 1, "c")[0]
271
+ while current_char != b"\x00": # look for the ASCII 0 entry
272
+ image_name += current_char.decode("utf-8")
273
+ current_char = read_next_bytes(fid, 1, "c")[0]
274
+ num_points2D = read_next_bytes(fid, num_bytes=8,
275
+ format_char_sequence="Q")[0]
276
+ x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
277
+ format_char_sequence="ddq"*num_points2D)
278
+ xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
279
+ tuple(map(float, x_y_id_s[1::3]))])
280
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
281
+ images[image_id] = Image(
282
+ id=image_id, qvec=qvec, tvec=tvec,
283
+ camera_id=camera_id, name=image_name,
284
+ xys=xys, point3D_ids=point3D_ids)
285
+ return images
286
+
287
+
288
+ def write_images_text(images, path):
289
+ """
290
+ see: src/base/reconstruction.cc
291
+ void Reconstruction::ReadImagesText(const std::string& path)
292
+ void Reconstruction::WriteImagesText(const std::string& path)
293
+ """
294
+ if len(images) == 0:
295
+ mean_observations = 0
296
+ else:
297
+ mean_observations = sum((len(img.point3D_ids) for _, img in images.items()))/len(images)
298
+ HEADER = "# Image list with two lines of data per image:\n" + \
299
+ "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + \
300
+ "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + \
301
+ "# Number of images: {}, mean observations per image: {}\n".format(len(images), mean_observations)
302
+
303
+ with open(path, "w") as fid:
304
+ fid.write(HEADER)
305
+ for _, img in images.items():
306
+ image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name]
307
+ first_line = " ".join(map(str, image_header))
308
+ fid.write(first_line + "\n")
309
+
310
+ points_strings = []
311
+ for xy, point3D_id in zip(img.xys, img.point3D_ids):
312
+ points_strings.append(" ".join(map(str, [*xy, point3D_id])))
313
+ fid.write(" ".join(points_strings) + "\n")
314
+
315
+
316
+ def write_images_binary(images, path_to_model_file):
317
+ """
318
+ see: src/base/reconstruction.cc
319
+ void Reconstruction::ReadImagesBinary(const std::string& path)
320
+ void Reconstruction::WriteImagesBinary(const std::string& path)
321
+ """
322
+ with open(path_to_model_file, "wb") as fid:
323
+ write_next_bytes(fid, len(images), "Q")
324
+ for _, img in images.items():
325
+ write_next_bytes(fid, img.id, "i")
326
+ write_next_bytes(fid, img.qvec.tolist(), "dddd")
327
+ write_next_bytes(fid, img.tvec.tolist(), "ddd")
328
+ write_next_bytes(fid, img.camera_id, "i")
329
+ for char in img.name:
330
+ write_next_bytes(fid, char.encode("utf-8"), "c")
331
+ write_next_bytes(fid, b"\x00", "c")
332
+ write_next_bytes(fid, len(img.point3D_ids), "Q")
333
+ for xy, p3d_id in zip(img.xys, img.point3D_ids):
334
+ write_next_bytes(fid, [*xy, p3d_id], "ddq")
335
+
336
+
337
+ def read_points3D_text(path):
338
+ """
339
+ see: src/base/reconstruction.cc
340
+ void Reconstruction::ReadPoints3DText(const std::string& path)
341
+ void Reconstruction::WritePoints3DText(const std::string& path)
342
+ """
343
+ points3D = {}
344
+ with open(path, "r") as fid:
345
+ while True:
346
+ line = fid.readline()
347
+ if not line:
348
+ break
349
+ line = line.strip()
350
+ if len(line) > 0 and line[0] != "#":
351
+ elems = line.split()
352
+ point3D_id = int(elems[0])
353
+ xyz = np.array(tuple(map(float, elems[1:4])))
354
+ rgb = np.array(tuple(map(int, elems[4:7])))
355
+ error = float(elems[7])
356
+ image_ids = np.array(tuple(map(int, elems[8::2])))
357
+ point2D_idxs = np.array(tuple(map(int, elems[9::2])))
358
+ points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
359
+ error=error, image_ids=image_ids,
360
+ point2D_idxs=point2D_idxs)
361
+ return points3D
362
+
363
+
364
+ def read_points3D_binary(path_to_model_file):
365
+ """
366
+ see: src/base/reconstruction.cc
367
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
368
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
369
+ """
370
+ points3D = {}
371
+ with open(path_to_model_file, "rb") as fid:
372
+ num_points = read_next_bytes(fid, 8, "Q")[0]
373
+ for _ in range(num_points):
374
+ binary_point_line_properties = read_next_bytes(
375
+ fid, num_bytes=43, format_char_sequence="QdddBBBd")
376
+ point3D_id = binary_point_line_properties[0]
377
+ xyz = np.array(binary_point_line_properties[1:4])
378
+ rgb = np.array(binary_point_line_properties[4:7])
379
+ error = np.array(binary_point_line_properties[7])
380
+ track_length = read_next_bytes(
381
+ fid, num_bytes=8, format_char_sequence="Q")[0]
382
+ track_elems = read_next_bytes(
383
+ fid, num_bytes=8*track_length,
384
+ format_char_sequence="ii"*track_length)
385
+ image_ids = np.array(tuple(map(int, track_elems[0::2])))
386
+ point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
387
+ points3D[point3D_id] = Point3D(
388
+ id=point3D_id, xyz=xyz, rgb=rgb,
389
+ error=error, image_ids=image_ids,
390
+ point2D_idxs=point2D_idxs)
391
+ return points3D
392
+
393
+
394
+ def write_points3D_text(points3D, path):
395
+ """
396
+ see: src/base/reconstruction.cc
397
+ void Reconstruction::ReadPoints3DText(const std::string& path)
398
+ void Reconstruction::WritePoints3DText(const std::string& path)
399
+ """
400
+ if len(points3D) == 0:
401
+ mean_track_length = 0
402
+ else:
403
+ mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items()))/len(points3D)
404
+ HEADER = "# 3D point list with one line of data per point:\n" + \
405
+ "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + \
406
+ "# Number of points: {}, mean track length: {}\n".format(len(points3D), mean_track_length)
407
+
408
+ with open(path, "w") as fid:
409
+ fid.write(HEADER)
410
+ for _, pt in points3D.items():
411
+ point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
412
+ fid.write(" ".join(map(str, point_header)) + " ")
413
+ track_strings = []
414
+ for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
415
+ track_strings.append(" ".join(map(str, [image_id, point2D])))
416
+ fid.write(" ".join(track_strings) + "\n")
417
+
418
+
419
+ def write_points3D_binary(points3D, path_to_model_file):
420
+ """
421
+ see: src/base/reconstruction.cc
422
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
423
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
424
+ """
425
+ with open(path_to_model_file, "wb") as fid:
426
+ write_next_bytes(fid, len(points3D), "Q")
427
+ for _, pt in points3D.items():
428
+ write_next_bytes(fid, pt.id, "Q")
429
+ write_next_bytes(fid, pt.xyz.tolist(), "ddd")
430
+ write_next_bytes(fid, pt.rgb.tolist(), "BBB")
431
+ write_next_bytes(fid, pt.error, "d")
432
+ track_length = pt.image_ids.shape[0]
433
+ write_next_bytes(fid, track_length, "Q")
434
+ for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
435
+ write_next_bytes(fid, [image_id, point2D_id], "ii")
436
+
437
+
438
+ def detect_model_format(path, ext):
439
+ if os.path.isfile(os.path.join(path, "cameras" + ext)) and \
440
+ os.path.isfile(os.path.join(path, "images" + ext)) and \
441
+ os.path.isfile(os.path.join(path, "points3D" + ext)):
442
+ print("Detected model format: '" + ext + "'")
443
+ return True
444
+
445
+ return False
446
+
447
+
448
+ def read_model(path, ext="") -> Tuple[Dict[int, Camera], Dict[int, Image], Dict[int, Point3D]]:
449
+ # try to detect the extension automatically
450
+ if ext == "":
451
+ if detect_model_format(path, ".bin"):
452
+ ext = ".bin"
453
+ elif detect_model_format(path, ".txt"):
454
+ ext = ".txt"
455
+ else:
456
+ raise ValueError("Provide model format: '.bin' or '.txt'")
457
+
458
+ if ext == ".txt":
459
+ cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
460
+ images = read_images_text(os.path.join(path, "images" + ext))
461
+ points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
462
+ else:
463
+ cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
464
+ images = read_images_binary(os.path.join(path, "images" + ext))
465
+ points3D = read_points3D_binary(os.path.join(path, "points3D") + ext)
466
+ return cameras, images, points3D
467
+
468
+
469
+ def write_model(cameras, images, points3D, path, ext=".bin"):
470
+ if ext == ".txt":
471
+ write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
472
+ write_images_text(images, os.path.join(path, "images" + ext))
473
+ write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
474
+ else:
475
+ write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
476
+ write_images_binary(images, os.path.join(path, "images" + ext))
477
+ write_points3D_binary(points3D, os.path.join(path, "points3D") + ext)
478
+ return cameras, images, points3D
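A minimal round-trip sketch: read a binary COLMAP model and re-export it as text; the directory names are placeholders:

    cameras, images, points3D = read_model("sparse/0")              # auto-detects .bin / .txt
    write_model(cameras, images, points3D, "sparse_text", ext=".txt")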
479
+
480
+
481
+ def qvec2rotmat(qvec):
482
+ return np.array([
483
+ [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
484
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
485
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
486
+ [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
487
+ 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
488
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
489
+ [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
490
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
491
+ 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
492
+
493
+
494
+ def rotmat2qvec(R):
495
+ Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
496
+ K = np.array([
497
+ [Rxx - Ryy - Rzz, 0, 0, 0],
498
+ [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
499
+ [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
500
+ [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
501
+ eigvals, eigvecs = np.linalg.eigh(K)
502
+ qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
503
+ if qvec[0] < 0:
504
+ qvec *= -1
505
+ return qvec
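A quick illustrative consistency check between the two quaternion helpers (COLMAP stores quaternions in w, x, y, z order):

    q = np.array([0.5, 0.5, 0.5, 0.5])     # unit quaternion, w first
    R = qvec2rotmat(q)
    assert np.allclose(rotmat2qvec(R), q)   # round trip up to numerical noise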
506
+
507
+
508
+ def main():
509
+ parser = argparse.ArgumentParser(description="Read and write COLMAP binary and text models")
510
+ parser.add_argument("--input_model", help="path to input model folder")
511
+ parser.add_argument("--input_format", choices=[".bin", ".txt"],
512
+ help="input model format", default="")
513
+ parser.add_argument("--output_model",
514
+ help="path to output model folder")
515
+ parser.add_argument("--output_format", choices=[".bin", ".txt"],
516
+ help="outut model format", default=".txt")
517
+ args = parser.parse_args()
518
+
519
+ cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format)
520
+
521
+ print("num_cameras:", len(cameras))
522
+ print("num_images:", len(images))
523
+ print("num_points3D:", len(points3D))
524
+
525
+ if args.output_model is not None:
526
+ write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format)
527
+
528
+
529
+ if __name__ == "__main__":
530
+ main()
imcui/third_party/MatchAnything/src/utils/colmap/__init__.py ADDED
File without changes
imcui/third_party/MatchAnything/src/utils/colmap/database.py ADDED
@@ -0,0 +1,417 @@
1
+ # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
2
+ # All rights reserved.
3
+ #
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+ #
7
+ # * Redistributions of source code must retain the above copyright
8
+ # notice, this list of conditions and the following disclaimer.
9
+ #
10
+ # * Redistributions in binary form must reproduce the above copyright
11
+ # notice, this list of conditions and the following disclaimer in the
12
+ # documentation and/or other materials provided with the distribution.
13
+ #
14
+ # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15
+ # its contributors may be used to endorse or promote products derived
16
+ # from this software without specific prior written permission.
17
+ #
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28
+ # POSSIBILITY OF SUCH DAMAGE.
29
+ #
30
+ # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
31
+
32
+ # This script is based on an original implementation by True Price.
33
+
34
+ import sys
35
+ import sqlite3
36
+ import numpy as np
37
+ from loguru import logger
38
+
39
+
40
+ IS_PYTHON3 = sys.version_info[0] >= 3
41
+
42
+ MAX_IMAGE_ID = 2**31 - 1
43
+
44
+ CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
45
+ camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
46
+ model INTEGER NOT NULL,
47
+ width INTEGER NOT NULL,
48
+ height INTEGER NOT NULL,
49
+ params BLOB,
50
+ prior_focal_length INTEGER NOT NULL)"""
51
+
52
+ CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
53
+ image_id INTEGER PRIMARY KEY NOT NULL,
54
+ rows INTEGER NOT NULL,
55
+ cols INTEGER NOT NULL,
56
+ data BLOB,
57
+ FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
58
+
59
+ CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
60
+ image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
61
+ name TEXT NOT NULL UNIQUE,
62
+ camera_id INTEGER NOT NULL,
63
+ prior_qw REAL,
64
+ prior_qx REAL,
65
+ prior_qy REAL,
66
+ prior_qz REAL,
67
+ prior_tx REAL,
68
+ prior_ty REAL,
69
+ prior_tz REAL,
70
+ CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}),
71
+ FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))
72
+ """.format(MAX_IMAGE_ID)
73
+
74
+ CREATE_TWO_VIEW_GEOMETRIES_TABLE = """
75
+ CREATE TABLE IF NOT EXISTS two_view_geometries (
76
+ pair_id INTEGER PRIMARY KEY NOT NULL,
77
+ rows INTEGER NOT NULL,
78
+ cols INTEGER NOT NULL,
79
+ data BLOB,
80
+ config INTEGER NOT NULL,
81
+ F BLOB,
82
+ E BLOB,
83
+ H BLOB,
84
+ qvec BLOB,
85
+ tvec BLOB)
86
+ """
87
+
88
+ CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
89
+ image_id INTEGER PRIMARY KEY NOT NULL,
90
+ rows INTEGER NOT NULL,
91
+ cols INTEGER NOT NULL,
92
+ data BLOB,
93
+ FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)
94
+ """
95
+
96
+ CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
97
+ pair_id INTEGER PRIMARY KEY NOT NULL,
98
+ rows INTEGER NOT NULL,
99
+ cols INTEGER NOT NULL,
100
+ data BLOB)"""
101
+
102
+ CREATE_NAME_INDEX = \
103
+ "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"
104
+
105
+ CREATE_ALL = "; ".join([
106
+ CREATE_CAMERAS_TABLE,
107
+ CREATE_IMAGES_TABLE,
108
+ CREATE_KEYPOINTS_TABLE,
109
+ CREATE_DESCRIPTORS_TABLE,
110
+ CREATE_MATCHES_TABLE,
111
+ CREATE_TWO_VIEW_GEOMETRIES_TABLE,
112
+ CREATE_NAME_INDEX
113
+ ])
114
+
115
+
116
+ def image_ids_to_pair_id(image_id1, image_id2):
117
+ if image_id1 > image_id2:
118
+ image_id1, image_id2 = image_id2, image_id1
119
+ return image_id1 * MAX_IMAGE_ID + image_id2
120
+
121
+
122
+ def pair_id_to_image_ids(pair_id):
123
+ image_id2 = pair_id % MAX_IMAGE_ID
124
+ image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID
125
+ return image_id1, image_id2
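A tiny worked example of COLMAP's pair-id encoding; the ids are swapped so the smaller one always comes first, and the decoder returns the first id as a float:

    pid = image_ids_to_pair_id(7, 3)            # internally reordered to (3, 7)
    assert pid == 3 * MAX_IMAGE_ID + 7
    assert pair_id_to_image_ids(pid) == (3, 7)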
126
+
127
+
128
+ def array_to_blob(array):
129
+ if IS_PYTHON3:
130
+ return array.tobytes()
131
+ else:
132
+ return np.getbuffer(array)
133
+
134
+
135
+ def blob_to_array(blob, dtype, shape=(-1,)):
136
+ if IS_PYTHON3:
137
+ return np.frombuffer(blob, dtype=dtype).reshape(*shape)
138
+ else:
139
+ return np.frombuffer(blob, dtype=dtype).reshape(*shape)
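A round-trip sanity check for the BLOB helpers used by the table writers below; purely illustrative:

    kps = np.random.rand(100, 2).astype(np.float32)
    blob = array_to_blob(kps)
    assert np.allclose(blob_to_array(blob, np.float32, (-1, 2)), kps)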
140
+
141
+
142
+ class COLMAPDatabase(sqlite3.Connection):
143
+
144
+ @staticmethod
145
+ def connect(database_path):
146
+ return sqlite3.connect(str(database_path), factory=COLMAPDatabase)
147
+
148
+
149
+ def __init__(self, *args, **kwargs):
150
+ super(COLMAPDatabase, self).__init__(*args, **kwargs)
151
+
152
+ self.create_tables = lambda: self.executescript(CREATE_ALL)
153
+ self.create_cameras_table = \
154
+ lambda: self.executescript(CREATE_CAMERAS_TABLE)
155
+ self.create_descriptors_table = \
156
+ lambda: self.executescript(CREATE_DESCRIPTORS_TABLE)
157
+ self.create_images_table = \
158
+ lambda: self.executescript(CREATE_IMAGES_TABLE)
159
+ self.create_two_view_geometries_table = \
160
+ lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE)
161
+ self.create_keypoints_table = \
162
+ lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
163
+ self.create_matches_table = \
164
+ lambda: self.executescript(CREATE_MATCHES_TABLE)
165
+ self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)
166
+
167
+ def add_camera(self, model, width, height, params,
168
+ prior_focal_length=False, camera_id=None):
169
+ params = np.asarray(params, np.float64)
170
+ cursor = self.execute(
171
+ "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
172
+ (camera_id, model, width, height, array_to_blob(params),
173
+ prior_focal_length))
174
+ return cursor.lastrowid
175
+
176
+ def add_image(self, name, camera_id,
177
+ prior_q=np.full(4, np.nan), prior_t=np.full(3, np.nan),
178
+ image_id=None):
179
+ cursor = self.execute(
180
+ "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
181
+ (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2],
182
+ prior_q[3], prior_t[0], prior_t[1], prior_t[2]))
183
+ return cursor.lastrowid
184
+
185
+ def add_keypoints(self, image_id, keypoints):
186
+ assert(len(keypoints.shape) == 2)
187
+ assert(keypoints.shape[1] in [2, 4, 6])
188
+
189
+ keypoints = np.asarray(keypoints, np.float32)
190
+ self.execute(
191
+ "INSERT INTO keypoints VALUES (?, ?, ?, ?)",
192
+ (image_id,) + keypoints.shape + (array_to_blob(keypoints),))
193
+
194
+ def add_descriptors(self, image_id, descriptors):
195
+ descriptors = np.ascontiguousarray(descriptors, np.uint8)
196
+ self.execute(
197
+ "INSERT INTO descriptors VALUES (?, ?, ?, ?)",
198
+ (image_id,) + descriptors.shape + (array_to_blob(descriptors),))
199
+
200
+ def add_matches(self, image_id1, image_id2, matches):
201
+ assert(len(matches.shape) == 2)
202
+ assert(matches.shape[1] == 2)
203
+
204
+ if image_id1 > image_id2:
205
+ matches = matches[:,::-1]
206
+
207
+ pair_id = image_ids_to_pair_id(image_id1, image_id2)
208
+ matches = np.asarray(matches, np.uint32)
209
+ self.execute(
210
+ "INSERT INTO matches VALUES (?, ?, ?, ?)",
211
+ (pair_id,) + matches.shape + (array_to_blob(matches),))
212
+
213
+ def add_two_view_geometry(self, image_id1, image_id2, matches,
214
+ F=np.eye(3), E=np.eye(3), H=np.eye(3),
215
+ qvec=np.array([1.0, 0.0, 0.0, 0.0]),
216
+ tvec=np.zeros(3), config=2):
217
+ assert(len(matches.shape) == 2)
218
+ assert(matches.shape[1] == 2)
219
+
220
+ if image_id1 > image_id2:
221
+ matches = matches[:,::-1]
222
+
223
+ pair_id = image_ids_to_pair_id(image_id1, image_id2)
224
+ matches = np.asarray(matches, np.uint32)
225
+ F = np.asarray(F, dtype=np.float64)
226
+ E = np.asarray(E, dtype=np.float64)
227
+ H = np.asarray(H, dtype=np.float64)
228
+ qvec = np.asarray(qvec, dtype=np.float64)
229
+ tvec = np.asarray(tvec, dtype=np.float64)
230
+ self.execute(
231
+ "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
232
+ (pair_id,) + matches.shape + (array_to_blob(matches), config,
233
+ array_to_blob(F), array_to_blob(E), array_to_blob(H),
234
+ array_to_blob(qvec), array_to_blob(tvec)))
235
+
236
+ def update_two_view_geometry(self, image_id1, image_id2, matches,
237
+ F=np.eye(3), E=np.eye(3), H=np.eye(3), config=2):
238
+ assert(len(matches.shape) == 2)
239
+ assert(matches.shape[1] == 2)
240
+
241
+ if image_id1 > image_id2:
242
+ matches = matches[:,::-1]
243
+
244
+ pair_id = image_ids_to_pair_id(image_id1, image_id2)
245
+ matches = np.asarray(matches, np.uint32)
246
+ F = np.asarray(F, dtype=np.float64)
247
+ E = np.asarray(E, dtype=np.float64)
248
+ H = np.asarray(H, dtype=np.float64)
249
+
250
+ # Find whether exists:
251
+ row = self.execute(f"SELECT * FROM two_view_geometries WHERE pair_id = {pair_id} ")
252
+ data = list(next(row))
253
+ try:
254
+ matches_old = blob_to_array(data[3], np.uint32, (-1, 2))
255
+ except Exception:
256
+ matches_old = None
257
+
258
+ if matches_old is not None:
259
+ for match in matches:
260
+ img0_id, img1_id = match
261
+
262
+ # Find duplicated pts
263
+ img0_dup_idxs = np.where(matches_old[:, 0] == img0_id)
264
+ img1_dup_idxs = np.where(matches_old[:, 1] == img1_id)
265
+
266
+ if len(img0_dup_idxs[0]) == 0 and len(img1_dup_idxs[0]) == 0:
267
+ # No duplicated matches:
268
+ matches_old = np.concatenate([matches_old, match[None]], axis=0)
269
+ elif len(img0_dup_idxs[0]) == 1 and len(img1_dup_idxs[0]) == 0:
270
+ matches_old[img0_dup_idxs[0]][0,1] = img1_id
271
+ elif len(img0_dup_idxs[0]) == 0 and len(img1_dup_idxs[0]) == 1:
272
+ matches_old[img1_dup_idxs[0]][0,0] = img0_id
273
+ elif len(img0_dup_idxs[0]) == 1 and len(img1_dup_idxs[0]) == 1:
274
+ if img0_dup_idxs[0] != img1_dup_idxs[0]:
275
+ # logger.warning(f"Duplicated matches exists!")
276
+ matches_old[img0_dup_idxs[0]][0,1] = img1_id
277
+ matches_old[img1_dup_idxs[0]][0,0] = img0_id
278
+ else:
279
+ raise NotImplementedError
280
+
281
+ # matches = np.concatenate([matches_old, matches], axis=0) # N * 2
282
+ matches = matches_old
283
+ self.execute(f"DELETE FROM two_view_geometries WHERE pair_id = {pair_id}")
284
+
285
+ data[1:4] = matches.shape + (array_to_blob(np.asarray(matches, np.uint32)),)
286
+ self.execute("INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", tuple(data))
287
+ else:
288
+ raise NotImplementedError
289
+
290
+ # self.add_two_view_geometry(image_id1, image_id2, matches)
291
+
292
+
293
+ def example_usage():
294
+ import os
295
+ import argparse
296
+
297
+ parser = argparse.ArgumentParser()
298
+ parser.add_argument("--database_path", default="database.db")
299
+ args = parser.parse_args()
300
+
301
+ if os.path.exists(args.database_path):
302
+ print("ERROR: database path already exists -- will not modify it.")
303
+ return
304
+
305
+ # Open the database.
306
+
307
+ db = COLMAPDatabase.connect(args.database_path)
308
+
309
+ # For convenience, try creating all the tables upfront.
310
+
311
+ db.create_tables()
312
+
313
+ # Create dummy cameras.
314
+
315
+ model1, width1, height1, params1 = \
316
+ 0, 1024, 768, np.array((1024., 512., 384.))
317
+ model2, width2, height2, params2 = \
318
+ 2, 1024, 768, np.array((1024., 512., 384., 0.1))
319
+
320
+ camera_id1 = db.add_camera(model1, width1, height1, params1)
321
+ camera_id2 = db.add_camera(model2, width2, height2, params2)
322
+
323
+ # Create dummy images.
324
+
325
+ image_id1 = db.add_image("image1.png", camera_id1)
326
+ image_id2 = db.add_image("image2.png", camera_id1)
327
+ image_id3 = db.add_image("image3.png", camera_id2)
328
+ image_id4 = db.add_image("image4.png", camera_id2)
329
+
330
+ # Create dummy keypoints.
331
+ #
332
+ # Note that COLMAP supports:
333
+ # - 2D keypoints: (x, y)
334
+ # - 4D keypoints: (x, y, theta, scale)
335
+ # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22)
336
+
337
+ num_keypoints = 1000
338
+ keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1)
339
+ keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1)
340
+ keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2)
341
+ keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2)
342
+
343
+ db.add_keypoints(image_id1, keypoints1)
344
+ db.add_keypoints(image_id2, keypoints2)
345
+ db.add_keypoints(image_id3, keypoints3)
346
+ db.add_keypoints(image_id4, keypoints4)
347
+
348
+ # Create dummy matches.
349
+
350
+ M = 50
351
+ matches12 = np.random.randint(num_keypoints, size=(M, 2))
352
+ matches23 = np.random.randint(num_keypoints, size=(M, 2))
353
+ matches34 = np.random.randint(num_keypoints, size=(M, 2))
354
+
355
+ db.add_matches(image_id1, image_id2, matches12)
356
+ db.add_matches(image_id2, image_id3, matches23)
357
+ db.add_matches(image_id3, image_id4, matches34)
358
+
359
+ # Commit the data to the file.
360
+
361
+ db.commit()
362
+
363
+ # Read and check cameras.
364
+
365
+ rows = db.execute("SELECT * FROM cameras")
366
+
367
+ camera_id, model, width, height, params, prior = next(rows)
368
+ params = blob_to_array(params, np.float64)
369
+ assert camera_id == camera_id1
370
+ assert model == model1 and width == width1 and height == height1
371
+ assert np.allclose(params, params1)
372
+
373
+ camera_id, model, width, height, params, prior = next(rows)
374
+ params = blob_to_array(params, np.float64)
375
+ assert camera_id == camera_id2
376
+ assert model == model2 and width == width2 and height == height2
377
+ assert np.allclose(params, params2)
378
+
379
+ # Read and check keypoints.
380
+
381
+ keypoints = dict(
382
+ (image_id, blob_to_array(data, np.float32, (-1, 2)))
383
+ for image_id, data in db.execute(
384
+ "SELECT image_id, data FROM keypoints"))
385
+
386
+ assert np.allclose(keypoints[image_id1], keypoints1)
387
+ assert np.allclose(keypoints[image_id2], keypoints2)
388
+ assert np.allclose(keypoints[image_id3], keypoints3)
389
+ assert np.allclose(keypoints[image_id4], keypoints4)
390
+
391
+ # Read and check matches.
392
+
393
+ pair_ids = [image_ids_to_pair_id(*pair) for pair in
394
+ ((image_id1, image_id2),
395
+ (image_id2, image_id3),
396
+ (image_id3, image_id4))]
397
+
398
+ matches = dict(
399
+ (pair_id_to_image_ids(pair_id),
400
+ blob_to_array(data, np.uint32, (-1, 2)))
401
+ for pair_id, data in db.execute("SELECT pair_id, data FROM matches")
402
+ )
403
+
404
+ assert np.all(matches[(image_id1, image_id2)] == matches12)
405
+ assert np.all(matches[(image_id2, image_id3)] == matches23)
406
+ assert np.all(matches[(image_id3, image_id4)] == matches34)
407
+
408
+ # Clean up.
409
+
410
+ db.close()
411
+
412
+ if os.path.exists(args.database_path):
413
+ os.remove(args.database_path)
414
+
415
+
416
+ if __name__ == "__main__":
417
+ example_usage()
imcui/third_party/MatchAnything/src/utils/colmap/eval_helper.py ADDED
@@ -0,0 +1,232 @@
1
+ import math
2
+ import cv2
3
+ import os
4
+ import numpy as np
5
+ from .read_write_model import read_images_binary
+
+ # tolerance below which a quaternion is treated as zero-length (used by quaternion_matrix)
+ _EPS = np.finfo(float).eps * 4.0
6
+
7
+
8
+ def align_model(model, rot, trans, scale):
9
+ return (np.matmul(rot, model) + trans) * scale
10
+
11
+
12
+ def align(model, data):
13
+ '''
14
+ Source: https://vision.in.tum.de/data/datasets/rgbd-dataset/tools
15
+ #absolute_trajectory_error_ate
16
+ Align two trajectories using the method of Horn (closed-form).
17
+
18
+ Input:
19
+ model -- first trajectory (3xn)
20
+ data -- second trajectory (3xn)
21
+
22
+ Output:
23
+ rot -- rotation matrix (3x3)
24
+ trans -- translation vector (3x1)
25
+ trans_error -- translational error per point (1xn)
26
+
27
+ '''
28
+
29
+ if model.shape[1] < 3:
30
+ print('Need at least 3 points for ATE: {}'.format(model))
31
+ return np.identity(3), np.zeros((3, 1)), 1
32
+
33
+ # Get zero centered point cloud
34
+ model_zerocentered = model - model.mean(1, keepdims=True)
35
+ data_zerocentered = data - data.mean(1, keepdims=True)
36
+
37
+ # constructed covariance matrix
38
+ W = np.zeros((3, 3))
39
+ for column in range(model.shape[1]):
40
+ W += np.outer(model_zerocentered[:, column],
41
+ data_zerocentered[:, column])
42
+
43
+ # SVD
44
+ U, d, Vh = np.linalg.svd(W.transpose())
45
+ S = np.identity(3)
46
+ if (np.linalg.det(U) * np.linalg.det(Vh) < 0):
47
+ S[2, 2] = -1
48
+ rot = np.matmul(np.matmul(U, S), Vh)
49
+ trans = data.mean(1, keepdims=True) - np.matmul(
50
+ rot, model.mean(1, keepdims=True))
51
+
52
+ # apply rot and trans to point cloud
53
+ model_aligned = align_model(model, rot, trans, 1.0)
54
+ model_aligned_zerocentered = model_aligned - model_aligned.mean(
55
+ 1, keepdims=True)
56
+
57
+ # calc scale based on distance to point cloud center
58
+ data_dist = np.sqrt((data_zerocentered * data_zerocentered).sum(axis=0))
59
+ model_aligned_dist = np.sqrt(
60
+ (model_aligned_zerocentered * model_aligned_zerocentered).sum(axis=0))
61
+ scale_array = data_dist / model_aligned_dist
62
+ scale = np.median(scale_array)
63
+
64
+ return rot, trans, scale
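A short sketch of using `align` to compute an absolute trajectory error; `est_xyz` and `gt_xyz` are placeholder 3xN arrays of camera centres:

    rot, trans, scale = align(est_xyz, gt_xyz)
    est_aligned = align_model(est_xyz, rot, trans, scale)
    ate_rmse = np.sqrt(((est_aligned - gt_xyz) ** 2).sum(axis=0).mean())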
65
+
66
+
67
+ def quaternion_matrix(quaternion):
68
+ '''Return homogeneous rotation matrix from quaternion.
69
+
70
+ >>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0])
71
+ >>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0]))
72
+ True
73
+ >>> M = quaternion_matrix([1, 0, 0, 0])
74
+ >>> numpy.allclose(M, numpy.identity(4))
75
+ True
76
+ >>> M = quaternion_matrix([0, 1, 0, 0])
77
+ >>> numpy.allclose(M, numpy.diag([1, -1, -1, 1]))
78
+ True
79
+ '''
80
+
81
+ q = np.array(quaternion, dtype=np.float64, copy=True)
82
+ n = np.dot(q, q)
83
+ if n < _EPS:
84
+ return np.identity(4)
85
+
86
+ q *= math.sqrt(2.0 / n)
87
+ q = np.outer(q, q)
88
+
89
+ return np.array(
90
+ [[1.0 - q[2, 2] - q[3, 3], q[1, 2] - q[3, 0], q[1, 3] + q[2, 0], 0.0],
91
+ [q[1, 2] + q[3, 0], 1.0 - q[1, 1] - q[3, 3], q[2, 3] - q[1, 0], 0.0],
92
+ [q[1, 3] - q[2, 0], q[2, 3] + q[1, 0], 1.0 - q[1, 1] - q[2, 2], 0.0],
93
+ [0.0, 0.0, 0.0, 1.0]])
94
+
95
+
96
+ def quaternion_from_matrix(matrix, isprecise=False):
97
+ '''Return quaternion from rotation matrix.
98
+
99
+ If isprecise is True, the input matrix is assumed to be a precise rotation
100
+ matrix and a faster algorithm is used.
101
+
102
+ >>> q = quaternion_from_matrix(numpy.identity(4), True)
103
+ >>> numpy.allclose(q, [1, 0, 0, 0])
104
+ True
105
+ >>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1]))
106
+ >>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0])
107
+ True
108
+ >>> R = rotation_matrix(0.123, (1, 2, 3))
109
+ >>> q = quaternion_from_matrix(R, True)
110
+ >>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786])
111
+ True
112
+ >>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0],
113
+ ... [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]]
114
+ >>> q = quaternion_from_matrix(R)
115
+ >>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611])
116
+ True
117
+ >>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0],
118
+ ... [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]]
119
+ >>> q = quaternion_from_matrix(R)
120
+ >>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603])
121
+ True
122
+ >>> R = random_rotation_matrix()
123
+ >>> q = quaternion_from_matrix(R)
124
+ >>> is_same_transform(R, quaternion_matrix(q))
125
+ True
126
+ >>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0)
127
+ >>> numpy.allclose(quaternion_from_matrix(R, isprecise=False),
128
+ ... quaternion_from_matrix(R, isprecise=True))
129
+ True
130
+
131
+ '''
132
+
133
+ M = np.array(matrix, dtype=np.float64, copy=False)[:4, :4]
134
+ if isprecise:
135
+ q = np.empty((4, ))
136
+ t = np.trace(M)
137
+ if t > M[3, 3]:
138
+ q[0] = t
139
+ q[3] = M[1, 0] - M[0, 1]
140
+ q[2] = M[0, 2] - M[2, 0]
141
+ q[1] = M[2, 1] - M[1, 2]
142
+ else:
143
+ i, j, k = 1, 2, 3
144
+ if M[1, 1] > M[0, 0]:
145
+ i, j, k = 2, 3, 1
146
+ if M[2, 2] > M[i, i]:
147
+ i, j, k = 3, 1, 2
148
+ t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
149
+ q[i] = t
150
+ q[j] = M[i, j] + M[j, i]
151
+ q[k] = M[k, i] + M[i, k]
152
+ q[3] = M[k, j] - M[j, k]
153
+ q *= 0.5 / math.sqrt(t * M[3, 3])
154
+ else:
155
+ m00 = M[0, 0]
156
+ m01 = M[0, 1]
157
+ m02 = M[0, 2]
158
+ m10 = M[1, 0]
159
+ m11 = M[1, 1]
160
+ m12 = M[1, 2]
161
+ m20 = M[2, 0]
162
+ m21 = M[2, 1]
163
+ m22 = M[2, 2]
164
+
165
+ # symmetric matrix K
166
+ K = np.array([[m00 - m11 - m22, 0.0, 0.0, 0.0],
167
+ [m01 + m10, m11 - m00 - m22, 0.0, 0.0],
168
+ [m02 + m20, m12 + m21, m22 - m00 - m11, 0.0],
169
+ [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22]])
170
+ K /= 3.0
171
+
172
+ # quaternion is eigenvector of K that corresponds to largest eigenvalue
173
+ w, V = np.linalg.eigh(K)
174
+ q = V[[3, 0, 1, 2], np.argmax(w)]
175
+
176
+ if q[0] < 0.0:
177
+ np.negative(q, q)
178
+
179
+ return q
180
+
181
+ def is_colmap_img_valid(colmap_img_file):
182
+ '''Return validity of a colmap reconstruction'''
183
+
184
+ images_bin = read_images_binary(colmap_img_file)
185
+ # Check if everything is finite for this subset
186
+ for key in images_bin.keys():
187
+ q = np.asarray(images_bin[key].qvec).flatten()
188
+ t = np.asarray(images_bin[key].tvec).flatten()
189
+
190
+ is_cur_valid = True
191
+ is_cur_valid = is_cur_valid and q.shape == (4, )
192
+ is_cur_valid = is_cur_valid and t.shape == (3, )
193
+ is_cur_valid = is_cur_valid and np.all(np.isfinite(q))
194
+ is_cur_valid = is_cur_valid and np.all(np.isfinite(t))
195
+
196
+ # If any is invalid, immediately return
197
+ if not is_cur_valid:
198
+ return False
199
+
200
+ return True
201
+
202
+ def get_best_colmap_index(colmap_output_path):
203
+ '''
204
+ Determines the colmap model with the most images if there is more than one.
205
+ '''
206
+
207
+ # First find the colmap reconstruction with the most number of images.
208
+ best_index, best_num_images = -1, 0
209
+
210
+ # Check all valid sub reconstructions.
211
+ if os.path.exists(colmap_output_path):
212
+ idx_list = [
213
+ _d for _d in os.listdir(colmap_output_path)
214
+ if os.path.isdir(os.path.join(colmap_output_path, _d))
215
+ ]
216
+ else:
217
+ idx_list = []
218
+
219
+ for cur_index in idx_list:
220
+ cur_output_path = os.path.join(colmap_output_path, cur_index)
221
+ if os.path.isdir(cur_output_path):
222
+ colmap_img_file = os.path.join(cur_output_path, 'images.bin')
223
+ images_bin = read_images_binary(colmap_img_file)
224
+ # Check validity
225
+ if not is_colmap_img_valid(colmap_img_file):
226
+ continue
227
+ # Find the reconstruction with most number of images
228
+ if len(images_bin) > best_num_images:
229
+ best_index = int(cur_index)
230
+ best_num_images = len(images_bin)
231
+
232
+ return str(best_index)
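A minimal usage sketch of the two helpers above (the folder layout and path strings are illustrative assumptions, and read_images_binary is the reader from read_write_model.py):

    import os

    # Hypothetical COLMAP output folder containing sub-models "0", "1", ...
    colmap_output_path = "outputs/colmap_sparse"
    best_index = get_best_colmap_index(colmap_output_path)
    if best_index != "-1":
        images_bin = read_images_binary(
            os.path.join(colmap_output_path, best_index, "images.bin"))
        print(f"best sub-model {best_index} registers {len(images_bin)} images")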
imcui/third_party/MatchAnything/src/utils/colmap/read_write_model.py ADDED
@@ -0,0 +1,509 @@
1
+ # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
2
+ # All rights reserved.
3
+ #
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+ #
7
+ # * Redistributions of source code must retain the above copyright
8
+ # notice, this list of conditions and the following disclaimer.
9
+ #
10
+ # * Redistributions in binary form must reproduce the above copyright
11
+ # notice, this list of conditions and the following disclaimer in the
12
+ # documentation and/or other materials provided with the distribution.
13
+ #
14
+ # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15
+ # its contributors may be used to endorse or promote products derived
16
+ # from this software without specific prior written permission.
17
+ #
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28
+ # POSSIBILITY OF SUCH DAMAGE.
29
+ #
30
+ # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
31
+
32
+ import os
33
+ import sys
34
+ import collections
35
+ import numpy as np
36
+ import struct
37
+ import argparse
38
+
39
+
40
+ CameraModel = collections.namedtuple(
41
+ "CameraModel", ["model_id", "model_name", "num_params"])
42
+ Camera = collections.namedtuple(
43
+ "Camera", ["id", "model", "width", "height", "params"])
44
+ BaseImage = collections.namedtuple(
45
+ "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
46
+ Point3D = collections.namedtuple(
47
+ "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
48
+
49
+
50
+ class Image(BaseImage):
51
+ def qvec2rotmat(self):
52
+ return qvec2rotmat(self.qvec)
53
+
54
+
55
+ CAMERA_MODELS = {
56
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
57
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
58
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
59
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
60
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
61
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
62
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
63
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
64
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
65
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
66
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
67
+ }
68
+ CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
69
+ for camera_model in CAMERA_MODELS])
70
+ CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
71
+ for camera_model in CAMERA_MODELS])
72
+
73
+
74
+ def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
75
+ """Read and unpack the next bytes from a binary file.
76
+ :param fid:
77
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
78
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
79
+ :param endian_character: Any of {@, =, <, >, !}
80
+ :return: Tuple of read and unpacked values.
81
+ """
82
+ data = fid.read(num_bytes)
83
+ return struct.unpack(endian_character + format_char_sequence, data)
84
+
85
+
86
+ def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
87
+ """pack and write to a binary file.
88
+ :param fid:
89
+ :param data: data to send, if multiple elements are sent at the same time,
90
+ they should be encapsulated either in a list or a tuple
91
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
92
+ should be the same length as the data list or tuple
93
+ :param endian_character: Any of {@, =, <, >, !}
94
+ """
95
+ if isinstance(data, (list, tuple)):
96
+ bytes = struct.pack(endian_character + format_char_sequence, *data)
97
+ else:
98
+ bytes = struct.pack(endian_character + format_char_sequence, data)
99
+ fid.write(bytes)
100
+
101
+
102
+ def read_cameras_text(path):
103
+ """
104
+ see: src/base/reconstruction.cc
105
+ void Reconstruction::WriteCamerasText(const std::string& path)
106
+ void Reconstruction::ReadCamerasText(const std::string& path)
107
+ """
108
+ cameras = {}
109
+ with open(path, "r") as fid:
110
+ while True:
111
+ line = fid.readline()
112
+ if not line:
113
+ break
114
+ line = line.strip()
115
+ if len(line) > 0 and line[0] != "#":
116
+ elems = line.split()
117
+ camera_id = int(elems[0])
118
+ model = elems[1]
119
+ width = int(elems[2])
120
+ height = int(elems[3])
121
+ params = np.array(tuple(map(float, elems[4:])))
122
+ cameras[camera_id] = Camera(id=camera_id, model=model,
123
+ width=width, height=height,
124
+ params=params)
125
+ return cameras
126
+
127
+
128
+ def read_cameras_binary(path_to_model_file):
129
+ """
130
+ see: src/base/reconstruction.cc
131
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
132
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
133
+ """
134
+ cameras = {}
135
+ with open(path_to_model_file, "rb") as fid:
136
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
137
+ for _ in range(num_cameras):
138
+ camera_properties = read_next_bytes(
139
+ fid, num_bytes=24, format_char_sequence="iiQQ")
140
+ camera_id = camera_properties[0]
141
+ model_id = camera_properties[1]
142
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
143
+ width = camera_properties[2]
144
+ height = camera_properties[3]
145
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
146
+ params = read_next_bytes(fid, num_bytes=8*num_params,
147
+ format_char_sequence="d"*num_params)
148
+ cameras[camera_id] = Camera(id=camera_id,
149
+ model=model_name,
150
+ width=width,
151
+ height=height,
152
+ params=np.array(params))
153
+ assert len(cameras) == num_cameras
154
+ return cameras
155
+
156
+
157
+ def write_cameras_text(cameras, path):
158
+ """
159
+ see: src/base/reconstruction.cc
160
+ void Reconstruction::WriteCamerasText(const std::string& path)
161
+ void Reconstruction::ReadCamerasText(const std::string& path)
162
+ """
163
+ HEADER = "# Camera list with one line of data per camera:\n"
164
+ "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n"
165
+ "# Number of cameras: {}\n".format(len(cameras))
166
+ with open(path, "w") as fid:
167
+ fid.write(HEADER)
168
+ for _, cam in cameras.items():
169
+ to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
170
+ line = " ".join([str(elem) for elem in to_write])
171
+ fid.write(line + "\n")
172
+
173
+
174
+ def write_cameras_binary(cameras, path_to_model_file):
175
+ """
176
+ see: src/base/reconstruction.cc
177
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
178
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
179
+ """
180
+ with open(path_to_model_file, "wb") as fid:
181
+ write_next_bytes(fid, len(cameras), "Q")
182
+ for _, cam in cameras.items():
183
+ model_id = CAMERA_MODEL_NAMES[cam.model].model_id
184
+ camera_properties = [cam.id,
185
+ model_id,
186
+ cam.width,
187
+ cam.height]
188
+ write_next_bytes(fid, camera_properties, "iiQQ")
189
+ for p in cam.params:
190
+ write_next_bytes(fid, float(p), "d")
191
+ return cameras
192
+
193
+
194
+ def read_images_text(path):
195
+ """
196
+ see: src/base/reconstruction.cc
197
+ void Reconstruction::ReadImagesText(const std::string& path)
198
+ void Reconstruction::WriteImagesText(const std::string& path)
199
+ """
200
+ images = {}
201
+ with open(path, "r") as fid:
202
+ while True:
203
+ line = fid.readline()
204
+ if not line:
205
+ break
206
+ line = line.strip()
207
+ if len(line) > 0 and line[0] != "#":
208
+ elems = line.split()
209
+ image_id = int(elems[0])
210
+ qvec = np.array(tuple(map(float, elems[1:5])))
211
+ tvec = np.array(tuple(map(float, elems[5:8])))
212
+ camera_id = int(elems[8])
213
+ image_name = elems[9]
214
+ elems = fid.readline().split()
215
+ xys = np.column_stack([tuple(map(float, elems[0::3])),
216
+ tuple(map(float, elems[1::3]))])
217
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
218
+ images[image_id] = Image(
219
+ id=image_id, qvec=qvec, tvec=tvec,
220
+ camera_id=camera_id, name=image_name,
221
+ xys=xys, point3D_ids=point3D_ids)
222
+ return images
223
+
224
+
225
+ def read_images_binary(path_to_model_file):
226
+ """
227
+ see: src/base/reconstruction.cc
228
+ void Reconstruction::ReadImagesBinary(const std::string& path)
229
+ void Reconstruction::WriteImagesBinary(const std::string& path)
230
+ """
231
+ images = {}
232
+ with open(path_to_model_file, "rb") as fid:
233
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
234
+ for _ in range(num_reg_images):
235
+ binary_image_properties = read_next_bytes(
236
+ fid, num_bytes=64, format_char_sequence="idddddddi")
237
+ image_id = binary_image_properties[0]
238
+ qvec = np.array(binary_image_properties[1:5])
239
+ tvec = np.array(binary_image_properties[5:8])
240
+ camera_id = binary_image_properties[8]
241
+ image_name = ""
242
+ current_char = read_next_bytes(fid, 1, "c")[0]
243
+ while current_char != b"\x00": # look for the ASCII 0 entry
244
+ image_name += current_char.decode("utf-8")
245
+ current_char = read_next_bytes(fid, 1, "c")[0]
246
+ num_points2D = read_next_bytes(fid, num_bytes=8,
247
+ format_char_sequence="Q")[0]
248
+ x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
249
+ format_char_sequence="ddq"*num_points2D)
250
+ xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
251
+ tuple(map(float, x_y_id_s[1::3]))])
252
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
253
+ images[image_id] = Image(
254
+ id=image_id, qvec=qvec, tvec=tvec,
255
+ camera_id=camera_id, name=image_name,
256
+ xys=xys, point3D_ids=point3D_ids)
257
+ return images
258
+
259
+
260
+ def write_images_text(images, path):
261
+ """
262
+ see: src/base/reconstruction.cc
263
+ void Reconstruction::ReadImagesText(const std::string& path)
264
+ void Reconstruction::WriteImagesText(const std::string& path)
265
+ """
266
+ if len(images) == 0:
267
+ mean_observations = 0
268
+ else:
269
+ mean_observations = sum((len(img.point3D_ids) for _, img in images.items()))/len(images)
270
+ HEADER = "# Image list with two lines of data per image:\n"
271
+ "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n"
272
+ "# POINTS2D[] as (X, Y, POINT3D_ID)\n"
273
+ "# Number of images: {}, mean observations per image: {}\n".format(len(images), mean_observations)
274
+
275
+ with open(path, "w") as fid:
276
+ fid.write(HEADER)
277
+ for _, img in images.items():
278
+ image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name]
279
+ first_line = " ".join(map(str, image_header))
280
+ fid.write(first_line + "\n")
281
+
282
+ points_strings = []
283
+ for xy, point3D_id in zip(img.xys, img.point3D_ids):
284
+ points_strings.append(" ".join(map(str, [*xy, point3D_id])))
285
+ fid.write(" ".join(points_strings) + "\n")
286
+
287
+
288
+ def write_images_binary(images, path_to_model_file):
289
+ """
290
+ see: src/base/reconstruction.cc
291
+ void Reconstruction::ReadImagesBinary(const std::string& path)
292
+ void Reconstruction::WriteImagesBinary(const std::string& path)
293
+ """
294
+ with open(path_to_model_file, "wb") as fid:
295
+ write_next_bytes(fid, len(images), "Q")
296
+ for _, img in images.items():
297
+ write_next_bytes(fid, img.id, "i")
298
+ write_next_bytes(fid, img.qvec.tolist(), "dddd")
299
+ write_next_bytes(fid, img.tvec.tolist(), "ddd")
300
+ write_next_bytes(fid, img.camera_id, "i")
301
+ for char in img.name:
302
+ write_next_bytes(fid, char.encode("utf-8"), "c")
303
+ write_next_bytes(fid, b"\x00", "c")
304
+ write_next_bytes(fid, len(img.point3D_ids), "Q")
305
+ for xy, p3d_id in zip(img.xys, img.point3D_ids):
306
+ write_next_bytes(fid, [*xy, p3d_id], "ddq")
307
+
308
+
309
+ def read_points3D_text(path):
310
+ """
311
+ see: src/base/reconstruction.cc
312
+ void Reconstruction::ReadPoints3DText(const std::string& path)
313
+ void Reconstruction::WritePoints3DText(const std::string& path)
314
+ """
315
+ points3D = {}
316
+ with open(path, "r") as fid:
317
+ while True:
318
+ line = fid.readline()
319
+ if not line:
320
+ break
321
+ line = line.strip()
322
+ if len(line) > 0 and line[0] != "#":
323
+ elems = line.split()
324
+ point3D_id = int(elems[0])
325
+ xyz = np.array(tuple(map(float, elems[1:4])))
326
+ rgb = np.array(tuple(map(int, elems[4:7])))
327
+ error = float(elems[7])
328
+ image_ids = np.array(tuple(map(int, elems[8::2])))
329
+ point2D_idxs = np.array(tuple(map(int, elems[9::2])))
330
+ points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
331
+ error=error, image_ids=image_ids,
332
+ point2D_idxs=point2D_idxs)
333
+ return points3D
334
+
335
+
336
+ def read_points3d_binary(path_to_model_file):
337
+ """
338
+ see: src/base/reconstruction.cc
339
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
340
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
341
+ """
342
+ points3D = {}
343
+ with open(path_to_model_file, "rb") as fid:
344
+ num_points = read_next_bytes(fid, 8, "Q")[0]
345
+ for _ in range(num_points):
346
+ binary_point_line_properties = read_next_bytes(
347
+ fid, num_bytes=43, format_char_sequence="QdddBBBd")
348
+ point3D_id = binary_point_line_properties[0]
349
+ xyz = np.array(binary_point_line_properties[1:4])
350
+ rgb = np.array(binary_point_line_properties[4:7])
351
+ error = np.array(binary_point_line_properties[7])
352
+ track_length = read_next_bytes(
353
+ fid, num_bytes=8, format_char_sequence="Q")[0]
354
+ track_elems = read_next_bytes(
355
+ fid, num_bytes=8*track_length,
356
+ format_char_sequence="ii"*track_length)
357
+ image_ids = np.array(tuple(map(int, track_elems[0::2])))
358
+ point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
359
+ points3D[point3D_id] = Point3D(
360
+ id=point3D_id, xyz=xyz, rgb=rgb,
361
+ error=error, image_ids=image_ids,
362
+ point2D_idxs=point2D_idxs)
363
+ return points3D
364
+
365
+ def write_points3D_text(points3D, path):
366
+ """
367
+ see: src/base/reconstruction.cc
368
+ void Reconstruction::ReadPoints3DText(const std::string& path)
369
+ void Reconstruction::WritePoints3DText(const std::string& path)
370
+ """
371
+ if len(points3D) == 0:
372
+ mean_track_length = 0
373
+ else:
374
+ mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items()))/len(points3D)
375
+ HEADER = "# 3D point list with one line of data per point:\n"
376
+ "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n"
377
+ "# Number of points: {}, mean track length: {}\n".format(len(points3D), mean_track_length)
378
+
379
+ with open(path, "w") as fid:
380
+ fid.write(HEADER)
381
+ for _, pt in points3D.items():
382
+ point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
383
+ fid.write(" ".join(map(str, point_header)) + " ")
384
+ track_strings = []
385
+ for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
386
+ track_strings.append(" ".join(map(str, [image_id, point2D])))
387
+ fid.write(" ".join(track_strings) + "\n")
388
+
389
+
390
+ def write_points3d_binary(points3D, path_to_model_file):
391
+ """
392
+ see: src/base/reconstruction.cc
393
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
394
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
395
+ """
396
+ with open(path_to_model_file, "wb") as fid:
397
+ write_next_bytes(fid, len(points3D), "Q")
398
+ for _, pt in points3D.items():
399
+ write_next_bytes(fid, pt.id, "Q")
400
+ write_next_bytes(fid, pt.xyz.tolist(), "ddd")
401
+ write_next_bytes(fid, pt.rgb.tolist(), "BBB")
402
+ write_next_bytes(fid, pt.error, "d")
403
+ track_length = pt.image_ids.shape[0]
404
+ write_next_bytes(fid, track_length, "Q")
405
+ for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
406
+ write_next_bytes(fid, [image_id, point2D_id], "ii")
407
+
408
+
409
+ def detect_model_format(path, ext):
410
+ if os.path.isfile(os.path.join(path, "cameras" + ext)) and \
411
+ os.path.isfile(os.path.join(path, "images" + ext)) and \
412
+ os.path.isfile(os.path.join(path, "points3D" + ext)):
413
+ print("Detected model format: '" + ext + "'")
414
+ return True
415
+
416
+ return False
417
+
418
+
419
+ def read_model(path, ext=""):
420
+ # try to detect the extension automatically
421
+ if ext == "":
422
+ if detect_model_format(path, ".bin"):
423
+ ext = ".bin"
424
+ elif detect_model_format(path, ".txt"):
425
+ ext = ".txt"
426
+ else:
427
+ print("Provide model format: '.bin' or '.txt'")
428
+ return
429
+
430
+ if ext == ".txt":
431
+ cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
432
+ images = read_images_text(os.path.join(path, "images" + ext))
433
+ points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
434
+ else:
435
+ cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
436
+ images = read_images_binary(os.path.join(path, "images" + ext))
437
+ points3D = read_points3d_binary(os.path.join(path, "points3D") + ext)
438
+ return cameras, images, points3D
439
+
440
+
441
+ def write_model(cameras, images, points3D, path, ext=".bin"):
442
+ if ext == ".txt":
443
+ write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
444
+ write_images_text(images, os.path.join(path, "images" + ext))
445
+ write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
446
+ else:
447
+ write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
448
+ write_images_binary(images, os.path.join(path, "images" + ext))
449
+ write_points3d_binary(points3D, os.path.join(path, "points3D") + ext)
450
+ return cameras, images, points3D
451
+
452
+
453
+ def qvec2rotmat(qvec):
454
+ return np.array([
455
+ [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
456
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
457
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
458
+ [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
459
+ 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
460
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
461
+ [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
462
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
463
+ 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
464
+
465
+
466
+ def rotmat2qvec(R):
467
+ Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
468
+ K = np.array([
469
+ [Rxx - Ryy - Rzz, 0, 0, 0],
470
+ [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
471
+ [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
472
+ [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
473
+ eigvals, eigvecs = np.linalg.eigh(K)
474
+ qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
475
+ if qvec[0] < 0:
476
+ qvec *= -1
477
+ return qvec
478
+
479
+
480
+ def main():
481
+ parser = argparse.ArgumentParser(description="Read and write COLMAP binary and text models")
482
+ parser.add_argument("--input_model", help="path to input model folder")
483
+ parser.add_argument("--input_format", choices=[".bin", ".txt"],
484
+ help="input model format", default="")
485
+ parser.add_argument("--output_model",
486
+ help="path to output model folder")
487
+ parser.add_argument("--output_format", choices=[".bin", ".txt"],
488
+ help="outut model format", default=".txt")
489
+ args = parser.parse_args()
490
+
491
+ cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format)
492
+
493
+ # FIXME: for debug only
494
+ # images_ = images[1]
495
+ # tvec, qvec = images_.tvec, images_.qvec
496
+ # rotation = qvec2rotmat(qvec).reshape(3, 3)
497
+ # pose = np.concatenate([rotation, tvec.reshape(3, 1)], axis=1)
498
+ # import ipdb; ipdb.set_trace()
499
+
500
+ print("num_cameras:", len(cameras))
501
+ print("num_images:", len(images))
502
+ print("num_points3D:", len(points3D))
503
+
504
+ if args.output_model is not None:
505
+ write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format)
506
+
507
+
508
+ if __name__ == "__main__":
509
+ main()
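As a usage sketch of the reader/writer API defined above (paths are placeholders): read_model auto-detects the .bin/.txt format when ext is left empty, and write_model mirrors it, so converting a binary reconstruction to text takes two calls. The same conversion is exposed on the command line through main(), e.g. python read_write_model.py --input_model sparse/0 --input_format .bin --output_model sparse_txt --output_format .txt.

    # Hypothetical input/output folders
    cameras, images, points3D = read_model("sparse/0", ext=".bin")
    print("cameras:", len(cameras), "images:", len(images), "points3D:", len(points3D))
    write_model(cameras, images, points3D, path="sparse_txt", ext=".txt")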
imcui/third_party/MatchAnything/src/utils/comm.py ADDED
@@ -0,0 +1,265 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ [Copied from detectron2]
4
+ This file contains primitives for multi-gpu communication.
5
+ This is useful when doing distributed training.
6
+ """
7
+
8
+ import functools
9
+ import logging
10
+ import numpy as np
11
+ import pickle
12
+ import torch
13
+ import torch.distributed as dist
14
+
15
+ _LOCAL_PROCESS_GROUP = None
16
+ """
17
+ A torch process group which only includes processes that on the same machine as the current process.
18
+ This variable is set when processes are spawned by `launch()` in "engine/launch.py".
19
+ """
20
+
21
+
22
+ def get_world_size() -> int:
23
+ if not dist.is_available():
24
+ return 1
25
+ if not dist.is_initialized():
26
+ return 1
27
+ return dist.get_world_size()
28
+
29
+
30
+ def get_rank() -> int:
31
+ if not dist.is_available():
32
+ return 0
33
+ if not dist.is_initialized():
34
+ return 0
35
+ return dist.get_rank()
36
+
37
+
38
+ def get_local_rank() -> int:
39
+ """
40
+ Returns:
41
+ The rank of the current process within the local (per-machine) process group.
42
+ """
43
+ if not dist.is_available():
44
+ return 0
45
+ if not dist.is_initialized():
46
+ return 0
47
+ assert _LOCAL_PROCESS_GROUP is not None
48
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
49
+
50
+
51
+ def get_local_size() -> int:
52
+ """
53
+ Returns:
54
+ The size of the per-machine process group,
55
+ i.e. the number of processes per machine.
56
+ """
57
+ if not dist.is_available():
58
+ return 1
59
+ if not dist.is_initialized():
60
+ return 1
61
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
62
+
63
+
64
+ def is_main_process() -> bool:
65
+ return get_rank() == 0
66
+
67
+
68
+ def synchronize():
69
+ """
70
+ Helper function to synchronize (barrier) among all processes when
71
+ using distributed training
72
+ """
73
+ if not dist.is_available():
74
+ return
75
+ if not dist.is_initialized():
76
+ return
77
+ world_size = dist.get_world_size()
78
+ if world_size == 1:
79
+ return
80
+ dist.barrier()
81
+
82
+
83
+ @functools.lru_cache()
84
+ def _get_global_gloo_group():
85
+ """
86
+ Return a process group based on gloo backend, containing all the ranks
87
+ The result is cached.
88
+ """
89
+ if dist.get_backend() == "nccl":
90
+ return dist.new_group(backend="gloo")
91
+ else:
92
+ return dist.group.WORLD
93
+
94
+
95
+ def _serialize_to_tensor(data, group):
96
+ backend = dist.get_backend(group)
97
+ assert backend in ["gloo", "nccl"]
98
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
99
+
100
+ buffer = pickle.dumps(data)
101
+ if len(buffer) > 1024 ** 3:
102
+ logger = logging.getLogger(__name__)
103
+ logger.warning(
104
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
105
+ get_rank(), len(buffer) / (1024 ** 3), device
106
+ )
107
+ )
108
+ storage = torch.ByteStorage.from_buffer(buffer)
109
+ tensor = torch.ByteTensor(storage).to(device=device)
110
+ return tensor
111
+
112
+
113
+ def _pad_to_largest_tensor(tensor, group):
114
+ """
115
+ Returns:
116
+ list[int]: size of the tensor, on each rank
117
+ Tensor: padded tensor that has the max size
118
+ """
119
+ world_size = dist.get_world_size(group=group)
120
+ assert (
121
+ world_size >= 1
122
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
123
+ local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
124
+ size_list = [
125
+ torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
126
+ ]
127
+ dist.all_gather(size_list, local_size, group=group)
128
+
129
+ size_list = [int(size.item()) for size in size_list]
130
+
131
+ max_size = max(size_list)
132
+
133
+ # we pad the tensor because torch all_gather does not support
134
+ # gathering tensors of different shapes
135
+ if local_size != max_size:
136
+ padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
137
+ tensor = torch.cat((tensor, padding), dim=0)
138
+ return size_list, tensor
139
+
140
+
141
+ def all_gather(data, group=None):
142
+ """
143
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
144
+
145
+ Args:
146
+ data: any picklable object
147
+ group: a torch process group. By default, will use a group which
148
+ contains all ranks on gloo backend.
149
+
150
+ Returns:
151
+ list[data]: list of data gathered from each rank
152
+ """
153
+ if get_world_size() == 1:
154
+ return [data]
155
+ if group is None:
156
+ group = _get_global_gloo_group()
157
+ if dist.get_world_size(group) == 1:
158
+ return [data]
159
+
160
+ tensor = _serialize_to_tensor(data, group)
161
+
162
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
163
+ max_size = max(size_list)
164
+
165
+ # receiving Tensor from all ranks
166
+ tensor_list = [
167
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
168
+ ]
169
+ dist.all_gather(tensor_list, tensor, group=group)
170
+
171
+ data_list = []
172
+ for size, tensor in zip(size_list, tensor_list):
173
+ buffer = tensor.cpu().numpy().tobytes()[:size]
174
+ data_list.append(pickle.loads(buffer))
175
+
176
+ return data_list
177
+
178
+
179
+ def gather(data, dst=0, group=None):
180
+ """
181
+ Run gather on arbitrary picklable data (not necessarily tensors).
182
+
183
+ Args:
184
+ data: any picklable object
185
+ dst (int): destination rank
186
+ group: a torch process group. By default, will use a group which
187
+ contains all ranks on gloo backend.
188
+
189
+ Returns:
190
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
191
+ an empty list.
192
+ """
193
+ if get_world_size() == 1:
194
+ return [data]
195
+ if group is None:
196
+ group = _get_global_gloo_group()
197
+ if dist.get_world_size(group=group) == 1:
198
+ return [data]
199
+ rank = dist.get_rank(group=group)
200
+
201
+ tensor = _serialize_to_tensor(data, group)
202
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
203
+
204
+ # receiving Tensor from all ranks
205
+ if rank == dst:
206
+ max_size = max(size_list)
207
+ tensor_list = [
208
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
209
+ ]
210
+ dist.gather(tensor, tensor_list, dst=dst, group=group)
211
+
212
+ data_list = []
213
+ for size, tensor in zip(size_list, tensor_list):
214
+ buffer = tensor.cpu().numpy().tobytes()[:size]
215
+ data_list.append(pickle.loads(buffer))
216
+ return data_list
217
+ else:
218
+ dist.gather(tensor, [], dst=dst, group=group)
219
+ return []
220
+
221
+
222
+ def shared_random_seed():
223
+ """
224
+ Returns:
225
+ int: a random number that is the same across all workers.
226
+ If workers need a shared RNG, they can use this shared seed to
227
+ create one.
228
+
229
+ All workers must call this function, otherwise it will deadlock.
230
+ """
231
+ ints = np.random.randint(2 ** 31)
232
+ all_ints = all_gather(ints)
233
+ return all_ints[0]
234
+
235
+
236
+ def reduce_dict(input_dict, average=True):
237
+ """
238
+ Reduce the values in the dictionary from all processes so that process with rank
239
+ 0 has the reduced results.
240
+
241
+ Args:
242
+ input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
243
+ average (bool): whether to do average or sum
244
+
245
+ Returns:
246
+ a dict with the same keys as input_dict, after reduction.
247
+ """
248
+ world_size = get_world_size()
249
+ if world_size < 2:
250
+ return input_dict
251
+ with torch.no_grad():
252
+ names = []
253
+ values = []
254
+ # sort the keys so that they are consistent across processes
255
+ for k in sorted(input_dict.keys()):
256
+ names.append(k)
257
+ values.append(input_dict[k])
258
+ values = torch.stack(values, dim=0)
259
+ dist.reduce(values, dst=0)
260
+ if dist.get_rank() == 0 and average:
261
+ # only main process gets accumulated, so only divide by
262
+ # world_size in this case
263
+ values /= world_size
264
+ reduced_dict = {k: v for k, v in zip(names, values)}
265
+ return reduced_dict
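A minimal sketch of how these primitives are typically used from a training loop; it assumes torch.distributed has already been initialized by the launcher, and the metric names are purely illustrative:

    import torch

    # Gather arbitrary picklable per-rank results onto every rank
    local_stats = {"num_matches": 123}
    all_stats = all_gather(local_stats)          # list with one dict per rank
    if is_main_process():
        print("total matches:", sum(s["num_matches"] for s in all_stats))

    # Average scalar CUDA tensors onto rank 0
    loss_dict = {"loss": torch.tensor(0.5, device="cuda")}
    reduced = reduce_dict(loss_dict, average=True)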
imcui/third_party/MatchAnything/src/utils/dataloader.py ADDED
@@ -0,0 +1,23 @@
1
+ import numpy as np
2
+
3
+
4
+ # --- PL-DATAMODULE ---
5
+
6
+ def get_local_split(items: list, world_size: int, rank: int, seed: int):
7
+ """ The local rank only loads a split of the dataset. """
8
+ n_items = len(items)
9
+ items_permute = np.random.RandomState(seed).permutation(items)
10
+ if n_items % world_size == 0:
11
+ padded_items = items_permute
12
+ else:
13
+ padding = np.random.RandomState(seed).choice(
14
+ items,
15
+ world_size - (n_items % world_size),
16
+ replace=True)
17
+ padded_items = np.concatenate([items_permute, padding])
18
+ assert len(padded_items) % world_size == 0, \
19
+ f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}'
20
+ n_per_rank = len(padded_items) // world_size
21
+ local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)]
22
+
23
+ return local_items
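For illustration (the world_size, rank and seed values are arbitrary), each rank calls the helper above with the full item list and keeps only its own padded share:

    scene_names = [f"scene_{i:04d}" for i in range(10)]
    # 10 items over 4 ranks: the list is padded to 12, so every rank receives 3 scenes
    local_scenes = get_local_split(scene_names, world_size=4, rank=0, seed=66)
    assert len(local_scenes) == 3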
imcui/third_party/MatchAnything/src/utils/dataset.py ADDED
@@ -0,0 +1,518 @@
1
+ import io
2
+ from loguru import logger
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from pathlib import Path
7
+ import h5py
8
+ import torch
9
+ import re
10
+ from PIL import Image
11
+ from numpy.linalg import inv
12
+ from torchvision.transforms import Normalize
13
+ from .sample_homo import sample_homography_sap
14
+ from kornia.geometry import homography_warp, normalize_homography, normal_transform_pixel
15
+ OSS_FOLDER_PATH = '???'
16
+ PCACHE_FOLDER_PATH = '???'
17
+
18
+ import fsspec
19
+ from PIL import Image
20
+
21
+ # Initialize pcache
22
+ try:
23
+ PCACHE_HOST = "???"
24
+ PCACHE_PORT = 00000
25
+ pcache_kwargs = {"host": PCACHE_HOST, "port": PCACHE_PORT}
26
+ pcache_fs = fsspec.filesystem("pcache", pcache_kwargs=pcache_kwargs)
27
+ root_dir='???'
28
+ except Exception as e:
29
+ logger.error(f"Error captured:{e}")
30
+
31
+ try:
32
+ # for internel use only
33
+ from pcache_fileio import fileio
34
+ except Exception:
35
+ MEGADEPTH_CLIENT = SCANNET_CLIENT = None
36
+
37
+ # --- DATA IO ---
38
+
39
+ def load_pfm(pfm_path):
40
+ with open(pfm_path, 'rb') as fin:
41
+ color = None
42
+ width = None
43
+ height = None
44
+ scale = None
45
+ data_type = None
46
+ header = str(fin.readline().decode('UTF-8')).rstrip()
47
+
48
+ if header == 'PF':
49
+ color = True
50
+ elif header == 'Pf':
51
+ color = False
52
+ else:
53
+ raise Exception('Not a PFM file.')
54
+
55
+ dim_match = re.match(r'^(\d+)\s(\d+)\s$', fin.readline().decode('UTF-8'))
56
+ if dim_match:
57
+ width, height = map(int, dim_match.groups())
58
+ else:
59
+ raise Exception('Malformed PFM header.')
60
+ scale = float((fin.readline().decode('UTF-8')).rstrip())
61
+ if scale < 0: # little-endian
62
+ data_type = '<f'
63
+ else:
64
+ data_type = '>f' # big-endian
65
+ data_string = fin.read()
66
+ data = np.fromstring(data_string, data_type)
67
+ shape = (height, width, 3) if color else (height, width)
68
+ data = np.reshape(data, shape)
69
+ data = np.flip(data, 0)
70
+ return data
71
+
72
+ def load_array_from_pcache(
73
+ path, cv_type,
74
+ use_h5py=False,
75
+ ):
76
+
77
+ filename = path.split(root_dir)[1]
78
+ pcache_path = Path(root_dir) / filename
79
+ try:
80
+ if not use_h5py:
81
+ load_failed = True
82
+ failed_num = 0
83
+ while load_failed:
84
+ try:
85
+ with pcache_fs.open(str(pcache_path), 'rb') as f:
86
+ data = Image.open(f).convert("L")
87
+ data = np.array(data)
88
+ load_failed = False
89
+ except:
90
+ failed_num += 1
91
+ if failed_num > 5000:
92
+ logger.error(f"Try to load: {pcache_path}, but failed {failed_num} times")
93
+ continue
94
+ else:
95
+ load_failed = True
96
+ failed_num = 0
97
+ while load_failed:
98
+ try:
99
+ with pcache_fs.open(str(pcache_path), 'rb') as f:
100
+ data = np.array(h5py.File(io.BytesIO(f.read()), 'r')['/depth'])
101
+ load_failed = False
102
+ except:
103
+ failed_num += 1
104
+ if failed_num > 5000:
105
+ logger.error(f"Try to load: {pcache_path}, but failed {failed_num} times")
106
+ continue
107
+
108
+ except Exception as ex:
109
+ print(f"==> Data loading failure: {path}")
110
+ raise ex
111
+
112
+ assert data is not None
113
+ return data
114
+
115
+
116
+ def imread_gray(path, augment_fn=None, cv_type=None):
117
+ if path.startswith('oss://'):
118
+ path = path.replace(OSS_FOLDER_PATH, PCACHE_FOLDER_PATH)
119
+ if path.startswith('pcache://'):
120
+ path = path[:9] + path[9:].replace('////', '/').replace('///', '/').replace('//', '/') # remove all continuous '/'
121
+
122
+ if cv_type is None:
123
+ cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \
124
+ else cv2.IMREAD_COLOR
125
+ if str(path).startswith('oss://') or str(path).startswith('pcache://'):
126
+ image = load_array_from_pcache(str(path), cv_type)
127
+ else:
128
+ image = cv2.imread(str(path), cv_type)
129
+
130
+ if augment_fn is not None:
131
+ image = cv2.imread(str(path), cv2.IMREAD_COLOR)
132
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
133
+ image = augment_fn(image)
134
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
135
+ return image # (h, w)
136
+
137
+ def imread_color(path, augment_fn=None):
138
+ if path.startswith('oss://'):
139
+ path = path.replace(OSS_FOLDER_PATH, PCACHE_FOLDER_PATH)
140
+ if path.startswith('pcache://'):
141
+ path = path[:9] + path[9:].replace('////', '/').replace('///', '/').replace('//', '/') # remove all continuous '/'
142
+
143
+ if str(path).startswith('oss://') or str(path).startswith('pcache://'):
144
+ filename = path.split(root_dir)[1]
145
+ pcache_path = Path(root_dir) / filename
146
+ load_failed = True
147
+ failed_num = 0
148
+ while load_failed:
149
+ try:
150
+ with pcache_fs.open(str(pcache_path), 'rb') as f:
151
+ pil_image = Image.open(f).convert("RGB")
152
+ load_failed = False
153
+ except:
154
+ failed_num += 1
155
+ if failed_num > 5000:
156
+ logger.error(f"Try to load: {pcache_path}, but failed {failed_num} times")
157
+ continue
158
+ else:
159
+ pil_image = Image.open(str(path)).convert("RGB")
160
+ image = np.array(pil_image)
161
+
162
+ if augment_fn is not None:
163
+ image = augment_fn(image)
164
+ return image # (h, w)
165
+
166
+
167
+ def get_resized_wh(w, h, resize=None):
168
+ if resize is not None: # resize the longer edge
169
+ scale = resize / max(h, w)
170
+ w_new, h_new = int(round(w*scale)), int(round(h*scale))
171
+ else:
172
+ w_new, h_new = w, h
173
+ return w_new, h_new
174
+
175
+
176
+ def get_divisible_wh(w, h, df=None):
177
+ if df is not None:
178
+ w_new, h_new = map(lambda x: int(x // df * df), [w, h])
179
+ else:
180
+ w_new, h_new = w, h
181
+ return w_new, h_new
182
+
183
+
184
+ def pad_bottom_right(inp, pad_size, ret_mask=False):
185
+ assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
186
+ mask = None
187
+ if inp.ndim == 2:
188
+ padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
189
+ padded[:inp.shape[0], :inp.shape[1]] = inp
190
+ if ret_mask:
191
+ mask = np.zeros((pad_size, pad_size), dtype=bool)
192
+ mask[:inp.shape[0], :inp.shape[1]] = True
193
+ elif inp.ndim == 3:
194
+ padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
195
+ padded[:, :inp.shape[1], :inp.shape[2]] = inp
196
+ if ret_mask:
197
+ mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
198
+ mask[:, :inp.shape[1], :inp.shape[2]] = True
199
+ mask = mask[0]
200
+ else:
201
+ raise NotImplementedError()
202
+ return padded, mask
203
+
204
+
205
+ # --- MEGADEPTH ---
206
+
207
+ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None, read_gray=True, normalize_img=False, resize_by_stretch=False):
208
+ """
209
+ Args:
210
+ resize (int, optional): the longer edge of resized images. None for no resize.
211
+ padding (bool): If set to 'True', zero-pad resized images to squared size.
212
+ augment_fn (callable, optional): augments images with pre-defined visual effects
213
+ Returns:
214
+ image (torch.tensor): (1, h, w)
215
+ mask (torch.tensor): (h, w)
216
+ scale (torch.tensor): [w/w_new, h/h_new]
217
+ """
218
+ # read image
219
+ if read_gray:
220
+ image = imread_gray(path, augment_fn)
221
+ else:
222
+ image = imread_color(path, augment_fn)
223
+
224
+ # resize image
225
+ try:
226
+ w, h = image.shape[1], image.shape[0]
227
+ except:
228
+ logger.error(f"{path} not exist or read image error!")
229
+ if resize_by_stretch:
230
+ w_new, h_new = (resize, resize) if isinstance(resize, int) else (resize[1], resize[0])
231
+ else:
232
+ if resize:
233
+ if not isinstance(resize, int):
234
+ assert resize[0] == resize[1]
235
+ resize = resize[0]
236
+ w_new, h_new = get_resized_wh(w, h, resize)
237
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
238
+ else:
239
+ w_new, h_new = w, h
240
+
241
+ image = cv2.resize(image, (w_new, h_new))
242
+ scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
243
+ origin_img_size = torch.tensor([h, w], dtype=torch.float)
244
+
245
+ if not read_gray:
246
+ image = image.transpose(2,0,1)
247
+
248
+ if padding: # padding
249
+ pad_to = max(h_new, w_new)
250
+ image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
251
+ else:
252
+ mask = None
253
+
254
+ if len(image.shape) == 2:
255
+ image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized
256
+ else:
257
+ image = torch.from_numpy(image).float() / 255 # (h, w) -> (1, h, w) and normalized
258
+ if mask is not None:
259
+ mask = torch.from_numpy(mask)
260
+
261
+ if image.shape[0] == 3 and normalize_img:
262
+ # Normalize image:
263
+ image = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image) # Input: 3*H*W
264
+
265
+ return image, mask, scale, origin_img_size
266
+
267
+ def read_megadepth_gray_sample_homowarp(path, resize=None, df=None, padding=False, augment_fn=None, read_gray=True, normalize_img=False, resize_by_stretch=False):
268
+ """
269
+ Args:
270
+ resize (int, optional): the longer edge of resized images. None for no resize.
271
+ padding (bool): If set to 'True', zero-pad resized images to squared size.
272
+ augment_fn (callable, optional): augments images with pre-defined visual effects
273
+ Returns:
274
+ image (torch.tensor): (1, h, w)
275
+ mask (torch.tensor): (h, w)
276
+ scale (torch.tensor): [w/w_new, h/h_new]
277
+ """
278
+ # read image
279
+ if read_gray:
280
+ image = imread_gray(path, augment_fn)
281
+ else:
282
+ image = imread_color(path, augment_fn)
283
+
284
+ # resize image
285
+ w, h = image.shape[1], image.shape[0]
286
+ if resize_by_stretch:
287
+ w_new, h_new = (resize, resize) if isinstance(resize, int) else (resize[1], resize[0])
288
+ else:
289
+ if not isinstance(resize, int):
290
+ assert resize[0] == resize[1]
291
+ resize = resize[0]
292
+ w_new, h_new = get_resized_wh(w, h, resize)
293
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
294
+
295
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
296
+
297
+ origin_img_size = torch.tensor([h, w], dtype=torch.float)
298
+
299
+ # Sample homography and warp:
300
+ homo_sampled = sample_homography_sap(h, w) # 3*3
301
+ homo_sampled_normed = normalize_homography(
302
+ torch.from_numpy(homo_sampled[None]).to(torch.float32),
303
+ (h, w),
304
+ (h, w),
305
+ )
306
+
307
+ if len(image.shape) == 2:
308
+ image = torch.from_numpy(image).float()[None, None] / 255 # B * C * H * W
309
+ else:
310
+ image = torch.from_numpy(image).float().permute(2,0,1)[None] / 255
311
+
312
+ homo_warpped_image = homography_warp(
313
+ image, # 1 * C * H * W
314
+ torch.linalg.inv(homo_sampled_normed),
315
+ (h, w),
316
+ )
317
+ image = (homo_warpped_image[0].permute(1,2,0).numpy() * 255).astype(np.uint8)
318
+ norm_pixel_mat = normal_transform_pixel(h, w) # 1 * 3 * 3
319
+
320
+ image = cv2.resize(image, (w_new, h_new))
321
+ scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
322
+
323
+ if not read_gray:
324
+ image = image.transpose(2,0,1)
325
+
326
+ if padding: # padding
327
+ pad_to = max(h_new, w_new)
328
+ image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
329
+ else:
330
+ mask = None
331
+
332
+ if len(image.shape) == 2:
333
+ image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized
334
+ else:
335
+ image = torch.from_numpy(image).float() / 255 # (h, w) -> (1, h, w) and normalized
336
+ if mask is not None:
337
+ mask = torch.from_numpy(mask)
338
+
339
+ if image.shape[0] == 3 and normalize_img:
340
+ # Normalize image:
341
+ image = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image) # Input: 3*H*W
342
+
343
+ return image, mask, scale, origin_img_size, norm_pixel_mat[0], homo_sampled_normed[0]
344
+
345
+
346
+ def read_megadepth_depth_gray(path, resize=None, df=None, padding=False, augment_fn=None, read_gray=True, normalize_img=False, resize_by_stretch=False):
347
+ """
348
+ Args:
349
+ resize (int, optional): the longer edge of resized images. None for no resize.
350
+ padding (bool): If set to 'True', zero-pad resized images to squared size.
351
+ augment_fn (callable, optional): augments images with pre-defined visual effects
352
+ Returns:
353
+ image (torch.tensor): (1, h, w)
354
+ mask (torch.tensor): (h, w)
355
+ scale (torch.tensor): [w/w_new, h/h_new]
356
+ """
357
+ depth = read_megadepth_depth(path, return_tensor=False)
358
+
359
+ # following controlnet 1-depth
360
+ depth = depth.astype(np.float64)
361
+ depth_non_zero = depth[depth!=0]
362
+ vmin = np.percentile(depth_non_zero, 2)
363
+ vmax = np.percentile(depth_non_zero, 85)
364
+ depth -= vmin
365
+ depth /= (vmax - vmin + 1e-4)
366
+ depth = 1.0 - depth
367
+ image = (depth * 255.0).clip(0, 255).astype(np.uint8)
368
+
369
+ # resize image
370
+ w, h = image.shape[1], image.shape[0]
371
+ if resize_by_stretch:
372
+ w_new, h_new = (resize, resize) if isinstance(resize, int) else (resize[1], resize[0])
373
+ else:
374
+ if not isinstance(resize, int):
375
+ assert resize[0] == resize[1]
376
+ resize = resize[0]
377
+ w_new, h_new = get_resized_wh(w, h, resize)
378
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
379
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
380
+ origin_img_size = torch.tensor([h, w], dtype=torch.float)
381
+
382
+ image = cv2.resize(image, (w_new, h_new))
383
+ scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
384
+
385
+ if padding: # padding
386
+ pad_to = max(h_new, w_new)
387
+ image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
388
+ else:
389
+ mask = None
390
+
391
+ if read_gray:
392
+ image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized
393
+ else:
394
+ image = np.stack([image]*3) # 3 * H * W
395
+ image = torch.from_numpy(image).float() / 255 # (h, w) -> (1, h, w) and normalized
396
+ if mask is not None:
397
+ mask = torch.from_numpy(mask)
398
+
399
+ if image.shape[0] == 3 and normalize_img:
400
+ # Normalize image:
401
+ image = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image) # Input: 3*H*W
402
+
403
+ return image, mask, scale, origin_img_size
404
+
405
+ def read_megadepth_depth(path, pad_to=None, return_tensor=True):
406
+ if path.startswith('oss://'):
407
+ path = path.replace(OSS_FOLDER_PATH, PCACHE_FOLDER_PATH)
408
+ if path.startswith('pcache://'):
409
+ path = path[:9] + path[9:].replace('////', '/').replace('///', '/').replace('//', '/') # remove all continuous '/'
410
+
411
+ load_failed = True
412
+ failed_num = 0
413
+ while load_failed:
414
+ try:
415
+ if '.png' in path:
416
+ if 'scannet_plus' in path:
417
+ depth = imread_gray(path, cv_type=cv2.IMREAD_UNCHANGED).astype(np.float32)
418
+
419
+ with open(path, 'rb') as f:
420
+ # CO3D
421
+ depth = np.asarray(Image.open(f)).astype(np.float32)
422
+ depth = depth / 1000
423
+ elif '.pfm' in path:
424
+ # For BlendedMVS dataset (not support pcache):
425
+ depth = load_pfm(path).copy()
426
+ else:
427
+ # For MegaDepth
428
+ if str(path).startswith('oss://') or str(path).startswith('pcache://'):
429
+ depth = load_array_from_pcache(path, None, use_h5py=True)
430
+ else:
431
+ depth = np.array(h5py.File(path, 'r')['depth'])
432
+ load_failed = False
433
+ except:
434
+ failed_num += 1
435
+ if failed_num > 5000:
436
+ logger.error(f"Try to load: {path}, but failed {failed_num} times")
437
+ continue
438
+
439
+ if pad_to is not None:
440
+ depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
441
+ if return_tensor:
442
+ depth = torch.from_numpy(depth).float() # (h, w)
443
+ return depth
444
+
445
+
446
+ # --- ScanNet ---
447
+
448
+ def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
449
+ """
450
+ Args:
451
+ resize (tuple): align image to depthmap, in (w, h).
452
+ augment_fn (callable, optional): augments images with pre-defined visual effects
453
+ Returns:
454
+ image (torch.tensor): (1, h, w)
455
+ mask (torch.tensor): (h, w)
456
+ scale (torch.tensor): [w/w_new, h/h_new]
457
+ """
458
+ # read and resize image
459
+ image = imread_gray(path, augment_fn)
460
+ image = cv2.resize(image, resize)
461
+
462
+ # (h, w) -> (1, h, w) and normalized
463
+ image = torch.from_numpy(image).float()[None] / 255
464
+ return image
465
+
466
+
467
+ def read_scannet_depth(path):
468
+ if str(path).startswith('s3://'):
469
+ depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
470
+ else:
471
+ depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
472
+ depth = depth / 1000
473
+ depth = torch.from_numpy(depth).float() # (h, w)
474
+ return depth
475
+
476
+
477
+ def read_scannet_pose(path):
478
+ """ Read ScanNet's Camera2World pose and transform it to World2Camera.
479
+
480
+ Returns:
481
+ pose_w2c (np.ndarray): (4, 4)
482
+ """
483
+ cam2world = np.loadtxt(path, delimiter=' ')
484
+ world2cam = inv(cam2world)
485
+ return world2cam
486
+
487
+
488
+ def read_scannet_intrinsic(path):
489
+ """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
490
+ """
491
+ intrinsic = np.loadtxt(path, delimiter=' ')
492
+ return intrinsic[:-1, :-1]
493
+
494
+ def dict_to_cuda(data_dict):
495
+ data_dict_cuda = {}
496
+ for k, v in data_dict.items():
497
+ if isinstance(v, torch.Tensor):
498
+ data_dict_cuda[k] = v.cuda()
499
+ elif isinstance(v, dict):
500
+ data_dict_cuda[k] = dict_to_cuda(v)
501
+ elif isinstance(v, list):
502
+ data_dict_cuda[k] = list_to_cuda(v)
503
+ else:
504
+ data_dict_cuda[k] = v
505
+ return data_dict_cuda
506
+
507
+ def list_to_cuda(data_list):
508
+ data_list_cuda = []
509
+ for obj in data_list:
510
+ if isinstance(obj, torch.Tensor):
511
+ data_list_cuda.append(obj.cuda())
512
+ elif isinstance(obj, dict):
513
+ data_list_cuda.append(dict_to_cuda(obj))
514
+ elif isinstance(obj, list):
515
+ data_list_cuda.append(list_to_cuda(obj))
516
+ else:
517
+ data_list_cuda.append(obj)
518
+ return data_list_cuda
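An illustrative sketch of the MegaDepth-style reader above (the image path is a placeholder that must exist locally, and the dict_to_cuda call assumes a CUDA device): with resize=832, df=8 and padding=True, the longer edge is resized to 832, rounded down to a multiple of 8, and the result is zero-padded to a square.

    image, mask, scale, origin_size = read_megadepth_gray(
        "assets/demo.jpg", resize=832, df=8, padding=True, read_gray=True)
    print(image.shape)   # torch.Size([1, 832, 832]) after square padding
    print(scale)         # [w / w_new, h / h_new], used to map matches back to full resolution
    batch = dict_to_cuda({"image0": image[None], "scale0": scale[None]})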
imcui/third_party/MatchAnything/src/utils/easydict.py ADDED
@@ -0,0 +1,148 @@
1
+ class EasyDict(dict):
2
+ """
3
+ Get attributes
4
+
5
+ >>> d = EasyDict({'foo':3})
6
+ >>> d['foo']
7
+ 3
8
+ >>> d.foo
9
+ 3
10
+ >>> d.bar
11
+ Traceback (most recent call last):
12
+ ...
13
+ AttributeError: 'EasyDict' object has no attribute 'bar'
14
+
15
+ Works recursively
16
+
17
+ >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
18
+ >>> isinstance(d.bar, dict)
19
+ True
20
+ >>> d.bar.x
21
+ 1
22
+
23
+ Bullet-proof
24
+
25
+ >>> EasyDict({})
26
+ {}
27
+ >>> EasyDict(d={})
28
+ {}
29
+ >>> EasyDict(None)
30
+ {}
31
+ >>> d = {'a': 1}
32
+ >>> EasyDict(**d)
33
+ {'a': 1}
34
+
35
+ Set attributes
36
+
37
+ >>> d = EasyDict()
38
+ >>> d.foo = 3
39
+ >>> d.foo
40
+ 3
41
+ >>> d.bar = {'prop': 'value'}
42
+ >>> d.bar.prop
43
+ 'value'
44
+ >>> d
45
+ {'foo': 3, 'bar': {'prop': 'value'}}
46
+ >>> d.bar.prop = 'newer'
47
+ >>> d.bar.prop
48
+ 'newer'
49
+
50
+
51
+ Values extraction
52
+
53
+ >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
54
+ >>> isinstance(d.bar, list)
55
+ True
56
+ >>> from operator import attrgetter
57
+ >>> map(attrgetter('x'), d.bar)
58
+ [1, 3]
59
+ >>> map(attrgetter('y'), d.bar)
60
+ [2, 4]
61
+ >>> d = EasyDict()
62
+ >>> d.keys()
63
+ []
64
+ >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
65
+ >>> d.foo
66
+ 3
67
+ >>> d.bar.x
68
+ 1
69
+
70
+ Still like a dict though
71
+
72
+ >>> o = EasyDict({'clean':True})
73
+ >>> o.items()
74
+ [('clean', True)]
75
+
76
+ And like a class
77
+
78
+ >>> class Flower(EasyDict):
79
+ ... power = 1
80
+ ...
81
+ >>> f = Flower()
82
+ >>> f.power
83
+ 1
84
+ >>> f = Flower({'height': 12})
85
+ >>> f.height
86
+ 12
87
+ >>> f['power']
88
+ 1
89
+ >>> sorted(f.keys())
90
+ ['height', 'power']
91
+
92
+ update and pop items
93
+ >>> d = EasyDict(a=1, b='2')
94
+ >>> e = EasyDict(c=3.0, a=9.0)
95
+ >>> d.update(e)
96
+ >>> d.c
97
+ 3.0
98
+ >>> d['c']
99
+ 3.0
100
+ >>> d.get('c')
101
+ 3.0
102
+ >>> d.update(a=4, b=4)
103
+ >>> d.b
104
+ 4
105
+ >>> d.pop('a')
106
+ 4
107
+ >>> d.a
108
+ Traceback (most recent call last):
109
+ ...
110
+ AttributeError: 'EasyDict' object has no attribute 'a'
111
+ """
112
+
113
+ def __init__(self, d=None, **kwargs):
114
+ if d is None:
115
+ d = {}
116
+ if kwargs:
117
+ d.update(**kwargs)
118
+ for k, v in d.items():
119
+ setattr(self, k, v)
120
+ # Class attributes
121
+ for k in self.__class__.__dict__.keys():
122
+ if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
123
+ setattr(self, k, getattr(self, k))
124
+
125
+ def __setattr__(self, name, value):
126
+ if isinstance(value, (list, tuple)):
127
+ value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
128
+ elif isinstance(value, dict) and not isinstance(value, self.__class__):
129
+ value = self.__class__(value)
130
+ super(EasyDict, self).__setattr__(name, value)
131
+ super(EasyDict, self).__setitem__(name, value)
132
+
133
+ __setitem__ = __setattr__
134
+
135
+ def update(self, e=None, **f):
136
+ d = e or dict()
137
+ d.update(f)
138
+ for k in d:
139
+ setattr(self, k, d[k])
140
+
141
+ def pop(self, k, d=None):
142
+ if hasattr(self, k):
143
+ delattr(self, k)
144
+ return super(EasyDict, self).pop(k, d)
145
+
146
+
147
+ if __name__ == "__main__":
148
+ import doctest
imcui/third_party/MatchAnything/src/utils/geometry.py ADDED
@@ -0,0 +1,366 @@
1
+ from __future__ import division
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ # from numba import jit
6
+
7
+ pixel_coords = None
8
+
9
+ def set_id_grid(depth):
10
+ b, h, w = depth.size()
11
+ i_range = torch.arange(0, h).view(1, h, 1).expand(1,h,w).type_as(depth) # [1, H, W]
12
+ j_range = torch.arange(0, w).view(1, 1, w).expand(1,h,w).type_as(depth) # [1, H, W]
13
+ ones = torch.ones(1,h,w).type_as(depth)
14
+
15
+ pixel_coords = torch.stack((j_range, i_range, ones), dim=1) # [1, 3, H, W]
16
+ return pixel_coords
17
+
18
+ def check_sizes(input, input_name, expected):
19
+ condition = [input.ndimension() == len(expected)]
20
+ for i,size in enumerate(expected):
21
+ if size.isdigit():
22
+ condition.append(input.size(i) == int(size))
23
+ assert(all(condition)), "wrong size for {}, expected {}, got {}".format(input_name, 'x'.join(expected), list(input.size()))
24
+
25
+
26
+ def pixel2cam(depth, intrinsics_inv):
27
+ """Transform coordinates in the pixel frame to the camera frame.
28
+ Args:
29
+ depth: depth maps -- [B, H, W]
30
+ intrinsics_inv: intrinsics_inv matrix for each element of batch -- [B, 3, 3]
31
+ Returns:
32
+ array of (u,v,1) cam coordinates -- [B, 3, H, W]
33
+ """
34
+ b, h, w = depth.size()
35
+ pixel_coords = set_id_grid(depth)
36
+ current_pixel_coords = pixel_coords[:,:,:h,:w].expand(b,3,h,w).reshape(b, 3, -1) # [B, 3, H*W]
37
+ cam_coords = (intrinsics_inv.float() @ current_pixel_coords.float()).reshape(b, 3, h, w)
38
+ return cam_coords * depth.unsqueeze(1)
39
+
40
+ def cam2pixel_depth(cam_coords, proj_c2p_rot, proj_c2p_tr):
41
+ """Transform coordinates in the camera frame to the pixel frame and get depth map.
42
+ Args:
43
+ cam_coords: 3D points in the first camera coordinate system -- [B, 3, H, W]
44
+ proj_c2p_rot: rotation matrix of cameras -- [B, 3, 4]
45
+ proj_c2p_tr: translation vectors of cameras -- [B, 3, 1]
46
+ Returns:
47
+ tensor of [-1,1] coordinates -- [B, 2, H, W]
48
+ depth map -- [B, H, W]
49
+ """
50
+ b, _, h, w = cam_coords.size()
51
+ cam_coords_flat = cam_coords.reshape(b, 3, -1) # [B, 3, H*W]
52
+ if proj_c2p_rot is not None:
53
+ pcoords = proj_c2p_rot @ cam_coords_flat
54
+ else:
55
+ pcoords = cam_coords_flat
56
+
57
+ if proj_c2p_tr is not None:
58
+ pcoords = pcoords + proj_c2p_tr # [B, 3, H*W]
59
+ X = pcoords[:, 0]
60
+ Y = pcoords[:, 1]
61
+ Z = pcoords[:, 2].clamp(min=1e-3) # [B, H*W] min_depth = 1 mm
62
+
63
+ X_norm = 2*(X / Z)/(w-1) - 1 # Normalized, -1 if on extreme left, 1 if on extreme right (x = w-1) [B, H*W]
64
+ Y_norm = 2*(Y / Z)/(h-1) - 1 # Idem [B, H*W]
65
+
66
+ pixel_coords = torch.stack([X_norm, Y_norm], dim=2) # [B, H*W, 2]
67
+ return pixel_coords.reshape(b,h,w,2), Z.reshape(b, h, w)
68
+
69
+
70
+ def cam2pixel(cam_coords, proj_c2p_rot, proj_c2p_tr):
71
+ """Transform coordinates in the camera frame to the pixel frame.
72
+ Args:
73
+ cam_coords: 3D points in the first camera coordinate system -- [B, 3, H, W]
74
+ proj_c2p_rot: rotation matrix of cameras -- [B, 3, 4]
75
+ proj_c2p_tr: translation vectors of cameras -- [B, 3, 1]
76
+ Returns:
77
+ array of [-1,1] coordinates -- [B, 2, H, W]
78
+ """
79
+ b, _, h, w = cam_coords.size()
80
+ cam_coords_flat = cam_coords.reshape(b, 3, -1) # [B, 3, H*W]
81
+ if proj_c2p_rot is not None:
82
+ pcoords = proj_c2p_rot @ cam_coords_flat
83
+ else:
84
+ pcoords = cam_coords_flat
85
+
86
+ if proj_c2p_tr is not None:
87
+ pcoords = pcoords + proj_c2p_tr # [B, 3, H*W]
88
+ X = pcoords[:, 0]
89
+ Y = pcoords[:, 1]
90
+ Z = pcoords[:, 2].clamp(min=1e-3) # [B, H*W] min_depth = 1 mm
91
+
92
+ X_norm = 2*(X / Z)/(w-1) - 1 # Normalized, -1 if on extreme left, 1 if on extreme right (x = w-1) [B, H*W]
93
+ Y_norm = 2*(Y / Z)/(h-1) - 1 # Idem [B, H*W]
94
+
95
+ pixel_coords = torch.stack([X_norm, Y_norm], dim=2) # [B, H*W, 2]
96
+ return pixel_coords.reshape(b,h,w,2)
97
+
98
+
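A minimal usage sketch (my own toy example, not part of the repository) of how pixel2cam and cam2pixel above compose into a reprojection; the intrinsics, pose and depth values are made up:

import torch

B, H, W = 1, 4, 6
depth = torch.ones(B, H, W)                        # flat scene, 1 m everywhere
K = torch.tensor([[[100., 0., W / 2],
                   [0., 100., H / 2],
                   [0.,   0., 1.]]])               # toy intrinsics, (B, 3, 3)
T_tgt_to_src = torch.eye(3, 4).unsqueeze(0)        # identity relative pose, (B, 3, 4)

cam = pixel2cam(depth, K.inverse())                # back-project pixels, (B, 3, H, W)
proj = K @ T_tgt_to_src                            # target-cam -> source-pixel projection, (B, 3, 4)
uv = cam2pixel(cam, proj[:, :, :3], proj[:, :, -1:])   # normalized grid in [-1, 1], (B, H, W, 2)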
99
+ def reproject_kpts(dim0_idxs, kpts, depth, rel_pose, K0, K1):
100
+ """ Reproject keypoints with depth, relative pose and camera intrinsics
101
+ Args:
102
+ dim0_idxs (torch.LongTensor): (B*max_kpts, )
103
+ kpts (torch.LongTensor): (B, max_kpts, 2) - <x,y>
104
+ depth (torch.Tensor): (B, H, W)
105
+ rel_pose (torch.Tensor): (B, 3, 4) relative transformation from target to source (T_0to1)
106
+ K0: (torch.Tensor): (N, 3, 3) - (K_0)
107
+ K1: (torch.Tensor): (N, 3, 3) - (K_1)
108
+ Returns:
109
+ (torch.Tensor): (B, max_kpts, 2) the reprojected kpts
110
+ """
111
+ # pixel to camera
112
+ device = kpts.device
113
+ B, max_kpts, _ = kpts.shape
114
+
115
+ kpts = kpts.reshape(-1, 2) # (B*K, 2)
116
+ kpts_depth = depth[dim0_idxs, kpts[:, 1], kpts[:, 0]] # (B*K, )
117
+ kpts = torch.cat([kpts.float(),
118
+ torch.ones((kpts.shape[0], 1), dtype=torch.float32, device=device)], -1) # (B*K, 3)
119
+ pixel_coords = (kpts * kpts_depth[:, None]).reshape(B, max_kpts, 3).permute(0, 2, 1) # (B, 3, K)
120
+
121
+ cam_coords = K0.inverse() @ pixel_coords # (N, 3, max_kpts)
122
+ # camera1 to camera 2
123
+ rel_pose_R = rel_pose[:, :, :-1] # (B, 3, 3)
124
+ rel_pose_t = rel_pose[:, :, -1][..., None] # (B, 3, 1)
125
+ cam2_coords = rel_pose_R @ cam_coords + rel_pose_t # (B, 3, max_kpts)
126
+ # projection
127
+ pixel2_coords = K1 @ cam2_coords # (B, 3, max_kpts)
128
+ reproj_kpts = pixel2_coords[:, :-1, :] / pixel2_coords[:, -1, :][:, None].expand(-1, 2, -1)
129
+ return reproj_kpts.permute(0, 2, 1)
130
+
131
+
132
+ def check_depth_consistency(b_idxs, kpts0, depth0, kpts1, depth1, T_0to1, K0, K1,
133
+ atol=0.1, rtol=0.0):
134
+ """
135
+ Args:
136
+ b_idxs (torch.LongTensor): (n_kpts, ) the batch indices which each keypoints pairs belong to
137
+ kpts0 (torch.LongTensor): (n_kpts, 2) - <x, y>
138
+ depth0 (torch.Tensor): (B, H, W)
139
+ kpts1 (torch.LongTensor): (n_kpts, 2)
140
+ depth1 (torch.Tensor): (B, H, W)
141
+ T_0to1 (torch.Tensor): (B, 3, 4)
142
+ K0: (torch.Tensor): (N, 3, 3) - (K_0)
143
+ K1: (torch.Tensor): (N, 3, 3) - (K_1)
144
+ atol (float): the absolute tolerance for depth consistency check
145
+ rtol (float): the relative tolerance for depth consistency check
146
+ Returns:
147
+ valid_mask (torch.Tensor): (n_kpts, )
148
+ Notes:
149
+ The two corresponding keypoints are depth consistent if the following inequality holds:
150
+ abs(kpt_0to1_depth - kpt1_depth) <= (atol + rtol * abs(kpt1_depth))
151
+ * In the initial reimplementation, `atol=0.1, rtol=0` was used; the result is better with
152
+ `atol=1.0, rtol=0` (which nearly ignores the depth consistency check).
153
+ * However, the author suggests using `atol=0.0, rtol=0.1` as in https://github.com/magicleap/SuperGluePretrainedNetwork/issues/31#issuecomment-681866054
154
+ """
155
+ device = kpts0.device
156
+ n_kpts = kpts0.shape[0]
157
+
158
+ kpts0_depth = depth0[b_idxs, kpts0[:, 1], kpts0[:, 0]] # (n_kpts, )
159
+ kpts1_depth = depth1[b_idxs, kpts1[:, 1], kpts1[:, 0]] # (n_kpts, )
160
+ kpts0 = torch.cat([kpts0.float(),
161
+ torch.ones((n_kpts, 1), dtype=torch.float32, device=device)], -1) # (n_kpts, 3)
162
+ pixel_coords = (kpts0 * kpts0_depth[:, None])[..., None] # (n_kpts, 3, 1)
163
+
164
+ # indexing from T_0to1 and K - treat all kpts as a batch
165
+ K0 = K0[b_idxs, :, :] # (n_kpts, 3, 3)
166
+ T_0to1 = T_0to1[b_idxs, :, :] # (n_kpts, 3, 4)
167
+ cam_coords = K0.inverse() @ pixel_coords # (n_kpts, 3, 1)
168
+
169
+ # camera1 to camera2
170
+ R_0to1 = T_0to1[:, :, :-1] # (n_kpts, 3, 3)
171
+ t_0to1 = T_0to1[:, :, -1][..., None] # (n_kpts, 3, 1)
172
+ cam1_coords = R_0to1 @ cam_coords + t_0to1 # (n_kpts, 3, 1)
173
+ K1 = K1[b_idxs, :, :] # (n_kpts, 3, 3)
174
+ pixel1_coords = K1 @ cam1_coords # (n_kpts, 3, 1)
175
+ kpts_0to1_depth = pixel1_coords[:, -1, 0] # (n_kpts, )
176
+ return (kpts_0to1_depth - kpts1_depth).abs() <= atol + rtol * kpts1_depth.abs()
177
+
178
+
179
+ def inverse_warp(img, depth, pose, intrinsics, mode='bilinear', padding_mode='zeros'):
180
+ """
181
+ Inverse warp a source image to the target image plane.
182
+
183
+ Args:
184
+ img: the source image (where to sample pixels) -- [B, 3, H, W]
185
+ depth: depth map of the target image -- [B, H, W]
186
+ pose: relative transformation from target to source -- [B, 3, 4]
187
+ intrinsics: camera intrinsic matrix -- [B, 3, 3]
188
+ Returns:
189
+ projected_img: Source image warped to the target image plane
190
+ valid_points: Boolean array indicating point validity
191
+ """
192
+ # check_sizes(img, 'img', 'B3HW')
193
+ check_sizes(depth, 'depth', 'BHW')
194
+ # check_sizes(pose, 'pose', 'B6')
195
+ check_sizes(intrinsics, 'intrinsics', 'B33')
196
+
197
+ batch_size, _, img_height, img_width = img.size()
198
+
199
+ cam_coords = pixel2cam(depth, intrinsics.inverse()) # [B,3,H,W]
200
+
201
+ pose_mat = pose # (B, 3, 4)
202
+
203
+ # Get projection matrix for target camera frame to source pixel frame
204
+ proj_cam_to_src_pixel = intrinsics @ pose_mat # [B, 3, 4]
205
+
206
+ rot, tr = proj_cam_to_src_pixel[:,:,:3], proj_cam_to_src_pixel[:,:,-1:]
207
+ src_pixel_coords = cam2pixel(cam_coords, rot, tr) # [B,H,W,2]
208
+ projected_img = F.grid_sample(img, src_pixel_coords, mode=mode,
209
+ padding_mode=padding_mode, align_corners=True)
210
+
211
+ valid_points = src_pixel_coords.abs().max(dim=-1)[0] <= 1
212
+
213
+ return projected_img, valid_points
214
+
215
+ def depth_inverse_warp(depth_source, depth, pose, intrinsic_source, intrinsic, mode='nearest', padding_mode='zeros'):
216
+ """
217
+ 1. Inversely warp a source depth map to the target image plane (warped depth map still in source frame)
218
+ 2. Transform the target depth map to the source image frame
219
+ Args:
220
+ depth_source: the source image (where to sample pixels) -- [B, H, W]
221
+ depth: depth map of the target image -- [B, H, W]
222
+ pose: relative transformation from target to source -- [B, 3, 4]
223
+ intrinsics: camera intrinsic matrix -- [B, 3, 3]
224
+ Returns:
225
+ warped_depth: Source depth warped to the target image plane -- [B, H, W]
226
+ projected_depth: Target depth projected to the source image frame -- [B, H, W]
227
+ valid_points: Boolean array indicating point validity -- [B, H, W]
228
+ """
229
+ check_sizes(depth_source, 'depth', 'BHW')
230
+ check_sizes(depth, 'depth', 'BHW')
231
+ check_sizes(intrinsic_source, 'intrinsics', 'B33')
232
+
233
+ b, h, w = depth.size()
234
+
235
+ cam_coords = pixel2cam(depth, intrinsic.inverse()) # [B,3,H,W]
236
+
237
+ pose_mat = pose # (B, 3, 4)
238
+
239
+ # Get projection matrix from target camera frame to source pixel frame
240
+ proj_cam_to_src_pixel = intrinsic_source @ pose_mat # [B, 3, 4]
241
+
242
+ rot, tr = proj_cam_to_src_pixel[:,:,:3], proj_cam_to_src_pixel[:,:,-1:]
243
+ src_pixel_coords, depth_target2src = cam2pixel_depth(cam_coords, rot, tr) # [B,H,W,2]
244
+ warped_depth = F.grid_sample(depth_source[:, None], src_pixel_coords, mode=mode,
245
+ padding_mode=padding_mode, align_corners=True) # [B, 1, H, W]
246
+
247
+ valid_points = (src_pixel_coords.abs().max(dim=-1)[0] <= 1) &\
248
+ (depth > 0.0) & (warped_depth[:, 0] > 0.0) # [B, H, W]
249
+ return warped_depth[:, 0], depth_target2src, valid_points
250
+
251
+ def to_skew(t):
252
+ """ Transform the translation vector t to skew-symmetric matrix.
253
+ Args:
254
+ t (torch.Tensor): (B, 3)
255
+ """
256
+ t_skew = t.new_zeros((t.shape[0], 3, 3))
257
+ t_skew[:, 0, 1] = -t[:, 2]
258
+ t_skew[:, 1, 0] = t[:, 2]
259
+ t_skew[:, 0, 2] = t[:, 1]
260
+ t_skew[:, 2, 0] = -t[:, 1]
261
+ t_skew[:, 1, 2] = -t[:, 0]
262
+ t_skew[:, 2, 1] = t[:, 0]
263
+ return t_skew # (B, 3, 3)
264
+
265
+
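A quick sanity check (a sketch of assumed usage, not repository code) that the skew matrix built above reproduces the cross product:

import torch

t = torch.tensor([[1.0, 2.0, 3.0]])               # (B, 3)
v = torch.tensor([[0.5, -1.0, 2.0]])              # (B, 3)
lhs = (to_skew(t) @ v.unsqueeze(-1)).squeeze(-1)  # [t]_x v
rhs = torch.cross(t, v, dim=-1)                   # t x v
print(torch.allclose(lhs, rhs))                   # True for a skew-symmetric [t]_x (zero diagonal)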
266
+ def to_homogeneous(pts):
267
+ """
268
+ Args:
269
+ pts (torch.Tensor): (B, K, 2)
270
+ """
271
+ return torch.cat([pts, torch.ones_like(pts[..., :1])], -1) # (B, K, 3)
272
+
273
+
274
+ def pix2img(pts, K):
275
+ """
276
+ Args:
277
+ pts (torch.Tensor): (B, K, 2)
278
+ K (torch.Tensor): (B, 3, 3)
279
+ """
280
+ return (pts - K[:, [0, 1], [2, 2]][:, None]) / K[:, [0, 1], [0, 1]][:, None]
281
+
282
+
283
+ def weighted_blind_sed(kpts0, kpts1, weights, E, K0, K1):
284
+ """ Calculate the squared weighted blind symmetric epipolar distance, which is the sed between
285
+ all possible keypoints pairs.
286
+ Args:
287
+ kpts0 (torch.Tensor): (B, K0, 2)
288
+ ktps1 (torch.Tensor): (B, K1, 2)
289
+ weights (torch.Tensor): (B, K0, K1)
290
+ E (torch.Tensor): (B, 3, 3) - the essential matrix
291
+ K0 (torch.Tensor): (B, 3, 3)
292
+ K1 (torch.Tensor): (B, 3, 3)
293
+ Returns:
294
+ w_sed (torch.Tensor): (B, K0, K1)
295
+ """
296
+ M, N = kpts0.shape[1], kpts1.shape[1]
297
+
298
+ kpts0 = to_homogeneous(pix2img(kpts0, K0))
299
+ kpts1 = to_homogeneous(pix2img(kpts1, K1)) # (B, K1, 3)
300
+
301
+ R = kpts0 @ E.transpose(1, 2) @ kpts1.transpose(1, 2) # (B, K0, K1)
302
+ # w_R = weights * R # (B, K0, K1)
303
+
304
+ Ep0 = kpts0 @ E.transpose(1, 2) # (B, K0, 3)
305
+ Etp1 = kpts1 @ E # (B, K1, 3)
306
+ d = R**2 * (1.0 / (Ep0[..., 0]**2 + Ep0[..., 1]**2)[..., None].expand(-1, -1, N)
307
+ + 1.0 / (Etp1[..., 0]**2 + Etp1[..., 1]**2)[:, None].expand(-1, M, -1)) * weights # (B, K0, K1)
308
+ return d
309
+
310
+ def weighted_blind_sampson(kpts0, kpts1, weights, E, K0, K1):
311
+ """ Calculate the squared weighted blind sampson distance, which is the sampson distance between
312
+ all possible keypoints pairs weighted by the given weights.
313
+ """
314
+ M, N = kpts0.shape[1], kpts1.shape[1]
315
+
316
+ kpts0 = to_homogeneous(pix2img(kpts0, K0))
317
+ kpts1 = to_homogeneous(pix2img(kpts1, K1)) # (B, K1, 3)
318
+
319
+ R = kpts0 @ E.transpose(1, 2) @ kpts1.transpose(1, 2) # (B, K0, K1)
320
+ # w_R = weights * R # (B, K0, K1)
321
+
322
+ Ep0 = kpts0 @ E.transpose(1, 2) # (B, K0, 3)
323
+ Etp1 = kpts1 @ E # (B, K1, 3)
324
+ d = R**2 * (1.0 / ((Ep0[..., 0]**2 + Ep0[..., 1]**2)[..., None].expand(-1, -1, N)
325
+ + (Etp1[..., 0]**2 + Etp1[..., 1]**2)[:, None].expand(-1, M, -1))) * weights # (B, K0, K1)
326
+ return d
327
+
328
+
329
+ def angular_rel_rot(T_0to1):
330
+ """
331
+ Args:
332
+ T0_to_1 (np.ndarray): (4, 4)
333
+ """
334
+ cos = (np.trace(T_0to1[:-1, :-1]) - 1) / 2
335
+ if cos < -1:
336
+ cos = -1.0
337
+ if cos > 1:
338
+ cos = 1.0
339
+ angle_error_rot = np.rad2deg(np.abs(np.arccos(cos)))
340
+
341
+ return angle_error_rot
342
+
343
+ def angular_rel_pose(T0, T1):
344
+ """
345
+ Args:
346
+ T0 (np.ndarray): (4, 4)
347
+ T1 (np.ndarray): (4, 4)
348
+
349
+ """
350
+ cos = (np.trace(T0[:-1, :-1].T @ T1[:-1, :-1]) - 1) / 2
351
+ if cos < -1:
352
+ cos = -1.0
353
+ if cos > 1:
354
+ cos = 1.0
355
+ angle_error_rot = np.rad2deg(np.abs(np.arccos(cos)))
356
+
357
+ # calculate angular translation error
358
+ n = np.linalg.norm(T0[:-1, -1]) * np.linalg.norm(T1[:-1, -1])
359
+ cos = np.dot(T0[:-1, -1], T1[:-1, -1]) / n
360
+ if cos < -1:
361
+ cos = -1.0
362
+ if cos > 1:
363
+ cos = 1.0
364
+ angle_error_trans = np.rad2deg(np.arccos(cos))
365
+
366
+ return angle_error_rot, angle_error_trans
imcui/third_party/MatchAnything/src/utils/homography_utils.py ADDED
@@ -0,0 +1,366 @@
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ def to_homogeneous(points):
8
+ """Convert N-dimensional points to homogeneous coordinates.
9
+ Args:
10
+ points: torch.Tensor or numpy.ndarray with size (..., N).
11
+ Returns:
12
+ A torch.Tensor or numpy.ndarray with size (..., N+1).
13
+ """
14
+ if isinstance(points, torch.Tensor):
15
+ pad = points.new_ones(points.shape[:-1] + (1,))
16
+ return torch.cat([points, pad], dim=-1)
17
+ elif isinstance(points, np.ndarray):
18
+ pad = np.ones((points.shape[:-1] + (1,)), dtype=points.dtype)
19
+ return np.concatenate([points, pad], axis=-1)
20
+ else:
21
+ raise ValueError
22
+
23
+
24
+ def from_homogeneous(points, eps=0.0):
25
+ """Remove the homogeneous dimension of N-dimensional points.
26
+ Args:
27
+ points: torch.Tensor or numpy.ndarray with size (..., N+1).
28
+ eps: Epsilon value to prevent zero division.
29
+ Returns:
30
+ A torch.Tensor or numpy ndarray with size (..., N).
31
+ """
32
+ return points[..., :-1] / (points[..., -1:] + eps)
33
+
34
+
35
+ def flat2mat(H):
36
+ return np.reshape(np.concatenate([H, np.ones_like(H[:, :1])], axis=1), [3, 3])
37
+
38
+
39
+ # Homography creation
40
+
41
+
42
+ def create_center_patch(shape, patch_shape=None):
43
+ if patch_shape is None:
44
+ patch_shape = shape
45
+ width, height = shape
46
+ pwidth, pheight = patch_shape
47
+ left = int((width - pwidth) / 2)
48
+ bottom = int((height - pheight) / 2)
49
+ right = int((width + pwidth) / 2)
50
+ top = int((height + pheight) / 2)
51
+ return np.array([[left, bottom], [left, top], [right, top], [right, bottom]])
52
+
53
+
54
+ def check_convex(patch, min_convexity=0.05):
55
+ """Checks if given polygon vertices [N,2] form a convex shape"""
56
+ for i in range(patch.shape[0]):
57
+ x1, y1 = patch[(i - 1) % patch.shape[0]]
58
+ x2, y2 = patch[i]
59
+ x3, y3 = patch[(i + 1) % patch.shape[0]]
60
+ if (x2 - x1) * (y3 - y2) - (x3 - x2) * (y2 - y1) > -min_convexity:
61
+ return False
62
+ return True
63
+
64
+
65
+ def sample_homography_corners(
66
+ shape,
67
+ patch_shape,
68
+ difficulty=1.0,
69
+ translation=0.4,
70
+ n_angles=10,
71
+ max_angle=90,
72
+ min_convexity=0.05,
73
+ rng=np.random,
74
+ ):
75
+ max_angle = max_angle / 180.0 * math.pi
76
+ width, height = shape
77
+ pwidth, pheight = width * (1 - difficulty), height * (1 - difficulty)
78
+ min_pts1 = create_center_patch(shape, (pwidth, pheight))
79
+ full = create_center_patch(shape)
80
+ pts2 = create_center_patch(patch_shape)
81
+ scale = min_pts1 - full
82
+ found_valid = False
83
+ cnt = -1
84
+ while not found_valid:
85
+ offsets = rng.uniform(0.0, 1.0, size=(4, 2)) * scale
86
+ pts1 = full + offsets
87
+ found_valid = check_convex(pts1 / np.array(shape), min_convexity)
88
+ cnt += 1
89
+
90
+ # re-center
91
+ pts1 = pts1 - np.mean(pts1, axis=0, keepdims=True)
92
+ pts1 = pts1 + np.mean(min_pts1, axis=0, keepdims=True)
93
+
94
+ # Rotation
95
+ if n_angles > 0 and difficulty > 0:
96
+ angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
97
+ rng.shuffle(angles)
98
+ rng.shuffle(angles)
99
+ angles = np.concatenate([[0.0], angles], axis=0)
100
+
101
+ center = np.mean(pts1, axis=0, keepdims=True)
102
+ rot_mat = np.reshape(
103
+ np.stack(
104
+ [np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)],
105
+ axis=1,
106
+ ),
107
+ [-1, 2, 2],
108
+ )
109
+ rotated = (
110
+ np.matmul(
111
+ np.tile(np.expand_dims(pts1 - center, axis=0), [n_angles + 1, 1, 1]),
112
+ rot_mat,
113
+ )
114
+ + center
115
+ )
116
+
117
+ for idx in range(1, n_angles):
118
+ warped_points = rotated[idx] / np.array(shape)
119
+ if np.all((warped_points >= 0.0) & (warped_points < 1.0)):
120
+ pts1 = rotated[idx]
121
+ break
122
+
123
+ # Translation
124
+ if translation > 0:
125
+ min_trans = -np.min(pts1, axis=0)
126
+ max_trans = shape - np.max(pts1, axis=0)
127
+ trans = rng.uniform(min_trans, max_trans)[None]
128
+ pts1 += trans * translation * difficulty
129
+
130
+ H = compute_homography(pts1, pts2, [1.0, 1.0])
131
+ warped = warp_points(full, H, inverse=False)
132
+ return H, full, warped, patch_shape
133
+
134
+
135
+ def compute_homography(pts1_, pts2_, shape):
136
+ """Compute the homography matrix from 4 point correspondences"""
137
+ # Rescale to actual size
138
+ shape = np.array(shape[::-1], dtype=np.float32) # different convention [y, x]
139
+ pts1 = pts1_ * np.expand_dims(shape, axis=0)
140
+ pts2 = pts2_ * np.expand_dims(shape, axis=0)
141
+
142
+ def ax(p, q):
143
+ return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]
144
+
145
+ def ay(p, q):
146
+ return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]
147
+
148
+ a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0)
149
+ p_mat = np.transpose(
150
+ np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0)
151
+ )
152
+ homography = np.transpose(np.linalg.solve(a_mat, p_mat))
153
+ return flat2mat(homography)
154
+
155
+
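As a rough usage sketch (assumed, not from the repository), sample_homography_corners and compute_homography can be exercised on a toy image size; the numbers below are arbitrary, and warp_points is the helper defined further down in this file:

import numpy as np

rng = np.random.RandomState(0)
H, corners, warped, patch_shape = sample_homography_corners(
    shape=(640, 480), patch_shape=(640, 480), difficulty=0.5, rng=rng)
# `warped` was produced by warping `corners` with H, so this round-trips exactly.
print(np.allclose(warp_points(corners, H, inverse=False), warped))  # True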
156
+ # Point warping utils
157
+
158
+
159
+ def warp_points(points, homography, inverse=True):
160
+ """
161
+ Warp a list of points with the INVERSE of the given homography.
162
+ The inverse is used to be coherent with tf.contrib.image.transform
163
+ Arguments:
164
+ points: list of N points, shape (N, 2).
165
+ homography: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
166
+ Returns: a Tensor of shape (N, 2) or (B, N, 2) (depending on whether the homography
167
+ is batched) containing the new coordinates of the warped points.
168
+ """
169
+ H = homography[None] if len(homography.shape) == 2 else homography
170
+
171
+ # Get the points to the homogeneous format
172
+ num_points = points.shape[0]
173
+ points = np.concatenate([points, np.ones([num_points, 1], dtype=np.float32)], -1)
174
+
175
+ H_inv = np.transpose(np.linalg.inv(H) if inverse else H)
176
+ warped_points = np.tensordot(points, H_inv, axes=[[1], [0]])
177
+
178
+ warped_points = np.transpose(warped_points, [2, 0, 1])
179
+ warped_points[np.abs(warped_points[:, :, 2]) < 1e-8, 2] = 1e-8
180
+ warped_points = warped_points[:, :, :2] / warped_points[:, :, 2:]
181
+
182
+ return warped_points[0] if len(homography.shape) == 2 else warped_points
183
+
184
+
185
+ def warp_points_torch(points, H, inverse=True):
186
+ """
187
+ Warp a list of points with the INVERSE of the given homography.
188
+ The inverse is used to be coherent with tf.contrib.image.transform
189
+ Arguments:
190
+ points: batched list of N points, shape (B, N, 2).
191
+ H: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
192
+ inverse: Whether to multiply the points by H or the inverse of H
193
+ Returns: a Tensor of shape (B, N, 2) containing the new coordinates of the warps.
194
+ """
195
+
196
+ # Get the points to the homogeneous format
197
+ points = to_homogeneous(points)
198
+
199
+ # Apply the homography
200
+ H_mat = (torch.inverse(H) if inverse else H).transpose(-2, -1)
201
+ warped_points = torch.einsum("...nj,...ji->...ni", points, H_mat)
202
+
203
+ warped_points = from_homogeneous(warped_points, eps=1e-5)
204
+ return warped_points
205
+
206
+
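A small round-trip check (a sketch with made-up values, not repository code) for warp_points_torch: warping forward with H and back with inverse=True should return the original points up to the epsilon used in from_homogeneous:

import torch

H = torch.tensor([[1.2, 0.0, 5.0],
                  [0.1, 0.9, -3.0],
                  [0.0, 0.0, 1.0]])
pts = torch.rand(1, 8, 2) * 10                    # (B, N, 2)
fwd = warp_points_torch(pts, H, inverse=False)
back = warp_points_torch(fwd, H, inverse=True)
print(torch.allclose(back, pts, atol=1e-3))       # True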
207
+ # Line warping utils
208
+
209
+
210
+ def seg_equation(segs):
211
+ # calculate list of start, end and midpoints points from both lists
212
+ start_points, end_points = to_homogeneous(segs[..., 0, :]), to_homogeneous(
213
+ segs[..., 1, :]
214
+ )
215
+ # Compute the line equations as ax + by + c = 0 , where x^2 + y^2 = 1
216
+ lines = torch.cross(start_points, end_points, dim=-1)
217
+ lines_norm = torch.sqrt(lines[..., 0] ** 2 + lines[..., 1] ** 2)[..., None]
218
+ assert torch.all(
219
+ lines_norm > 0
220
+ ), "Error: trying to compute the equation of a line with a single point"
221
+ lines = lines / lines_norm
222
+ return lines
223
+
224
+
225
+ def is_inside_img(pts: torch.Tensor, img_shape: Tuple[int, int]):
226
+ h, w = img_shape
227
+ return (
228
+ (pts >= 0).all(dim=-1)
229
+ & (pts[..., 0] < w)
230
+ & (pts[..., 1] < h)
231
+ & (~torch.isinf(pts).any(dim=-1))
232
+ )
233
+
234
+
235
+ def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch.Tensor:
236
+ """
237
+ Shrink an array of segments to fit inside the image.
238
+ :param segs: The tensor of segments with shape (N, 2, 2)
239
+ :param img_shape: The image shape in format (H, W)
240
+ """
241
+ EPS = 1e-4
242
+ device = segs.device
243
+ w, h = img_shape[1], img_shape[0]
244
+ # Project the segments to the reference image
245
+ segs = segs.clone()
246
+ eqs = seg_equation(segs)
247
+ x0, y0 = torch.tensor([1.0, 0, 0.0], device=device), torch.tensor(
248
+ [0.0, 1, 0], device=device
249
+ )
250
+ x0 = x0.repeat(eqs.shape[:-1] + (1,))
251
+ y0 = y0.repeat(eqs.shape[:-1] + (1,))
252
+ pt_x0s = torch.cross(eqs, x0, dim=-1)
253
+ pt_x0s = pt_x0s[..., :-1] / pt_x0s[..., None, -1]
254
+ pt_x0s_valid = is_inside_img(pt_x0s, img_shape)
255
+ pt_y0s = torch.cross(eqs, y0, dim=-1)
256
+ pt_y0s = pt_y0s[..., :-1] / pt_y0s[..., None, -1]
257
+ pt_y0s_valid = is_inside_img(pt_y0s, img_shape)
258
+
259
+ xW = torch.tensor([1.0, 0, EPS - w], device=device)
260
+ yH = torch.tensor([0.0, 1, EPS - h], device=device)
261
+ xW = xW.repeat(eqs.shape[:-1] + (1,))
262
+ yH = yH.repeat(eqs.shape[:-1] + (1,))
263
+ pt_xWs = torch.cross(eqs, xW, dim=-1)
264
+ pt_xWs = pt_xWs[..., :-1] / pt_xWs[..., None, -1]
265
+ pt_xWs_valid = is_inside_img(pt_xWs, img_shape)
266
+ pt_yHs = torch.cross(eqs, yH, dim=-1)
267
+ pt_yHs = pt_yHs[..., :-1] / pt_yHs[..., None, -1]
268
+ pt_yHs_valid = is_inside_img(pt_yHs, img_shape)
269
+
270
+ # If the X coordinate of the first endpoint is out
271
+ mask = (segs[..., 0, 0] < 0) & pt_x0s_valid
272
+ segs[mask, 0, :] = pt_x0s[mask]
273
+ mask = (segs[..., 0, 0] > (w - 1)) & pt_xWs_valid
274
+ segs[mask, 0, :] = pt_xWs[mask]
275
+ # If the X coordinate of the second endpoint is out
276
+ mask = (segs[..., 1, 0] < 0) & pt_x0s_valid
277
+ segs[mask, 1, :] = pt_x0s[mask]
278
+ mask = (segs[:, 1, 0] > (w - 1)) & pt_xWs_valid
279
+ segs[mask, 1, :] = pt_xWs[mask]
280
+ # If the Y coordinate of the first endpoint is out
281
+ mask = (segs[..., 0, 1] < 0) & pt_y0s_valid
282
+ segs[mask, 0, :] = pt_y0s[mask]
283
+ mask = (segs[..., 0, 1] > (h - 1)) & pt_yHs_valid
284
+ segs[mask, 0, :] = pt_yHs[mask]
285
+ # If the Y coordinate of the second endpoint is out
286
+ mask = (segs[..., 1, 1] < 0) & pt_y0s_valid
287
+ segs[mask, 1, :] = pt_y0s[mask]
288
+ mask = (segs[..., 1, 1] > (h - 1)) & pt_yHs_valid
289
+ segs[mask, 1, :] = pt_yHs[mask]
290
+
291
+ assert (
292
+ torch.all(segs >= 0)
293
+ and torch.all(segs[..., 0] < w)
294
+ and torch.all(segs[..., 1] < h)
295
+ )
296
+ return segs
297
+
298
+
299
+ def warp_lines_torch(
300
+ lines, H, inverse=True, dst_shape: Tuple[int, int] = None
301
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
302
+ """
303
+ :param lines: A tensor of shape (B, N, 2, 2)
304
+ where B is the batch size, N the number of lines.
305
+ :param H: The homography used to convert the lines.
306
+ batched or not (shapes (B, 3, 3) and (3, 3) respectively).
307
+ :param inverse: Whether to apply H or the inverse of H
308
+ :param dst_shape: If provided, lines are trimmed to be inside the image
309
+ """
310
+ device = lines.device
311
+ batch_size = len(lines)
312
+ lines = warp_points_torch(lines.reshape(batch_size, -1, 2), H, inverse).reshape(
313
+ lines.shape
314
+ )
315
+
316
+ if dst_shape is None:
317
+ return lines, torch.ones(lines.shape[:-2], dtype=torch.bool, device=device)
318
+
319
+ out_img = torch.any(
320
+ (lines < 0) | (lines >= torch.tensor(dst_shape[::-1], device=device)), -1
321
+ )
322
+ valid = ~out_img.all(-1)
323
+ any_out_of_img = out_img.any(-1)
324
+ lines_to_trim = valid & any_out_of_img
325
+
326
+ for b in range(batch_size):
327
+ lines_to_trim_mask_b = lines_to_trim[b]
328
+ lines_to_trim_b = lines[b][lines_to_trim_mask_b]
329
+ corrected_lines = shrink_segs_to_img(lines_to_trim_b, dst_shape)
330
+ lines[b][lines_to_trim_mask_b] = corrected_lines
331
+
332
+ return lines, valid
333
+
334
+
335
+ # Homography evaluation utils
336
+
337
+
338
+ def sym_homography_error(kpts0, kpts1, T_0to1):
339
+ kpts0_1 = from_homogeneous(to_homogeneous(kpts0) @ T_0to1.transpose(-1, -2))
340
+ dist0_1 = ((kpts0_1 - kpts1) ** 2).sum(-1).sqrt()
341
+
342
+ kpts1_0 = from_homogeneous(
343
+ to_homogeneous(kpts1) @ torch.pinverse(T_0to1.transpose(-1, -2))
344
+ )
345
+ dist1_0 = ((kpts1_0 - kpts0) ** 2).sum(-1).sqrt()
346
+
347
+ return (dist0_1 + dist1_0) / 2.0
348
+
349
+
350
+ def sym_homography_error_all(kpts0, kpts1, H):
351
+ kp0_1 = warp_points_torch(kpts0, H, inverse=False)
352
+ kp1_0 = warp_points_torch(kpts1, H, inverse=True)
353
+
354
+ # build a distance matrix of size [... x M x N]
355
+ dist0 = torch.sum((kp0_1.unsqueeze(-2) - kpts1.unsqueeze(-3)) ** 2, -1).sqrt()
356
+ dist1 = torch.sum((kpts0.unsqueeze(-2) - kp1_0.unsqueeze(-3)) ** 2, -1).sqrt()
357
+ return (dist0 + dist1) / 2.0
358
+
359
+
360
+ def homography_corner_error(T, T_gt, image_size):
361
+ W, H = image_size[..., 0], image_size[..., 1]
362
+ corners0 = torch.Tensor([[0, 0], [W, 0], [W, H], [0, H]]).float().to(T)
363
+ corners1_gt = from_homogeneous(to_homogeneous(corners0) @ T_gt.transpose(-1, -2))
364
+ corners1 = from_homogeneous(to_homogeneous(corners0) @ T.transpose(-1, -2))
365
+ d = torch.sqrt(((corners1 - corners1_gt) ** 2).sum(-1))
366
+ return d.mean(-1)
imcui/third_party/MatchAnything/src/utils/metrics.py ADDED
@@ -0,0 +1,445 @@
1
+ import torch
2
+ import cv2
3
+ import numpy as np
4
+ from collections import OrderedDict
5
+ from loguru import logger
6
+ from .homography_utils import warp_points, warp_points_torch
7
+ from kornia.geometry.epipolar import numeric
8
+ from kornia.geometry.conversions import convert_points_to_homogeneous
9
+ import pprint
10
+
11
+
12
+ # --- METRICS ---
13
+
14
+ def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
15
+ # angle error between 2 vectors
16
+ t_gt = T_0to1[:3, 3]
17
+ n = np.linalg.norm(t) * np.linalg.norm(t_gt)
18
+ t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0)))
19
+ t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity
20
+ if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging
21
+ t_err = 0
22
+
23
+ # angle error between 2 rotation matrices
24
+ R_gt = T_0to1[:3, :3]
25
+ cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
26
+ cos = np.clip(cos, -1., 1.) # handle numerical errors
27
+ R_err = np.rad2deg(np.abs(np.arccos(cos)))
28
+
29
+ return t_err, R_err
30
+
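A minimal sketch (toy values, not from the repository): a perfectly recovered pose gives zero translation and rotation error.

import numpy as np

T_0to1 = np.eye(4)
T_0to1[:3, 3] = [1.0, 0.0, 0.0]                   # ground-truth pose with unit translation
t_err, R_err = relative_pose_error(T_0to1, R=np.eye(3), t=np.array([1.0, 0.0, 0.0]))
print(t_err, R_err)                                # 0.0 0.0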
31
+ def warp_pts_error(H_est, pts_coord, H_gt=None, pts_gt=None):
32
+ """
33
+ pts_coord: (N, 2) points to warp (e.g. the four image corners)
34
+ """
35
+ if H_gt is not None:
36
+ est_warp = warp_points(pts_coord, H_est, False)
37
+ est_gt = warp_points(pts_coord, H_gt, False)
38
+ diff = est_warp - est_gt
39
+ elif pts_gt is not None:
40
+ est_warp = warp_points(pts_coord, H_est, False)
41
+ diff = est_warp - pts_gt
42
+
43
+ return np.mean(np.linalg.norm(diff, axis=1))
44
+
45
+ def homo_warp_match_distance(H_gt, kpts0, kpts1, hw):
46
+ """
47
+ kpts0, kpts1: (N, 2) matched keypoints; hw: (H, W) image size used to normalize the distances
48
+ """
49
+ if isinstance(H_gt, np.ndarray):
50
+ kpts_warped = warp_points(kpts0, H_gt)
51
+ normalized_distance = np.linalg.norm((kpts_warped - kpts1) / hw[None, [1,0]], axis=1)
52
+ else:
53
+ kpts_warped = warp_points_torch(kpts0, H_gt)
54
+ normalized_distance = torch.linalg.norm((kpts_warped - kpts1) / hw[None, [1,0]], axis=1)
55
+ return normalized_distance
56
+
57
+ def symmetric_epipolar_distance(pts0, pts1, E, K0, K1):
58
+ """Squared symmetric epipolar distance.
59
+ This can be seen as a biased estimation of the reprojection error.
60
+ Args:
61
+ pts0 (torch.Tensor): [N, 2]
62
+ E (torch.Tensor): [3, 3]
63
+ """
64
+ pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
65
+ pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
66
+ pts0 = convert_points_to_homogeneous(pts0)
67
+ pts1 = convert_points_to_homogeneous(pts1)
68
+
69
+ Ep0 = pts0 @ E.T # [N, 3]
70
+ p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,]
71
+ Etp1 = pts1 @ E # [N, 3]
72
+
73
+ d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N
74
+ return d
75
+
76
+
77
+ def compute_symmetrical_epipolar_errors(data, config):
78
+ """
79
+ Update:
80
+ data (dict):{"epi_errs": [M]}
81
+ """
82
+ Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
83
+ E_mat = Tx @ data['T_0to1'][:, :3, :3]
84
+
85
+ m_bids = data['m_bids']
86
+ pts0 = data['mkpts0_f']
87
+ pts1 = data['mkpts1_f'].clone().detach()
88
+
89
+ if config.LOFTR.FINE.MTD_SPVS:
90
+ m_bids = data['m_bids_f'] if 'm_bids_f' in data else data['m_bids']
91
+ epi_errs = []
92
+ for bs in range(Tx.size(0)):
93
+ mask = m_bids == bs
94
+ epi_errs.append(
95
+ symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
96
+ epi_errs = torch.cat(epi_errs, dim=0)
97
+
98
+ data.update({'epi_errs': epi_errs})
99
+
100
+ def compute_homo_match_warp_errors(data, config):
101
+ """
102
+ Update:
103
+ data (dict):{"epi_errs": [M]}
104
+ """
105
+
106
+ homography_gt = data['homography']
107
+ m_bids = data['m_bids']
108
+ pts0 = data['mkpts0_f']
109
+ pts1 = data['mkpts1_f']
110
+ origin_img0_size = data['origin_img_size0']
111
+
112
+ if config.LOFTR.FINE.MTD_SPVS:
113
+ m_bids = data['m_bids_f'] if 'm_bids_f' in data else data['m_bids']
114
+ epi_errs = []
115
+ for bs in range(homography_gt.shape[0]):
116
+ mask = m_bids == bs
117
+ epi_errs.append(
118
+ homo_warp_match_distance(homography_gt[bs], pts0[mask], pts1[mask], origin_img0_size[bs]))
119
+ epi_errs = torch.cat(epi_errs, dim=0)
120
+
121
+ data.update({'epi_errs': epi_errs})
122
+
123
+
124
+ def compute_symmetrical_epipolar_errors_gt(data, config):
125
+ """
126
+ Update:
127
+ data (dict):{"epi_errs": [M]}
128
+ """
129
+ Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
130
+ E_mat = Tx @ data['T_0to1'][:, :3, :3]
131
+
132
+ m_bids = data['m_bids']
133
+ pts0 = data['mkpts0_f_gt']
134
+ pts1 = data['mkpts1_f_gt']
135
+
136
+ epi_errs = []
137
+ for bs in range(Tx.size(0)):
138
+ # mask = m_bids == bs
139
+ assert bs == 0
140
+ mask = torch.tensor([True]*pts0.shape[0], device = pts0.device)
141
+ if config.LOFTR.FINE.MTD_SPVS:
142
+ epi_errs.append(
143
+ symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
144
+ else:
145
+ epi_errs.append(
146
+ symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
147
+ epi_errs = torch.cat(epi_errs, dim=0)
148
+
149
+ data.update({'epi_errs': epi_errs})
150
+
151
+
152
+ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
153
+ if len(kpts0) < 5:
154
+ return None
155
+ # normalize keypoints
156
+ kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
157
+ kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
158
+
159
+ # normalize ransac threshold
160
+ ransac_thr = thresh / np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]])
161
+
162
+ # compute pose with cv2
163
+ E, mask = cv2.findEssentialMat(
164
+ kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC)
165
+ if E is None:
166
+ print("\nE is None while trying to recover pose.\n")
167
+ return None
168
+
169
+ # recover pose from E
170
+ best_num_inliers = 0
171
+ ret = None
172
+ for _E in np.split(E, len(E) / 3):
173
+ n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
174
+ if n > best_num_inliers:
175
+ ret = (R, t[:, 0], mask.ravel() > 0)
176
+ best_num_inliers = n
177
+
178
+ return ret
179
+
180
+ def estimate_homo(kpts0, kpts1, thresh, conf=0.99999, mode='affine'):
181
+ if mode == 'affine':
182
+ H_est, inliers = cv2.estimateAffine2D(kpts0, kpts1, ransacReprojThreshold=thresh, confidence=conf, method=cv2.RANSAC)
183
+ if H_est is None:
184
+ return np.eye(3) * 0, np.empty((0))
185
+ H_est = np.concatenate([H_est, np.array([[0, 0, 1]])], axis=0) # 3 * 3
186
+ elif mode == 'homo':
187
+ H_est, inliers = cv2.findHomography(kpts0, kpts1, method=cv2.LMEDS, ransacReprojThreshold=thresh)
188
+ if H_est is None:
189
+ return np.eye(3) * 0, np.empty((0))
190
+
191
+ return H_est, inliers
192
+
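An assumed usage sketch (not repository code) for estimate_homo: with noise-free correspondences generated from a known affine map, the estimate should recover it.

import numpy as np

A = np.array([[1.1, 0.0, 4.0],
              [0.0, 0.9, -2.0]], dtype=np.float32)
kpts0 = (np.random.rand(50, 2) * 100).astype(np.float32)
kpts1 = kpts0 @ A[:, :2].T + A[:, 2]
H_est, inliers = estimate_homo(kpts0, kpts1, thresh=3.0, mode='affine')
print(np.allclose(H_est[:2], A, atol=1e-3))        # True up to numerical precision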
193
+ def compute_homo_corner_warp_errors(data, config):
194
+ """
195
+ Update:
196
+ data (dict):{
197
+ "R_errs" List[float]: [N] # Actually warp error
198
+ "t_errs" List[float]: [N] # Zero, place holder
199
+ "inliers" List[np.ndarray]: [N]
200
+ }
201
+ """
202
+ pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5
203
+ conf = config.TRAINER.RANSAC_CONF # 0.99999
204
+ data.update({'R_errs': [], 't_errs': [], 'inliers': []})
205
+
206
+ if config.LOFTR.FINE.MTD_SPVS:
207
+ m_bids = data['m_bids_f'].cpu().numpy() if 'm_bids_f' in data else data['m_bids'].cpu().numpy()
208
+
209
+ else:
210
+ m_bids = data['m_bids'].cpu().numpy()
211
+ pts0 = data['mkpts0_f'].cpu().numpy()
212
+ pts1 = data['mkpts1_f'].cpu().numpy()
213
+ homography_gt = data['homography'].cpu().numpy()
214
+ origin_size_0 = data['origin_img_size0'].cpu().numpy()
215
+
216
+ for bs in range(homography_gt.shape[0]):
217
+ mask = m_bids == bs
218
+ ret = estimate_homo(pts0[mask], pts1[mask], pixel_thr, conf=conf)
219
+
220
+ if ret is None:
221
+ data['R_errs'].append(np.inf)
222
+ data['t_errs'].append(np.inf)
223
+ data['inliers'].append(np.array([]).astype(bool))
224
+ else:
225
+ H_est, inliers = ret
226
+ corner_coord = np.array([[0, 0], [0, origin_size_0[bs][0]], [origin_size_0[bs][1], 0], [origin_size_0[bs][1], origin_size_0[bs][0]]])
227
+ corner_warp_distance = warp_pts_error(H_est, corner_coord, H_gt=homography_gt[bs])
228
+ data['R_errs'].append(corner_warp_distance)
229
+ data['t_errs'].append(0)
230
+ data['inliers'].append(inliers)
231
+
232
+ def compute_warp_control_pts_errors(data, config):
233
+ """
234
+ Update:
235
+ data (dict):{
236
+ "R_errs" List[float]: [N] # Actually warp error
237
+ "t_errs" List[float]: [N] # Zero, place holder
238
+ "inliers" List[np.ndarray]: [N]
239
+ }
240
+ """
241
+ pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5
242
+ conf = config.TRAINER.RANSAC_CONF # 0.99999
243
+ data.update({'R_errs': [], 't_errs': [], 'inliers': []})
244
+
245
+ if config.LOFTR.FINE.MTD_SPVS:
246
+ m_bids = data['m_bids_f'].cpu().numpy() if 'm_bids_f' in data else data['m_bids'].cpu().numpy()
247
+
248
+ else:
249
+ m_bids = data['m_bids'].cpu().numpy()
250
+ pts0 = data['mkpts0_f'].cpu().numpy()
251
+ pts1 = data['mkpts1_f'].cpu().numpy()
252
+ gt_2D_matches = data["gt_2D_matches"].cpu().numpy()
253
+
254
+ data.update({'epi_errs': torch.zeros(m_bids.shape[0])})
255
+ for bs in range(gt_2D_matches.shape[0]):
256
+ mask = m_bids == bs
257
+ ret = estimate_homo(pts0[mask], pts1[mask], pixel_thr, conf=conf, mode=config.TRAINER.WARP_ESTIMATOR_MODEL)
258
+
259
+ if ret is None:
260
+ data['R_errs'].append(np.inf)
261
+ data['t_errs'].append(np.inf)
262
+ data['inliers'].append(np.array([]).astype(bool))
263
+ else:
264
+ H_est, inliers = ret
265
+ img0_pts, img1_pts = gt_2D_matches[bs][:, :2], gt_2D_matches[bs][:, 2:]
266
+ pts_warp_distance = warp_pts_error(H_est, img0_pts, pts_gt=img1_pts)
267
+ print(pts_warp_distance)
268
+ data['R_errs'].append(pts_warp_distance)
269
+ data['t_errs'].append(0)
270
+ data['inliers'].append(inliers)
271
+
272
+ def compute_pose_errors(data, config):
273
+ """
274
+ Update:
275
+ data (dict):{
276
+ "R_errs" List[float]: [N]
277
+ "t_errs" List[float]: [N]
278
+ "inliers" List[np.ndarray]: [N]
279
+ }
280
+ """
281
+ pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5
282
+ conf = config.TRAINER.RANSAC_CONF # 0.99999
283
+ data.update({'R_errs': [], 't_errs': [], 'inliers': []})
284
+
285
+ if config.LOFTR.FINE.MTD_SPVS:
286
+ m_bids = data['m_bids_f'].cpu().numpy() if 'm_bids_f' in data else data['m_bids'].cpu().numpy()
287
+
288
+ else:
289
+ m_bids = data['m_bids'].cpu().numpy()
290
+ pts0 = data['mkpts0_f'].cpu().numpy()
291
+ pts1 = data['mkpts1_f'].cpu().numpy()
292
+ K0 = data['K0'].cpu().numpy()
293
+ K1 = data['K1'].cpu().numpy()
294
+ T_0to1 = data['T_0to1'].cpu().numpy()
295
+
296
+ for bs in range(K0.shape[0]):
297
+ mask = m_bids == bs
298
+ if config.LOFTR.EVAL_TIMES >= 1:
299
+ bpts0, bpts1 = pts0[mask], pts1[mask]
300
+ R_list, T_list, inliers_list = [], [], []
301
+ for _ in range(5):
302
+ shuffling = np.random.permutation(np.arange(len(bpts0)))
303
+ if _ >= config.LOFTR.EVAL_TIMES:
304
+ continue
305
+ bpts0 = bpts0[shuffling]
306
+ bpts1 = bpts1[shuffling]
307
+
308
+ ret = estimate_pose(bpts0, bpts1, K0[bs], K1[bs], pixel_thr, conf=conf)
309
+ if ret is None:
310
+ R_list.append(np.inf)
311
+ T_list.append(np.inf)
312
+ inliers_list.append(np.array([]).astype(bool))
313
+ print('Pose error: inf')
314
+ else:
315
+ R, t, inliers = ret
316
+ t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
317
+ R_list.append(R_err)
318
+ T_list.append(t_err)
319
+ inliers_list.append(inliers)
320
+ print(f'Pose error: {max(R_err, t_err)}')
321
+ R_err_mean = np.array(R_list).mean()
322
+ T_err_mean = np.array(T_list).mean()
323
+ # inliers_mean = np.array(inliers_list).mean()
324
+
325
+ data['R_errs'].append(R_list)
326
+ data['t_errs'].append(T_list)
327
+ data['inliers'].append(inliers_list[0])
328
+
329
+ else:
330
+ ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf)
331
+
332
+ if ret is None:
333
+ data['R_errs'].append(np.inf)
334
+ data['t_errs'].append(np.inf)
335
+ data['inliers'].append(np.array([]).astype(bool))
336
+ print('Pose error: inf')
337
+ else:
338
+ R, t, inliers = ret
339
+ t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
340
+ data['R_errs'].append(R_err)
341
+ data['t_errs'].append(t_err)
342
+ data['inliers'].append(inliers)
343
+ print(f'Pose error: {max(R_err, t_err)}')
344
+
345
+
346
+ # --- METRIC AGGREGATION ---
347
+ def error_rmse(error):
348
+ squared_errors = np.square(error) # N * 2
349
+ mse = np.mean(np.sum(squared_errors, axis=1))
350
+ rmse = np.sqrt(mse)
351
+ return rmse
352
+
353
+ def error_mae(error):
354
+ abs_diff = np.abs(error) # N * 2
355
+ absolute_errors = np.sum(abs_diff, axis=1)
356
+
357
+ # Return the maximum absolute error
358
+ mae = np.max(absolute_errors)
359
+ return mae
360
+
361
+ def error_auc(errors, thresholds, method='exact_auc'):
362
+ """
363
+ Args:
364
+ errors (list): [N,]
365
+ thresholds (list)
366
+ """
367
+ if method == 'exact_auc':
368
+ errors = [0] + sorted(list(errors))
369
+ recall = list(np.linspace(0, 1, len(errors)))
370
+
371
+ aucs = []
372
+ for thr in thresholds:
373
+ last_index = np.searchsorted(errors, thr)
374
+ y = recall[:last_index] + [recall[last_index-1]]
375
+ x = errors[:last_index] + [thr]
376
+ aucs.append(np.trapz(y, x) / thr)
377
+ return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
378
+ elif method == 'fire_paper':
379
+ aucs = []
380
+ for threshold in thresholds:
381
+ accum_error = 0
382
+ percent_error_below = np.zeros(threshold + 1)
383
+ for i in range(1, threshold + 1):
384
+ percent_error_below[i] = np.sum(errors < i) * 100 / len(errors)
385
+ accum_error += percent_error_below[i]
386
+
387
+ aucs.append(accum_error / (threshold * 100))
388
+
389
+ return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
390
+ elif method == 'success_rate':
391
+ aucs = []
392
+ for threshold in thresholds:
393
+ aucs.append((errors < threshold).astype(float).mean())
394
+ return {f'SR@{t}': auc for t, auc in zip(thresholds, aucs)}
395
+ else:
396
+ raise NotImplementedError
397
+
398
+
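For instance (a toy sketch, not from the repository), the exact-AUC variant can be probed with a handful of hand-made pose errors in degrees:

import numpy as np

errs = np.array([1.0, 3.0, 7.0, 15.0, 50.0])
print(error_auc(errs, thresholds=[5, 10, 20], method='exact_auc'))
# -> {'auc@5': ..., 'auc@10': ..., 'auc@20': ...}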
399
+ def epidist_prec(errors, thresholds, ret_dict=False):
400
+ precs = []
401
+ for thr in thresholds:
402
+ prec_ = []
403
+ for errs in errors:
404
+ correct_mask = errs < thr
405
+ prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0)
406
+ precs.append(np.mean(prec_) if len(prec_) > 0 else 0)
407
+ if ret_dict:
408
+ return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)}
409
+ else:
410
+ return precs
411
+
412
+
413
+ def aggregate_metrics(metrics, epi_err_thr=5e-4, eval_n_time=1, threshold=[5, 10, 20], method='exact_auc'):
414
+ """ Aggregate metrics for the whole dataset:
415
+ (This method should be called once per dataset)
416
+ 1. AUC of the pose error (angular) at the threshold [5, 10, 20]
417
+ 2. Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth)
418
+ """
419
+ # filter duplicates
420
+ unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers']))
421
+ unq_ids = list(unq_ids.values())
422
+ logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...')
423
+
424
+ # pose auc
425
+ angular_thresholds = threshold
426
+ if eval_n_time >= 1:
427
+ pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0).reshape(-1, eval_n_time)[unq_ids].reshape(-1)
428
+ else:
429
+ pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids]
430
+ logger.info('num of pose_errors: {}'.format(pose_errors.shape))
431
+ aucs = error_auc(pose_errors, angular_thresholds, method=method) # (auc@5, auc@10, auc@20)
432
+
433
+ if eval_n_time >= 1:
434
+ for i in range(eval_n_time):
435
+ aucs_i = error_auc(pose_errors.reshape(-1, eval_n_time)[:,i], angular_thresholds, method=method)
436
+ logger.info('\n' + f'results of {i}-th RANSAC' + pprint.pformat(aucs_i))
437
+ # matching precision
438
+ dist_thresholds = [epi_err_thr]
439
+ precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr)
440
+
441
+ u_num_matches = np.array(metrics['num_matches'], dtype=object)[unq_ids]
442
+ u_percent_inliers = np.array(metrics['percent_inliers'], dtype=object)[unq_ids]
443
+ num_matches = {f'num_matches': u_num_matches.mean() }
444
+ percent_inliers = {f'percent_inliers': u_percent_inliers.mean()}
445
+ return {**aucs, **precs, **num_matches, **percent_inliers}
imcui/third_party/MatchAnything/src/utils/misc.py ADDED
@@ -0,0 +1,101 @@
1
+ import os
2
+ import contextlib
3
+ import joblib
4
+ from typing import Union
5
+ from loguru import _Logger, logger
6
+ from itertools import chain
7
+
8
+ import torch
9
+ from yacs.config import CfgNode as CN
10
+ from pytorch_lightning.utilities import rank_zero_only
11
+
12
+
13
+ def lower_config(yacs_cfg):
14
+ if not isinstance(yacs_cfg, CN):
15
+ return yacs_cfg
16
+ return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
17
+
18
+
19
+ def upper_config(dict_cfg):
20
+ if not isinstance(dict_cfg, dict):
21
+ return dict_cfg
22
+ return {k.upper(): upper_config(v) for k, v in dict_cfg.items()}
23
+
24
+
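A minimal sketch (assumed usage, made-up config keys) of the two config helpers round-tripping a yacs node:

from yacs.config import CfgNode as CN

cfg = CN()
cfg.LOFTR = CN()
cfg.LOFTR.FINE_WINDOW_SIZE = 5
print(lower_config(cfg))                 # {'loftr': {'fine_window_size': 5}}
print(upper_config(lower_config(cfg)))   # {'LOFTR': {'FINE_WINDOW_SIZE': 5}}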
25
+ def log_on(condition, message, level):
26
+ if condition:
27
+ assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL']
28
+ logger.log(level, message)
29
+
30
+
31
+ def get_rank_zero_only_logger(logger: _Logger):
32
+ if rank_zero_only.rank == 0:
33
+ return logger
34
+ else:
35
+ for _level in logger._core.levels.keys():
36
+ level = _level.lower()
37
+ setattr(logger, level,
38
+ lambda x: None)
39
+ logger._log = lambda x: None
40
+ return logger
41
+
42
+
43
+ def setup_gpus(gpus: Union[str, int]) -> int:
44
+ """ A temporary fix for pytorch-lightning 1.3.x """
45
+ gpus = str(gpus)
46
+ gpu_ids = []
47
+
48
+ if ',' not in gpus:
49
+ n_gpus = int(gpus)
50
+ return n_gpus if n_gpus != -1 else torch.cuda.device_count()
51
+ else:
52
+ gpu_ids = [i.strip() for i in gpus.split(',') if i != '']
53
+
54
+ # setup environment variables
55
+ visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
56
+ if visible_devices is None:
57
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
58
+ os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids)
59
+ visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
60
+ logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}')
61
+ else:
62
+ logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.')
63
+ return len(gpu_ids)
64
+
65
+
66
+ def flattenList(x):
67
+ return list(chain(*x))
68
+
69
+
70
+ @contextlib.contextmanager
71
+ def tqdm_joblib(tqdm_object):
72
+ """Context manager to patch joblib to report into tqdm progress bar given as argument
73
+
74
+ Usage:
75
+ with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
76
+ Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
77
+
78
+ When iterating over a generator, using tqdm directly is also a solution (but it monitors task queuing instead of completion)
79
+ ret_vals = Parallel(n_jobs=args.world_size)(
80
+ delayed(lambda x: _compute_cov_score(pid, *x))(param)
81
+ for param in tqdm(combinations(image_ids, 2),
82
+ desc=f'Computing cov_score of [{pid}]',
83
+ total=len(image_ids)*(len(image_ids)-1)/2))
84
+ Src: https://stackoverflow.com/a/58936697
85
+ """
86
+ class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
87
+ def __init__(self, *args, **kwargs):
88
+ super().__init__(*args, **kwargs)
89
+
90
+ def __call__(self, *args, **kwargs):
91
+ tqdm_object.update(n=self.batch_size)
92
+ return super().__call__(*args, **kwargs)
93
+
94
+ old_batch_callback = joblib.parallel.BatchCompletionCallBack
95
+ joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
96
+ try:
97
+ yield tqdm_object
98
+ finally:
99
+ joblib.parallel.BatchCompletionCallBack = old_batch_callback
100
+ tqdm_object.close()
101
+