Spaces:
Running
on
Zero
Running
on
Zero
ADD: MatchAnything
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- imcui/third_party/MatchAnything/LICENSE +202 -0
- imcui/third_party/MatchAnything/README.md +104 -0
- imcui/third_party/MatchAnything/configs/models/eloftr_model.py +128 -0
- imcui/third_party/MatchAnything/configs/models/roma_model.py +27 -0
- imcui/third_party/MatchAnything/environment.yaml +14 -0
- imcui/third_party/MatchAnything/notebooks/notebooks_utils/__init__.py +1 -0
- imcui/third_party/MatchAnything/notebooks/notebooks_utils/plotting.py +344 -0
- imcui/third_party/MatchAnything/requirements.txt +22 -0
- imcui/third_party/MatchAnything/scripts/evaluate/eval_harvard_brain.sh +17 -0
- imcui/third_party/MatchAnything/scripts/evaluate/eval_liver_ct_mr.sh +17 -0
- imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_ground.sh +17 -0
- imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_mtv.sh +17 -0
- imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_remote_sense.sh +17 -0
- imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_sar.sh +17 -0
- imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_vectorized_map.sh +17 -0
- imcui/third_party/MatchAnything/src/__init__.py +0 -0
- imcui/third_party/MatchAnything/src/config/default.py +344 -0
- imcui/third_party/MatchAnything/src/lightning/lightning_loftr.py +343 -0
- imcui/third_party/MatchAnything/src/loftr/__init__.py +1 -0
- imcui/third_party/MatchAnything/src/loftr/backbone/__init__.py +61 -0
- imcui/third_party/MatchAnything/src/loftr/backbone/repvgg.py +319 -0
- imcui/third_party/MatchAnything/src/loftr/backbone/resnet_fpn.py +1094 -0
- imcui/third_party/MatchAnything/src/loftr/backbone/s2dnet.py +131 -0
- imcui/third_party/MatchAnything/src/loftr/loftr.py +273 -0
- imcui/third_party/MatchAnything/src/loftr/loftr_module/__init__.py +2 -0
- imcui/third_party/MatchAnything/src/loftr/loftr_module/fine_preprocess.py +350 -0
- imcui/third_party/MatchAnything/src/loftr/loftr_module/linear_attention.py +217 -0
- imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer.py +1768 -0
- imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer_utils.py +76 -0
- imcui/third_party/MatchAnything/src/loftr/utils/coarse_matching.py +266 -0
- imcui/third_party/MatchAnything/src/loftr/utils/fine_matching.py +493 -0
- imcui/third_party/MatchAnything/src/loftr/utils/geometry.py +298 -0
- imcui/third_party/MatchAnything/src/loftr/utils/position_encoding.py +131 -0
- imcui/third_party/MatchAnything/src/loftr/utils/supervision.py +475 -0
- imcui/third_party/MatchAnything/src/optimizers/__init__.py +50 -0
- imcui/third_party/MatchAnything/src/utils/__init__.py +0 -0
- imcui/third_party/MatchAnything/src/utils/augment.py +55 -0
- imcui/third_party/MatchAnything/src/utils/colmap.py +530 -0
- imcui/third_party/MatchAnything/src/utils/colmap/__init__.py +0 -0
- imcui/third_party/MatchAnything/src/utils/colmap/database.py +417 -0
- imcui/third_party/MatchAnything/src/utils/colmap/eval_helper.py +232 -0
- imcui/third_party/MatchAnything/src/utils/colmap/read_write_model.py +509 -0
- imcui/third_party/MatchAnything/src/utils/comm.py +265 -0
- imcui/third_party/MatchAnything/src/utils/dataloader.py +23 -0
- imcui/third_party/MatchAnything/src/utils/dataset.py +518 -0
- imcui/third_party/MatchAnything/src/utils/easydict.py +148 -0
- imcui/third_party/MatchAnything/src/utils/geometry.py +366 -0
- imcui/third_party/MatchAnything/src/utils/homography_utils.py +366 -0
- imcui/third_party/MatchAnything/src/utils/metrics.py +445 -0
- imcui/third_party/MatchAnything/src/utils/misc.py +101 -0
imcui/third_party/MatchAnything/LICENSE
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Apache License
|
3 |
+
Version 2.0, January 2004
|
4 |
+
http://www.apache.org/licenses/
|
5 |
+
|
6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
7 |
+
|
8 |
+
1. Definitions.
|
9 |
+
|
10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
12 |
+
|
13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
14 |
+
the copyright owner that is granting the License.
|
15 |
+
|
16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
17 |
+
other entities that control, are controlled by, or are under common
|
18 |
+
control with that entity. For the purposes of this definition,
|
19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
20 |
+
direction or management of such entity, whether by contract or
|
21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
23 |
+
|
24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
25 |
+
exercising permissions granted by this License.
|
26 |
+
|
27 |
+
"Source" form shall mean the preferred form for making modifications,
|
28 |
+
including but not limited to software source code, documentation
|
29 |
+
source, and configuration files.
|
30 |
+
|
31 |
+
"Object" form shall mean any form resulting from mechanical
|
32 |
+
transformation or translation of a Source form, including but
|
33 |
+
not limited to compiled object code, generated documentation,
|
34 |
+
and conversions to other media types.
|
35 |
+
|
36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
37 |
+
Object form, made available under the License, as indicated by a
|
38 |
+
copyright notice that is included in or attached to the work
|
39 |
+
(an example is provided in the Appendix below).
|
40 |
+
|
41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
42 |
+
form, that is based on (or derived from) the Work and for which the
|
43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
45 |
+
of this License, Derivative Works shall not include works that remain
|
46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
47 |
+
the Work and Derivative Works thereof.
|
48 |
+
|
49 |
+
"Contribution" shall mean any work of authorship, including
|
50 |
+
the original version of the Work and any modifications or additions
|
51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
55 |
+
means any form of electronic, verbal, or written communication sent
|
56 |
+
to the Licensor or its representatives, including but not limited to
|
57 |
+
communication on electronic mailing lists, source code control systems,
|
58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
60 |
+
excluding communication that is conspicuously marked or otherwise
|
61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
62 |
+
|
63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
65 |
+
subsequently incorporated within the Work.
|
66 |
+
|
67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
72 |
+
Work and such Derivative Works in Source or Object form.
|
73 |
+
|
74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
77 |
+
(except as stated in this section) patent license to make, have made,
|
78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
79 |
+
where such license applies only to those patent claims licensable
|
80 |
+
by such Contributor that are necessarily infringed by their
|
81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
83 |
+
institute patent litigation against any entity (including a
|
84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
85 |
+
or a Contribution incorporated within the Work constitutes direct
|
86 |
+
or contributory patent infringement, then any patent licenses
|
87 |
+
granted to You under this License for that Work shall terminate
|
88 |
+
as of the date such litigation is filed.
|
89 |
+
|
90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
91 |
+
Work or Derivative Works thereof in any medium, with or without
|
92 |
+
modifications, and in Source or Object form, provided that You
|
93 |
+
meet the following conditions:
|
94 |
+
|
95 |
+
(a) You must give any other recipients of the Work or
|
96 |
+
Derivative Works a copy of this License; and
|
97 |
+
|
98 |
+
(b) You must cause any modified files to carry prominent notices
|
99 |
+
stating that You changed the files; and
|
100 |
+
|
101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
102 |
+
that You distribute, all copyright, patent, trademark, and
|
103 |
+
attribution notices from the Source form of the Work,
|
104 |
+
excluding those notices that do not pertain to any part of
|
105 |
+
the Derivative Works; and
|
106 |
+
|
107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
108 |
+
distribution, then any Derivative Works that You distribute must
|
109 |
+
include a readable copy of the attribution notices contained
|
110 |
+
within such NOTICE file, excluding those notices that do not
|
111 |
+
pertain to any part of the Derivative Works, in at least one
|
112 |
+
of the following places: within a NOTICE text file distributed
|
113 |
+
as part of the Derivative Works; within the Source form or
|
114 |
+
documentation, if provided along with the Derivative Works; or,
|
115 |
+
within a display generated by the Derivative Works, if and
|
116 |
+
wherever such third-party notices normally appear. The contents
|
117 |
+
of the NOTICE file are for informational purposes only and
|
118 |
+
do not modify the License. You may add Your own attribution
|
119 |
+
notices within Derivative Works that You distribute, alongside
|
120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
121 |
+
that such additional attribution notices cannot be construed
|
122 |
+
as modifying the License.
|
123 |
+
|
124 |
+
You may add Your own copyright statement to Your modifications and
|
125 |
+
may provide additional or different license terms and conditions
|
126 |
+
for use, reproduction, or distribution of Your modifications, or
|
127 |
+
for any such Derivative Works as a whole, provided Your use,
|
128 |
+
reproduction, and distribution of the Work otherwise complies with
|
129 |
+
the conditions stated in this License.
|
130 |
+
|
131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
133 |
+
by You to the Licensor shall be under the terms and conditions of
|
134 |
+
this License, without any additional terms or conditions.
|
135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
136 |
+
the terms of any separate license agreement you may have executed
|
137 |
+
with Licensor regarding such Contributions.
|
138 |
+
|
139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
141 |
+
except as required for reasonable and customary use in describing the
|
142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
143 |
+
|
144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
145 |
+
agreed to in writing, Licensor provides the Work (and each
|
146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
148 |
+
implied, including, without limitation, any warranties or conditions
|
149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
151 |
+
appropriateness of using or redistributing the Work and assume any
|
152 |
+
risks associated with Your exercise of permissions under this License.
|
153 |
+
|
154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
155 |
+
whether in tort (including negligence), contract, or otherwise,
|
156 |
+
unless required by applicable law (such as deliberate and grossly
|
157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
158 |
+
liable to You for damages, including any direct, indirect, special,
|
159 |
+
incidental, or consequential damages of any character arising as a
|
160 |
+
result of this License or out of the use or inability to use the
|
161 |
+
Work (including but not limited to damages for loss of goodwill,
|
162 |
+
work stoppage, computer failure or malfunction, or any and all
|
163 |
+
other commercial damages or losses), even if such Contributor
|
164 |
+
has been advised of the possibility of such damages.
|
165 |
+
|
166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
169 |
+
or other liability obligations and/or rights consistent with this
|
170 |
+
License. However, in accepting such obligations, You may act only
|
171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
172 |
+
of any other Contributor, and only if You agree to indemnify,
|
173 |
+
defend, and hold each Contributor harmless for any liability
|
174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
175 |
+
of your accepting any such warranty or additional liability.
|
176 |
+
|
177 |
+
END OF TERMS AND CONDITIONS
|
178 |
+
|
179 |
+
APPENDIX: How to apply the Apache License to your work.
|
180 |
+
|
181 |
+
To apply the Apache License to your work, attach the following
|
182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
183 |
+
replaced with your own identifying information. (Don't include
|
184 |
+
the brackets!) The text should be enclosed in the appropriate
|
185 |
+
comment syntax for the file format. We also recommend that a
|
186 |
+
file or class name and description of purpose be included on the
|
187 |
+
same "printed page" as the copyright notice for easier
|
188 |
+
identification within third-party archives.
|
189 |
+
|
190 |
+
Copyright [yyyy] [name of copyright owner]
|
191 |
+
|
192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
193 |
+
you may not use this file except in compliance with the License.
|
194 |
+
You may obtain a copy of the License at
|
195 |
+
|
196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
197 |
+
|
198 |
+
Unless required by applicable law or agreed to in writing, software
|
199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
201 |
+
See the License for the specific language governing permissions and
|
202 |
+
limitations under the License.
|
imcui/third_party/MatchAnything/README.md
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MatchAnything: Universal Cross-Modality Image Matching with Large-Scale Pre-Training
|
2 |
+
### [Project Page](https://zju3dv.github.io/MatchAnything) | [Paper](??)
|
3 |
+
|
4 |
+
> MatchAnything: Universal Cross-Modality Image Matching with Large-Scale Pre-Training\
|
5 |
+
> [Xingyi He](https://hxy-123.github.io/),
|
6 |
+
[Hao Yu](https://ritianyu.github.io/),
|
7 |
+
[Sida Peng](https://pengsida.net),
|
8 |
+
[Dongli Tan](https://github.com/Cuistiano),
|
9 |
+
[Zehong Shen](https://zehongs.github.io),
|
10 |
+
[Xiaowei Zhou](https://xzhou.me/),
|
11 |
+
[Hujun Bao](http://www.cad.zju.edu.cn/home/bao/)<sup>†</sup>\
|
12 |
+
> Arxiv 2025
|
13 |
+
|
14 |
+
<p align="center">
|
15 |
+
<img src=docs/teaser_demo.gif alt="animated" />
|
16 |
+
</p>
|
17 |
+
|
18 |
+
## TODO List
|
19 |
+
- [x] Pre-trained models and inference code
|
20 |
+
- [x] Huggingface demo
|
21 |
+
- [ ] Data generation and training code
|
22 |
+
- [ ] Finetune code to further train on your own data
|
23 |
+
- [ ] Incorporate more synthetic modalities and image generation methods
|
24 |
+
|
25 |
+
## Quick Start
|
26 |
+
|
27 |
+
### [<img src="https://s2.loli.net/2024/09/15/aw3rElfQAsOkNCn.png" width="20"> HuggingFace demo for MatchAnything](https://huggingface.co/spaces/LittleFrog/MatchAnything)
|
28 |
+
|
29 |
+
## Setup
|
30 |
+
Create the python environment by:
|
31 |
+
```
|
32 |
+
conda env create -f environment.yaml
|
33 |
+
conda activate env
|
34 |
+
```
|
35 |
+
We have tested our code on the device with CUDA 11.7.
|
36 |
+
|
37 |
+
Download pretrained weights from [here](https://drive.google.com/file/d/12L3g9-w8rR9K2L4rYaGaDJ7NqX1D713d/view?usp=sharing) and place it under repo directory. Then unzip it by running the following command:
|
38 |
+
```
|
39 |
+
unzip weights.zip
|
40 |
+
rm -rf weights.zip
|
41 |
+
```
|
42 |
+
|
43 |
+
## Test:
|
44 |
+
We evaluate the models pretrained by our framework using a single network weight on all cross-modality matching and registration tasks.
|
45 |
+
|
46 |
+
### Data Preparing
|
47 |
+
Download the `test_data` directory from [here](https://drive.google.com/drive/folders/1jpxIOcgnQfl9IEPPifdXQ7S7xuj9K4j7?usp=sharing) and plase it under `repo_directory/data`. Then, unzip all datasets by:
|
48 |
+
```shell
|
49 |
+
cd repo_directiry/data/test_data
|
50 |
+
|
51 |
+
for file in *.zip; do
|
52 |
+
unzip "$file" && rm "$file"
|
53 |
+
done
|
54 |
+
```
|
55 |
+
|
56 |
+
The data structure should looks like:
|
57 |
+
```
|
58 |
+
repo_directiry/data/test_data
|
59 |
+
- Liver_CT-MR
|
60 |
+
- havard_medical_matching
|
61 |
+
- remote_sense_thermal
|
62 |
+
- MTV_cross_modal_data
|
63 |
+
- thermal_visible_ground
|
64 |
+
- visible_sar_dataset
|
65 |
+
- visible_vectorized_map
|
66 |
+
```
|
67 |
+
|
68 |
+
### Evaluation
|
69 |
+
```shell
|
70 |
+
# For Tomography datasets:
|
71 |
+
sh scripts/evaluate/eval_liver_ct_mr.sh
|
72 |
+
sh scripts/evaluate/eval_harvard_brain.sh
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
# For visible-thermal datasets:
|
77 |
+
sh scripts/evaluate/eval_thermal_remote_sense.sh
|
78 |
+
sh scripts/evaluate/eval_thermal_mtv.sh
|
79 |
+
sh scripts/evaluate/eval_thermal_ground.sh
|
80 |
+
|
81 |
+
# For visible-sar dataset:
|
82 |
+
sh scripts/evaluate/eval_visible_sar.sh
|
83 |
+
|
84 |
+
# For visible-vectorized map dataset:
|
85 |
+
sh scripts/evaluate/eval_visible_vectorized_map.sh
|
86 |
+
```
|
87 |
+
|
88 |
+
# Citation
|
89 |
+
|
90 |
+
If you find this code useful for your research, please use the following BibTeX entry.
|
91 |
+
|
92 |
+
```
|
93 |
+
@inproceedings{he2025matchanything,
|
94 |
+
title={MatchAnything: Universal Cross-Modality Image Matching with Large-Scale Pre-Training},
|
95 |
+
author={He, Xingyi and Yu, Hao and Peng, Sida and Tan, Dongli and Shen, Zehong and Bao, Hujun and Zhou, Xiaowei},
|
96 |
+
booktitle={Arxiv},
|
97 |
+
year={2025}
|
98 |
+
}
|
99 |
+
```
|
100 |
+
|
101 |
+
# Acknowledgement
|
102 |
+
We thank the authors of
|
103 |
+
[ELoFTR](https://github.com/zju3dv/EfficientLoFTR),
|
104 |
+
[ROMA](https://github.com/Parskatt/RoMa) for their great works, without which our project/code would not be possible.
|
imcui/third_party/MatchAnything/configs/models/eloftr_model.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.config.default import _CN as cfg
|
2 |
+
|
3 |
+
cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
|
4 |
+
cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False
|
5 |
+
|
6 |
+
cfg.TRAINER.CANONICAL_LR = 8e-3
|
7 |
+
cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
|
8 |
+
cfg.TRAINER.WARMUP_RATIO = 0.1
|
9 |
+
|
10 |
+
cfg.TRAINER.MSLR_MILESTONES = [4, 6, 8, 10, 12, 14, 16]
|
11 |
+
|
12 |
+
# pose estimation
|
13 |
+
cfg.TRAINER.RANSAC_PIXEL_THR = 0.5
|
14 |
+
|
15 |
+
cfg.TRAINER.OPTIMIZER = "adamw"
|
16 |
+
cfg.TRAINER.ADAMW_DECAY = 0.1
|
17 |
+
|
18 |
+
cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.1
|
19 |
+
|
20 |
+
cfg.LOFTR.MATCH_COARSE.MTD_SPVS = True
|
21 |
+
cfg.LOFTR.FINE.MTD_SPVS = True
|
22 |
+
|
23 |
+
cfg.LOFTR.RESOLUTION = (8, 1) # options: [(8, 2), (16, 4)]
|
24 |
+
cfg.LOFTR.FINE_WINDOW_SIZE = 8 # window_size in fine_level, must be odd
|
25 |
+
cfg.LOFTR.MATCH_FINE.THR = 0
|
26 |
+
cfg.LOFTR.LOSS.FINE_TYPE = 'l2' # ['l2_with_std', 'l2']
|
27 |
+
|
28 |
+
cfg.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
|
29 |
+
|
30 |
+
cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = True
|
31 |
+
|
32 |
+
# PAN
|
33 |
+
cfg.LOFTR.COARSE.PAN = True
|
34 |
+
cfg.LOFTR.COARSE.POOl_SIZE = 4
|
35 |
+
cfg.LOFTR.COARSE.BN = False
|
36 |
+
cfg.LOFTR.COARSE.XFORMER = True
|
37 |
+
cfg.LOFTR.COARSE.ATTENTION = 'full' # options: ['linear', 'full']
|
38 |
+
|
39 |
+
cfg.LOFTR.FINE.PAN = False
|
40 |
+
cfg.LOFTR.FINE.POOl_SIZE = 4
|
41 |
+
cfg.LOFTR.FINE.BN = False
|
42 |
+
cfg.LOFTR.FINE.XFORMER = False
|
43 |
+
|
44 |
+
# noalign
|
45 |
+
cfg.LOFTR.ALIGN_CORNER = False
|
46 |
+
|
47 |
+
# fp16
|
48 |
+
cfg.DATASET.FP16 = False
|
49 |
+
cfg.LOFTR.FP16 = False
|
50 |
+
|
51 |
+
# DEBUG
|
52 |
+
cfg.LOFTR.FP16LOG = False
|
53 |
+
cfg.LOFTR.MATCH_COARSE.FP16LOG = False
|
54 |
+
|
55 |
+
# fine skip
|
56 |
+
cfg.LOFTR.FINE.SKIP = True
|
57 |
+
|
58 |
+
# clip
|
59 |
+
cfg.TRAINER.GRADIENT_CLIPPING = 0.5
|
60 |
+
|
61 |
+
# backbone
|
62 |
+
cfg.LOFTR.BACKBONE_TYPE = 'RepVGG'
|
63 |
+
|
64 |
+
# A1
|
65 |
+
cfg.LOFTR.RESNETFPN.INITIAL_DIM = 64
|
66 |
+
cfg.LOFTR.RESNETFPN.BLOCK_DIMS = [64, 128, 256] # s1, s2, s3
|
67 |
+
cfg.LOFTR.COARSE.D_MODEL = 256
|
68 |
+
cfg.LOFTR.FINE.D_MODEL = 64
|
69 |
+
|
70 |
+
# FPN backbone_inter_feat with coarse_attn.
|
71 |
+
cfg.LOFTR.COARSE_FEAT_ONLY = True
|
72 |
+
cfg.LOFTR.INTER_FEAT = True
|
73 |
+
cfg.LOFTR.RESNETFPN.COARSE_FEAT_ONLY = True
|
74 |
+
cfg.LOFTR.RESNETFPN.INTER_FEAT = True
|
75 |
+
|
76 |
+
# loop back spv coarse match
|
77 |
+
cfg.LOFTR.FORCE_LOOP_BACK = False
|
78 |
+
|
79 |
+
# fix norm fine match
|
80 |
+
cfg.LOFTR.MATCH_FINE.NORMFINEM = True
|
81 |
+
|
82 |
+
# loss cf weight
|
83 |
+
cfg.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT = True
|
84 |
+
cfg.LOFTR.LOSS.FINE_OVERLAP_WEIGHT = True
|
85 |
+
|
86 |
+
# leaky relu
|
87 |
+
cfg.LOFTR.RESNETFPN.LEAKY = False
|
88 |
+
cfg.LOFTR.COARSE.LEAKY = 0.01
|
89 |
+
|
90 |
+
# prevent FP16 OVERFLOW in dirty data
|
91 |
+
cfg.LOFTR.NORM_FPNFEAT = True
|
92 |
+
cfg.LOFTR.REPLACE_NAN = True
|
93 |
+
|
94 |
+
# force mutual nearest
|
95 |
+
cfg.LOFTR.MATCH_COARSE.FORCE_NEAREST = True
|
96 |
+
cfg.LOFTR.MATCH_COARSE.THR = 0.1
|
97 |
+
|
98 |
+
# fix fine matching
|
99 |
+
cfg.LOFTR.MATCH_FINE.FIX_FINE_MATCHING = True
|
100 |
+
|
101 |
+
# dwconv
|
102 |
+
cfg.LOFTR.COARSE.DWCONV = True
|
103 |
+
|
104 |
+
# localreg
|
105 |
+
cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS = True
|
106 |
+
cfg.LOFTR.LOSS.LOCAL_WEIGHT = 0.25
|
107 |
+
|
108 |
+
# it5
|
109 |
+
cfg.LOFTR.EVAL_TIMES = 1
|
110 |
+
|
111 |
+
# rope
|
112 |
+
cfg.LOFTR.COARSE.ROPE = True
|
113 |
+
|
114 |
+
# local regress temperature
|
115 |
+
cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 10.0
|
116 |
+
|
117 |
+
# SLICE
|
118 |
+
cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICE = True
|
119 |
+
cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8
|
120 |
+
|
121 |
+
# inner with no mask [64,100]
|
122 |
+
cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_INNER = True
|
123 |
+
cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_NOMASK = True
|
124 |
+
|
125 |
+
cfg.LOFTR.MATCH_FINE.TOPK = 1
|
126 |
+
cfg.LOFTR.MATCH_COARSE.FINE_TOPK = 1
|
127 |
+
|
128 |
+
cfg.LOFTR.MATCH_COARSE.FP16MATMUL = False
|
imcui/third_party/MatchAnything/configs/models/roma_model.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.config.default import _CN as cfg
|
2 |
+
cfg.ROMA.RESIZE_BY_STRETCH = True
|
3 |
+
cfg.DATASET.RESIZE_BY_STRETCH = True
|
4 |
+
|
5 |
+
cfg.TRAINER.CANONICAL_LR = 8e-3
|
6 |
+
cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
|
7 |
+
cfg.TRAINER.WARMUP_RATIO = 0.1
|
8 |
+
|
9 |
+
cfg.TRAINER.MSLR_MILESTONES = [4, 6, 8, 10, 12, 14, 16, 18, 20]
|
10 |
+
|
11 |
+
# pose estimation
|
12 |
+
cfg.TRAINER.RANSAC_PIXEL_THR = 0.5
|
13 |
+
|
14 |
+
cfg.TRAINER.OPTIMIZER = "adamw"
|
15 |
+
cfg.TRAINER.ADAMW_DECAY = 0.1
|
16 |
+
cfg.TRAINER.OPTIMIZER_EPS = 5e-7
|
17 |
+
|
18 |
+
cfg.TRAINER.EPI_ERR_THR = 5e-4
|
19 |
+
|
20 |
+
# fp16
|
21 |
+
cfg.DATASET.FP16 = False
|
22 |
+
cfg.LOFTR.FP16 = True
|
23 |
+
|
24 |
+
# clip
|
25 |
+
cfg.TRAINER.GRADIENT_CLIPPING = 0.5
|
26 |
+
|
27 |
+
cfg.LOFTR.ROMA_LOSS.IGNORE_EMPTY_IN_SPARSE_MATCH_SPV = True
|
imcui/third_party/MatchAnything/environment.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: env
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
- conda-forge
|
6 |
+
- defaults
|
7 |
+
dependencies:
|
8 |
+
- python=3.8
|
9 |
+
- pytorch-cuda=11.7
|
10 |
+
- pytorch=1.12.1
|
11 |
+
- torchvision=0.13.1
|
12 |
+
- pip
|
13 |
+
- pip:
|
14 |
+
- -r requirements.txt
|
imcui/third_party/MatchAnything/notebooks/notebooks_utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .plotting import *
|
imcui/third_party/MatchAnything/notebooks/notebooks_utils/plotting.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import matplotlib
|
4 |
+
from matplotlib.colors import hsv_to_rgb
|
5 |
+
import pylab as pl
|
6 |
+
import matplotlib.cm as cm
|
7 |
+
from PIL import Image
|
8 |
+
import cv2
|
9 |
+
|
10 |
+
|
11 |
+
def visualize_features(feat, img_h, img_w, save_path=None):
|
12 |
+
from sklearn.decomposition import PCA
|
13 |
+
pca = PCA(n_components=3, svd_solver="arpack")
|
14 |
+
img = pca.fit_transform(feat).reshape(img_h * 2, img_w, 3)
|
15 |
+
img_norm = cv2.normalize(
|
16 |
+
img, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC3
|
17 |
+
)
|
18 |
+
img_resized = cv2.resize(
|
19 |
+
img_norm, (img_w * 8, img_h * 2 * 8), interpolation=cv2.INTER_LINEAR
|
20 |
+
)
|
21 |
+
img_colormap = img_resized
|
22 |
+
img1, img2 = img_colormap[: img_h * 8, :, :], img_colormap[img_h * 8 :, :, :]
|
23 |
+
img_gapped = np.hstack(
|
24 |
+
(img1, np.ones((img_h * 8, 10, 3), dtype=np.uint8) * 255, img2)
|
25 |
+
)
|
26 |
+
if save_path is not None:
|
27 |
+
cv2.imwrite(save_path, img_gapped)
|
28 |
+
|
29 |
+
fig, axes = plt.subplots(1, 1, dpi=200)
|
30 |
+
axes.imshow(img_gapped)
|
31 |
+
axes.get_yaxis().set_ticks([])
|
32 |
+
axes.get_xaxis().set_ticks([])
|
33 |
+
plt.tight_layout(pad=0.5)
|
34 |
+
return fig
|
35 |
+
|
36 |
+
def make_matching_figure(
    img0,
    img1,
    mkpts0,
    mkpts1,
    color,
    kpts0=None,
    kpts1=None,
    text=(),
    path=None,
    draw_detection=False,
    draw_match_type='corres',  # options: ['color', 'corres', None]
    r_normalize_factor=0.4,
    white_center=True,
    vertical=False,
    use_position_color=False,
    draw_local_window=False,
    window_size=(9, 9),
    plot_size_factor=1,  # point size and line width multiplier
    anchor_pts0=None,
    anchor_pts1=None,
    rescale_thr=5000,
):
    """Plot an image pair with their matches; save to ``path`` or return the figure.

    Args:
        img0, img1: images as float arrays scaled to [0, 1].
        mkpts0, mkpts1: (N, 2) matched keypoint coordinates in pixels.
        color: (N, 4) RGBA color per match (mutated in place when
            ``use_position_color`` is set).
        kpts0, kpts1: optional detected keypoints, drawn when ``draw_detection``.
        text: lines of text drawn in the top-left corner of the figure.
        path: when given, the figure is saved there and closed; otherwise the
            figure object is returned.
        draw_match_type: 'corres' draws match lines across the two images,
            'color' colors each keypoint by its position instead, None skips both.
        r_normalize_factor, white_center: forwarded to matching_coord2color;
            a smaller factor makes position coloring more contrastive.
        draw_local_window / window_size / anchor_pts0 / anchor_pts1: draw a
            rectangle of ``window_size`` around each anchor point.
        rescale_thr: images larger than this in any dimension are downscaled
            2x (keypoints included) to bound figure size.

    Returns:
        The matplotlib figure when ``path`` is None, otherwise nothing.
    """
    # Downscale oversized inputs (and keypoints accordingly).
    if (max(img0.shape) > rescale_thr) or (max(img1.shape) > rescale_thr):
        scale_factor = 0.5
        img0 = np.array(Image.fromarray((img0 * 255).astype(np.uint8)).resize((int(img0.shape[1] * scale_factor), int(img0.shape[0] * scale_factor)))) / 255.
        img1 = np.array(Image.fromarray((img1 * 255).astype(np.uint8)).resize((int(img1.shape[1] * scale_factor), int(img1.shape[0] * scale_factor)))) / 255.
        mkpts0, mkpts1 = mkpts0 * scale_factor, mkpts1 * scale_factor
        if kpts0 is not None:
            kpts0, kpts1 = kpts0 * scale_factor, kpts1 * scale_factor

    # Draw the image pair side by side (or stacked when vertical).
    fig, axes = (
        plt.subplots(2, 1, figsize=(10, 6), dpi=600)
        if vertical
        else plt.subplots(1, 2, figsize=(10, 6), dpi=600)
    )
    axes[0].imshow(img0, aspect='auto')
    axes[1].imshow(img1, aspect='auto')

    for i in range(2):  # clear ticks and frames on both axes
        axes[i].get_yaxis().set_ticks([])
        axes[i].get_xaxis().set_ticks([])
        for spine in axes[i].spines.values():
            spine.set_visible(False)
    plt.tight_layout(pad=1)

    if use_position_color:
        # Recolor matches by their position relative to the keypoint centroid.
        mean_coord = np.mean(mkpts0, axis=0)
        x_center, y_center = mean_coord
        # NOTE: a smaller r_normalize_factor makes the figure more contrastive.
        position_color = matching_coord2color(
            mkpts0,
            x_center,
            y_center,
            r_normalize_factor=r_normalize_factor,
            white_center=white_center,
        )
        color[:, :3] = position_color

    if draw_detection and kpts0 is not None and kpts1 is not None:
        color = 'r'
        axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=1 * plot_size_factor)
        axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=1 * plot_size_factor)

    # FIX: compare strings with '==' instead of 'is'; identity comparison on
    # string literals is implementation-dependent and emits a SyntaxWarning
    # on Python >= 3.8.
    if draw_match_type == 'corres':
        # canvas.draw() is required so the data->figure transforms are valid.
        # FIX: removed a stray debug `plt.pause(2.0)` that blocked 2 s per call.
        fig.canvas.draw()
        transFigure = fig.transFigure.inverted()
        fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
        fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
        # NOTE(review): assigning to fig.lines replaces the figure's artist
        # list; newer matplotlib versions may make this read-only — confirm
        # against the pinned matplotlib before upgrading.
        fig.lines = [
            matplotlib.lines.Line2D(
                (fkpts0[i, 0], fkpts1[i, 0]),
                (fkpts0[i, 1], fkpts1[i, 1]),
                transform=fig.transFigure,
                c=color[i],
                linewidth=1 * plot_size_factor,
            )
            for i in range(len(mkpts0))
        ]

        axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=2 * plot_size_factor)
        axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=2 * plot_size_factor)
    elif draw_match_type == 'color':
        # Color the keypoints by position instead of drawing match lines.
        mean_coord = np.mean(mkpts0, axis=0)
        x_center, y_center = mean_coord
        kpts_color = matching_coord2color(
            mkpts0,
            x_center,
            y_center,
            r_normalize_factor=r_normalize_factor,
            white_center=white_center,
        )
        axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=kpts_color, s=1 * plot_size_factor)
        axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=kpts_color, s=1 * plot_size_factor)

    if draw_local_window:
        anchor_pts0 = mkpts0 if anchor_pts0 is None else anchor_pts0
        anchor_pts1 = mkpts1 if anchor_pts1 is None else anchor_pts1
        plot_local_windows(
            anchor_pts0, color=(1, 0, 0, 0.4), lw=0.2, ax_=0, window_size=window_size
        )
        plot_local_windows(
            anchor_pts1, color=(1, 0, 0, 0.4), lw=0.2, ax_=1, window_size=window_size
        )

    # Put the text, black on bright background, white otherwise.
    # NOTE(review): img0 appears to be in [0, 1] here, so `> 200` can never be
    # true and the text is always white — confirm whether `> 200/255` was meant.
    txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
    fig.text(
        0.01,
        0.99,
        "\n".join(text),
        transform=fig.axes[0].transAxes,
        fontsize=15,
        va="top",
        ha="left",
        color=txt_color,
    )
    plt.tight_layout(pad=1)

    # Save or return the figure.
    if path:
        plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
        plt.close()
    else:
        return fig
171 |
+
|
172 |
+
def make_triple_matching_figure(
    img0,
    img1,
    img2,
    mkpts01,
    mkpts12,
    color01,
    color12,
    text=[],
    path=None,
    draw_match=True,
    r_normalize_factor=0.4,
    white_center=True,
    vertical=False,
    draw_local_window=False,
    window_size=(9, 9),
    anchor_pts0=None,
    anchor_pts1=None,
):
    """Plot three images in a row (or column) with matches between consecutive pairs.

    Args:
        img0, img1, img2: the three images to display.
        mkpts01: pair (pts_in_img0, pts_in_img1) of (N, 2) match coordinates.
        mkpts12: pair (pts_in_img1, pts_in_img2) of (M, 2) match coordinates.
        color01, color12: per-match RGBA colors for the two pairs.
        path: when given, save the figure there and close it; otherwise return it.
        vertical: stack the images vertically instead of horizontally.

    Note:
        text, r_normalize_factor, white_center, draw_local_window, window_size,
        anchor_pts0 and anchor_pts1 are currently unused in this body
        (the text-drawing block below is commented out); they appear to be kept
        for signature symmetry with make_matching_figure.

    Returns:
        The matplotlib figure when ``path`` is None, otherwise nothing.
    """
    # draw image pair
    fig, axes = (
        plt.subplots(3, 1, figsize=(10, 6), dpi=600)
        if vertical
        else plt.subplots(1, 3, figsize=(10, 6), dpi=600)
    )
    axes[0].imshow(img0)
    axes[1].imshow(img1)
    axes[2].imshow(img2)
    for i in range(3):  # clear all frames
        axes[i].get_yaxis().set_ticks([])
        axes[i].get_xaxis().set_ticks([])
        for spine in axes[i].spines.values():
            spine.set_visible(False)
    plt.tight_layout(pad=1)

    if draw_match:
        # draw matches for [0,1]
        # canvas.draw() makes the data->figure transforms below valid.
        fig.canvas.draw()
        transFigure = fig.transFigure.inverted()
        fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts01[0]))
        fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts01[1]))
        # NOTE(review): assigning to fig.lines replaces the figure's artist
        # list; newer matplotlib versions may make this read-only — confirm
        # against the pinned matplotlib version.
        fig.lines = [
            matplotlib.lines.Line2D(
                (fkpts0[i, 0], fkpts1[i, 0]),
                (fkpts0[i, 1], fkpts1[i, 1]),
                transform=fig.transFigure,
                c=color01[i],
                linewidth=1,
            )
            for i in range(len(mkpts01[0]))
        ]

        axes[0].scatter(mkpts01[0][:, 0], mkpts01[0][:, 1], c=color01[:, :3], s=1)
        axes[1].scatter(mkpts01[1][:, 0], mkpts01[1][:, 1], c=color01[:, :3], s=1)

        fig.canvas.draw()
        # draw matches for [1,2]
        fkpts1_1 = transFigure.transform(axes[1].transData.transform(mkpts12[0]))
        fkpts2 = transFigure.transform(axes[2].transData.transform(mkpts12[1]))
        fig.lines += [
            matplotlib.lines.Line2D(
                (fkpts1_1[i, 0], fkpts2[i, 0]),
                (fkpts1_1[i, 1], fkpts2[i, 1]),
                transform=fig.transFigure,
                c=color12[i],
                linewidth=1,
            )
            for i in range(len(mkpts12[0]))
        ]

        axes[1].scatter(mkpts12[0][:, 0], mkpts12[0][:, 1], c=color12[:, :3], s=1)
        axes[2].scatter(mkpts12[1][:, 0], mkpts12[1][:, 1], c=color12[:, :3], s=1)

    # # put txts
    # txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
    # fig.text(
    #     0.01,
    #     0.99,
    #     "\n".join(text),
    #     transform=fig.axes[0].transAxes,
    #     fontsize=15,
    #     va="top",
    #     ha="left",
    #     color=txt_color,
    # )
    plt.tight_layout(pad=0.1)

    # save or return figure
    if path:
        plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
        plt.close()
    else:
        return fig
265 |
+
|
266 |
+
|
267 |
+
def matching_coord2color(kpts, x_center, y_center, r_normalize_factor=0.4, white_center=True):
    """Map 2D keypoint positions to RGB colors via an HSV color wheel.

    The hue encodes a keypoint's angle around ``(x_center, y_center)`` and the
    saturation (white center) or value (dark center) encodes its normalized
    distance from that center, so nearby points receive similar colors.

    Args:
        kpts: (N, 2) keypoint coordinates; not modified (a copy is used).
        x_center, y_center: center of the color wheel.
        r_normalize_factor: at most 1; smaller values saturate colors faster
            with distance, making the plotted figure more contrastive.
        white_center: if True, points near the center are white; otherwise dark.

    Returns:
        (N, 3) float RGB colors in [0, 1].
    """
    if not white_center:
        # Dark center points: ramp the value, keep saturation at 1.
        V, H = np.mgrid[0:1:10j, 0:1:360j]
        S = np.ones_like(V)
    else:
        # White center points: ramp the saturation, keep value at 1.
        S, H = np.mgrid[0:1:10j, 0:1:360j]
        V = np.ones_like(S)

    # 10 x 360 lookup table: rows index radius bins (0..9), columns degrees.
    HSV = np.dstack((H, S, V))
    RGB = hsv_to_rgb(HSV)

    kpts = np.copy(kpts)
    distance = kpts - np.array([x_center, y_center])[None]
    # 85th-percentile radius as reference, robust to outlying keypoints.
    r_max = np.percentile(np.linalg.norm(distance, axis=1), 85)
    # FIX: guard against a zero reference radius (all points at the center),
    # which previously produced a divide-by-zero warning; the clamped result
    # is unchanged for non-degenerate inputs.
    r_max = max(r_max, 1e-6)
    kpts[:, 0] = kpts[:, 0] - x_center  # x
    kpts[:, 1] = kpts[:, 1] - y_center  # y

    r = np.sqrt(kpts[:, 0] ** 2 + kpts[:, 1] ** 2) + 1e-6
    r_normalized = r / (r_max * r_normalize_factor)
    r_normalized[r_normalized > 1] = 1  # clamp, then scale to row indices 0..9
    r_normalized = (r_normalized) * 9

    # Angle in [0, 2*pi) measured from the +x axis; points with negative y are
    # mirrored so the full circle is covered.
    cos_theta = kpts[:, 0] / r  # x / r
    theta = np.arccos(cos_theta)  # from 0 to pi
    change_angle_mask = kpts[:, 1] < 0
    theta[change_angle_mask] = 2 * np.pi - theta[change_angle_mask]
    theta_degree = np.degrees(theta)
    theta_degree[theta_degree == 360] = 0  # avoid indexing past column 359
    # FIX: removed a no-op `theta_degree = theta_degree / 360 * 360` line.
    kpts_color = RGB[r_normalized.astype(int), theta_degree.astype(int)]
    return kpts_color
|
312 |
+
|
313 |
+
|
314 |
+
def show_image_pair(img0, img1, path=None):
    """Display two images side by side in grayscale; optionally save to ``path``.

    Returns the matplotlib figure in either case.
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=200)
    # Show each image and strip ticks and frame from its axis.
    for ax, img in zip(axes, (img0, img1)):
        ax.imshow(img, cmap="gray")
        ax.get_yaxis().set_ticks([])
        ax.get_xaxis().set_ticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    plt.tight_layout(pad=1)
    if path:
        plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
    return fig
|
327 |
+
|
328 |
+
def plot_local_windows(kpts, color="r", lw=1, ax_=0, window_size=(9, 9)):
    """Outline a ``window_size`` rectangle around every keypoint.

    Patches are added to axis index ``ax_`` of the current figure; the
    rectangle is one pixel wider/taller than ``window_size`` and shifted so
    the keypoint sits at its center.
    """
    target_ax = plt.gcf().axes[ax_]
    half_w = window_size[0] // 2
    half_h = window_size[1] // 2
    for kpt in kpts:
        rect = matplotlib.patches.Rectangle(
            (kpt[0] - half_w - 1, kpt[1] - half_h - 1),
            window_size[0] + 1,
            window_size[1] + 1,
            lw=lw,
            color=color,
            fill=False,
        )
        target_ax.add_patch(rect)
344 |
+
|
imcui/third_party/MatchAnything/requirements.txt
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
opencv_python==4.4.0.46
|
2 |
+
albumentations==0.5.1 --no-binary=imgaug,albumentations
|
3 |
+
Pillow==9.5.0
|
4 |
+
ray==2.9.3
|
5 |
+
einops==0.3.0
|
6 |
+
kornia==0.4.1
|
7 |
+
loguru==0.5.3
|
8 |
+
yacs>=0.1.8
|
9 |
+
tqdm
|
10 |
+
autopep8
|
11 |
+
pylint
|
12 |
+
ipython
|
13 |
+
jupyterlab
|
14 |
+
matplotlib
|
15 |
+
h5py==3.1.0
|
16 |
+
pytorch-lightning==1.3.5
|
17 |
+
torchmetrics==0.6.0 # version problem: https://github.com/NVIDIA/DeepLearningExamples/issues/1113#issuecomment-1102969461
|
18 |
+
joblib>=1.0.1
|
19 |
+
pynvml
|
20 |
+
gpustat
|
21 |
+
safetensors
|
22 |
+
timm==0.6.7
|
imcui/third_party/MatchAnything/scripts/evaluate/eval_harvard_brain.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash -l
# Evaluate pretrained MatchAnything models (ELoFTR and ROMA variants) on the
# Harvard brain medical matching benchmark; metrics are written to $OUTPUT_PATH.

# Resolve the repository root relative to this script's own location.
SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

cd $PROJECT_DIR

# GPU id and dataset index/output locations ("havard" spelling matches the
# on-disk dataset directory).
DEVICE_ID='0'
NPZ_ROOT=data/test_data/havard_medical_matching/all_eval
NPZ_LIST_PATH=data/test_data/havard_medical_matching/all_eval/val_list.txt
OUTPUT_PATH=results/havard_medical_matching

# ELoFTR pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --thr 0.05 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH

# ROMA pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
|
imcui/third_party/MatchAnything/scripts/evaluate/eval_liver_ct_mr.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash -l
# Evaluate pretrained MatchAnything models (ELoFTR and ROMA variants) on the
# liver CT-MR cross-modality benchmark; metrics are written to $OUTPUT_PATH.

# Resolve the repository root relative to this script's own location.
SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

cd $PROJECT_DIR

# GPU id and dataset index/output locations.
DEVICE_ID='0'
NPZ_ROOT=data/test_data/Liver_CT-MR/eval_indexs
NPZ_LIST_PATH=data/test_data/Liver_CT-MR/eval_indexs/val_list.txt
OUTPUT_PATH=results/Liver_CT-MR

# ELoFTR pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH

# ROMA pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
|
imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_ground.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash -l
# Evaluate pretrained MatchAnything models (ELoFTR and ROMA variants) on the
# ground-level thermal-visible benchmark; metrics are written to $OUTPUT_PATH.

# Resolve the repository root relative to this script's own location.
SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

cd $PROJECT_DIR

# GPU id and dataset index/output locations.
DEVICE_ID='0'
NPZ_ROOT=data/test_data/thermal_visible_ground/eval_indexs
NPZ_LIST_PATH=data/test_data/thermal_visible_ground/eval_indexs/val_list.txt
OUTPUT_PATH=results/thermal_visible_ground

# ELoFTR pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH

# ROMA pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
|
imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_mtv.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash -l
# Evaluate pretrained MatchAnything models (ELoFTR and ROMA variants) on the
# MTV cross-modal benchmark; metrics are written to $OUTPUT_PATH.

# Resolve the repository root relative to this script's own location.
SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

cd $PROJECT_DIR

# GPU id and dataset index/output locations.
DEVICE_ID='0'
NPZ_ROOT=data/test_data/MTV_cross_modal_data/scene_info/scene_info
NPZ_LIST_PATH=data/test_data/MTV_cross_modal_data/scene_info/test_list.txt
OUTPUT_PATH=results/MTV_cross_modal_data

# ELoFTR pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH

# ROMA pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
|
imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_remote_sense.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash -l
# Evaluate pretrained MatchAnything models (ELoFTR and ROMA variants) on the
# remote-sensing optical-infrared benchmark; metrics go to $OUTPUT_PATH.

# Resolve the repository root relative to this script's own location.
SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

cd $PROJECT_DIR

# GPU id and dataset index/output locations.
DEVICE_ID='0'
NPZ_ROOT=data/test_data/remote_sense_thermal/eval_Optical-Infrared
NPZ_LIST_PATH=data/test_data/remote_sense_thermal/eval_Optical-Infrared/val_list.txt
OUTPUT_PATH=results/remote_sense_thermal

# ELoFTR pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH

# ROMA pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
|
imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_sar.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash -l
# Evaluate pretrained MatchAnything models (ELoFTR and ROMA variants) on the
# visible-SAR benchmark; metrics are written to $OUTPUT_PATH.

# Resolve the repository root relative to this script's own location.
SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

cd $PROJECT_DIR

# GPU id and dataset index/output locations.
DEVICE_ID='0'
NPZ_ROOT=data/test_data/visible_sar_dataset/eval
NPZ_LIST_PATH=data/test_data/visible_sar_dataset/eval/val_list.txt
OUTPUT_PATH=results/visible_sar_dataset

# ELoFTR pretrained (note the lower match threshold 0.05 for this dataset):
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --thr 0.05 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH

# ROMA pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
|
imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_vectorized_map.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash -l
# Evaluate pretrained MatchAnything models (ELoFTR and ROMA variants) on the
# visible-vectorized-map benchmark; metrics are written to $OUTPUT_PATH.

# Resolve the repository root relative to this script's own location.
SCRIPTPATH=$(dirname $(readlink -f "$0"))
PROJECT_DIR="${SCRIPTPATH}/../../"

cd $PROJECT_DIR

# GPU id and dataset index/output locations.
DEVICE_ID='0'
NPZ_ROOT=data/test_data/visible_vectorized_map/scene_indices
NPZ_LIST_PATH=data/test_data/visible_vectorized_map/scene_indices/val_list.txt
OUTPUT_PATH=results/visible_vectorized_map

# ELoFTR pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH

# ROMA pretrained:
CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
|
imcui/third_party/MatchAnything/src/__init__.py
ADDED
File without changes
|
imcui/third_party/MatchAnything/src/config/default.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from yacs.config import CfgNode as CN
|
2 |
+
_CN = CN()
|
3 |
+
############## ROMA Pipeline #########
|
4 |
+
_CN.ROMA = CN()
|
5 |
+
_CN.ROMA.MATCH_THRESH = 0.0
|
6 |
+
_CN.ROMA.RESIZE_BY_STRETCH = False # Used for test mode
|
7 |
+
_CN.ROMA.NORMALIZE_IMG = False # Used for test mode
|
8 |
+
|
9 |
+
_CN.ROMA.MODE = "train_framework" # Used in Lightning Train & Val
|
10 |
+
_CN.ROMA.MODEL = CN()
|
11 |
+
_CN.ROMA.MODEL.COARSE_BACKBONE = 'DINOv2_large'
|
12 |
+
_CN.ROMA.MODEL.COARSE_FEAT_DIM = 1024
|
13 |
+
_CN.ROMA.MODEL.MEDIUM_FEAT_DIM = 512
|
14 |
+
_CN.ROMA.MODEL.COARSE_PATCH_SIZE = 14
|
15 |
+
_CN.ROMA.MODEL.AMP = True # FP16 mode
|
16 |
+
|
17 |
+
_CN.ROMA.SAMPLE = CN()
|
18 |
+
_CN.ROMA.SAMPLE.METHOD = "threshold_balanced"
|
19 |
+
_CN.ROMA.SAMPLE.N_SAMPLE = 5000
|
20 |
+
_CN.ROMA.SAMPLE.THRESH = 0.05
|
21 |
+
|
22 |
+
_CN.ROMA.TEST_TIME = CN()
|
23 |
+
_CN.ROMA.TEST_TIME.COARSE_RES = (560, 560) # need to divisable by 14 & 8
|
24 |
+
_CN.ROMA.TEST_TIME.UPSAMPLE = True
|
25 |
+
_CN.ROMA.TEST_TIME.UPSAMPLE_RES = (864, 864) # need to divisable by 8
|
26 |
+
_CN.ROMA.TEST_TIME.SYMMETRIC = True
|
27 |
+
_CN.ROMA.TEST_TIME.ATTENUTATE_CERT = True
|
28 |
+
|
29 |
+
############## ↓ LoFTR Pipeline ↓ ##############
|
30 |
+
_CN.LOFTR = CN()
|
31 |
+
_CN.LOFTR.BACKBONE_TYPE = 'ResNetFPN'
|
32 |
+
_CN.LOFTR.ALIGN_CORNER = True
|
33 |
+
_CN.LOFTR.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
|
34 |
+
_CN.LOFTR.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
|
35 |
+
_CN.LOFTR.FINE_WINDOW_MATCHING_SIZE = 5 # window_size for loftr fine-matching, odd for select and even for average
|
36 |
+
_CN.LOFTR.FINE_CONCAT_COARSE_FEAT = True
|
37 |
+
_CN.LOFTR.FINE_SAMPLE_COARSE_FEAT = False
|
38 |
+
_CN.LOFTR.COARSE_FEAT_ONLY = False # TO BE DONE
|
39 |
+
_CN.LOFTR.INTER_FEAT = False # FPN backbone inter feat with coarse_attn.
|
40 |
+
_CN.LOFTR.FP16 = False
|
41 |
+
_CN.LOFTR.FIX_BIAS = False
|
42 |
+
_CN.LOFTR.MATCHABILITY = False
|
43 |
+
_CN.LOFTR.FORCE_LOOP_BACK = False
|
44 |
+
_CN.LOFTR.NORM_FPNFEAT = False
|
45 |
+
_CN.LOFTR.NORM_FPNFEAT2 = False
|
46 |
+
_CN.LOFTR.REPLACE_NAN = False
|
47 |
+
_CN.LOFTR.PLOT_SCORES = False
|
48 |
+
_CN.LOFTR.REP_FPN = False
|
49 |
+
_CN.LOFTR.REP_DEPLOY = False
|
50 |
+
_CN.LOFTR.EVAL_TIMES = 1
|
51 |
+
|
52 |
+
# 1. LoFTR-backbone (local feature CNN) config
|
53 |
+
_CN.LOFTR.RESNETFPN = CN()
|
54 |
+
_CN.LOFTR.RESNETFPN.INITIAL_DIM = 128
|
55 |
+
_CN.LOFTR.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3
|
56 |
+
_CN.LOFTR.RESNETFPN.SAMPLE_FINE = False
|
57 |
+
_CN.LOFTR.RESNETFPN.COARSE_FEAT_ONLY = False # TO BE DONE
|
58 |
+
_CN.LOFTR.RESNETFPN.INTER_FEAT = False # FPN backbone inter feat with coarse_attn.
|
59 |
+
_CN.LOFTR.RESNETFPN.LEAKY = False
|
60 |
+
_CN.LOFTR.RESNETFPN.REPVGGMODEL = None
|
61 |
+
|
62 |
+
# 2. LoFTR-coarse module config
|
63 |
+
_CN.LOFTR.COARSE = CN()
|
64 |
+
_CN.LOFTR.COARSE.D_MODEL = 256
|
65 |
+
_CN.LOFTR.COARSE.D_FFN = 256
|
66 |
+
_CN.LOFTR.COARSE.NHEAD = 8
|
67 |
+
_CN.LOFTR.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
|
68 |
+
_CN.LOFTR.COARSE.ATTENTION = 'linear' # options: ['linear', 'full']
|
69 |
+
_CN.LOFTR.COARSE.TEMP_BUG_FIX = True
|
70 |
+
_CN.LOFTR.COARSE.NPE = False
|
71 |
+
_CN.LOFTR.COARSE.PAN = False
|
72 |
+
_CN.LOFTR.COARSE.POOl_SIZE = 4
|
73 |
+
_CN.LOFTR.COARSE.POOl_SIZE2 = 4
|
74 |
+
_CN.LOFTR.COARSE.BN = True
|
75 |
+
_CN.LOFTR.COARSE.XFORMER = False
|
76 |
+
_CN.LOFTR.COARSE.BIDIRECTION = False
|
77 |
+
_CN.LOFTR.COARSE.DEPTH_CONFIDENCE = -1.0
|
78 |
+
_CN.LOFTR.COARSE.WIDTH_CONFIDENCE = -1.0
|
79 |
+
_CN.LOFTR.COARSE.LEAKY = -1.0
|
80 |
+
_CN.LOFTR.COARSE.ASYMMETRIC = False
|
81 |
+
_CN.LOFTR.COARSE.ASYMMETRIC_SELF = False
|
82 |
+
_CN.LOFTR.COARSE.ROPE = False
|
83 |
+
_CN.LOFTR.COARSE.TOKEN_MIXER = None
|
84 |
+
_CN.LOFTR.COARSE.SKIP = False
|
85 |
+
_CN.LOFTR.COARSE.DWCONV = False
|
86 |
+
_CN.LOFTR.COARSE.DWCONV2 = False
|
87 |
+
_CN.LOFTR.COARSE.SCATTER = False
|
88 |
+
_CN.LOFTR.COARSE.ROPE = False
|
89 |
+
_CN.LOFTR.COARSE.NPE = None
|
90 |
+
_CN.LOFTR.COARSE.NORM_BEFORE = True
|
91 |
+
_CN.LOFTR.COARSE.VIT_NORM = False
|
92 |
+
_CN.LOFTR.COARSE.ROPE_DWPROJ = False
|
93 |
+
_CN.LOFTR.COARSE.ABSPE = False
|
94 |
+
|
95 |
+
|
96 |
+
# 3. Coarse-Matching config
|
97 |
+
_CN.LOFTR.MATCH_COARSE = CN()
|
98 |
+
_CN.LOFTR.MATCH_COARSE.THR = 0.2
|
99 |
+
_CN.LOFTR.MATCH_COARSE.BORDER_RM = 2
|
100 |
+
_CN.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn']
|
101 |
+
_CN.LOFTR.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
|
102 |
+
_CN.LOFTR.MATCH_COARSE.SKH_ITERS = 3
|
103 |
+
_CN.LOFTR.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
|
104 |
+
_CN.LOFTR.MATCH_COARSE.SKH_PREFILTER = False
|
105 |
+
_CN.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory
|
106 |
+
_CN.LOFTR.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
|
107 |
+
_CN.LOFTR.MATCH_COARSE.SPARSE_SPVS = True
|
108 |
+
_CN.LOFTR.MATCH_COARSE.MTD_SPVS = False
|
109 |
+
_CN.LOFTR.MATCH_COARSE.FIX_BIAS = False
|
110 |
+
_CN.LOFTR.MATCH_COARSE.BINARY = False
|
111 |
+
_CN.LOFTR.MATCH_COARSE.BINARY_SPV = 'l2'
|
112 |
+
_CN.LOFTR.MATCH_COARSE.NORMFEAT = False
|
113 |
+
_CN.LOFTR.MATCH_COARSE.NORMFEATMUL = False
|
114 |
+
_CN.LOFTR.MATCH_COARSE.DIFFSIGN2 = False
|
115 |
+
_CN.LOFTR.MATCH_COARSE.DIFFSIGN3 = False
|
116 |
+
_CN.LOFTR.MATCH_COARSE.CLASSIFY = False
|
117 |
+
_CN.LOFTR.MATCH_COARSE.D_CLASSIFY = 256
|
118 |
+
_CN.LOFTR.MATCH_COARSE.SKIP_SOFTMAX = False
|
119 |
+
_CN.LOFTR.MATCH_COARSE.FORCE_NEAREST = False # in case binary is True, force nearest neighbor, preventing finding a reasonable threshold
|
120 |
+
_CN.LOFTR.MATCH_COARSE.FP16MATMUL = False
|
121 |
+
_CN.LOFTR.MATCH_COARSE.SEQSOFTMAX = False
|
122 |
+
_CN.LOFTR.MATCH_COARSE.SEQSOFTMAX2 = False
|
123 |
+
_CN.LOFTR.MATCH_COARSE.RATIO_TEST = False
|
124 |
+
_CN.LOFTR.MATCH_COARSE.RATIO_TEST_VAL = -1.0
|
125 |
+
_CN.LOFTR.MATCH_COARSE.USE_GT_COARSE = False
|
126 |
+
_CN.LOFTR.MATCH_COARSE.CROSS_SOFTMAX = False
|
127 |
+
_CN.LOFTR.MATCH_COARSE.PLOT_ORIGIN_SCORES = False
|
128 |
+
_CN.LOFTR.MATCH_COARSE.USE_PERCENT_THR = False
|
129 |
+
_CN.LOFTR.MATCH_COARSE.PERCENT_THR = 0.1
|
130 |
+
_CN.LOFTR.MATCH_COARSE.ADD_SIGMOID = False
|
131 |
+
_CN.LOFTR.MATCH_COARSE.SIGMOID_BIAS = 20.0
|
132 |
+
_CN.LOFTR.MATCH_COARSE.SIGMOID_SIGMA = 2.5
|
133 |
+
_CN.LOFTR.MATCH_COARSE.CAL_PER_OF_GT = False
|
134 |
+
|
135 |
+
# 4. LoFTR-fine module config
|
136 |
+
_CN.LOFTR.FINE = CN()
|
137 |
+
_CN.LOFTR.FINE.SKIP = False
|
138 |
+
_CN.LOFTR.FINE.D_MODEL = 128
|
139 |
+
_CN.LOFTR.FINE.D_FFN = 128
|
140 |
+
_CN.LOFTR.FINE.NHEAD = 8
|
141 |
+
_CN.LOFTR.FINE.LAYER_NAMES = ['self', 'cross'] * 1
|
142 |
+
_CN.LOFTR.FINE.ATTENTION = 'linear'
|
143 |
+
_CN.LOFTR.FINE.MTD_SPVS = False
|
144 |
+
_CN.LOFTR.FINE.PAN = False
|
145 |
+
_CN.LOFTR.FINE.POOl_SIZE = 4
|
146 |
+
_CN.LOFTR.FINE.BN = True
|
147 |
+
_CN.LOFTR.FINE.XFORMER = False
|
148 |
+
_CN.LOFTR.FINE.BIDIRECTION = False
|
149 |
+
|
150 |
+
|
151 |
+
# Fine-Matching config
|
152 |
+
_CN.LOFTR.MATCH_FINE = CN()
|
153 |
+
_CN.LOFTR.MATCH_FINE.THR = 0
|
154 |
+
_CN.LOFTR.MATCH_FINE.TOPK = 3
|
155 |
+
_CN.LOFTR.MATCH_FINE.NORMFINEM = False
|
156 |
+
_CN.LOFTR.MATCH_FINE.USE_GT_FINE = False
|
157 |
+
_CN.LOFTR.MATCH_COARSE.FINE_TOPK = _CN.LOFTR.MATCH_FINE.TOPK
|
158 |
+
_CN.LOFTR.MATCH_FINE.FIX_FINE_MATCHING = False
|
159 |
+
_CN.LOFTR.MATCH_FINE.SKIP_FINE_SOFTMAX = False
|
160 |
+
_CN.LOFTR.MATCH_FINE.USE_SIGMOID = False
|
161 |
+
_CN.LOFTR.MATCH_FINE.SIGMOID_BIAS = 0.0
|
162 |
+
_CN.LOFTR.MATCH_FINE.NORMFEAT = False
|
163 |
+
_CN.LOFTR.MATCH_FINE.SPARSE_SPVS = True
|
164 |
+
_CN.LOFTR.MATCH_FINE.FORCE_NEAREST = False
|
165 |
+
_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS = False
|
166 |
+
_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_RMBORDER = False
|
167 |
+
# -- # fine-level local regression options (continued)
_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_NOMASK = False
_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 1.0  # softmax temperature used by local regression
_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_PADONE = False
_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICE = False
_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8
_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_INNER = False
_CN.LOFTR.MATCH_FINE.MULTI_REGRESS = False


# 5. LoFTR Losses
# -- # coarse-level
_CN.LOFTR.LOSS = CN()
_CN.LOFTR.LOSS.COARSE_TYPE = 'focal'  # ['focal', 'cross_entropy']
_CN.LOFTR.LOSS.COARSE_WEIGHT = 1.0
_CN.LOFTR.LOSS.COARSE_SIGMOID_WEIGHT = 1.0
_CN.LOFTR.LOSS.LOCAL_WEIGHT = 0.5
_CN.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT = False
_CN.LOFTR.LOSS.FINE_OVERLAP_WEIGHT = False
_CN.LOFTR.LOSS.FINE_OVERLAP_WEIGHT2 = False
# _CN.LOFTR.LOSS.SPARSE_SPVS = False
# -- - -- # focal loss (coarse)
_CN.LOFTR.LOSS.FOCAL_ALPHA = 0.25
_CN.LOFTR.LOSS.FOCAL_GAMMA = 2.0
_CN.LOFTR.LOSS.POS_WEIGHT = 1.0
_CN.LOFTR.LOSS.NEG_WEIGHT = 1.0
_CN.LOFTR.LOSS.CORRECT_NEG_WEIGHT = False
# _CN.LOFTR.LOSS.DUAL_SOFTMAX = False  # whether coarse-level use dual-softmax or not.
# use `_CN.LOFTR.MATCH_COARSE.MATCH_TYPE`

# -- # fine-level
_CN.LOFTR.LOSS.FINE_TYPE = 'l2_with_std'  # ['l2_with_std', 'l2']
_CN.LOFTR.LOSS.FINE_WEIGHT = 1.0
_CN.LOFTR.LOSS.FINE_CORRECT_THR = 1.0  # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window)

# -- # ROMA:
_CN.LOFTR.ROMA_LOSS = CN()
_CN.LOFTR.ROMA_LOSS.IGNORE_EMPTY_IN_SPARSE_MATCH_SPV = False

# -- # DKM:
_CN.LOFTR.DKM_LOSS = CN()
_CN.LOFTR.DKM_LOSS.IGNORE_EMPTY_IN_SPARSE_MATCH_SPV = False

############## Dataset ##############
_CN.DATASET = CN()
# 1. data config
# training and validating
_CN.DATASET.TB_LOG_DIR = "logs/tb_logs"  # tensorboard log directory
_CN.DATASET.TRAIN_DATA_SAMPLE_RATIO = [1.0]  # per-source sampling ratio for training data
_CN.DATASET.TRAIN_DATA_SOURCE = None  # options: ['ScanNet', 'MegaDepth']
_CN.DATASET.TRAIN_DATA_ROOT = None
_CN.DATASET.TRAIN_POSE_ROOT = None  # (optional directory for poses)
_CN.DATASET.TRAIN_NPZ_ROOT = None
_CN.DATASET.TRAIN_LIST_PATH = None
_CN.DATASET.TRAIN_INTRINSIC_PATH = None
_CN.DATASET.VAL_DATA_ROOT = None
_CN.DATASET.VAL_DATA_SOURCE = None  # options: ['ScanNet', 'MegaDepth']
_CN.DATASET.VAL_POSE_ROOT = None  # (optional directory for poses)
_CN.DATASET.VAL_NPZ_ROOT = None
_CN.DATASET.VAL_LIST_PATH = None  # None if val data from all scenes are bundled into a single npz file
_CN.DATASET.VAL_INTRINSIC_PATH = None
_CN.DATASET.FP16 = False
_CN.DATASET.TRAIN_GT_MATCHES_PADDING_N = 8000
# testing
_CN.DATASET.TEST_DATA_SOURCE = None
_CN.DATASET.TEST_DATA_ROOT = None
_CN.DATASET.TEST_POSE_ROOT = None  # (optional directory for poses)
_CN.DATASET.TEST_NPZ_ROOT = None
_CN.DATASET.TEST_LIST_PATH = None  # None if test data from all scenes are bundled into a single npz file
_CN.DATASET.TEST_INTRINSIC_PATH = None

# 2. dataset config
# general options
_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4  # discard data with overlap_score < min_overlap_score
_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
_CN.DATASET.AUGMENTATION_TYPE = None  # options: [None, 'dark', 'mobile']

# debug options
_CN.DATASET.TEST_N_PAIRS = None  # Debug first N pairs
# DEBUG
_CN.LOFTR.FP16LOG = False
_CN.LOFTR.MATCH_COARSE.FP16LOG = False

# scanNet options
_CN.DATASET.SCAN_IMG_RESIZEX = 640  # resize the longer side, zero-pad bottom-right to square.
_CN.DATASET.SCAN_IMG_RESIZEY = 480  # resize the shorter side, zero-pad bottom-right to square.

# MegaDepth options
_CN.DATASET.MGDPT_IMG_RESIZE = (640, 640)  # resize the longer side, zero-pad bottom-right to square.
_CN.DATASET.MGDPT_IMG_PAD = True  # pad img to square with size = MGDPT_IMG_RESIZE
_CN.DATASET.MGDPT_DEPTH_PAD = True  # pad depthmap to square with size = 2000
_CN.DATASET.MGDPT_DF = 8
_CN.DATASET.LOAD_ORIGIN_RGB = False  # Only open in test mode, useful for RGB required baselines such as DKM, ROMA.
_CN.DATASET.READ_GRAY = True
_CN.DATASET.RESIZE_BY_STRETCH = False
_CN.DATASET.NORMALIZE_IMG = False  # For backbone using pretrained DINO feats, use True may be better.
_CN.DATASET.HOMO_WARP_USE_MASK = False

_CN.DATASET.NPE_NAME = "megadepth"

############## Trainer ##############
_CN.TRAINER = CN()
_CN.TRAINER.WORLD_SIZE = 1
_CN.TRAINER.CANONICAL_BS = 64
_CN.TRAINER.CANONICAL_LR = 6e-3
_CN.TRAINER.SCALING = None  # this will be calculated automatically
_CN.TRAINER.FIND_LR = False  # use learning rate finder from pytorch-lightning

# optimizer
_CN.TRAINER.OPTIMIZER = "adamw"  # [adam, adamw]
_CN.TRAINER.OPTIMIZER_EPS = 1e-8  # Default for optimizers, but set smaller, e.g., 1e-7 for fp16 mix training
_CN.TRAINER.TRUE_LR = None  # this will be calculated automatically at runtime
_CN.TRAINER.ADAM_DECAY = 0.  # ADAM: for adam
_CN.TRAINER.ADAMW_DECAY = 0.1

# step-based warm-up
_CN.TRAINER.WARMUP_TYPE = 'linear'  # [linear, constant]
_CN.TRAINER.WARMUP_RATIO = 0.
_CN.TRAINER.WARMUP_STEP = 4800

# learning rate scheduler
_CN.TRAINER.SCHEDULER = 'MultiStepLR'  # [MultiStepLR, CosineAnnealing, ExponentialLR]
_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch'  # [epoch, step]
_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12]  # MSLR: MultiStepLR
_CN.TRAINER.MSLR_GAMMA = 0.5
_CN.TRAINER.COSA_TMAX = 30  # COSA: CosineAnnealing
_CN.TRAINER.ELR_GAMMA = 0.999992  # ELR: ExponentialLR, this value for 'step' interval

# plotting related
_CN.TRAINER.ENABLE_PLOTTING = True
_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 8  # number of val/test pairs for plotting
_CN.TRAINER.PLOT_MODE = 'evaluation'  # ['evaluation', 'confidence']
_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic'

# geometric metrics and pose solver
_CN.TRAINER.EPI_ERR_THR = 5e-4  # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
_CN.TRAINER.POSE_GEO_MODEL = 'E'  # ['E', 'F', 'H']
_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC'  # [RANSAC, DEGENSAC, MAGSAC]
_CN.TRAINER.WARP_ESTIMATOR_MODEL = 'affine'  # NOTE(review): original comment looked copy-pasted from the line above; presumably ['affine', 'homography'] — confirm
_CN.TRAINER.RANSAC_PIXEL_THR = 0.5
_CN.TRAINER.RANSAC_CONF = 0.99999
_CN.TRAINER.RANSAC_MAX_ITERS = 10000
_CN.TRAINER.USE_MAGSACPP = False
_CN.TRAINER.THRESHOLDS = [5, 10, 20]

# data sampler for train_dataloader
_CN.TRAINER.DATA_SAMPLER = 'scene_balance'  # options: ['scene_balance', 'random', 'normal']
# 'scene_balance' config
_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True  # whether sample each scene with replacement or not
_CN.TRAINER.SB_SUBSET_SHUFFLE = True  # after sampling from scenes, whether shuffle within the epoch or not
_CN.TRAINER.SB_REPEAT = 1  # repeat N times for training the sampled data
_CN.TRAINER.AUC_METHOD = 'exact_auc'
# 'random' config
_CN.TRAINER.RDM_REPLACEMENT = True
_CN.TRAINER.RDM_NUM_SAMPLES = None

# gradient clipping
_CN.TRAINER.GRADIENT_CLIPPING = 0.5

# Finetune Mode:
_CN.FINETUNE = CN()
_CN.FINETUNE.ENABLE = False
_CN.FINETUNE.METHOD = "lora"  # ['lora', 'whole_network']

_CN.FINETUNE.LORA = CN()
_CN.FINETUNE.LORA.RANK = 2
_CN.FINETUNE.LORA.MODE = "linear&conv"  # ["linear&conv", "linear_only"]
_CN.FINETUNE.LORA.SCALE = 1.0

_CN.TRAINER.SEED = 66
def get_cfg_defaults():
    """Return a fresh copy of the project's default configuration.

    A clone of the module-level ``_CN`` node is returned (the "local
    variable" use pattern), so callers may mutate their copy freely
    without altering the shared defaults.
    """
    return _CN.clone()
|
imcui/third_party/MatchAnything/src/lightning/lightning_loftr.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from collections import defaultdict
|
3 |
+
import pprint
|
4 |
+
from loguru import logger
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
import pytorch_lightning as pl
|
10 |
+
from matplotlib import pyplot as plt
|
11 |
+
|
12 |
+
from src.loftr import LoFTR
|
13 |
+
from src.loftr.utils.supervision import compute_supervision_coarse, compute_supervision_fine, compute_roma_supervision
|
14 |
+
from src.optimizers import build_optimizer, build_scheduler
|
15 |
+
from src.utils.metrics import (
|
16 |
+
compute_symmetrical_epipolar_errors,
|
17 |
+
compute_pose_errors,
|
18 |
+
compute_homo_corner_warp_errors,
|
19 |
+
compute_homo_match_warp_errors,
|
20 |
+
compute_warp_control_pts_errors,
|
21 |
+
aggregate_metrics
|
22 |
+
)
|
23 |
+
from src.utils.plotting import make_matching_figures, make_scores_figures
|
24 |
+
from src.utils.comm import gather, all_gather
|
25 |
+
from src.utils.misc import lower_config, flattenList
|
26 |
+
from src.utils.profiler import PassThroughProfiler
|
27 |
+
from third_party.ROMA.roma.matchanything_roma_model import MatchAnything_Model
|
28 |
+
|
29 |
+
import pynvml
|
30 |
+
|
31 |
+
def reparameter(matcher):
    """Switch every re-parameterizable (RepVGG-style) sub-module of
    ``matcher`` to its fused inference form, in place, and return it.

    Sub-modules without a ``switch_to_deploy`` method are left untouched.
    """
    def _deploy(module):
        # Fuses the multi-branch training structure into a single conv.
        if hasattr(module, 'switch_to_deploy'):
            module.switch_to_deploy()

    _deploy(matcher.backbone.layer0)
    print('m0 switch to deploy ok')

    for stage in (matcher.backbone.layer1, matcher.backbone.layer2, matcher.backbone.layer3):
        for block in stage:
            _deploy(block)
    print('backbone switch to deploy ok')

    for stage in (matcher.fine_preprocess.layer2_outconv2, matcher.fine_preprocess.layer1_outconv2):
        for block in stage:
            _deploy(block)
    print('fpn switch to deploy ok')

    return matcher
|
47 |
+
|
48 |
+
class PL_LoFTR(pl.LightningModule):
    """Lightning wrapper around the MatchAnything matchers (ELoFTR / ROMA).

    Handles model construction, pretrained-checkpoint loading, the
    train/val/test steps, cross-rank metric aggregation, and figure logging.

    TODO:
        - use the new version of PL logging API.
    """

    def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None, test_mode=False, baseline_config=None):
        """
        Args:
            config: full yacs config node (upper-case keys).
            pretrained_ckpt: optional checkpoint path whose ``'state_dict'``
                entry is loaded non-strictly into the matcher.
            profiler: optional profiler; defaults to ``PassThroughProfiler``.
            dump_dir: optional directory for test-time prediction dumps.
            test_mode: True at inference time (affects model construction).
            baseline_config: unused here; kept for interface compatibility.
        """
        super().__init__()
        # Misc
        self.config = config  # full config
        _config = lower_config(self.config)
        self.profiler = profiler or PassThroughProfiler()
        # Pairs plotted per rank; at least one even in large world sizes.
        self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)

        # Matcher: the underlying matching network.
        if config.METHOD == "matchanything_eloftr":
            self.matcher = LoFTR(config=_config['loftr'], profiler=self.profiler)
        elif config.METHOD == "matchanything_roma":
            self.matcher = MatchAnything_Model(config=_config['roma'], test_mode=test_mode)
        else:
            raise NotImplementedError

        if config.FINETUNE.ENABLE and test_mode:
            # Inference time change model architecture before load pretrained model:
            raise NotImplementedError

        # Pretrained weights
        if pretrained_ckpt:
            if config.METHOD in ["matchanything_eloftr", "matchanything_roma"]:
                state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
                # strict=False tolerates missing/unexpected keys; the load
                # report (IncompatibleKeys) is logged for inspection.
                logger.info(f"Load model from:{self.matcher.load_state_dict(state_dict, strict=False)}")
            else:
                raise NotImplementedError

        # Re-parameterize RepVGG blocks for deployment (legacy 'loftr' method only).
        if self.config.LOFTR.BACKBONE_TYPE == 'RepVGG' and test_mode and (config.METHOD == 'loftr'):
            module = self.matcher.backbone.layer0
            if hasattr(module, 'switch_to_deploy'):
                module.switch_to_deploy()
            print('m0 switch to deploy ok')
            for modules in [self.matcher.backbone.layer1, self.matcher.backbone.layer2, self.matcher.backbone.layer3]:
                for module in modules:
                    if hasattr(module, 'switch_to_deploy'):
                        module.switch_to_deploy()
            print('m switch to deploy ok')

        # Testing
        self.dump_dir = dump_dir
        self.max_gpu_memory = 0  # running max of observed GPU memory (GB)
        self.GPUID = 0           # default NVML device index for memory queries
        self.warmup = False      # when True, test_step runs 50 warm-up passes first

    def gpumem(self, des, gpuid=None):
        """Log current GPU memory usage via NVML and track the maximum seen.

        Args:
            des: short description of the current phase, for the log line.
            gpuid: optional NVML device index; defaults to ``self.GPUID``.
        """
        NUM_EXPAND = 1024 * 1024 * 1024  # bytes per GB
        # BUGFIX: an explicitly passed ``gpuid`` now takes precedence.
        # Previously ``self.GPUID`` was preferred whenever it was not None,
        # and since it is initialized to 0 the argument was always ignored.
        gpu_id = gpuid if gpuid is not None else self.GPUID
        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        gpu_Used = info.used
        logger.info(f"GPU {gpu_id} memory used: {gpu_Used / NUM_EXPAND} GB while {des}")
        if gpu_Used / NUM_EXPAND > self.max_gpu_memory:
            self.max_gpu_memory = gpu_Used / NUM_EXPAND
            logger.info(f"[MAX]GPU {gpu_id} memory used: {gpu_Used / NUM_EXPAND} GB while {des}")
            print('max_gpu_memory', self.max_gpu_memory)

    def configure_optimizers(self):
        """Build optimizer and LR scheduler from the TRAINER config."""
        optimizer = build_optimizer(self, self.config)
        scheduler = build_scheduler(self.config, optimizer)
        return [optimizer], [scheduler]

    def optimizer_step(
            self, epoch, batch_idx, optimizer, optimizer_idx,
            optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        """Apply manual LR warm-up, then the regular optimizer step."""
        # learning rate warm up
        warmup_step = self.config.TRAINER.WARMUP_STEP
        if self.trainer.global_step < warmup_step:
            if self.config.TRAINER.WARMUP_TYPE == 'linear':
                # Linear ramp from WARMUP_RATIO * TRUE_LR up to TRUE_LR.
                base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
                lr = base_lr + \
                    (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
                    abs(self.config.TRAINER.TRUE_LR - base_lr)
                for pg in optimizer.param_groups:
                    pg['lr'] = lr
            elif self.config.TRAINER.WARMUP_TYPE == 'constant':
                pass
            else:
                raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')

        # update params (the FP16 and non-FP16 code paths were identical,
        # so the dead branch on ``self.config.LOFTR.FP16`` was collapsed).
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()

    def _trainval_inference(self, batch):
        """Run supervision computation and the matcher forward pass on ``batch``.

        Coarse supervision is computed before the forward pass for the
        LoFTR-style methods; dense methods (ROMA / DKM) compute all of their
        supervision after the forward pass.
        """
        with self.profiler.profile("Compute coarse supervision"):
            with torch.autocast(enabled=False, device_type='cuda'):
                if ("roma" in self.config.METHOD) or ('dkm' in self.config.METHOD):
                    pass
                else:
                    compute_supervision_coarse(batch, self.config)

        with self.profiler.profile("LoFTR"):
            # Forward pass may run in mixed precision; supervision never does.
            with torch.autocast(enabled=self.config.LOFTR.FP16, device_type='cuda'):
                self.matcher(batch)

        with self.profiler.profile("Compute fine supervision"):
            with torch.autocast(enabled=False, device_type='cuda'):
                if ("roma" in self.config.METHOD) or ('dkm' in self.config.METHOD):
                    compute_roma_supervision(batch, self.config)
                else:
                    compute_supervision_fine(batch, self.config, self.logger)

        with self.profiler.profile("Compute losses"):
            # Losses are attached to ``batch`` elsewhere for these methods.
            pass

    def _compute_metrics(self, batch):
        """Compute per-pair error metrics for ``batch`` and package them.

        Chooses the metric family from the available ground truth:
        control points > homography-only pairs > relative-pose pairs.

        Returns:
            (ret_dict, rel_pair_names) where ret_dict holds a 'metrics' dict
            of per-pair lists.
        """
        if 'gt_2D_matches' in batch:
            compute_warp_control_pts_errors(batch, self.config)
        elif batch['homography'].sum() != 0 and batch['T_0to1'].sum() == 0:
            compute_homo_match_warp_errors(batch, self.config)  # compute warp_errors for each match
            compute_homo_corner_warp_errors(batch, self.config)  # compute mean corner warp error each pair
        else:
            compute_symmetrical_epipolar_errors(batch, self.config)  # compute epi_errs for each match
            compute_pose_errors(batch, self.config)  # compute R_errs, t_errs, pose_errs for each pair

        rel_pair_names = list(zip(*batch['pair_names']))
        bs = batch['image0'].size(0)
        if self.config.LOFTR.FINE.MTD_SPVS:
            # Multi-top-k fine supervision: epi_errs carry TOPK entries per match.
            topk = self.config.LOFTR.MATCH_FINE.TOPK
            metrics = {
                # to filter duplicate pairs caused by DistributedSampler
                'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
                'epi_errs': [(batch['epi_errs'].reshape(-1, topk))[batch['m_bids'] == b].reshape(-1).cpu().numpy() for b in range(bs)],
                'R_errs': batch['R_errs'],
                't_errs': batch['t_errs'],
                'inliers': batch['inliers'],
                'num_matches': [batch['mconf'].shape[0]],  # batch size = 1 only
                'percent_inliers': [batch['inliers'][0].shape[0] / batch['mconf'].shape[0] if batch['mconf'].shape[0] != 0 else 1],  # batch size = 1 only
            }
        else:
            metrics = {
                # to filter duplicate pairs caused by DistributedSampler
                'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
                'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
                'R_errs': batch['R_errs'],
                't_errs': batch['t_errs'],
                'inliers': batch['inliers'],
                'num_matches': [batch['mconf'].shape[0]],  # batch size = 1 only
                'percent_inliers': [batch['inliers'][0].shape[0] / batch['mconf'].shape[0] if batch['mconf'].shape[0] != 0 else 1],  # batch size = 1 only
            }
        ret_dict = {'metrics': metrics}
        return ret_dict, rel_pair_names

    def training_step(self, batch, batch_idx):
        self._trainval_inference(batch)

        # logging (rank 0 only, at the configured step interval)
        if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
            # scalars
            for k, v in batch['loss_scalars'].items():
                self.logger.experiment.add_scalar(f'train/{k}', v, self.global_step)

            # net-params
            method = 'LOFTR'
            if self.config[method]['MATCH_COARSE']['MATCH_TYPE'] == 'sinkhorn':
                self.logger.experiment.add_scalar(
                    f'skh_bin_score', self.matcher.coarse_matching.bin_score.clone().detach().cpu().data, self.global_step)

            figures = {}
            if self.config.TRAINER.ENABLE_PLOTTING:
                compute_symmetrical_epipolar_errors(batch, self.config)  # compute epi_errs for each match
                figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
                for k, v in figures.items():
                    self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step)

        return {'loss': batch['loss']}

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        if self.trainer.global_rank == 0:
            self.logger.experiment.add_scalar(
                'train/avg_loss_on_epoch', avg_loss,
                global_step=self.current_epoch)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        self._trainval_inference(batch)

        ret_dict, _ = self._compute_metrics(batch)

        # Plot only every ``val_plot_interval``-th batch to bound figure count.
        val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
        figures = {self.config.TRAINER.PLOT_MODE: []}
        if batch_idx % val_plot_interval == 0:
            figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE)
            if self.config.LOFTR.PLOT_SCORES:
                figs = make_scores_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
                figures[self.config.TRAINER.PLOT_MODE] += figs[self.config.TRAINER.PLOT_MODE]
                del figs

        return {
            **ret_dict,
            'loss_scalars': batch['loss_scalars'],
            'figures': figures,
        }

    def validation_epoch_end(self, outputs):
        # handle multiple validation sets
        multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
        multi_val_metrics = defaultdict(list)

        for valset_idx, outputs in enumerate(multi_outputs):
            # since pl performs sanity_check at the very begining of the training
            cur_epoch = self.trainer.current_epoch
            if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
                cur_epoch = -1

            # 1. loss_scalars: dict of list, on cpu
            _loss_scalars = [o['loss_scalars'] for o in outputs]
            loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}

            # 2. val metrics: dict of list, numpy
            _metrics = [o['metrics'] for o in outputs]
            metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
            # NOTE: all ranks need to `aggregate_merics`, but only log at rank-0
            val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR, self.config.LOFTR.EVAL_TIMES)
            for thr in [5, 10, 20]:
                multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}'])

            # 3. figures
            _figures = [o['figures'] for o in outputs]
            figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]}

            # tensorboard records only on rank 0
            if self.trainer.global_rank == 0:
                for k, v in loss_scalars.items():
                    mean_v = torch.stack(v).mean()
                    self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)

                for k, v in val_metrics_4tb.items():
                    self.logger.experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch)

                for k, v in figures.items():
                    if self.trainer.global_rank == 0:
                        for plot_idx, fig in enumerate(v):
                            self.logger.experiment.add_figure(
                                f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True)
            plt.close('all')

        for thr in [5, 10, 20]:
            # Averaged over all validation sets; checkpointing monitors these.
            self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}'])))

    def test_step(self, batch, batch_idx):
        if self.warmup:
            # One-time GPU warm-up so timing/profiling is not skewed.
            for i in range(50):
                self.matcher(batch)
            self.warmup = False

        with torch.autocast(enabled=self.config.LOFTR.FP16, device_type='cuda'):
            with self.profiler.profile("LoFTR"):
                self.matcher(batch)

        ret_dict, rel_pair_names = self._compute_metrics(batch)
        print(ret_dict['metrics']['num_matches'])
        # NOTE(review): dumping is force-disabled here, overriding the
        # ``dump_dir`` constructor argument — confirm this is intended.
        self.dump_dir = None

        return ret_dict

    def test_epoch_end(self, outputs):
        print(self.config)
        print('max GPU memory: ', self.max_gpu_memory)
        print(self.profiler.summary())
        # metrics: dict of list, numpy
        _metrics = [o['metrics'] for o in outputs]
        metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}

        # [{key: [{...}, *#bs]}, *#batch]
        if self.dump_dir is not None:
            Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
            _dumps = flattenList([o['dumps'] for o in outputs])  # [{...}, #bs*#batch]
            dumps = flattenList(gather(_dumps))  # [{...}, #proc*#bs*#batch]
            logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}')

        if self.trainer.global_rank == 0:
            # Final NVML memory reading, folded into the running maximum.
            NUM_EXPAND = 1024 * 1024 * 1024
            gpu_id = self.GPUID
            handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
            info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            gpu_Used = info.used
            print('pynvml', gpu_Used / NUM_EXPAND)
            if gpu_Used / NUM_EXPAND > self.max_gpu_memory:
                self.max_gpu_memory = gpu_Used / NUM_EXPAND

            print(self.profiler.summary())
            val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR, self.config.LOFTR.EVAL_TIMES, self.config.TRAINER.THRESHOLDS, method=self.config.TRAINER.AUC_METHOD)
            logger.info('\n' + pprint.pformat(val_metrics_4tb))
            if self.dump_dir is not None:
                np.save(Path(self.dump_dir) / 'LoFTR_pred_eval', dumps)
|
imcui/third_party/MatchAnything/src/loftr/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .loftr import LoFTR
|
imcui/third_party/MatchAnything/src/loftr/backbone/__init__.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4, ResNetFPN_8_1, ResNetFPN_8_2_align, ResNetFPN_8_1_align, ResNetFPN_8_2_fix, ResNet_8_1_align, VGG_8_1_align, RepVGG_8_1_align, \
|
2 |
+
RepVGGnfpn_8_1_align, RepVGG_8_2_fix, s2dnet_8_1_align
|
3 |
+
|
4 |
+
def build_backbone(config):
    """Instantiate the feature backbone selected by ``config['backbone_type']``.

    The concrete variant additionally depends on ``config['align_corner']``
    (None/True selects the legacy variants, False the ``_align`` variants)
    and on ``config['resolution']`` (the coarse/fine stride pair).

    Raises:
        ValueError: for backbone-type / align-corner combinations that are
            explicitly unsupported, and for unknown backbone types.

    Note: combinations not covered by any branch fall through to an
    implicit ``None`` return, mirroring the original control flow.
    """
    btype = config['backbone_type']

    def _unsupported():
        raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")

    if btype == 'ResNetFPN':
        ac = config['align_corner']
        if ac is None or ac is True:
            if config['resolution'] == (8, 2):
                return ResNetFPN_8_2(config['resnetfpn'])
            if config['resolution'] == (16, 4):
                return ResNetFPN_16_4(config['resnetfpn'])
            if config['resolution'] == (8, 1):
                return ResNetFPN_8_1(config['resnetfpn'])
        elif ac is False:
            if config['resolution'] == (8, 2):
                return ResNetFPN_8_2_align(config['resnetfpn'])
            if config['resolution'] == (16, 4):
                return ResNetFPN_16_4(config['resnetfpn'])
            if config['resolution'] == (8, 1):
                return ResNetFPN_8_1_align(config['resnetfpn'])
    elif btype == 'ResNetFPNFIX':
        ac = config['align_corner']
        if (ac is None or ac is True) and config['resolution'] == (8, 2):
            return ResNetFPN_8_2_fix(config['resnetfpn'])
    elif btype == 'ResNet':
        ac = config['align_corner']
        if ac is None or ac is True:
            _unsupported()
        elif ac is False:
            if config['resolution'] == (8, 1):
                return ResNet_8_1_align(config['resnetfpn'])
    elif btype == 'VGG':
        ac = config['align_corner']
        if ac is None or ac is True:
            _unsupported()
        elif ac is False:
            if config['resolution'] == (8, 1):
                return VGG_8_1_align(config['resnetfpn'])
    elif btype == 'RepVGG':
        ac = config['align_corner']
        if ac is None or ac is True:
            _unsupported()
        elif ac is False:
            if config['resolution'] == (8, 1):
                return RepVGG_8_1_align(config['resnetfpn'])
    elif btype == 'RepVGGNFPN':
        ac = config['align_corner']
        if ac is None or ac is True:
            _unsupported()
        elif ac is False:
            if config['resolution'] == (8, 1):
                return RepVGGnfpn_8_1_align(config['resnetfpn'])
    elif btype == 'RepVGGFPNFIX':
        ac = config['align_corner']
        if ac is None or ac is True:
            if config['resolution'] == (8, 2):
                return RepVGG_8_2_fix(config['resnetfpn'])
        elif ac is False:
            _unsupported()
    elif btype == 's2dnet':
        ac = config['align_corner']
        if ac is None or ac is True:
            _unsupported()
        elif ac is False:
            if config['resolution'] == (8, 1):
                return s2dnet_8_1_align(config['resnetfpn'])
    else:
        _unsupported()
|
imcui/third_party/MatchAnything/src/loftr/backbone/repvgg.py
ADDED
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
|
3 |
+
# Github source: https://github.com/DingXiaoH/RepVGG
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
import torch.nn as nn
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import copy
|
10 |
+
# from se_block import SEBlock
|
11 |
+
import torch.utils.checkpoint as checkpoint
|
12 |
+
from loguru import logger
|
13 |
+
|
14 |
+
def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    """A bias-free Conv2d followed by BatchNorm2d, packaged as an
    ``nn.Sequential`` whose submodules are named 'conv' and 'bn'.

    The submodule names are relied upon by the re-parameterization code
    (e.g. ``rbr_dense.conv.weight`` / ``rbr_dense.bn``), so keep them stable.
    """
    block = nn.Sequential()
    convolution = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                            kernel_size=kernel_size, stride=stride, padding=padding,
                            groups=groups, bias=False)  # bias folded into BN at deploy time
    block.add_module('conv', convolution)
    block.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return block
|
20 |
+
|
21 |
+
class RepVGGBlock(nn.Module):
|
22 |
+
|
23 |
+
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False, leaky=-1.0):
        """RepVGG basic block.

        In training mode (``deploy=False``) the block holds three parallel
        branches (3x3 conv+BN, 1x1 conv+BN, optional BN-only identity);
        in deploy mode a single fused conv replaces them.

        Args:
            leaky: activation selector — -2 for Identity, any other negative
                value for ReLU, >= 0 for LeakyReLU(leaky).
            use_se: unsupported here; raises ValueError when True.
        """
        super(RepVGGBlock, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels

        # Only the 3x3 / padding-1 configuration is supported.
        assert kernel_size == 3
        assert padding == 1

        # Padding for the parallel 1x1 branch so spatial sizes match (0 here).
        padding_11 = padding - kernel_size // 2

        if leaky == -2:
            self.nonlinearity = nn.Identity()
            logger.info(f"Using Identity nonlinearity in repvgg_block")
        elif leaky < 0:
            self.nonlinearity = nn.ReLU()
        else:
            self.nonlinearity = nn.LeakyReLU(leaky)

        if use_se:
            # Note that RepVGG-D2se uses SE before nonlinearity. But RepVGGplus models uses SE after nonlinearity.
            # self.se = SEBlock(out_channels, internal_neurons=out_channels // 16)
            raise ValueError(f"SEBlock not supported")
        else:
            self.se = nn.Identity()

        if deploy:
            # Deployed form: a single fused conv (with bias) replaces all branches.
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                                         padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)

        else:
            # Training form: identity branch exists only when shapes allow a
            # residual connection (same channel count, stride 1).
            self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
            self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
            self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
            print('RepVGG Block, identity = ', self.rbr_identity)
|
60 |
+
|
61 |
+
def forward(self, inputs):
|
62 |
+
if hasattr(self, 'rbr_reparam'):
|
63 |
+
return self.nonlinearity(self.se(self.rbr_reparam(inputs)))
|
64 |
+
|
65 |
+
if self.rbr_identity is None:
|
66 |
+
id_out = 0
|
67 |
+
else:
|
68 |
+
id_out = self.rbr_identity(inputs)
|
69 |
+
|
70 |
+
return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))
|
71 |
+
|
72 |
+
|
73 |
+
# Optional. This may improve the accuracy and facilitates quantization in some cases.
|
74 |
+
# 1. Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
|
75 |
+
# 2. Use like this.
|
76 |
+
# loss = criterion(....)
|
77 |
+
# for every RepVGGBlock blk:
|
78 |
+
# loss += weight_decay_coefficient * 0.5 * blk.get_cust_L2()
|
79 |
+
# optimizer.zero_grad()
|
80 |
+
# loss.backward()
|
81 |
+
def get_custom_L2(self):
|
82 |
+
K3 = self.rbr_dense.conv.weight
|
83 |
+
K1 = self.rbr_1x1.conv.weight
|
84 |
+
t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
|
85 |
+
t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
|
86 |
+
|
87 |
+
l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
|
88 |
+
eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1 # The equivalent resultant central point of 3x3 kernel.
|
89 |
+
l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum() # Normalize for an L2 coefficient comparable to regular L2.
|
90 |
+
return l2_loss_eq_kernel + l2_loss_circle
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
# This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
|
95 |
+
# You can get the equivalent kernel and bias at any time and do whatever you want,
|
96 |
+
# for example, apply some penalties or constraints during training, just like you do to the other models.
|
97 |
+
# May be useful for quantization or pruning.
|
98 |
+
def get_equivalent_kernel_bias(self):
|
99 |
+
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
|
100 |
+
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
|
101 |
+
kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
|
102 |
+
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
|
103 |
+
|
104 |
+
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
|
105 |
+
if kernel1x1 is None:
|
106 |
+
return 0
|
107 |
+
else:
|
108 |
+
return torch.nn.functional.pad(kernel1x1, [1,1,1,1])
|
109 |
+
|
110 |
+
def _fuse_bn_tensor(self, branch):
|
111 |
+
if branch is None:
|
112 |
+
return 0, 0
|
113 |
+
if isinstance(branch, nn.Sequential):
|
114 |
+
kernel = branch.conv.weight
|
115 |
+
running_mean = branch.bn.running_mean
|
116 |
+
running_var = branch.bn.running_var
|
117 |
+
gamma = branch.bn.weight
|
118 |
+
beta = branch.bn.bias
|
119 |
+
eps = branch.bn.eps
|
120 |
+
else:
|
121 |
+
assert isinstance(branch, nn.BatchNorm2d)
|
122 |
+
if not hasattr(self, 'id_tensor'):
|
123 |
+
input_dim = self.in_channels // self.groups
|
124 |
+
kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
|
125 |
+
for i in range(self.in_channels):
|
126 |
+
kernel_value[i, i % input_dim, 1, 1] = 1
|
127 |
+
self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
|
128 |
+
kernel = self.id_tensor
|
129 |
+
running_mean = branch.running_mean
|
130 |
+
running_var = branch.running_var
|
131 |
+
gamma = branch.weight
|
132 |
+
beta = branch.bias
|
133 |
+
eps = branch.eps
|
134 |
+
std = (running_var + eps).sqrt()
|
135 |
+
t = (gamma / std).reshape(-1, 1, 1, 1)
|
136 |
+
return kernel * t, beta - running_mean * gamma / std
|
137 |
+
|
138 |
+
def switch_to_deploy(self):
|
139 |
+
if hasattr(self, 'rbr_reparam'):
|
140 |
+
return
|
141 |
+
kernel, bias = self.get_equivalent_kernel_bias()
|
142 |
+
self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels,
|
143 |
+
kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
|
144 |
+
padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True)
|
145 |
+
self.rbr_reparam.weight.data = kernel
|
146 |
+
self.rbr_reparam.bias.data = bias
|
147 |
+
self.__delattr__('rbr_dense')
|
148 |
+
self.__delattr__('rbr_1x1')
|
149 |
+
if hasattr(self, 'rbr_identity'):
|
150 |
+
self.__delattr__('rbr_identity')
|
151 |
+
if hasattr(self, 'id_tensor'):
|
152 |
+
self.__delattr__('id_tensor')
|
153 |
+
self.deploy = True
|
154 |
+
|
155 |
+
|
156 |
+
|
157 |
+
class RepVGG(nn.Module):
    """RepVGG backbone truncated after stage3 (1/8-resolution features).

    The classification head of the original RepVGG (stage4, global average
    pooling, linear classifier) is intentionally not built here — this port
    is used as a dense-feature backbone — so ``forward`` returns the stage3
    feature map.

    Args:
        num_blocks: blocks per stage (only the first three entries are used).
        num_classes: kept for signature compatibility; unused (no classifier).
        width_multiplier: four per-stage channel multipliers (all required by
            the assert, though the last is only referenced by the commented-out
            head).
        override_groups_map: optional {layer_idx: groups} for grouped convs.
        deploy / use_se / leaky: forwarded to every RepVGGBlock.
        use_checkpoint: if True, run blocks under torch gradient checkpointing.
    """

    def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None, deploy=False, use_se=False, use_checkpoint=False, leaky=-1.0):
        super(RepVGG, self).__init__()
        assert len(width_multiplier) == 4
        self.deploy = deploy
        self.override_groups_map = override_groups_map or dict()
        # Layer index 0 is the stem; it may not be grouped.
        assert 0 not in self.override_groups_map
        self.use_se = use_se
        self.use_checkpoint = use_checkpoint

        self.in_planes = min(64, int(64 * width_multiplier[0]))
        # Stem: single-channel (grayscale) input, stride 2 -> 1/2 resolution.
        self.stage0 = RepVGGBlock(in_channels=1, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1, deploy=self.deploy, use_se=self.use_se, leaky=leaky)
        self.cur_layer_idx = 1
        self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=1, leaky=leaky)
        self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2, leaky=leaky)
        self.stage3 = self._make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride=2, leaky=leaky)
        # Classifier head intentionally omitted (backbone-only usage):
        # self.stage4 = self._make_stage(int(512 * width_multiplier[3]), num_blocks[3], stride=1)
        # self.gap = nn.AdaptiveAvgPool2d(output_size=1)
        # self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes)

    def _make_stage(self, planes, num_blocks, stride, leaky=-1.0):
        """Stack ``num_blocks`` RepVGGBlocks; only the first may downsample."""
        strides = [stride] + [1]*(num_blocks-1)
        blocks = []
        for stride in strides:
            cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
            blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
                                      stride=stride, padding=1, groups=cur_groups, deploy=self.deploy, use_se=self.use_se, leaky=leaky))
            self.in_planes = planes
            self.cur_layer_idx += 1
        return nn.ModuleList(blocks)

    def forward(self, x):
        """Return the stage3 feature map (1/8 of the input resolution)."""
        out = self.stage0(x)
        for stage in (self.stage1, self.stage2, self.stage3):  # , self.stage4):
            for block in stage:
                if self.use_checkpoint:
                    out = checkpoint.checkpoint(block, out)
                else:
                    out = block(out)
        # BUGFIX: the original tail called self.gap(out) and self.linear(out),
        # but those modules are never constructed (their creation is commented
        # out in __init__), so forward unconditionally raised AttributeError.
        # Return the backbone feature map instead.
        return out
|
201 |
+
|
202 |
+
|
203 |
+
# Layer indices (every other layer) that the "g2"/"g4" RepVGG variants run
# with grouped convolutions; all other layers stay dense.
optional_groupwise_layers = list(range(2, 28, 2))
g2_map = dict.fromkeys(optional_groupwise_layers, 2)
g4_map = dict.fromkeys(optional_groupwise_layers, 4)
|
206 |
+
|
207 |
+
def _build(num_blocks, width_multiplier, groups_map, deploy, use_checkpoint, use_se=False, leaky=-1.0):
    # Shared constructor for every named RepVGG variant below.
    return RepVGG(num_blocks=num_blocks, num_classes=1000,
                  width_multiplier=width_multiplier, override_groups_map=groups_map,
                  deploy=deploy, use_se=use_se, use_checkpoint=use_checkpoint, leaky=leaky)

def create_RepVGG_A0(deploy=False, use_checkpoint=False):
    return _build([2, 4, 14, 1], [0.75, 0.75, 0.75, 2.5], None, deploy, use_checkpoint)

def create_RepVGG_A1(deploy=False, use_checkpoint=False):
    return _build([2, 4, 14, 1], [1, 1, 1, 2.5], None, deploy, use_checkpoint)

def create_RepVGG_A15(deploy=False, use_checkpoint=False):
    return _build([2, 4, 14, 1], [1.25, 1.25, 1.25, 2.5], None, deploy, use_checkpoint)

def create_RepVGG_A1_leaky(deploy=False, use_checkpoint=False):
    # A1 widths with LeakyReLU(0.01) activations instead of ReLU.
    return _build([2, 4, 14, 1], [1, 1, 1, 2.5], None, deploy, use_checkpoint, leaky=0.01)

def create_RepVGG_A2(deploy=False, use_checkpoint=False):
    return _build([2, 4, 14, 1], [1.5, 1.5, 1.5, 2.75], None, deploy, use_checkpoint)

def create_RepVGG_B0(deploy=False, use_checkpoint=False):
    return _build([4, 6, 16, 1], [1, 1, 1, 2.5], None, deploy, use_checkpoint)

def create_RepVGG_B1(deploy=False, use_checkpoint=False):
    return _build([4, 6, 16, 1], [2, 2, 2, 4], None, deploy, use_checkpoint)

def create_RepVGG_B1g2(deploy=False, use_checkpoint=False):
    return _build([4, 6, 16, 1], [2, 2, 2, 4], g2_map, deploy, use_checkpoint)

def create_RepVGG_B1g4(deploy=False, use_checkpoint=False):
    return _build([4, 6, 16, 1], [2, 2, 2, 4], g4_map, deploy, use_checkpoint)

def create_RepVGG_B2(deploy=False, use_checkpoint=False):
    return _build([4, 6, 16, 1], [2.5, 2.5, 2.5, 5], None, deploy, use_checkpoint)

def create_RepVGG_B2g2(deploy=False, use_checkpoint=False):
    return _build([4, 6, 16, 1], [2.5, 2.5, 2.5, 5], g2_map, deploy, use_checkpoint)

def create_RepVGG_B2g4(deploy=False, use_checkpoint=False):
    return _build([4, 6, 16, 1], [2.5, 2.5, 2.5, 5], g4_map, deploy, use_checkpoint)

def create_RepVGG_B3(deploy=False, use_checkpoint=False):
    return _build([4, 6, 16, 1], [3, 3, 3, 5], None, deploy, use_checkpoint)

def create_RepVGG_B3g2(deploy=False, use_checkpoint=False):
    return _build([4, 6, 16, 1], [3, 3, 3, 5], g2_map, deploy, use_checkpoint)

def create_RepVGG_B3g4(deploy=False, use_checkpoint=False):
    return _build([4, 6, 16, 1], [3, 3, 3, 5], g4_map, deploy, use_checkpoint)

def create_RepVGG_D2se(deploy=False, use_checkpoint=False):
    # NOTE: use_se=True makes RepVGGBlock raise (SE not supported in this port).
    return _build([8, 14, 24, 1], [2.5, 2.5, 2.5, 5], None, deploy, use_checkpoint, use_se=True)


# Registry of all named variants, keyed by their canonical checkpoint names.
func_dict = {
    'RepVGG-A0': create_RepVGG_A0,
    'RepVGG-A1': create_RepVGG_A1,
    'RepVGG-A15': create_RepVGG_A15,
    'RepVGG-A1_leaky': create_RepVGG_A1_leaky,
    'RepVGG-A2': create_RepVGG_A2,
    'RepVGG-B0': create_RepVGG_B0,
    'RepVGG-B1': create_RepVGG_B1,
    'RepVGG-B1g2': create_RepVGG_B1g2,
    'RepVGG-B1g4': create_RepVGG_B1g4,
    'RepVGG-B2': create_RepVGG_B2,
    'RepVGG-B2g2': create_RepVGG_B2g2,
    'RepVGG-B2g4': create_RepVGG_B2g4,
    'RepVGG-B3': create_RepVGG_B3,
    'RepVGG-B3g2': create_RepVGG_B3g2,
    'RepVGG-B3g4': create_RepVGG_B3g4,
    'RepVGG-D2se': create_RepVGG_D2se,  # Updated at April 25, 2021. This is not reported in the CVPR paper.
}

def get_RepVGG_func_by_name(name):
    """Look up a variant factory by name; raises KeyError on unknown names."""
    return func_dict[name]
|
292 |
+
|
293 |
+
|
294 |
+
|
295 |
+
# Use this for converting a RepVGG model or a bigger model with RepVGG as its component
|
296 |
+
# Use like this
|
297 |
+
# model = create_RepVGG_A0(deploy=False)
|
298 |
+
# train model or load weights
|
299 |
+
# repvgg_model_convert(model, save_path='repvgg_deploy.pth')
|
300 |
+
# If you want to preserve the original model, call with do_copy=True
|
301 |
+
|
302 |
+
# ====================== for using RepVGG as the backbone of a bigger model, e.g., PSPNet, the pseudo code will be like
|
303 |
+
# train_backbone = create_RepVGG_B2(deploy=False)
|
304 |
+
# train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth'))
|
305 |
+
# train_pspnet = build_pspnet(backbone=train_backbone)
|
306 |
+
# segmentation_train(train_pspnet)
|
307 |
+
# deploy_pspnet = repvgg_model_convert(train_pspnet)
|
308 |
+
# segmentation_test(deploy_pspnet)
|
309 |
+
# ===================== example_pspnet.py shows an example
|
310 |
+
|
311 |
+
def repvgg_model_convert(model: torch.nn.Module, save_path=None, do_copy=True):
    """Fold every RepVGG block in *model* into its single-conv deploy form.

    Walks all submodules and invokes ``switch_to_deploy`` wherever it exists,
    so it also works on larger models that merely contain RepVGG blocks.

    Args:
        model: model to convert (any nn.Module).
        save_path: optional path; when given, the converted state_dict is saved.
        do_copy: if True (default) convert a deep copy and leave *model* intact;
            if False, convert in place.

    Returns:
        The converted model (the copy when ``do_copy`` is True).
    """
    converted = copy.deepcopy(model) if do_copy else model
    for submodule in converted.modules():
        switch = getattr(submodule, 'switch_to_deploy', None)
        if switch is not None:
            switch()
    if save_path is not None:
        torch.save(converted.state_dict(), save_path)
    return converted
|
imcui/third_party/MatchAnything/src/loftr/backbone/resnet_fpn.py
ADDED
@@ -0,0 +1,1094 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from .repvgg import get_RepVGG_func_by_name
|
4 |
+
from .s2dnet import S2DNet
|
5 |
+
|
6 |
+
|
7 |
+
def conv1x1(in_planes, out_planes, stride=1):
    """Bias-free pointwise (1x1) convolution with no padding."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        padding=0,
        bias=False,
    )
|
10 |
+
|
11 |
+
|
12 |
+
def conv3x3(in_planes, out_planes, stride=1):
    """Bias-free 3x3 convolution with "same" padding (padding=1)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
|
15 |
+
|
16 |
+
|
17 |
+
class BasicBlock(nn.Module):
    """Two-conv residual block (conv3x3 -> BN -> ReLU -> conv3x3 -> BN),
    with a 1x1-conv projection shortcut whenever stride != 1."""

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        # Submodule names (conv1/conv2/bn1/bn2/relu/downsample) are part of
        # the checkpoint interface — do not rename.
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.conv2 = conv3x3(planes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # Projection shortcut only when the spatial size changes.
        self.downsample = None if stride == 1 else nn.Sequential(
            conv1x1(in_planes, planes, stride=stride),
            nn.BatchNorm2d(planes),
        )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        identity = x if self.downsample is None else self.downsample(x)
        return self.relu(identity + out)
|
43 |
+
|
44 |
+
|
45 |
+
class ResNetFPN_8_2(nn.Module):
    """
    ResNet+FPN, output resolution are 1/8 and 1/2.
    Each block has 2 layers.

    Returns a dict with 'feats_c' (coarse, 1/8) and 'feats_f' (fine, 1/2).
    Expected config keys: 'initial_dim' (stem width) and 'block_dims'
    (three channel widths, one per ResNet stage).
    """

    def __init__(self, config):
        super().__init__()
        # Config
        block = BasicBlock
        initial_dim = config['initial_dim']
        block_dims = config['block_dims']

        # Class Variable
        # in_planes tracks the running input width across _make_layer calls.
        self.in_planes = initial_dim

        # Networks
        # Stem takes a single-channel (grayscale) image; stride 2 -> 1/2 res.
        self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(initial_dim)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, block_dims[0], stride=1)  # 1/2
        self.layer2 = self._make_layer(block, block_dims[1], stride=2)  # 1/4
        self.layer3 = self._make_layer(block, block_dims[2], stride=2)  # 1/8

        # 3. FPN upsample
        # Lateral 1x1 convs project each stage to the next-coarser width,
        # then a conv3x3/BN/LeakyReLU/conv3x3 stack merges with the upsampled
        # top-down signal.
        self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
        self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
        self.layer2_outconv2 = nn.Sequential(
            conv3x3(block_dims[2], block_dims[2]),
            nn.BatchNorm2d(block_dims[2]),
            nn.LeakyReLU(),
            conv3x3(block_dims[2], block_dims[1]),
        )
        self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
        self.layer1_outconv2 = nn.Sequential(
            conv3x3(block_dims[1], block_dims[1]),
            nn.BatchNorm2d(block_dims[1]),
            nn.LeakyReLU(),
            conv3x3(block_dims[1], block_dims[0]),
        )

        # Kaiming init for convs; unit-scale/zero-shift for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, dim, stride=1):
        """Two BasicBlocks; only the first applies the (possibly >1) stride."""
        layer1 = block(self.in_planes, dim, stride=stride)
        layer2 = block(dim, dim, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Bottom-up ResNet pass followed by a top-down FPN merge."""
        # ResNet Backbone
        x0 = self.relu(self.bn1(self.conv1(x)))
        x1 = self.layer1(x0)  # 1/2
        x2 = self.layer2(x1)  # 1/4
        x3 = self.layer3(x2)  # 1/8

        # FPN
        x3_out = self.layer3_outconv(x3)

        x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
        x2_out = self.layer2_outconv(x2)
        x2_out = self.layer2_outconv2(x2_out+x3_out_2x)

        x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
        x1_out = self.layer1_outconv(x1)
        x1_out = self.layer1_outconv2(x1_out+x2_out_2x)

        return {'feats_c': x3_out, 'feats_f': x1_out}

    def pro(self, x, profiler):
        """Same computation as forward(), instrumented with a profiler whose
        ``profile(name)`` context manager times the two phases."""
        with profiler.profile('ResNet Backbone'):
            # ResNet Backbone
            x0 = self.relu(self.bn1(self.conv1(x)))
            x1 = self.layer1(x0)  # 1/2
            x2 = self.layer2(x1)  # 1/4
            x3 = self.layer3(x2)  # 1/8

        with profiler.profile('ResNet FPN'):
            # FPN
            x3_out = self.layer3_outconv(x3)

            x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
            x2_out = self.layer2_outconv(x2)
            x2_out = self.layer2_outconv2(x2_out+x3_out_2x)

            x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
            x1_out = self.layer1_outconv(x1)
            x1_out = self.layer1_outconv2(x1_out+x2_out_2x)

        return {'feats_c': x3_out, 'feats_f': x1_out}
|
143 |
+
|
144 |
+
class ResNetFPN_8_2_fix(nn.Module):
    """
    ResNet+FPN, output resolution are 1/8 and 1/2.
    Each block has 2 layers.

    Variant of ResNetFPN_8_2 that (a) can skip the fine FPN branch entirely
    ('coarse_feat_only'), optionally also returning intermediate stage
    features ('inter_feat'), and (b) upsamples to exact odd target sizes
    ((n-1)*2+1) instead of using a plain scale factor — hence the "_fix".
    """

    def __init__(self, config):
        super().__init__()
        # Config
        block = BasicBlock
        initial_dim = config['initial_dim']
        block_dims = config['block_dims']

        # Class Variable
        self.in_planes = initial_dim
        # When True, forward() returns raw stage3 features and no fine map.
        self.skip_fine_feature = config['coarse_feat_only']
        # When True (together with skip_fine_feature), also expose x2/x1.
        self.inter_feat = config['inter_feat']

        # Networks
        # Stem takes a single-channel (grayscale) image; stride 2 -> 1/2 res.
        self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(initial_dim)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, block_dims[0], stride=1)  # 1/2
        self.layer2 = self._make_layer(block, block_dims[1], stride=2)  # 1/4
        self.layer3 = self._make_layer(block, block_dims[2], stride=2)  # 1/8

        # 3. FPN upsample
        self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
        self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
        self.layer2_outconv2 = nn.Sequential(
            conv3x3(block_dims[2], block_dims[2]),
            nn.BatchNorm2d(block_dims[2]),
            nn.LeakyReLU(),
            conv3x3(block_dims[2], block_dims[1]),
        )
        self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
        self.layer1_outconv2 = nn.Sequential(
            conv3x3(block_dims[1], block_dims[1]),
            nn.BatchNorm2d(block_dims[1]),
            nn.LeakyReLU(),
            conv3x3(block_dims[1], block_dims[0]),
        )

        # Kaiming init for convs; unit-scale/zero-shift for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, dim, stride=1):
        """Two BasicBlocks; only the first applies the (possibly >1) stride."""
        layer1 = block(self.in_planes, dim, stride=stride)
        layer2 = block(dim, dim, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # ResNet Backbone
        x0 = self.relu(self.bn1(self.conv1(x)))
        x1 = self.layer1(x0)  # 1/2
        x2 = self.layer2(x1)  # 1/4
        x3 = self.layer3(x2)  # 1/8

        # FPN
        if self.skip_fine_feature:
            if self.inter_feat:
                return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
            else:
                return {'feats_c': x3, 'feats_f': None}


        x3_out = self.layer3_outconv(x3)  # n+1

        # NOTE: targets of size (s-1)*2+1 keep odd spatial dims aligned across
        # levels (assumes the incoming maps follow the n+1 / 2n+1 / 4n+1
        # convention noted in the inline comments — confirm against callers).
        x3_out_2x = F.interpolate(x3_out, size=((x3_out.size(-2)-1)*2+1, (x3_out.size(-1)-1)*2+1), mode='bilinear', align_corners=True)  # 2n+1
        x2_out = self.layer2_outconv(x2)
        x2_out = self.layer2_outconv2(x2_out+x3_out_2x)

        x2_out_2x = F.interpolate(x2_out, size=((x2_out.size(-2)-1)*2+1, (x2_out.size(-1)-1)*2+1), mode='bilinear', align_corners=True)  # 4n+1
        x1_out = self.layer1_outconv(x1)
        x1_out = self.layer1_outconv2(x1_out+x2_out_2x)

        return {'feats_c': x3_out, 'feats_f': x1_out}
|
229 |
+
|
230 |
+
|
231 |
+
class ResNetFPN_16_4(nn.Module):
    """
    ResNet+FPN, output resolution are 1/16 and 1/4.
    Each block has 2 layers.

    Four-stage variant: adds layer4 (1/16) and merges the FPN only down to
    the 1/4 level.  Returns 'feats_c' (1/16) and 'feats_f' (1/4).
    """

    def __init__(self, config):
        super().__init__()
        # Config
        block = BasicBlock
        initial_dim = config['initial_dim']
        block_dims = config['block_dims']

        # Class Variable
        self.in_planes = initial_dim

        # Networks
        # Stem takes a single-channel (grayscale) image; stride 2 -> 1/2 res.
        self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(initial_dim)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, block_dims[0], stride=1)  # 1/2
        self.layer2 = self._make_layer(block, block_dims[1], stride=2)  # 1/4
        self.layer3 = self._make_layer(block, block_dims[2], stride=2)  # 1/8
        self.layer4 = self._make_layer(block, block_dims[3], stride=2)  # 1/16

        # 3. FPN upsample
        self.layer4_outconv = conv1x1(block_dims[3], block_dims[3])
        self.layer3_outconv = conv1x1(block_dims[2], block_dims[3])
        self.layer3_outconv2 = nn.Sequential(
            conv3x3(block_dims[3], block_dims[3]),
            nn.BatchNorm2d(block_dims[3]),
            nn.LeakyReLU(),
            conv3x3(block_dims[3], block_dims[2]),
        )

        self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
        self.layer2_outconv2 = nn.Sequential(
            conv3x3(block_dims[2], block_dims[2]),
            nn.BatchNorm2d(block_dims[2]),
            nn.LeakyReLU(),
            conv3x3(block_dims[2], block_dims[1]),
        )

        # Kaiming init for convs; unit-scale/zero-shift for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, dim, stride=1):
        """Two BasicBlocks; only the first applies the (possibly >1) stride."""
        layer1 = block(self.in_planes, dim, stride=stride)
        layer2 = block(dim, dim, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Bottom-up ResNet pass, then a two-level top-down FPN merge."""
        # ResNet Backbone
        x0 = self.relu(self.bn1(self.conv1(x)))
        x1 = self.layer1(x0)  # 1/2
        x2 = self.layer2(x1)  # 1/4
        x3 = self.layer3(x2)  # 1/8
        x4 = self.layer4(x3)  # 1/16

        # FPN
        x4_out = self.layer4_outconv(x4)

        x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True)
        x3_out = self.layer3_outconv(x3)
        x3_out = self.layer3_outconv2(x3_out+x4_out_2x)

        x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
        x2_out = self.layer2_outconv(x2)
        x2_out = self.layer2_outconv2(x2_out+x3_out_2x)

        return {'feats_c': x4_out, 'feats_f': x2_out}
|
310 |
+
|
311 |
+
|
312 |
+
class ResNetFPN_8_1(nn.Module):
|
313 |
+
"""
|
314 |
+
ResNet+FPN, output resolution are 1/8 and 1.
|
315 |
+
Each block has 2 layers.
|
316 |
+
"""
|
317 |
+
|
318 |
+
def __init__(self, config):
|
319 |
+
super().__init__()
|
320 |
+
# Config
|
321 |
+
block = BasicBlock
|
322 |
+
initial_dim = config['initial_dim']
|
323 |
+
block_dims = config['block_dims']
|
324 |
+
|
325 |
+
# Class Variable
|
326 |
+
self.in_planes = initial_dim
|
327 |
+
self.skip_fine_feature = config['coarse_feat_only']
|
328 |
+
self.inter_feat = config['inter_feat']
|
329 |
+
|
330 |
+
# Networks
|
331 |
+
self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
|
332 |
+
self.bn1 = nn.BatchNorm2d(initial_dim)
|
333 |
+
self.relu = nn.ReLU(inplace=True)
|
334 |
+
|
335 |
+
self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
|
336 |
+
self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
|
337 |
+
self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
|
338 |
+
|
339 |
+
# 3. FPN upsample
|
340 |
+
if not self.skip_fine_feature:
|
341 |
+
self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
|
342 |
+
self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
|
343 |
+
self.layer2_outconv2 = nn.Sequential(
|
344 |
+
conv3x3(block_dims[2], block_dims[2]),
|
345 |
+
nn.BatchNorm2d(block_dims[2]),
|
346 |
+
nn.LeakyReLU(),
|
347 |
+
conv3x3(block_dims[2], block_dims[1]),
|
348 |
+
)
|
349 |
+
self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
|
350 |
+
self.layer1_outconv2 = nn.Sequential(
|
351 |
+
conv3x3(block_dims[1], block_dims[1]),
|
352 |
+
nn.BatchNorm2d(block_dims[1]),
|
353 |
+
nn.LeakyReLU(),
|
354 |
+
conv3x3(block_dims[1], block_dims[0]),
|
355 |
+
)
|
356 |
+
|
357 |
+
for m in self.modules():
|
358 |
+
if isinstance(m, nn.Conv2d):
|
359 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
360 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
361 |
+
nn.init.constant_(m.weight, 1)
|
362 |
+
nn.init.constant_(m.bias, 0)
|
363 |
+
|
364 |
+
def _make_layer(self, block, dim, stride=1):
    """Stack two `block`s into one stage; the first may downsample via `stride`."""
    # First block maps the current channel count to `dim`; the second keeps it.
    first = block(self.in_planes, dim, stride=stride)
    second = block(dim, dim, stride=1)
    # The next stage starts from this stage's output width.
    self.in_planes = dim
    return nn.Sequential(first, second)
def forward(self, x):
    """Extract coarse (1/8) and fine (full-res) feature maps from `x`.

    Returns a dict with 'feats_c' and 'feats_f' (None in coarse-only mode),
    plus 'feats_x2'/'feats_x1' when `inter_feat` is set.
    """
    # Stem followed by the three residual stages.
    stem = self.relu(self.bn1(self.conv1(x)))
    feat_half = self.layer1(stem)        # 1/2
    feat_quarter = self.layer2(feat_half)  # 1/4
    feat_eighth = self.layer3(feat_quarter)  # 1/8

    # Coarse-only mode: the FPN decoder was never built, bail out early.
    if self.skip_fine_feature:
        result = {'feats_c': feat_eighth, 'feats_f': None}
        if self.inter_feat:
            result.update({'feats_x2': feat_quarter, 'feats_x1': feat_half})
        return result

    # FPN top-down fusion back to full resolution.
    # (align_corners=True on the two intermediate upsamples, False on the
    # final one — kept exactly as the trained weights expect.)
    top = self.layer3_outconv(feat_eighth)
    top_up = F.interpolate(top, scale_factor=2., mode='bilinear', align_corners=True)
    mid = self.layer2_outconv2(self.layer2_outconv(feat_quarter) + top_up)
    mid_up = F.interpolate(mid, scale_factor=2., mode='bilinear', align_corners=True)
    fine = self.layer1_outconv2(self.layer1_outconv(feat_half) + mid_up)
    full = F.interpolate(fine, scale_factor=2., mode='bilinear', align_corners=False)

    result = {'feats_c': feat_eighth, 'feats_f': full}
    if self.inter_feat:
        result.update({'feats_x2': feat_quarter, 'feats_x1': feat_half})
    return result
404 |
+
class ResNetFPN_8_1_align(nn.Module):
    """
    ResNet+FPN, output resolution are 1/8 and 1.
    Each block has 2 layers.

    Returns a dict with coarse (1/8) features and, unless configured as
    coarse-only, full-resolution fine features from an FPN decoder (all
    upsampling here uses align_corners=False).
    """

    def __init__(self, config):
        """Expected config keys: 'initial_dim', 'block_dims' (3 widths),
        'coarse_feat_only', 'inter_feat'."""
        super().__init__()
        # Config
        block = BasicBlock
        initial_dim = config['initial_dim']
        block_dims = config['block_dims']

        # Class Variable
        self.in_planes = initial_dim
        self.skip_fine_feature = config['coarse_feat_only']
        self.inter_feat = config['inter_feat']
        # Networks — stem expects single-channel (grayscale) input.
        self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(initial_dim)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, block_dims[0], stride=1)  # 1/2
        self.layer2 = self._make_layer(block, block_dims[1], stride=2)  # 1/4
        self.layer3 = self._make_layer(block, block_dims[2], stride=2)  # 1/8

        # FPN decoder, only built when fine features are requested.
        if not self.skip_fine_feature:
            self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
            self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
            self.layer2_outconv2 = nn.Sequential(
                conv3x3(block_dims[2], block_dims[2]),
                nn.BatchNorm2d(block_dims[2]),
                nn.LeakyReLU(),
                conv3x3(block_dims[2], block_dims[1]),
            )
            self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
            self.layer1_outconv2 = nn.Sequential(
                conv3x3(block_dims[1], block_dims[1]),
                nn.BatchNorm2d(block_dims[1]),
                nn.LeakyReLU(),
                conv3x3(block_dims[1], block_dims[0]),
            )

        # Kaiming init for convs, unit/zero for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, dim, stride=1):
        """Two stacked blocks; the first may downsample via `stride`."""
        layer1 = block(self.in_planes, dim, stride=stride)
        layer2 = block(dim, dim, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return {'feats_c', 'feats_f'[, 'feats_x2', 'feats_x1']}."""
        # ResNet Backbone
        x0 = self.relu(self.bn1(self.conv1(x)))
        x1 = self.layer1(x0)  # 1/2
        x2 = self.layer2(x1)  # 1/4
        x3 = self.layer3(x2)  # 1/8

        # Coarse-only mode: no FPN decoder exists, return early.
        if self.skip_fine_feature:
            if self.inter_feat:
                return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
            else:
                return {'feats_c': x3, 'feats_f': None}

        # FPN top-down fusion back to full resolution.
        x3_out = self.layer3_outconv(x3)
        x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
        x2_out = self.layer2_outconv(x2)
        x2_out = self.layer2_outconv2(x2_out + x3_out_2x)

        x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
        x1_out = self.layer1_outconv(x1)
        x1_out = self.layer1_outconv2(x1_out + x2_out_2x)

        x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)

        if not self.inter_feat:
            return {'feats_c': x3, 'feats_f': x0_out}
        else:
            return {'feats_c': x3, 'feats_f': x0_out, 'feats_x2': x2, 'feats_x1': x1}

    def pro(self, x, profiler):
        """Profiled variant of :meth:`forward`.

        BUGFIX vs. the original: it read the nonexistent attribute
        ``self.skip_fine`` (the attribute is ``skip_fine_feature``), accessed
        ``layer3_outconv`` before the skip check (that module is never built
        in coarse-only mode, so this always raised AttributeError there), and
        returned a bare list instead of the dict format used by every other
        path.  Coarse-only mode now mirrors forward().

        NOTE(review): unlike forward(), the FPN path returns the decoded
        ``x3_out`` as 'feats_c' — preserved from the original non-skip path.
        """
        with profiler.profile('ResNet Backbone'):
            x0 = self.relu(self.bn1(self.conv1(x)))
            x1 = self.layer1(x0)  # 1/2
            x2 = self.layer2(x1)  # 1/4
            x3 = self.layer3(x2)  # 1/8

        if self.skip_fine_feature:
            return {'feats_c': x3, 'feats_f': None}

        with profiler.profile('FPN'):
            x3_out = self.layer3_outconv(x3)
            x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
            x2_out = self.layer2_outconv(x2)
            x2_out = self.layer2_outconv2(x2_out + x3_out_2x)

            x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
            x1_out = self.layer1_outconv(x1)
            x1_out = self.layer1_outconv2(x1_out + x2_out_2x)

        with profiler.profile('upsample*1'):
            x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)

        return {'feats_c': x3_out, 'feats_f': x0_out}
522 |
+
class ResNetFPN_8_2_align(nn.Module):
    """
    ResNet+FPN, output resolution are 1/8 and 1/2.
    Each block has 2 layers.

    Like ResNetFPN_8_1_align, but the fine features stop at 1/2 resolution
    (no final 2x upsample) and all interpolation uses align_corners=False.
    """

    def __init__(self, config):
        """Expected config keys: 'initial_dim', 'block_dims' (3 widths),
        'coarse_feat_only', 'inter_feat'."""
        super().__init__()
        # Config
        block = BasicBlock
        initial_dim = config['initial_dim']
        block_dims = config['block_dims']

        # Class Variable
        self.in_planes = initial_dim
        self.skip_fine_feature = config['coarse_feat_only']
        self.inter_feat = config['inter_feat']
        # Networks — stem expects single-channel (grayscale) input.
        self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(initial_dim)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, block_dims[0], stride=1)  # 1/2
        self.layer2 = self._make_layer(block, block_dims[1], stride=2)  # 1/4
        self.layer3 = self._make_layer(block, block_dims[2], stride=2)  # 1/8

        # 3. FPN upsample — decoder only built when fine features are needed.
        if not self.skip_fine_feature:
            self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
            self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
            self.layer2_outconv2 = nn.Sequential(
                conv3x3(block_dims[2], block_dims[2]),
                nn.BatchNorm2d(block_dims[2]),
                nn.LeakyReLU(),
                conv3x3(block_dims[2], block_dims[1]),
            )
            self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
            self.layer1_outconv2 = nn.Sequential(
                conv3x3(block_dims[1], block_dims[1]),
                nn.BatchNorm2d(block_dims[1]),
                nn.LeakyReLU(),
                conv3x3(block_dims[1], block_dims[0]),
            )

        # Kaiming init for convs, unit/zero for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, dim, stride=1):
        """Two stacked blocks; the first may downsample via `stride`."""
        layer1 = block(self.in_planes, dim, stride=stride)
        layer2 = block(dim, dim, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return {'feats_c' (1/8), 'feats_f' (1/2 or None)[, 'feats_x2', 'feats_x1']}."""
        # ResNet Backbone
        x0 = self.relu(self.bn1(self.conv1(x)))
        x1 = self.layer1(x0)  # 1/2
        x2 = self.layer2(x1)  # 1/4
        x3 = self.layer3(x2)  # 1/8

        # Coarse-only mode: no FPN decoder exists, return early.
        if self.skip_fine_feature:
            if self.inter_feat:
                return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
            else:
                return {'feats_c': x3, 'feats_f': None}

        # FPN
        x3_out = self.layer3_outconv(x3)

        x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
        x2_out = self.layer2_outconv(x2)
        x2_out = self.layer2_outconv2(x2_out+x3_out_2x)

        x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
        x1_out = self.layer1_outconv(x1)
        x1_out = self.layer1_outconv2(x1_out+x2_out_2x)

        # Fine features are left at 1/2 resolution (x1_out).
        if not self.inter_feat:
            return {'feats_c': x3, 'feats_f': x1_out}
        else:
            return {'feats_c': x3, 'feats_f': x1_out, 'feats_x2': x2, 'feats_x1': x1}
|
611 |
+
class ResNet_8_1_align(nn.Module):
    """
    ResNet (no FPN), output resolution are 1/8 and 1.
    Each block has 2 layers.

    Fine features are produced by directly upsampling the 1/8 map 8x and
    projecting channels with a 1x1 conv, instead of an FPN decoder.
    """

    def __init__(self, config):
        """Expected config keys: 'initial_dim', 'block_dims' (3 widths)."""
        super().__init__()
        # Config
        block = BasicBlock
        initial_dim = config['initial_dim']
        block_dims = config['block_dims']

        # Class Variable
        self.in_planes = initial_dim

        # Networks — stem expects single-channel (grayscale) input.
        self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(initial_dim)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, block_dims[0], stride=1)  # 1/2
        self.layer2 = self._make_layer(block, block_dims[1], stride=2)  # 1/4
        self.layer3 = self._make_layer(block, block_dims[2], stride=2)  # 1/8

        # 3. FPN upsample — decoder intentionally disabled in this variant.
        # self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
        # self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
        # self.layer2_outconv2 = nn.Sequential(
        #     conv3x3(block_dims[2], block_dims[2]),
        #     nn.BatchNorm2d(block_dims[2]),
        #     nn.LeakyReLU(),
        #     conv3x3(block_dims[2], block_dims[1]),
        # )
        # self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
        # self.layer1_outconv2 = nn.Sequential(
        #     conv3x3(block_dims[1], block_dims[1]),
        #     nn.BatchNorm2d(block_dims[1]),
        #     nn.LeakyReLU(),
        #     conv3x3(block_dims[1], block_dims[0]),
        # )
        # Projects the 8x-upsampled coarse map to the fine feature width.
        self.layer0_outconv = conv1x1(block_dims[2], block_dims[0])
        # Kaiming init for convs, unit/zero for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, dim, stride=1):
        """Two stacked blocks; the first may downsample via `stride`."""
        layer1 = block(self.in_planes, dim, stride=stride)
        layer2 = block(dim, dim, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return {'feats_c': 1/8 map, 'feats_f': full-res projection of it}."""
        # ResNet Backbone
        x0 = self.relu(self.bn1(self.conv1(x)))
        x1 = self.layer1(x0)  # 1/2
        x2 = self.layer2(x1)  # 1/4
        x3 = self.layer3(x2)  # 1/8

        # FPN path disabled in this variant:
        # x3_out = self.layer3_outconv(x3)
        # x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
        # x2_out = self.layer2_outconv(x2)
        # x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
        # x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
        # x1_out = self.layer1_outconv(x1)
        # x1_out = self.layer1_outconv2(x1_out+x2_out_2x)

        # Fine features: plain 8x bilinear upsample of the coarse map + 1x1 conv.
        x0_out = F.interpolate(x3, scale_factor=8., mode='bilinear', align_corners=False)
        x0_out = self.layer0_outconv(x0_out)

        return {'feats_c': x3, 'feats_f': x0_out}
691 |
+
class VGG_8_1_align(nn.Module):
    """
    VGG-like (SuperPoint-style) backbone, output resolution are 1/8 and 1.

    The coarse map is an L2-normalized 256-d descriptor map at 1/8 resolution;
    the fine map is that descriptor map bilinearly upsampled 8x and projected
    to ``block_dims[0]`` channels with a 1x1 conv.
    """

    def __init__(self, config):
        """Expected config key: 'block_dims' (channel widths; [2] must match
        the 256-d descriptor head, [0] is the fine feature width)."""
        super().__init__()
        # NOTE: the original also read the unused 'initial_dim' key and bound
        # an unused `block = BasicBlock`; both dead locals were removed.
        block_dims = config['block_dims']

        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256

        # Encoder: conv pairs with three 2x max-poolings -> 1/8 resolution.
        self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
        self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
        self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
        self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
        self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
        self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
        self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
        self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)

        # Descriptor head (SuperPoint naming: convDa/convDb).
        self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.convDb = nn.Conv2d(
            c5, 256,
            kernel_size=1, stride=1, padding=0)
        # Projects the 8x-upsampled descriptor map to the fine feature width.
        # NOTE(review): assumes block_dims[2] == 256 (convDb output) — confirm.
        self.layer0_outconv = conv1x1(block_dims[2], block_dims[0])
        # Kaiming init for convs, unit/zero for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, dim, stride=1):
        # NOTE(review): dead code inherited from the ResNet variants — this
        # class never initializes ``self.in_planes``, so calling this raises
        # AttributeError.  Kept only to avoid changing the class interface.
        layer1 = block(self.in_planes, dim, stride=stride)
        layer2 = block(dim, dim, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return {'feats_c': 1/8 L2-normalized descriptors, 'feats_f': full-res map}."""
        # Shared Encoder
        x = self.relu(self.conv1a(x))
        x = self.relu(self.conv1b(x))
        x = self.pool(x)
        x = self.relu(self.conv2a(x))
        x = self.relu(self.conv2b(x))
        x = self.pool(x)
        x = self.relu(self.conv3a(x))
        x = self.relu(self.conv3b(x))
        x = self.pool(x)
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))

        # Descriptor head + channel-wise L2 normalization.
        cDa = self.relu(self.convDa(x))
        descriptors = self.convDb(cDa)
        x3_out = nn.functional.normalize(descriptors, p=2, dim=1)

        x0_out = F.interpolate(x3_out, scale_factor=8., mode='bilinear', align_corners=False)
        x0_out = self.layer0_outconv(x0_out)

        return {'feats_c': x3_out, 'feats_f': x0_out}
781 |
+
class RepVGG_8_1_align(nn.Module):
    """
    RepVGG backbone, output resolution are 1/8 and 1.
    Each block has 2 layers.

    Stages 0-3 of a RepVGG model provide 1/2, 1/2, 1/4 and 1/8 features; an
    optional FPN decoder restores full resolution.
    """

    def __init__(self, config):
        """Expected config keys: 'block_dims' (3 widths), 'coarse_feat_only',
        'inter_feat', 'leaky'; optional 'repvggmodel' variant name."""
        super().__init__()
        # Config
        block_dims = config['block_dims']
        self.skip_fine_feature = config['coarse_feat_only']
        self.inter_feat = config['inter_feat']
        self.leaky = config['leaky']

        # Variant selection: an explicit 'repvggmodel' entry wins over the
        # leaky-ReLU flag; default is RepVGG-A1.
        if config.get('repvggmodel') is not None:
            backbone_name=config['repvggmodel']
        elif self.leaky:
            backbone_name='RepVGG-A1_leaky'
        else:
            backbone_name='RepVGG-A1'
        repvgg_fn = get_RepVGG_func_by_name(backbone_name)
        # False: build the training-time (non-deployed) architecture —
        # see repvgg.py for the flag's semantics.
        backbone = repvgg_fn(False)
        # Only stages 0-3 are used (stage4 intentionally dropped).
        self.layer0, self.layer1, self.layer2, self.layer3 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3

        # 3. FPN upsample — decoder only built when fine features are needed.
        if not self.skip_fine_feature:
            self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
            self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
            self.layer2_outconv2 = nn.Sequential(
                conv3x3(block_dims[2], block_dims[2]),
                nn.BatchNorm2d(block_dims[2]),
                nn.LeakyReLU(),
                conv3x3(block_dims[2], block_dims[1]),
            )
            self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
            self.layer1_outconv2 = nn.Sequential(
                conv3x3(block_dims[1], block_dims[1]),
                nn.BatchNorm2d(block_dims[1]),
                nn.LeakyReLU(),
                conv3x3(block_dims[1], block_dims[0]),
            )

        # Re-initialize the backbone stages (Kaiming convs, unit/zero norms).
        for layer in [self.layer0, self.layer1, self.layer2, self.layer3]:
            for m in layer.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return {'feats_c', 'feats_f'[, 'feats_x2', 'feats_x1']}.

        NOTE(review): on the FPN path this returns the decoded ``x3_out`` as
        'feats_c', while the ResNet variants return the raw ``x3``.
        """
        out = self.layer0(x)  # 1/2
        for module in self.layer1:
            out = module(out)  # 1/2
        x1 = out
        for module in self.layer2:
            out = module(out)  # 1/4
        x2 = out
        for module in self.layer3:
            out = module(out)  # 1/8
        x3 = out

        # Coarse-only mode: no FPN decoder exists, return early.
        if self.skip_fine_feature:
            if self.inter_feat:
                return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
            else:
                return {'feats_c': x3, 'feats_f': None}
        # FPN top-down fusion back to full resolution.
        x3_out = self.layer3_outconv(x3)
        x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
        x2_out = self.layer2_outconv(x2)
        x2_out = self.layer2_outconv2(x2_out+x3_out_2x)

        x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
        x1_out = self.layer1_outconv(x1)
        x1_out = self.layer1_outconv2(x1_out+x2_out_2x)

        x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)

        return {'feats_c': x3_out, 'feats_f': x0_out}
881 |
+
class RepVGG_8_2_fix(nn.Module):
    """
    RepVGG backbone, output resolution are 1/8 and 1/2.
    Each block has 2 layers.

    Variant of RepVGG_8_1_align that stops the FPN at 1/2 resolution and
    upsamples to explicit odd target sizes (2*s - 1) with align_corners=True.
    NOTE(review): the (size-1)*2+1 targets only add up with the 1/4-stage
    sizes when input sides are of the form 8k+1 — confirm against the
    data-loading/padding code.
    """

    def __init__(self, config):
        """Expected config keys: 'block_dims' (3 widths), 'coarse_feat_only',
        'inter_feat'."""
        super().__init__()
        # Config
        block_dims = config['block_dims']
        self.skip_fine_feature = config['coarse_feat_only']
        self.inter_feat = config['inter_feat']

        # Fixed variant (no leaky/config selection here).
        backbone_name='RepVGG-A1'
        repvgg_fn = get_RepVGG_func_by_name(backbone_name)
        # False: build the training-time (non-deployed) architecture.
        backbone = repvgg_fn(False)
        # Only stages 0-3 are used (stage4 intentionally dropped).
        self.layer0, self.layer1, self.layer2, self.layer3 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3

        # 3. FPN upsample — decoder only built when fine features are needed.
        if not self.skip_fine_feature:
            self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
            self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
            self.layer2_outconv2 = nn.Sequential(
                conv3x3(block_dims[2], block_dims[2]),
                nn.BatchNorm2d(block_dims[2]),
                nn.LeakyReLU(),
                conv3x3(block_dims[2], block_dims[1]),
            )
            self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
            self.layer1_outconv2 = nn.Sequential(
                conv3x3(block_dims[1], block_dims[1]),
                nn.BatchNorm2d(block_dims[1]),
                nn.LeakyReLU(),
                conv3x3(block_dims[1], block_dims[0]),
            )

        # Re-initialize the backbone stages (Kaiming convs, unit/zero norms).
        for layer in [self.layer0, self.layer1, self.layer2, self.layer3]:
            for m in layer.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return {'feats_c' (decoded 1/8), 'feats_f' (1/2 or None)[, inter feats]}."""
        x0 = self.layer0(x)  # 1/2
        out = x0
        for module in self.layer1:
            out = module(out)  # 1/2
        x1 = out
        for module in self.layer2:
            out = module(out)  # 1/4
        x2 = out
        for module in self.layer3:
            out = module(out)  # 1/8
        x3 = out

        # Coarse-only mode: no FPN decoder exists, return early.
        if self.skip_fine_feature:
            if self.inter_feat:
                return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
            else:
                return {'feats_c': x3, 'feats_f': None}
        x3_out = self.layer3_outconv(x3)
        # Upsample to an explicit odd size (2s-1) with corners pinned.
        x3_out_2x = F.interpolate(x3_out, size=((x3_out.size(-2)-1)*2+1, (x3_out.size(-1)-1)*2+1), mode='bilinear', align_corners=True)
        x2_out = self.layer2_outconv(x2)
        x2_out = self.layer2_outconv2(x2_out+x3_out_2x)

        x2_out_2x = F.interpolate(x2_out, size=((x2_out.size(-2)-1)*2+1, (x2_out.size(-1)-1)*2+1), mode='bilinear', align_corners=True)
        x1_out = self.layer1_outconv(x1)
        x1_out = self.layer1_outconv2(x1_out+x2_out_2x)

        # Fine features stay at 1/2 resolution in this variant.
        return {'feats_c': x3_out, 'feats_f': x1_out}
|
967 |
+
class RepVGGnfpn_8_1_align(nn.Module):
    """
    RepVGG backbone without FPN, output resolution are 1/8 and 1.
    Each block has 2 layers.

    Fine features are a direct 8x bilinear upsample of the 1/8 map followed
    by a 1x1 channel projection (no top-down fusion).
    """

    def __init__(self, config):
        """Expected config keys: 'block_dims' (3 widths), 'coarse_feat_only',
        'inter_feat'."""
        super().__init__()
        # Config
        block_dims = config['block_dims']
        self.skip_fine_feature = config['coarse_feat_only']
        self.inter_feat = config['inter_feat']

        backbone_name='RepVGG-A1'
        repvgg_fn = get_RepVGG_func_by_name(backbone_name)
        # False: build the training-time (non-deployed) architecture.
        backbone = repvgg_fn(False)
        # Only stages 0-3 are used (stage4 intentionally dropped).
        self.layer0, self.layer1, self.layer2, self.layer3 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3

        # 3. FPN upsample — only the fine-projection conv is needed, and only
        # when fine features are requested.
        if not self.skip_fine_feature:
            self.layer0_outconv = conv1x1(block_dims[2], block_dims[0])

        # Re-initialize the backbone stages (Kaiming convs, unit/zero norms).
        for layer in [self.layer0, self.layer1, self.layer2, self.layer3]:
            for m in layer.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return {'feats_c': 1/8 map, 'feats_f': full-res projection or None}."""
        x0 = self.layer0(x)  # 1/2
        out = x0
        for module in self.layer1:
            out = module(out)  # 1/2
        x1 = out
        for module in self.layer2:
            out = module(out)  # 1/4
        x2 = out
        for module in self.layer3:
            out = module(out)  # 1/8
        x3 = out

        # Coarse-only mode: the projection conv was never built, return early.
        if self.skip_fine_feature:
            if self.inter_feat:
                return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
            else:
                return {'feats_c': x3, 'feats_f': None}

        # Fine features: plain 8x bilinear upsample of the coarse map + 1x1 conv.
        x_f = F.interpolate(x3, scale_factor=8., mode='bilinear', align_corners=False)
        x_f = self.layer0_outconv(x_f)
        return {'feats_c': x3, 'feats_f': x_f}
|
1040 |
+
class s2dnet_8_1_align(nn.Module):
    """S2DNet-based backbone wrapper (coarse features only).

    Wraps a pretrained S2DNet and returns its deepest hypercolumn map,
    upsampled 2x, as the coarse feature map.  The fine/FPN path was never
    implemented, so ``coarse_feat_only`` must be True in the config.
    """

    def __init__(self, config):
        """Expected config keys: 'initial_dim', 'coarse_feat_only',
        'inter_feat'; optional 's2dnet_checkpoint_path'."""
        super().__init__()
        # Kept for interface parity with the other backbones in this file.
        # (The original also bound unused locals `block`/`block_dims`; removed.)
        self.in_planes = config['initial_dim']
        self.skip_fine_feature = config['coarse_feat_only']
        self.inter_feat = config['inter_feat']
        # Checkpoint path is now configurable; the default preserves the
        # original hard-coded cluster path for backward compatibility.
        ckpt = config.get(
            's2dnet_checkpoint_path',
            '/cephfs-mvs/3dv-research/hexingyi/code_yf/loftrdev/weights/s2dnet/s2dnet_weights.pth')
        self.backbone = S2DNet(checkpoint_path=ckpt)

    def forward(self, x):
        """Return {'feats_c', 'feats_f': None[, 'feats_x2', 'feats_x1']}."""
        # S2DNet returns a list of hypercolumn feature maps (shallow -> deep).
        ret = self.backbone(x)
        # 2x upsample of the deepest map; presumably brings it to the 1/8
        # coarse resolution expected by the matcher — TODO confirm against
        # S2DNet's default hypercolumn layers.
        ret[2] = F.interpolate(ret[2], scale_factor=2., mode='bilinear', align_corners=False)
        if self.skip_fine_feature:
            if self.inter_feat:
                return {'feats_c': ret[2], 'feats_f': None, 'feats_x2': ret[1], 'feats_x1': ret[0]}
            return {'feats_c': ret[2], 'feats_f': None}
        # BUGFIX: the original fell through here and implicitly returned None,
        # which would crash callers downstream.  Fail loudly instead.
        raise NotImplementedError(
            's2dnet_8_1_align only supports coarse_feat_only=True (no fine-level FPN).')

    def pro(self, x, profiler):
        # Profiling variant is not implemented for this backbone.
        pass
imcui/third_party/MatchAnything/src/loftr/backbone/s2dnet.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
# from torchvision import models
|
6 |
+
from typing import List, Dict
|
7 |
+
|
8 |
+
# VGG-16 Layer Names and Channels
|
9 |
+
# Output channel count for every named VGG-16 layer; used to look up input
# widths when building one adaptation layer per hypercolumn extraction point.
vgg16_layers = {
    "conv1_1": 64,
    "relu1_1": 64,
    "conv1_2": 64,
    "relu1_2": 64,
    "pool1": 64,
    "conv2_1": 128,
    "relu2_1": 128,
    "conv2_2": 128,
    "relu2_2": 128,
    "pool2": 128,
    "conv3_1": 256,
    "relu3_1": 256,
    "conv3_2": 256,
    "relu3_2": 256,
    "conv3_3": 256,
    "relu3_3": 256,
    "pool3": 256,
    "conv4_1": 512,
    "relu4_1": 512,
    "conv4_2": 512,
    "relu4_2": 512,
    "conv4_3": 512,
    "relu4_3": 512,
    "pool4": 512,
    "conv5_1": 512,
    "relu5_1": 512,
    "conv5_2": 512,
    "relu5_2": 512,
    "conv5_3": 512,
    "relu5_3": 512,
    "pool5": 512,
}
|
43 |
+
class AdapLayers(nn.Module):
    """Small adaptation layers.

    One conv head (1x1 -> ReLU -> 5x5 -> BN) per hypercolumn extraction
    point, projecting each VGG feature map to a common `output_dim`.
    """

    def __init__(self, hypercolumn_layers: List[str], output_dim: int = 128):
        """Initialize one adaptation layer for every extraction point.

        Args:
            hypercolumn_layers: The list of the hypercolumn layer names.
            output_dim: The output channel dimension.
        """
        super(AdapLayers, self).__init__()
        # NOTE(review): plain-list attribute; parameters are still tracked
        # because each layer is also registered via add_module() below, and
        # the "adap_layer_{i}" names define the checkpoint state_dict keys.
        self.layers = []
        channel_sizes = [vgg16_layers[name] for name in hypercolumn_layers]
        for i, l in enumerate(channel_sizes):
            layer = nn.Sequential(
                nn.Conv2d(l, 64, kernel_size=1, stride=1, padding=0),
                nn.ReLU(),
                nn.Conv2d(64, output_dim, kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(output_dim),
            )
            self.layers.append(layer)
            self.add_module("adap_layer_{}".format(i), layer)

    def forward(self, features: List[torch.Tensor]):
        """Apply adaptation layers.

        Note: mutates the input list in place and returns it.
        """
        for i, _ in enumerate(features):
            features[i] = getattr(self, "adap_layer_{}".format(i))(features[i])
        return features
74 |
+
class S2DNet(nn.Module):
|
75 |
+
"""The S2DNet model
|
76 |
+
"""
|
77 |
+
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
# hypercolumn_layers: List[str] = ["conv2_2", "conv3_3", "relu4_3"],
|
81 |
+
hypercolumn_layers: List[str] = ["conv1_2", "conv3_3", "conv5_3"],
|
82 |
+
checkpoint_path: str = None,
|
83 |
+
):
|
84 |
+
"""Initialize S2DNet.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
device: The torch device to put the model on
|
88 |
+
hypercolumn_layers: Names of the layers to extract features from
|
89 |
+
checkpoint_path: Path to the pre-trained model.
|
90 |
+
"""
|
91 |
+
super(S2DNet, self).__init__()
|
92 |
+
self._checkpoint_path = checkpoint_path
|
93 |
+
self.layer_to_index = dict((k, v) for v, k in enumerate(vgg16_layers.keys()))
|
94 |
+
self._hypercolumn_layers = hypercolumn_layers
|
95 |
+
|
96 |
+
# Initialize architecture
|
97 |
+
vgg16 = models.vgg16(pretrained=False)
|
98 |
+
# layers = list(vgg16.features.children())[:-2]
|
99 |
+
layers = list(vgg16.features.children())[:-1]
|
100 |
+
# layers = list(vgg16.features.children())[:23] # relu4_3
|
101 |
+
self.encoder = nn.Sequential(*layers)
|
102 |
+
self.adaptation_layers = AdapLayers(self._hypercolumn_layers) # .to(self._device)
|
103 |
+
self.eval()
|
104 |
+
|
105 |
+
# Restore params from checkpoint
|
106 |
+
if checkpoint_path:
|
107 |
+
print(">> Loading weights from {}".format(checkpoint_path))
|
108 |
+
self._checkpoint = torch.load(checkpoint_path)
|
109 |
+
self._hypercolumn_layers = self._checkpoint["hypercolumn_layers"]
|
110 |
+
self.load_state_dict(self._checkpoint["state_dict"])
|
111 |
+
|
112 |
+
def forward(self, image_tensor: torch.FloatTensor):
|
113 |
+
"""Compute intermediate feature maps at the provided extraction levels.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
image_tensor: The [N x 3 x H x Ws] input image tensor.
|
117 |
+
Returns:
|
118 |
+
feature_maps: The list of output feature maps.
|
119 |
+
"""
|
120 |
+
feature_maps, j = [], 0
|
121 |
+
feature_map = image_tensor.repeat(1,3,1,1)
|
122 |
+
layer_list = list(self.encoder.modules())[0]
|
123 |
+
for i, layer in enumerate(layer_list):
|
124 |
+
feature_map = layer(feature_map)
|
125 |
+
if j < len(self._hypercolumn_layers):
|
126 |
+
next_extraction_index = self.layer_to_index[self._hypercolumn_layers[j]]
|
127 |
+
if i == next_extraction_index:
|
128 |
+
feature_maps.append(feature_map)
|
129 |
+
j += 1
|
130 |
+
feature_maps = self.adaptation_layers(feature_maps)
|
131 |
+
return feature_maps
|
imcui/third_party/MatchAnything/src/loftr/loftr.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from einops.einops import rearrange
|
4 |
+
|
5 |
+
from .backbone import build_backbone
|
6 |
+
# from third_party.matchformer.model.backbone import build_backbone as build_backbone_matchformer
|
7 |
+
from .utils.position_encoding import PositionEncodingSine
|
8 |
+
from .loftr_module import LocalFeatureTransformer, FinePreprocess
|
9 |
+
from .utils.coarse_matching import CoarseMatching
|
10 |
+
from .utils.fine_matching import FineMatching
|
11 |
+
|
12 |
+
from loguru import logger
|
13 |
+
|
14 |
+
class LoFTR(nn.Module):
|
15 |
+
def __init__(self, config, profiler=None):
|
16 |
+
super().__init__()
|
17 |
+
# Misc
|
18 |
+
self.config = config
|
19 |
+
self.profiler = profiler
|
20 |
+
|
21 |
+
# Modules
|
22 |
+
self.backbone = build_backbone(config)
|
23 |
+
if not (self.config['coarse']['skip'] or self.config['coarse']['rope'] or self.config['coarse']['pan'] or self.config['coarse']['token_mixer'] is not None):
|
24 |
+
self.pos_encoding = PositionEncodingSine(
|
25 |
+
config['coarse']['d_model'],
|
26 |
+
temp_bug_fix=config['coarse']['temp_bug_fix'],
|
27 |
+
npe=config['coarse']['npe'],
|
28 |
+
)
|
29 |
+
if self.config['coarse']['abspe']:
|
30 |
+
self.pos_encoding = PositionEncodingSine(
|
31 |
+
config['coarse']['d_model'],
|
32 |
+
temp_bug_fix=config['coarse']['temp_bug_fix'],
|
33 |
+
npe=config['coarse']['npe'],
|
34 |
+
)
|
35 |
+
|
36 |
+
if self.config['coarse']['skip'] is False:
|
37 |
+
self.loftr_coarse = LocalFeatureTransformer(config)
|
38 |
+
self.coarse_matching = CoarseMatching(config['match_coarse'])
|
39 |
+
# self.fine_preprocess = FinePreprocess(config).float()
|
40 |
+
self.fine_preprocess = FinePreprocess(config)
|
41 |
+
if self.config['fine']['skip'] is False:
|
42 |
+
self.loftr_fine = LocalFeatureTransformer(config["fine"])
|
43 |
+
self.fine_matching = FineMatching(config)
|
44 |
+
|
45 |
+
def forward(self, data):
|
46 |
+
"""
|
47 |
+
Update:
|
48 |
+
data (dict): {
|
49 |
+
'image0': (torch.Tensor): (N, 1, H, W)
|
50 |
+
'image1': (torch.Tensor): (N, 1, H, W)
|
51 |
+
'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
|
52 |
+
'mask1'(optional) : (torch.Tensor): (N, H, W)
|
53 |
+
}
|
54 |
+
"""
|
55 |
+
# 1. Local Feature CNN
|
56 |
+
data.update({
|
57 |
+
'bs': data['image0'].size(0),
|
58 |
+
'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
|
59 |
+
})
|
60 |
+
|
61 |
+
if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
|
62 |
+
# feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
|
63 |
+
ret_dict = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
|
64 |
+
feats_c, feats_f = ret_dict['feats_c'], ret_dict['feats_f']
|
65 |
+
if self.config['inter_feat']:
|
66 |
+
data.update({
|
67 |
+
'feats_x2': ret_dict['feats_x2'],
|
68 |
+
'feats_x1': ret_dict['feats_x1'],
|
69 |
+
})
|
70 |
+
if self.config['coarse_feat_only']:
|
71 |
+
(feat_c0, feat_c1) = feats_c.split(data['bs'])
|
72 |
+
feat_f0, feat_f1 = None, None
|
73 |
+
else:
|
74 |
+
(feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs'])
|
75 |
+
else: # handle different input shapes
|
76 |
+
# (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1'])
|
77 |
+
ret_dict0, ret_dict1 = self.backbone(data['image0']), self.backbone(data['image1'])
|
78 |
+
feat_c0, feat_f0 = ret_dict0['feats_c'], ret_dict0['feats_f']
|
79 |
+
feat_c1, feat_f1 = ret_dict1['feats_c'], ret_dict1['feats_f']
|
80 |
+
if self.config['inter_feat']:
|
81 |
+
data.update({
|
82 |
+
'feats_x2_0': ret_dict0['feats_x2'],
|
83 |
+
'feats_x1_0': ret_dict0['feats_x1'],
|
84 |
+
'feats_x2_1': ret_dict1['feats_x2'],
|
85 |
+
'feats_x1_1': ret_dict1['feats_x1'],
|
86 |
+
})
|
87 |
+
if self.config['coarse_feat_only']:
|
88 |
+
feat_f0, feat_f1 = None, None
|
89 |
+
|
90 |
+
|
91 |
+
mul = self.config['resolution'][0] // self.config['resolution'][1]
|
92 |
+
# mul = 4
|
93 |
+
if self.config['fix_bias']:
|
94 |
+
data.update({
|
95 |
+
'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
|
96 |
+
'hw0_f': feat_f0.shape[2:] if feat_f0 is not None else [(feat_c0.shape[2]-1) * mul+1, (feat_c0.shape[3]-1) * mul+1] ,
|
97 |
+
'hw1_f': feat_f1.shape[2:] if feat_f1 is not None else [(feat_c1.shape[2]-1) * mul+1, (feat_c1.shape[3]-1) * mul+1]
|
98 |
+
})
|
99 |
+
else:
|
100 |
+
data.update({
|
101 |
+
'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
|
102 |
+
'hw0_f': feat_f0.shape[2:] if feat_f0 is not None else [feat_c0.shape[2] * mul, feat_c0.shape[3] * mul] ,
|
103 |
+
'hw1_f': feat_f1.shape[2:] if feat_f1 is not None else [feat_c1.shape[2] * mul, feat_c1.shape[3] * mul]
|
104 |
+
})
|
105 |
+
|
106 |
+
# 2. coarse-level loftr module
|
107 |
+
# add featmap with positional encoding, then flatten it to sequence [N, HW, C]
|
108 |
+
if self.config['coarse']['skip']:
|
109 |
+
mask_c0 = mask_c1 = None # mask is useful in training
|
110 |
+
if 'mask0' in data:
|
111 |
+
mask_c0, mask_c1 = data['mask0'], data['mask1']
|
112 |
+
feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
|
113 |
+
feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
|
114 |
+
|
115 |
+
elif self.config['coarse']['pan']:
|
116 |
+
# assert feat_c0.shape[0] == 1, 'batch size must be 1 when using mask Xformer now'
|
117 |
+
if self.config['coarse']['abspe']:
|
118 |
+
feat_c0 = self.pos_encoding(feat_c0)
|
119 |
+
feat_c1 = self.pos_encoding(feat_c1)
|
120 |
+
|
121 |
+
mask_c0 = mask_c1 = None # mask is useful in training
|
122 |
+
if 'mask0' in data:
|
123 |
+
mask_c0, mask_c1 = data['mask0'], data['mask1']
|
124 |
+
if self.config['matchability']: # else match in loftr_coarse
|
125 |
+
feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1, data=data)
|
126 |
+
else:
|
127 |
+
feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
|
128 |
+
|
129 |
+
feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
|
130 |
+
feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
|
131 |
+
else:
|
132 |
+
if not (self.config['coarse']['rope'] or self.config['coarse']['token_mixer'] is not None):
|
133 |
+
feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c')
|
134 |
+
feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c')
|
135 |
+
|
136 |
+
mask_c0 = mask_c1 = None # mask is useful in training
|
137 |
+
if self.config['coarse']['rope']:
|
138 |
+
if 'mask0' in data:
|
139 |
+
mask_c0, mask_c1 = data['mask0'], data['mask1']
|
140 |
+
else:
|
141 |
+
if 'mask0' in data:
|
142 |
+
mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
|
143 |
+
feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
|
144 |
+
if self.config['coarse']['rope']:
|
145 |
+
feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
|
146 |
+
feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
|
147 |
+
|
148 |
+
# detect nan
|
149 |
+
if self.config['replace_nan'] and (torch.any(torch.isnan(feat_c0)) or torch.any(torch.isnan(feat_c1))):
|
150 |
+
logger.info(f'replace nan in coarse attention')
|
151 |
+
logger.info(f"feat_c0_nan_num: {torch.isnan(feat_c0).int().sum()}, feat_c1_nan_num: {torch.isnan(feat_c1).int().sum()}")
|
152 |
+
logger.info(f"feat_c0: {feat_c0}, feat_c1: {feat_c1}")
|
153 |
+
logger.info(f"feat_c0_max: {feat_c0.abs().max()}, feat_c1_max: {feat_c1.abs().max()}")
|
154 |
+
feat_c0[torch.isnan(feat_c0)] = 0
|
155 |
+
feat_c1[torch.isnan(feat_c1)] = 0
|
156 |
+
logger.info(f"feat_c0_nanmax: {feat_c0.abs().max()}, feat_c1_nanmax: {feat_c1.abs().max()}")
|
157 |
+
|
158 |
+
# 3. match coarse-level
|
159 |
+
if not self.config['matchability']: # else match in loftr_coarse
|
160 |
+
self.coarse_matching(feat_c0, feat_c1, data,
|
161 |
+
mask_c0=mask_c0.view(mask_c0.size(0), -1) if mask_c0 is not None else mask_c0,
|
162 |
+
mask_c1=mask_c1.view(mask_c1.size(0), -1) if mask_c1 is not None else mask_c1
|
163 |
+
)
|
164 |
+
|
165 |
+
#return data['conf_matrix'],feat_c0,feat_c1,data['feats_x2'],data['feats_x1']
|
166 |
+
|
167 |
+
# norm FPNfeat
|
168 |
+
if self.config['norm_fpnfeat']:
|
169 |
+
feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
|
170 |
+
[feat_c0, feat_c1])
|
171 |
+
if self.config['norm_fpnfeat2']:
|
172 |
+
assert self.config['inter_feat']
|
173 |
+
logger.info(f'before norm_fpnfeat2 max of feat_c0, feat_c1:{feat_c0.abs().max()}, {feat_c1.abs().max()}')
|
174 |
+
if data['hw0_i'] == data['hw1_i']:
|
175 |
+
logger.info(f'before norm_fpnfeat2 max of data[feats_x2], data[feats_x1]:{data["feats_x2"].abs().max()}, {data["feats_x1"].abs().max()}')
|
176 |
+
feat_c0, feat_c1, data['feats_x2'], data['feats_x1'] = map(lambda feat: feat / feat.shape[-1]**.5,
|
177 |
+
[feat_c0, feat_c1, data['feats_x2'], data['feats_x1']])
|
178 |
+
else:
|
179 |
+
feat_c0, feat_c1, data['feats_x2_0'], data['feats_x2_1'], data['feats_x1_0'], data['feats_x1_1'] = map(lambda feat: feat / feat.shape[-1]**.5,
|
180 |
+
[feat_c0, feat_c1, data['feats_x2_0'], data['feats_x2_1'], data['feats_x1_0'], data['feats_x1_1']])
|
181 |
+
|
182 |
+
|
183 |
+
# 4. fine-level refinement
|
184 |
+
with torch.autocast(enabled=False, device_type="cuda"):
|
185 |
+
feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
|
186 |
+
|
187 |
+
# detect nan
|
188 |
+
if self.config['replace_nan'] and (torch.any(torch.isnan(feat_f0_unfold)) or torch.any(torch.isnan(feat_f1_unfold))):
|
189 |
+
logger.info(f'replace nan in fine_preprocess')
|
190 |
+
logger.info(f"feat_f0_unfold_nan_num: {torch.isnan(feat_f0_unfold).int().sum()}, feat_f1_unfold_nan_num: {torch.isnan(feat_f1_unfold).int().sum()}")
|
191 |
+
logger.info(f"feat_f0_unfold: {feat_f0_unfold}, feat_f1_unfold: {feat_f1_unfold}")
|
192 |
+
logger.info(f"feat_f0_unfold_max: {feat_f0_unfold}, feat_f1_unfold_max: {feat_f1_unfold}")
|
193 |
+
feat_f0_unfold[torch.isnan(feat_f0_unfold)] = 0
|
194 |
+
feat_f1_unfold[torch.isnan(feat_f1_unfold)] = 0
|
195 |
+
logger.info(f"feat_f0_unfold_nanmax: {feat_f0_unfold}, feat_f1_unfold_nanmax: {feat_f1_unfold}")
|
196 |
+
|
197 |
+
if self.config['fp16log'] and feat_c0 is not None:
|
198 |
+
logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}")
|
199 |
+
del feat_c0, feat_c1, mask_c0, mask_c1
|
200 |
+
if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
|
201 |
+
if self.config['fine']['pan']:
|
202 |
+
m, ww, c = feat_f0_unfold.size() # [m, ww, c]
|
203 |
+
w = self.config['fine_window_size']
|
204 |
+
feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold.reshape(m, c, w, w), feat_f1_unfold.reshape(m, c, w, w))
|
205 |
+
feat_f0_unfold = rearrange(feat_f0_unfold, 'm c w h -> m (w h) c')
|
206 |
+
feat_f1_unfold = rearrange(feat_f1_unfold, 'm c w h -> m (w h) c')
|
207 |
+
elif self.config['fine']['skip']:
|
208 |
+
pass
|
209 |
+
else:
|
210 |
+
feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
|
211 |
+
# 5. match fine-level
|
212 |
+
# log forward nan
|
213 |
+
if self.config['fp16log']:
|
214 |
+
if feat_f0_unfold.size(0) != 0 and feat_f0 is not None:
|
215 |
+
logger.info(f"f0: {feat_f0.abs().max()}, f1: {feat_f1.abs().max()}, uf0: {feat_f0_unfold.abs().max()}, uf1: {feat_f1_unfold.abs().max()}")
|
216 |
+
elif feat_f0_unfold.size(0) != 0:
|
217 |
+
logger.info(f"uf0: {feat_f0_unfold.abs().max()}, uf1: {feat_f1_unfold.abs().max()}")
|
218 |
+
# elif feat_c0 is not None:
|
219 |
+
# logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}")
|
220 |
+
|
221 |
+
with torch.autocast(enabled=False, device_type="cuda"):
|
222 |
+
self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
|
223 |
+
|
224 |
+
return data
|
225 |
+
|
226 |
+
def load_state_dict(self, state_dict, *args, **kwargs):
|
227 |
+
for k in list(state_dict.keys()):
|
228 |
+
if k.startswith('matcher.'):
|
229 |
+
state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
|
230 |
+
return super().load_state_dict(state_dict, *args, **kwargs)
|
231 |
+
|
232 |
+
def refine(self, data):
|
233 |
+
"""
|
234 |
+
Update:
|
235 |
+
data (dict): {
|
236 |
+
'image0': (torch.Tensor): (N, 1, H, W)
|
237 |
+
'image1': (torch.Tensor): (N, 1, H, W)
|
238 |
+
'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
|
239 |
+
'mask1'(optional) : (torch.Tensor): (N, H, W)
|
240 |
+
}
|
241 |
+
"""
|
242 |
+
# 1. Local Feature CNN
|
243 |
+
data.update({
|
244 |
+
'bs': data['image0'].size(0),
|
245 |
+
'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
|
246 |
+
})
|
247 |
+
feat_f0, feat_f1 = None, None
|
248 |
+
feat_c0, feat_c1 = data['feat_c0'], data['feat_c1']
|
249 |
+
# 4. fine-level refinement
|
250 |
+
feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
|
251 |
+
if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
|
252 |
+
if self.config['fine']['pan']:
|
253 |
+
m, ww, c = feat_f0_unfold.size() # [m, ww, c]
|
254 |
+
w = self.config['fine_window_size']
|
255 |
+
feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold.reshape(m, c, w, w), feat_f1_unfold.reshape(m, c, w, w))
|
256 |
+
feat_f0_unfold = rearrange(feat_f0_unfold, 'm c w h -> m (w h) c')
|
257 |
+
feat_f1_unfold = rearrange(feat_f1_unfold, 'm c w h -> m (w h) c')
|
258 |
+
elif self.config['fine']['skip']:
|
259 |
+
pass
|
260 |
+
else:
|
261 |
+
feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
|
262 |
+
# 5. match fine-level
|
263 |
+
# log forward nan
|
264 |
+
if self.config['fp16log']:
|
265 |
+
if feat_f0_unfold.size(0) != 0 and feat_f0 is not None and feat_c0 is not None:
|
266 |
+
logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}, f0: {feat_f0.abs().max()}, f1: {feat_f1.abs().max()}, uf0: {feat_f0_unfold.abs().max()}, uf1: {feat_f1_unfold.abs().max()}")
|
267 |
+
elif feat_f0 is not None and feat_c0 is not None:
|
268 |
+
logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}, f0: {feat_f0.abs().max()}, f1: {feat_f1.abs().max()}")
|
269 |
+
elif feat_c0 is not None:
|
270 |
+
logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}")
|
271 |
+
|
272 |
+
self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
|
273 |
+
return data
|
imcui/third_party/MatchAnything/src/loftr/loftr_module/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .transformer import LocalFeatureTransformer
|
2 |
+
from .fine_preprocess import FinePreprocess
|
imcui/third_party/MatchAnything/src/loftr/loftr_module/fine_preprocess.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops.einops import rearrange, repeat
|
5 |
+
from ..backbone.repvgg import RepVGGBlock
|
6 |
+
|
7 |
+
from loguru import logger
|
8 |
+
|
9 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
10 |
+
"""1x1 convolution without padding"""
|
11 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
|
12 |
+
|
13 |
+
|
14 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
15 |
+
"""3x3 convolution with padding"""
|
16 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
17 |
+
|
18 |
+
class FinePreprocess(nn.Module):
|
19 |
+
def __init__(self, config):
|
20 |
+
super().__init__()
|
21 |
+
|
22 |
+
self.config = config
|
23 |
+
self.cat_c_feat = config['fine_concat_coarse_feat']
|
24 |
+
self.sample_c_feat = config['fine_sample_coarse_feat']
|
25 |
+
self.fpn_inter_feat = config['inter_feat']
|
26 |
+
self.rep_fpn = config['rep_fpn']
|
27 |
+
self.deploy = config['rep_deploy']
|
28 |
+
self.multi_regress = config['match_fine']['multi_regress']
|
29 |
+
self.local_regress = config['match_fine']['local_regress']
|
30 |
+
self.local_regress_inner = config['match_fine']['local_regress_inner']
|
31 |
+
block_dims = config['resnetfpn']['block_dims']
|
32 |
+
|
33 |
+
self.mtd_spvs = self.config['fine']['mtd_spvs']
|
34 |
+
self.align_corner = self.config['align_corner']
|
35 |
+
self.fix_bias = self.config['fix_bias']
|
36 |
+
|
37 |
+
if self.mtd_spvs:
|
38 |
+
self.W = self.config['fine_window_size']
|
39 |
+
else:
|
40 |
+
# assert False, 'fine_window_matching_size to be revised' # good notification!
|
41 |
+
# self.W = self.config['fine_window_matching_size']
|
42 |
+
self.W = self.config['fine_window_size']
|
43 |
+
|
44 |
+
self.backbone_type = self.config['backbone_type']
|
45 |
+
|
46 |
+
d_model_c = self.config['coarse']['d_model']
|
47 |
+
d_model_f = self.config['fine']['d_model']
|
48 |
+
self.d_model_f = d_model_f
|
49 |
+
if self.fpn_inter_feat:
|
50 |
+
if self.rep_fpn:
|
51 |
+
self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
|
52 |
+
self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
|
53 |
+
self.layer2_outconv2 = []
|
54 |
+
self.layer2_outconv2.append(RepVGGBlock(in_channels=block_dims[2], out_channels=block_dims[2], kernel_size=3,
|
55 |
+
stride=1, padding=1, groups=1, deploy=self.deploy, use_se=False, leaky=0.01))
|
56 |
+
self.layer2_outconv2.append(RepVGGBlock(in_channels=block_dims[2], out_channels=block_dims[1], kernel_size=3,
|
57 |
+
stride=1, padding=1, groups=1, deploy=self.deploy, use_se=False, leaky=-2))
|
58 |
+
self.layer2_outconv2 = nn.ModuleList(self.layer2_outconv2)
|
59 |
+
self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
|
60 |
+
self.layer1_outconv2 = []
|
61 |
+
self.layer1_outconv2.append(RepVGGBlock(in_channels=block_dims[1], out_channels=block_dims[1], kernel_size=3,
|
62 |
+
stride=1, padding=1, groups=1, deploy=self.deploy, use_se=False, leaky=0.01))
|
63 |
+
self.layer1_outconv2.append(RepVGGBlock(in_channels=block_dims[1], out_channels=block_dims[0], kernel_size=3,
|
64 |
+
stride=1, padding=1, groups=1, deploy=self.deploy, use_se=False, leaky=-2))
|
65 |
+
self.layer1_outconv2 = nn.ModuleList(self.layer1_outconv2)
|
66 |
+
|
67 |
+
else:
|
68 |
+
self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
|
69 |
+
self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
|
70 |
+
self.layer2_outconv2 = nn.Sequential(
|
71 |
+
conv3x3(block_dims[2], block_dims[2]),
|
72 |
+
nn.BatchNorm2d(block_dims[2]),
|
73 |
+
nn.LeakyReLU(),
|
74 |
+
conv3x3(block_dims[2], block_dims[1]),
|
75 |
+
)
|
76 |
+
self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
|
77 |
+
self.layer1_outconv2 = nn.Sequential(
|
78 |
+
conv3x3(block_dims[1], block_dims[1]),
|
79 |
+
nn.BatchNorm2d(block_dims[1]),
|
80 |
+
nn.LeakyReLU(),
|
81 |
+
conv3x3(block_dims[1], block_dims[0]),
|
82 |
+
)
|
83 |
+
elif self.cat_c_feat:
|
84 |
+
self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
|
85 |
+
self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
|
86 |
+
if self.sample_c_feat:
|
87 |
+
self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
|
88 |
+
|
89 |
+
|
90 |
+
self._reset_parameters()
|
91 |
+
|
92 |
+
def _reset_parameters(self):
|
93 |
+
for p in self.parameters():
|
94 |
+
if p.dim() > 1:
|
95 |
+
nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")
|
96 |
+
|
97 |
+
def inter_fpn(self, feat_c, x2, x1, stride):
|
98 |
+
feat_c = self.layer3_outconv(feat_c)
|
99 |
+
feat_c = F.interpolate(feat_c, scale_factor=2., mode='bilinear', align_corners=False)
|
100 |
+
x2 = self.layer2_outconv(x2)
|
101 |
+
if self.rep_fpn:
|
102 |
+
x2 = x2 + feat_c
|
103 |
+
for layer in self.layer2_outconv2:
|
104 |
+
x2 = layer(x2)
|
105 |
+
else:
|
106 |
+
x2 = self.layer2_outconv2(x2+feat_c)
|
107 |
+
|
108 |
+
x2 = F.interpolate(x2, scale_factor=2., mode='bilinear', align_corners=False)
|
109 |
+
x1 = self.layer1_outconv(x1)
|
110 |
+
if self.rep_fpn:
|
111 |
+
x1 = x1 + x2
|
112 |
+
for layer in self.layer1_outconv2:
|
113 |
+
x1 = layer(x1)
|
114 |
+
else:
|
115 |
+
x1 = self.layer1_outconv2(x1+x2)
|
116 |
+
|
117 |
+
if stride == 4:
|
118 |
+
logger.info('stride == 4')
|
119 |
+
|
120 |
+
elif stride == 8:
|
121 |
+
logger.info('stride == 8')
|
122 |
+
x1 = F.interpolate(x1, scale_factor=2., mode='bilinear', align_corners=False)
|
123 |
+
else:
|
124 |
+
logger.info('stride not in {4,8}')
|
125 |
+
assert False
|
126 |
+
return x1
|
127 |
+
|
128 |
+
def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
|
129 |
+
W = self.W
|
130 |
+
if self.fix_bias:
|
131 |
+
stride = 4
|
132 |
+
else:
|
133 |
+
stride = data['hw0_f'][0] // data['hw0_c'][0]
|
134 |
+
|
135 |
+
data.update({'W': W})
|
136 |
+
if data['b_ids'].shape[0] == 0:
|
137 |
+
feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_c0.device)
|
138 |
+
feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_c0.device)
|
139 |
+
# return feat0, feat1
|
140 |
+
return feat0.float(), feat1.float()
|
141 |
+
|
142 |
+
if self.fpn_inter_feat:
|
143 |
+
if data['hw0_i'] != data['hw1_i']:
|
144 |
+
if self.align_corner is False:
|
145 |
+
assert self.backbone_type != 's2dnet'
|
146 |
+
|
147 |
+
feat_c0 = rearrange(feat_c0, 'b (h w) c -> b c h w', h=data['hw0_c'][0])
|
148 |
+
feat_c1 = rearrange(feat_c1, 'b (h w) c -> b c h w', h=data['hw1_c'][0])
|
149 |
+
x2_0, x1_0 = data['feats_x2_0'], data['feats_x1_0']
|
150 |
+
x2_1, x1_1 = data['feats_x2_1'], data['feats_x1_1']
|
151 |
+
del data['feats_x2_0'], data['feats_x1_0'], data['feats_x2_1'], data['feats_x1_1']
|
152 |
+
feat_f0, feat_f1 = self.inter_fpn(feat_c0, x2_0, x1_0, stride), self.inter_fpn(feat_c1, x2_1, x1_1, stride)
|
153 |
+
|
154 |
+
if self.local_regress_inner:
|
155 |
+
assert W == 8
|
156 |
+
feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
|
157 |
+
feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
|
158 |
+
feat_f1 = F.unfold(feat_f1, kernel_size=(W+2, W+2), stride=stride, padding=1)
|
159 |
+
feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=(W+2)**2)
|
160 |
+
elif W == 10 and self.multi_regress:
|
161 |
+
feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=1)
|
162 |
+
feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
|
163 |
+
feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=1)
|
164 |
+
feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
|
165 |
+
elif W == 10:
|
166 |
+
feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=1)
|
167 |
+
feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
|
168 |
+
feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=1)
|
169 |
+
feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
|
170 |
+
else:
|
171 |
+
assert not self.multi_regress
|
172 |
+
feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
|
173 |
+
feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
|
174 |
+
feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=0)
|
175 |
+
feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
|
176 |
+
|
177 |
+
# 2. select only the predicted matches
|
178 |
+
feat_f0 = feat_f0[data['b_ids'], data['i_ids']] # [n, ww, cf]
|
179 |
+
feat_f1 = feat_f1[data['b_ids'], data['j_ids']]
|
180 |
+
|
181 |
+
return feat_f0, feat_f1
|
182 |
+
|
183 |
+
else:
|
184 |
+
if self.align_corner is False:
|
185 |
+
feat_c = torch.cat([feat_c0, feat_c1], 0)
|
186 |
+
feat_c = rearrange(feat_c, 'b (h w) c -> b c h w', h=data['hw0_c'][0]) # 1/8 256
|
187 |
+
x2 = data['feats_x2'].float() # 1/4 128
|
188 |
+
x1 = data['feats_x1'].float() # 1/2 64
|
189 |
+
del data['feats_x2'], data['feats_x1']
|
190 |
+
assert self.backbone_type != 's2dnet'
|
191 |
+
feat_c = self.layer3_outconv(feat_c)
|
192 |
+
feat_c = F.interpolate(feat_c, scale_factor=2., mode='bilinear', align_corners=False)
|
193 |
+
x2 = self.layer2_outconv(x2)
|
194 |
+
if self.rep_fpn:
|
195 |
+
x2 = x2 + feat_c
|
196 |
+
for layer in self.layer2_outconv2:
|
197 |
+
x2 = layer(x2)
|
198 |
+
else:
|
199 |
+
x2 = self.layer2_outconv2(x2+feat_c)
|
200 |
+
|
201 |
+
x2 = F.interpolate(x2, scale_factor=2., mode='bilinear', align_corners=False)
|
202 |
+
x1 = self.layer1_outconv(x1)
|
203 |
+
if self.rep_fpn:
|
204 |
+
x1 = x1 + x2
|
205 |
+
for layer in self.layer1_outconv2:
|
206 |
+
x1 = layer(x1)
|
207 |
+
else:
|
208 |
+
x1 = self.layer1_outconv2(x1+x2)
|
209 |
+
|
210 |
+
if stride == 4:
|
211 |
+
# logger.info('stride == 4')
|
212 |
+
pass
|
213 |
+
elif stride == 8:
|
214 |
+
# logger.info('stride == 8')
|
215 |
+
x1 = F.interpolate(x1, scale_factor=2., mode='bilinear', align_corners=False)
|
216 |
+
else:
|
217 |
+
# logger.info('stride not in {4,8}')
|
218 |
+
assert False
|
219 |
+
|
220 |
+
feat_f0, feat_f1 = torch.chunk(x1, 2, dim=0)
|
221 |
+
|
222 |
+
# 1. unfold(crop) all local windows
|
223 |
+
if self.local_regress_inner:
|
224 |
+
assert W == 8
|
225 |
+
feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
|
226 |
+
feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
|
227 |
+
feat_f1 = F.unfold(feat_f1, kernel_size=(W+2, W+2), stride=stride, padding=1)
|
228 |
+
feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=(W+2)**2)
|
229 |
+
elif self.multi_regress or (self.local_regress and W == 10):
|
230 |
+
feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=1)
|
231 |
+
feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
|
232 |
+
feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=1)
|
233 |
+
feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
|
234 |
+
elif W == 10:
|
235 |
+
feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=1)
|
236 |
+
feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
|
237 |
+
feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=1)
|
238 |
+
feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
|
239 |
+
|
240 |
+
else:
|
241 |
+
feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
|
242 |
+
feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
|
243 |
+
feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=0)
|
244 |
+
feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
|
245 |
+
|
246 |
+
# 2. select only the predicted matches
|
247 |
+
feat_f0 = feat_f0[data['b_ids'], data['i_ids']] # [n, ww, cf]
|
248 |
+
feat_f1 = feat_f1[data['b_ids'], data['j_ids']]
|
249 |
+
|
250 |
+
return feat_f0, feat_f1
|
251 |
+
elif self.fix_bias:
|
252 |
+
feat_c = torch.cat([feat_c0, feat_c1], 0)
|
253 |
+
feat_c = rearrange(feat_c, 'b (h w) c -> b c h w', h=data['hw0_c'][0])
|
254 |
+
x2 = data['feats_x2'].float()
|
255 |
+
x1 = data['feats_x1'].float()
|
256 |
+
assert self.backbone_type != 's2dnet'
|
257 |
+
x3_out = self.layer3_outconv(feat_c)
|
258 |
+
x3_out_2x = F.interpolate(x3_out, size=((x3_out.size(-2)-1)*2+1, (x3_out.size(-1)-1)*2+1), mode='bilinear', align_corners=False)
|
259 |
+
x2 = self.layer2_outconv(x2)
|
260 |
+
x2 = self.layer2_outconv2(x2+x3_out_2x)
|
261 |
+
|
262 |
+
x2 = F.interpolate(x2, size=((x2.size(-2)-1)*2+1, (x2.size(-1)-1)*2+1), mode='bilinear', align_corners=False)
|
263 |
+
x1_out = self.layer1_outconv(x1)
|
264 |
+
x1_out = self.layer1_outconv2(x1_out+x2)
|
265 |
+
x0_out = x1_out
|
266 |
+
|
267 |
+
feat_f0, feat_f1 = torch.chunk(x0_out, 2, dim=0)
|
268 |
+
|
269 |
+
# 1. unfold(crop) all local windows
|
270 |
+
feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
|
271 |
+
feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
|
272 |
+
feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
|
273 |
+
feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
|
274 |
+
|
275 |
+
# 2. select only the predicted matches
|
276 |
+
feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
|
277 |
+
feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
|
278 |
+
|
279 |
+
return feat_f0_unfold, feat_f1_unfold
|
280 |
+
|
281 |
+
|
282 |
+
|
283 |
+
elif self.sample_c_feat:
|
284 |
+
if self.align_corner is False:
|
285 |
+
# easy implemented but memory consuming
|
286 |
+
feat_c = self.down_proj(torch.cat([feat_c0,
|
287 |
+
feat_c1], 0)) # [n, (h w), c] -> [2n, (h w), cf]
|
288 |
+
feat_c = rearrange(feat_c, 'n (h w) c -> n c h w', h=data['hw0_c'][0], w=data['hw0_c'][1])
|
289 |
+
feat_f = F.interpolate(feat_c, scale_factor=8., mode='bilinear', align_corners=False) # [2n, cf, hf, wf]
|
290 |
+
feat_f_unfold = F.unfold(feat_f, kernel_size=(W, W), stride=stride, padding=0)
|
291 |
+
feat_f_unfold = rearrange(feat_f_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
|
292 |
+
feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_f_unfold, 2, dim=0)
|
293 |
+
feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [m, ww, cf]
|
294 |
+
feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] # [m, ww, cf]
|
295 |
+
# return feat_f0_unfold, feat_f1_unfold
|
296 |
+
return feat_f0_unfold.float(), feat_f1_unfold.float()
|
297 |
+
else:
|
298 |
+
if self.align_corner is False:
|
299 |
+
# 1. unfold(crop) all local windows
|
300 |
+
assert False, 'maybe exist bugs'
|
301 |
+
feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
|
302 |
+
feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
|
303 |
+
feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=0)
|
304 |
+
feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
|
305 |
+
|
306 |
+
# 2. select only the predicted matches
|
307 |
+
feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
|
308 |
+
feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
|
309 |
+
|
310 |
+
# option: use coarse-level loftr feature as context: concat and linear
|
311 |
+
if self.cat_c_feat:
|
312 |
+
feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
|
313 |
+
feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c]
|
314 |
+
feat_cf_win = self.merge_feat(torch.cat([
|
315 |
+
torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
|
316 |
+
repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf]
|
317 |
+
], -1))
|
318 |
+
feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
|
319 |
+
|
320 |
+
return feat_f0_unfold, feat_f1_unfold
|
321 |
+
|
322 |
+
else:
|
323 |
+
# 1. unfold(crop) all local windows
|
324 |
+
if self.fix_bias:
|
325 |
+
feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
|
326 |
+
feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
|
327 |
+
feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
|
328 |
+
feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
|
329 |
+
else:
|
330 |
+
feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
|
331 |
+
feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
|
332 |
+
feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
|
333 |
+
feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
|
334 |
+
|
335 |
+
# 2. select only the predicted matches
|
336 |
+
feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
|
337 |
+
feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
|
338 |
+
|
339 |
+
# option: use coarse-level loftr feature as context: concat and linear
|
340 |
+
if self.cat_c_feat:
|
341 |
+
feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
|
342 |
+
feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c]
|
343 |
+
feat_cf_win = self.merge_feat(torch.cat([
|
344 |
+
torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
|
345 |
+
repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf]
|
346 |
+
], -1))
|
347 |
+
feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
|
348 |
+
|
349 |
+
# return feat_f0_unfold, feat_f1_unfold
|
350 |
+
return feat_f0_unfold.float(), feat_f1_unfold.float()
|
imcui/third_party/MatchAnything/src/loftr/loftr_module/linear_attention.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
|
3 |
+
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch.nn import Module, Dropout
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
# if hasattr(F, 'scaled_dot_product_attention'):
|
11 |
+
# FLASH_AVAILABLE = True
|
12 |
+
# else: # v100
|
13 |
+
FLASH_AVAILABLE = False
|
14 |
+
# import xformers.ops
|
15 |
+
from ..utils.position_encoding import PositionEncodingSine, RoPEPositionEncodingSine
|
16 |
+
from einops.einops import rearrange
|
17 |
+
from loguru import logger
|
18 |
+
|
19 |
+
|
20 |
+
# flash_attn_func_ok = True
|
21 |
+
# try:
|
22 |
+
# from flash_attn import flash_attn_func
|
23 |
+
# except ModuleNotFoundError:
|
24 |
+
# flash_attn_func_ok = False
|
25 |
+
|
26 |
+
def elu_feature_map(x):
|
27 |
+
return torch.nn.functional.elu(x) + 1
|
28 |
+
|
29 |
+
|
30 |
+
class LinearAttention(Module):
|
31 |
+
def __init__(self, eps=1e-6):
|
32 |
+
super().__init__()
|
33 |
+
self.feature_map = elu_feature_map
|
34 |
+
self.eps = eps
|
35 |
+
|
36 |
+
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
37 |
+
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
38 |
+
""" Multi-Head linear attention proposed in "Transformers are RNNs"
|
39 |
+
Args:
|
40 |
+
queries: [N, L, H, D]
|
41 |
+
keys: [N, S, H, D]
|
42 |
+
values: [N, S, H, D]
|
43 |
+
q_mask: [N, L]
|
44 |
+
kv_mask: [N, S]
|
45 |
+
Returns:
|
46 |
+
queried_values: (N, L, H, D)
|
47 |
+
"""
|
48 |
+
Q = self.feature_map(queries)
|
49 |
+
K = self.feature_map(keys)
|
50 |
+
|
51 |
+
# set padded position to zero
|
52 |
+
if q_mask is not None:
|
53 |
+
Q = Q * q_mask[:, :, None, None]
|
54 |
+
if kv_mask is not None:
|
55 |
+
K = K * kv_mask[:, :, None, None]
|
56 |
+
values = values * kv_mask[:, :, None, None]
|
57 |
+
|
58 |
+
v_length = values.size(1)
|
59 |
+
values = values / v_length # prevent fp16 overflow
|
60 |
+
KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
|
61 |
+
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
|
62 |
+
# queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
|
63 |
+
queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
|
64 |
+
|
65 |
+
return queried_values.contiguous()
|
66 |
+
|
67 |
+
class RoPELinearAttention(Module):
|
68 |
+
def __init__(self, eps=1e-6):
|
69 |
+
super().__init__()
|
70 |
+
self.feature_map = elu_feature_map
|
71 |
+
self.eps = eps
|
72 |
+
self.RoPE = RoPEPositionEncodingSine(256, max_shape=(256, 256))
|
73 |
+
|
74 |
+
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
75 |
+
def forward(self, queries, keys, values, q_mask=None, kv_mask=None, H=None, W=None):
|
76 |
+
""" Multi-Head linear attention proposed in "Transformers are RNNs"
|
77 |
+
Args:
|
78 |
+
queries: [N, L, H, D]
|
79 |
+
keys: [N, S, H, D]
|
80 |
+
values: [N, S, H, D]
|
81 |
+
q_mask: [N, L]
|
82 |
+
kv_mask: [N, S]
|
83 |
+
Returns:
|
84 |
+
queried_values: (N, L, H, D)
|
85 |
+
"""
|
86 |
+
Q = self.feature_map(queries)
|
87 |
+
K = self.feature_map(keys)
|
88 |
+
nhead, d = Q.size(2), Q.size(3)
|
89 |
+
# set padded position to zero
|
90 |
+
if q_mask is not None:
|
91 |
+
Q = Q * q_mask[:, :, None, None]
|
92 |
+
if kv_mask is not None:
|
93 |
+
K = K * kv_mask[:, :, None, None]
|
94 |
+
values = values * kv_mask[:, :, None, None]
|
95 |
+
|
96 |
+
v_length = values.size(1)
|
97 |
+
values = values / v_length # prevent fp16 overflow
|
98 |
+
# Q = Q / Q.size(1)
|
99 |
+
# logger.info(f"Q: {Q.dtype}, K: {K.dtype}, values: {values.dtype}")
|
100 |
+
|
101 |
+
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
|
102 |
+
# logger.info(f"Z_max: {Z.abs().max()}")
|
103 |
+
Q = rearrange(Q, 'n (h w) nhead d -> n h w (nhead d)', h=H, w=W)
|
104 |
+
K = rearrange(K, 'n (h w) nhead d -> n h w (nhead d)', h=H, w=W)
|
105 |
+
Q, K = self.RoPE(Q), self.RoPE(K)
|
106 |
+
# logger.info(f"Q_rope: {Q.abs().max()}, K_rope: {K.abs().max()}")
|
107 |
+
Q = rearrange(Q, 'n h w (nhead d) -> n (h w) nhead d', nhead=nhead, d=d)
|
108 |
+
K = rearrange(K, 'n h w (nhead d) -> n (h w) nhead d', nhead=nhead, d=d)
|
109 |
+
KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
|
110 |
+
del K, values
|
111 |
+
# logger.info(f"KV_max: {KV.abs().max()}")
|
112 |
+
# queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
|
113 |
+
# Q = torch.einsum("nlhd,nlh->nlhd", Q, Z)
|
114 |
+
# logger.info(f"QZ_max: {Q.abs().max()}")
|
115 |
+
# queried_values = torch.einsum("nlhd,nhdv->nlhv", Q, KV) * v_length
|
116 |
+
# logger.info(f"message_max: {queried_values.abs().max()}")
|
117 |
+
queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
|
118 |
+
|
119 |
+
return queried_values.contiguous()
|
120 |
+
|
121 |
+
|
122 |
+
class FullAttention(Module):
|
123 |
+
def __init__(self, use_dropout=False, attention_dropout=0.1):
|
124 |
+
super().__init__()
|
125 |
+
self.use_dropout = use_dropout
|
126 |
+
self.dropout = Dropout(attention_dropout)
|
127 |
+
|
128 |
+
# @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
129 |
+
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
130 |
+
""" Multi-head scaled dot-product attention, a.k.a full attention.
|
131 |
+
Args:
|
132 |
+
queries: [N, L, H, D]
|
133 |
+
keys: [N, S, H, D]
|
134 |
+
values: [N, S, H, D]
|
135 |
+
q_mask: [N, L]
|
136 |
+
kv_mask: [N, S]
|
137 |
+
Returns:
|
138 |
+
queried_values: (N, L, H, D)
|
139 |
+
"""
|
140 |
+
# assert kv_mask is None
|
141 |
+
# mask = torch.zeros(queries.size(0)*queries.size(2), queries.size(1), keys.size(1), device=queries.device)
|
142 |
+
# mask.masked_fill(~(q_mask[:, :, None] * kv_mask[:, None, :]), float('-inf'))
|
143 |
+
# if keys.size(1) % 8 != 0:
|
144 |
+
# mask = torch.cat([mask, torch.zeros(queries.size(0)*queries.size(2), queries.size(1), 8-keys.size(1)%8, device=queries.device)], dim=-1)
|
145 |
+
# out = xformers.ops.memory_efficient_attention(queries, keys, values, attn_bias=mask[...,:keys.size(1)])
|
146 |
+
# return out
|
147 |
+
|
148 |
+
# N = queries.size(0)
|
149 |
+
# list_q = [queries[i, :q_mask[i].sum, ...] for i in N]
|
150 |
+
# list_k = [keys[i, :kv_mask[i].sum, ...] for i in N]
|
151 |
+
# list_v = [values[i, :kv_mask[i].sum, ...] for i in N]
|
152 |
+
# assert N == 1
|
153 |
+
# out = xformers.ops.memory_efficient_attention(queries[:,:q_mask.sum(),...], keys[:,:kv_mask.sum(),...], values[:,:kv_mask.sum(),...])
|
154 |
+
# out = torch.cat([out, torch.zeros(out.size(0), queries.size(1)-q_mask.sum(), queries.size(2), queries.size(3), device=queries.device)], dim=1)
|
155 |
+
# return out
|
156 |
+
# Compute the unnormalized attention and apply the masks
|
157 |
+
QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
|
158 |
+
if kv_mask is not None:
|
159 |
+
QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), -1e5) # float('-inf')
|
160 |
+
|
161 |
+
# Compute the attention and the weighted average
|
162 |
+
softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
|
163 |
+
A = torch.softmax(softmax_temp * QK, dim=2)
|
164 |
+
if self.use_dropout:
|
165 |
+
A = self.dropout(A)
|
166 |
+
|
167 |
+
queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
|
168 |
+
|
169 |
+
return queried_values.contiguous()
|
170 |
+
|
171 |
+
|
172 |
+
class XAttention(Module):
|
173 |
+
def __init__(self, use_dropout=False, attention_dropout=0.1):
|
174 |
+
super().__init__()
|
175 |
+
self.use_dropout = use_dropout
|
176 |
+
if use_dropout:
|
177 |
+
self.dropout = Dropout(attention_dropout)
|
178 |
+
|
179 |
+
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
180 |
+
""" Multi-head scaled dot-product attention, a.k.a full attention.
|
181 |
+
Args:
|
182 |
+
if FLASH_AVAILABLE: # pytorch scaled_dot_product_attention
|
183 |
+
queries: [N, H, L, D]
|
184 |
+
keys: [N, H, S, D]
|
185 |
+
values: [N, H, S, D]
|
186 |
+
else:
|
187 |
+
queries: [N, L, H, D]
|
188 |
+
keys: [N, S, H, D]
|
189 |
+
values: [N, S, H, D]
|
190 |
+
q_mask: [N, L]
|
191 |
+
kv_mask: [N, S]
|
192 |
+
Returns:
|
193 |
+
queried_values: (N, L, H, D)
|
194 |
+
"""
|
195 |
+
|
196 |
+
assert q_mask is None and kv_mask is None, "already been sliced"
|
197 |
+
if FLASH_AVAILABLE:
|
198 |
+
# args = [x.half().contiguous() for x in [queries, keys, values]]
|
199 |
+
# out = F.scaled_dot_product_attention(*args, attn_mask=mask).to(queries.dtype)
|
200 |
+
args = [x.contiguous() for x in [queries, keys, values]]
|
201 |
+
out = F.scaled_dot_product_attention(*args)
|
202 |
+
else:
|
203 |
+
# if flash_attn_func_ok:
|
204 |
+
# out = flash_attn_func(queries, keys, values)
|
205 |
+
# else:
|
206 |
+
QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
|
207 |
+
|
208 |
+
# Compute the attention and the weighted average
|
209 |
+
softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
|
210 |
+
A = torch.softmax(softmax_temp * QK, dim=2)
|
211 |
+
|
212 |
+
out = torch.einsum("nlsh,nshd->nlhd", A, values)
|
213 |
+
|
214 |
+
# out = xformers.ops.memory_efficient_attention(queries, keys, values)
|
215 |
+
# out = xformers.ops.memory_efficient_attention(queries[:,:q_mask.sum(),...], keys[:,:kv_mask.sum(),...], values[:,:kv_mask.sum(),...])
|
216 |
+
# out = torch.cat([out, torch.zeros(out.size(0), queries.size(1)-q_mask.sum(), queries.size(2), queries.size(3), device=queries.device)], dim=1)
|
217 |
+
return out
|
imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer.py
ADDED
@@ -0,0 +1,1768 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from .linear_attention import LinearAttention, RoPELinearAttention, FullAttention, XAttention
|
6 |
+
from einops.einops import rearrange
|
7 |
+
from collections import OrderedDict
|
8 |
+
from .transformer_utils import TokenConfidence, MatchAssignment, filter_matches
|
9 |
+
from ..utils.coarse_matching import CoarseMatching
|
10 |
+
from ..utils.position_encoding import RoPEPositionEncodingSine
|
11 |
+
import numpy as np
|
12 |
+
from loguru import logger
|
13 |
+
|
14 |
+
PFLASH_AVAILABLE = False
|
15 |
+
|
16 |
+
class PANEncoderLayer(nn.Module):
|
17 |
+
def __init__(self,
|
18 |
+
d_model,
|
19 |
+
nhead,
|
20 |
+
attention='linear',
|
21 |
+
pool_size=4,
|
22 |
+
bn=True,
|
23 |
+
xformer=False,
|
24 |
+
leaky=-1.0,
|
25 |
+
dw_conv=False,
|
26 |
+
scatter=False,
|
27 |
+
):
|
28 |
+
super(PANEncoderLayer, self).__init__()
|
29 |
+
|
30 |
+
self.pool_size = pool_size
|
31 |
+
self.dw_conv = dw_conv
|
32 |
+
self.scatter = scatter
|
33 |
+
if self.dw_conv:
|
34 |
+
self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0, stride=pool_size, bias=False, groups=d_model)
|
35 |
+
|
36 |
+
assert not self.scatter, 'buggy implemented here'
|
37 |
+
self.dim = d_model // nhead
|
38 |
+
self.nhead = nhead
|
39 |
+
|
40 |
+
self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size)
|
41 |
+
# multi-head attention
|
42 |
+
if bn:
|
43 |
+
method = 'dw_bn'
|
44 |
+
else:
|
45 |
+
method = 'dw'
|
46 |
+
self.q_proj_conv = self._build_projection(d_model, d_model, method=method)
|
47 |
+
self.k_proj_conv = self._build_projection(d_model, d_model, method=method)
|
48 |
+
self.v_proj_conv = self._build_projection(d_model, d_model, method=method)
|
49 |
+
|
50 |
+
# self.q_proj = nn.Linear(d_mosdel, d_model, bias=False)
|
51 |
+
# self.k_proj = nn.Linear(d_model, d_model, bias=False)
|
52 |
+
# self.v_proj = nn.Linear(d_model, d_model, bias=False)
|
53 |
+
if xformer:
|
54 |
+
self.attention = XAttention()
|
55 |
+
else:
|
56 |
+
self.attention = LinearAttention() if attention == 'linear' else FullAttention()
|
57 |
+
self.merge = nn.Linear(d_model, d_model, bias=False)
|
58 |
+
|
59 |
+
# feed-forward network
|
60 |
+
if leaky > 0:
|
61 |
+
self.mlp = nn.Sequential(
|
62 |
+
nn.Linear(d_model*2, d_model*2, bias=False),
|
63 |
+
nn.LeakyReLU(leaky, True),
|
64 |
+
nn.Linear(d_model*2, d_model, bias=False),
|
65 |
+
)
|
66 |
+
|
67 |
+
else:
|
68 |
+
self.mlp = nn.Sequential(
|
69 |
+
nn.Linear(d_model*2, d_model*2, bias=False),
|
70 |
+
nn.ReLU(True),
|
71 |
+
nn.Linear(d_model*2, d_model, bias=False),
|
72 |
+
)
|
73 |
+
|
74 |
+
# norm and dropout
|
75 |
+
self.norm1 = nn.LayerNorm(d_model)
|
76 |
+
self.norm2 = nn.LayerNorm(d_model)
|
77 |
+
|
78 |
+
# self.norm1 = nn.BatchNorm2d(d_model)
|
79 |
+
|
def forward(self, x, source, x_mask=None, source_mask=None):
    """Pooled-attention encoder step: attend from `x` to `source`, residual to `x`.

    Query features come from `x`; key/value features come from `source`.
    All three are downsampled by `pool_size` (depth-wise conv or max-pool)
    before attention, and the attention message is upsampled back to the
    resolution of `x` before the feed-forward network.

    Args:
        x (torch.Tensor): [N, C, H1, W1]
        source (torch.Tensor): [N, C, H2, W2]
        x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
        source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)

    Returns:
        torch.Tensor: [N, C, H1, W1] — `x` plus the attention message.
    """
    bs = x.size(0)
    H1, W1 = x.size(-2), x.size(-1)
    H2, W2 = source.size(-2), source.size(-1)

    query, key, value = x, source, source

    # Downsample + LayerNorm (norm1 operates on the channel-last layout,
    # hence the permute round-trips).
    if self.dw_conv:
        query = self.norm1(self.aggregate(query).permute(0,2,3,1)).permute(0,3,1,2)
    else:
        query = self.norm1(self.max_pool(query).permute(0,2,3,1)).permute(0,3,1,2)
    # key and value are the same tensor here, so only one of these pools is
    # strictly needed ("only need to cal key or value...").
    key = self.norm1(self.max_pool(key).permute(0,2,3,1)).permute(0,3,1,2)
    value = self.norm1(self.max_pool(value).permute(0,2,3,1)).permute(0,3,1,2)

    # Depth-wise convolutional q/k/v projections (see _build_projection).
    query = self.q_proj_conv(query)  # [N, C, H1//pool, W1//pool]
    key = self.k_proj_conv(key)
    value = self.v_proj_conv(value)

    C = query.shape[-3]

    ismask = x_mask is not None and source_mask is not None
    if bs == 1 or not ismask:
        # Fast path: a single sample (or no masks), so the padded region can
        # simply be cropped away before attention and zero-padded afterwards.
        if ismask:
            x_mask = self.max_pool(x_mask.float()).bool()  # [N, H1//pool, W1//pool]
            source_mask = self.max_pool(source_mask.float()).bool()

            # NOTE(review): assumes the valid region is a top-left-aligned
            # rectangle (row/column sums taken at index 0) — confirm against
            # how the dataloader builds these masks.
            mask_h0, mask_w0 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0]
            mask_h1, mask_w1 = source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0]

            query = query[:, :, :mask_h0, :mask_w0]
            key = key[:, :, :mask_h1, :mask_w1]
            value = value[:, :, :mask_h1, :mask_w1]

        else:
            assert x_mask is None and source_mask is None

        # Flatten spatial dims for attention; layout depends on whether the
        # flash-attention backend is available.
        if PFLASH_AVAILABLE:  # N H L D
            query = rearrange(query, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
            key = rearrange(key, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
            value = rearrange(value, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)

        else:  # N L H D
            query = rearrange(query, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, L, H, D]
            key = rearrange(key, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, S, H, D]
            value = rearrange(value, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, S, H, D]

        message = self.attention(query, key, value, q_mask=None, kv_mask=None)  # [N, L, H, D] or [N, H, L, D]

        if PFLASH_AVAILABLE:  # N H L D -> N L H D
            message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)

        if ismask:
            # Zero-pad the cropped message back to the full pooled grid.
            message = message.view(bs, mask_h0, mask_w0, self.nhead, self.dim)
            if mask_h0 != x_mask.size(-2):
                message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=1)
            elif mask_w0 != x_mask.size(-1):
                message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=2)

        else:
            assert x_mask is None and source_mask is None

        message = self.merge(message.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size)  # [N, C, H, W]

        # Upsample back to the resolution of x.
        if self.scatter:
            # Nearest-neighbour scatter, re-weighted by the depth-wise
            # aggregation kernel (requires self.aggregate, i.e. dw_conv=True);
            # the constructor asserts `not self.scatter` ("buggy implemented").
            message = torch.repeat_interleave(message, self.pool_size, dim=-2)
            message = torch.repeat_interleave(message, self.pool_size, dim=-1)
            message = message * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1,1,message.shape[-2]//self.pool_size,message.shape[-1]//self.pool_size)
        else:
            message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False)  # [N, C, H1, W1]

        # Feed-forward on [x ; message], then residual.
        message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1))  # [N, H1, W1, C]
        message = self.norm2(message).permute(0, 3, 1, 2)  # [N, C, H1, W1]

        return x + message
    else:
        # Batched masked path: each sample may have a different valid
        # rectangle, so attention is run per sample.
        x_mask = self.max_pool(x_mask.float()).bool()
        source_mask = self.max_pool(source_mask.float()).bool()
        m_list = []
        for i in range(bs):
            mask_h0, mask_w0 = x_mask[i].sum(-2)[0], x_mask[i].sum(-1)[0]
            mask_h1, mask_w1 = source_mask[i].sum(-2)[0], source_mask[i].sum(-1)[0]

            q = query[i:i+1, :, :mask_h0, :mask_w0]
            k = key[i:i+1, :, :mask_h1, :mask_w1]
            v = value[i:i+1, :, :mask_h1, :mask_w1]

            if PFLASH_AVAILABLE:  # N H L D
                q = rearrange(q, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
                k = rearrange(k, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
                v = rearrange(v, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)

            else:  # N L H D
                q = rearrange(q, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, L, H, D]
                k = rearrange(k, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, S, H, D]
                v = rearrange(v, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, S, H, D]

            m = self.attention(q, k, v, q_mask=None, kv_mask=None)  # [N, L, H, D]

            if PFLASH_AVAILABLE:  # N H L D -> N L H D
                m = rearrange(m, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)

            # Zero-pad each per-sample message to the common pooled grid.
            m = m.view(1, mask_h0, mask_w0, self.nhead, self.dim)
            if mask_h0 != x_mask.size(-2):
                m = torch.cat([m, torch.zeros(1, x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=1)
            elif mask_w0 != x_mask.size(-1):
                m = torch.cat([m, torch.zeros(1, x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=2)
            m_list.append(m)
        message = torch.cat(m_list, dim=0)

        message = self.merge(message.view(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size)  # [N, C, H, W]

        if self.scatter:
            # NOTE(review): unlike the bs==1 branch above, this path does NOT
            # re-weight by the aggregate kernel — confirm whether that is
            # intentional (the constructor forbids scatter anyway).
            message = torch.repeat_interleave(message, self.pool_size, dim=-2)
            message = torch.repeat_interleave(message, self.pool_size, dim=-1)
        else:
            message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False)  # [N, C, H1, W1]

        # Feed-forward on [x ; message], then residual.
        message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1))  # [N, H1, W1, C]
        message = self.norm2(message).permute(0, 3, 1, 2)  # [N, C, H1, W1]

        return x + message
234 |
+
|
235 |
+
|
def pro(self, x, source, x_mask=None, source_mask=None, profiler=None):
    """Profiling twin of :meth:`forward`: same computation, wrapped in
    ``profiler.profile(...)`` sections so each stage can be timed.

    Differences from ``forward``: a deliberate no-op permute round-trip is
    kept (to measure permute cost in isolation), and the mask path is not
    specialised for bs > 1.

    Args:
        x (torch.Tensor): [N, C, H1, W1]
        source (torch.Tensor): [N, C, H2, W2]
        x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
        source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
        profiler: context-manager factory with a ``.profile(name)`` method
            (e.g. a PyTorch Lightning profiler). Required — no None check.

    Returns:
        torch.Tensor: [N, C, H1, W1] — ``x`` plus the attention message.
    """
    bs = x.size(0)
    H1, W1 = x.size(-2), x.size(-1)
    H2, W2 = source.size(-2), source.size(-1)

    query, key, value = x, source, source

    with profiler.profile("permute*6+norm1*3+max_pool*3"):
        if self.dw_conv:
            query = self.norm1(self.aggregate(query).permute(0,2,3,1)).permute(0,3,1,2)
        else:
            query = self.norm1(self.max_pool(query).permute(0,2,3,1)).permute(0,3,1,2)
        # key and value alias the same tensor; only one pooling is strictly needed.
        key = self.norm1(self.max_pool(key).permute(0,2,3,1)).permute(0,3,1,2)
        value = self.norm1(self.max_pool(value).permute(0,2,3,1)).permute(0,3,1,2)

    # Intentional no-op (permute there and back) to measure raw permute cost.
    with profiler.profile("permute*6"):
        query = query.permute(0, 2, 3, 1)
        key = key.permute(0, 2, 3, 1)
        value = value.permute(0, 2, 3, 1)

        query = query.permute(0,3,1,2)
        key = key.permute(0,3,1,2)
        value = value.permute(0,3,1,2)

    # multi-head attention projections
    with profiler.profile("q_conv+k_conv+v_conv"):
        query = self.q_proj_conv(query)  # [N, C, H1//pool, W1//pool]
        key = self.k_proj_conv(key)
        value = self.v_proj_conv(value)

    C = query.shape[-3]
    # TODO: Need to be consistent with bs=1 (where mask region do not in attention at all)
    if x_mask is not None and source_mask is not None:
        x_mask = self.max_pool(x_mask.float()).bool()  # [N, H1//pool, W1//pool]
        source_mask = self.max_pool(source_mask.float()).bool()

        # NOTE(review): assumes a top-left-aligned valid rectangle shared by
        # the whole batch (only element 0 is inspected) — confirm upstream.
        mask_h0, mask_w0 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0]
        mask_h1, mask_w1 = source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0]

        query = query[:, :, :mask_h0, :mask_w0]
        key = key[:, :, :mask_h1, :mask_w1]
        value = value[:, :, :mask_h1, :mask_w1]

    else:
        assert x_mask is None and source_mask is None

    with profiler.profile("rearrange*3"):
        query = rearrange(query, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, L, H, D]
        key = rearrange(key, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, S, H, D]
        value = rearrange(value, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, S, H, D]

    with profiler.profile("attention"):
        message = self.attention(query, key, value, q_mask=None, kv_mask=None)  # [N, L, H, D]

    if x_mask is not None and source_mask is not None:
        # Zero-pad the cropped message back to the full pooled grid.
        message = message.view(bs, mask_h0, mask_w0, self.nhead, self.dim)
        if mask_h0 != x_mask.size(-2):
            message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=1)
        elif mask_w0 != x_mask.size(-1):
            message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=2)

    else:
        assert x_mask is None and source_mask is None

    with profiler.profile("merge"):
        message = self.merge(message.view(bs, -1, self.nhead*self.dim))  # [N, L, C]

    with profiler.profile("rearrange*1"):
        message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size)  # [N, C, H, W]

    with profiler.profile("upsample"):
        if self.scatter:
            message = torch.repeat_interleave(message, self.pool_size, dim=-2)
            message = torch.repeat_interleave(message, self.pool_size, dim=-1)
        else:
            message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False)  # [N, C, H1, W1]

    # feed-forward network
    with profiler.profile("feed-forward_mlp+permute*2+norm2"):
        message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1))  # [N, H1, W1, C]
        message = self.norm2(message).permute(0, 3, 1, 2)  # [N, C, H1, W1]

    return x + message
363 |
+
|
364 |
+
|
365 |
+
def _build_projection(self,
|
366 |
+
dim_in,
|
367 |
+
dim_out,
|
368 |
+
kernel_size=3,
|
369 |
+
padding=1,
|
370 |
+
stride=1,
|
371 |
+
method='dw_bn',
|
372 |
+
):
|
373 |
+
if method == 'dw_bn':
|
374 |
+
proj = nn.Sequential(OrderedDict([
|
375 |
+
('conv', nn.Conv2d(
|
376 |
+
dim_in,
|
377 |
+
dim_in,
|
378 |
+
kernel_size=kernel_size,
|
379 |
+
padding=padding,
|
380 |
+
stride=stride,
|
381 |
+
bias=False,
|
382 |
+
groups=dim_in
|
383 |
+
)),
|
384 |
+
('bn', nn.BatchNorm2d(dim_in)),
|
385 |
+
# ('rearrage', Rearrange('b c h w -> b (h w) c')),
|
386 |
+
]))
|
387 |
+
elif method == 'avg':
|
388 |
+
proj = nn.Sequential(OrderedDict([
|
389 |
+
('avg', nn.AvgPool2d(
|
390 |
+
kernel_size=kernel_size,
|
391 |
+
padding=padding,
|
392 |
+
stride=stride,
|
393 |
+
ceil_mode=True
|
394 |
+
)),
|
395 |
+
# ('rearrage', Rearrange('b c h w -> b (h w) c')),
|
396 |
+
]))
|
397 |
+
elif method == 'linear':
|
398 |
+
proj = None
|
399 |
+
elif method == 'dw':
|
400 |
+
proj = nn.Sequential(OrderedDict([
|
401 |
+
('conv', nn.Conv2d(
|
402 |
+
dim_in,
|
403 |
+
dim_in,
|
404 |
+
kernel_size=kernel_size,
|
405 |
+
padding=padding,
|
406 |
+
stride=stride,
|
407 |
+
bias=False,
|
408 |
+
groups=dim_in
|
409 |
+
)),
|
410 |
+
# ('rearrage', Rearrange('b c h w -> b (h w) c')),
|
411 |
+
]))
|
412 |
+
else:
|
413 |
+
raise ValueError('Unknown method ({})'.format(method))
|
414 |
+
|
415 |
+
return proj
|
416 |
+
|
class AG_RoPE_EncoderLayer(nn.Module):
    """Aggregated (pooled) attention encoder layer with optional RoPE.

    Query features come from `x` (downsampled by `pool_size`), key/value
    features from `source` (downsampled by `pool_size2`). Supports
    depth-wise-conv or max-pool aggregation, pre-/post-LayerNorm, a
    ViT-style norm placement, and rotary position encoding on q/k.

    Fixes vs. the previous revision: removed a duplicate, unreachable
    `return x + m` at the end of `forward` (both branches of the preceding
    if/else already return, and the vit_norm branch has `del x`), and the
    four copy-pasted MLP constructions were collapsed into one
    (registration order and module structure are unchanged).
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 pool_size=4,
                 pool_size2=4,
                 xformer=False,
                 leaky=-1.0,
                 dw_conv=False,
                 dw_conv2=False,
                 scatter=False,
                 norm_before=True,
                 rope=False,
                 npe=None,
                 vit_norm=False,
                 dw_proj=False,
                 ):
        super().__init__()

        self.pool_size = pool_size
        self.pool_size2 = pool_size2
        self.dw_conv = dw_conv
        self.dw_conv2 = dw_conv2
        self.scatter = scatter
        self.norm_before = norm_before
        self.vit_norm = vit_norm
        self.dw_proj = dw_proj
        self.rope = rope
        # Depth-wise strided convs used as learned pooling (only needed when
        # the corresponding side is actually pooled).
        if self.dw_conv and self.pool_size != 1:
            self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0, stride=pool_size, bias=False, groups=d_model)
        if self.dw_conv2 and self.pool_size2 != 1:
            self.aggregate2 = nn.Conv2d(d_model, d_model, kernel_size=pool_size2, padding=0, stride=pool_size2, bias=False, groups=d_model)

        self.dim = d_model // nhead
        self.nhead = nhead

        # NOTE: max_pool uses pool_size2; it is also applied to x/x_mask when
        # dw_conv is False, which assumes pool_size == pool_size2 in that case.
        self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size2, stride=self.pool_size2)

        # multi-head attention projections: depth-wise conv or linear
        if self.dw_proj:
            self.q_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
            self.k_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
            self.v_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
        else:
            self.q_proj = nn.Linear(d_model, d_model, bias=False)
            self.k_proj = nn.Linear(d_model, d_model, bias=False)
            self.v_proj = nn.Linear(d_model, d_model, bias=False)

        if self.rope:
            self.rope_pos_enc = RoPEPositionEncodingSine(d_model, max_shape=(256, 256), npe=npe, ropefp16=True)

        if xformer:
            self.attention = XAttention()
        else:
            self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network; vit_norm feeds the message alone (d_model in),
        # otherwise the FFN sees [x ; message] (2*d_model in).
        mlp_in = d_model if self.vit_norm else d_model * 2
        activation = nn.LeakyReLU(leaky, True) if leaky > 0 else nn.ReLU(True)
        self.mlp = nn.Sequential(
            nn.Linear(mlp_in, d_model*2, bias=False),
            activation,
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # norm layers (LayerNorm over the channel dim, applied channel-last)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, C, H1, W1]
            source (torch.Tensor): [N, C, H2, W2]
            x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
            source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)

        Returns:
            torch.Tensor: [N, C, H1, W1]
        """
        bs, C, H1, W1 = x.size()
        H2, W2 = source.size(-2), source.size(-1)

        # Aggregate (pool) and optionally normalise; all branches produce
        # channel-last tensors [N, H, W, C].
        if self.norm_before and not self.vit_norm:
            # pre-norm after aggregation
            if self.pool_size == 1:
                query = self.norm1(x.permute(0,2,3,1))
            elif self.dw_conv:
                query = self.norm1(self.aggregate(x).permute(0,2,3,1))
            else:
                query = self.norm1(self.max_pool(x).permute(0,2,3,1))
            if self.pool_size2 == 1:
                source = self.norm1(source.permute(0,2,3,1))
            elif self.dw_conv2:
                source = self.norm1(self.aggregate2(source).permute(0,2,3,1))
            else:
                source = self.norm1(self.max_pool(source).permute(0,2,3,1))
        elif self.vit_norm:
            # ViT-style: normalise BEFORE aggregation
            if self.pool_size == 1:
                query = self.norm1(x.permute(0,2,3,1))
            elif self.dw_conv:
                query = self.aggregate(self.norm1(x.permute(0,2,3,1)).permute(0,3,1,2)).permute(0,2,3,1)
            else:
                query = self.max_pool(self.norm1(x.permute(0,2,3,1)).permute(0,3,1,2)).permute(0,2,3,1)
            if self.pool_size2 == 1:
                source = self.norm1(source.permute(0,2,3,1))
            elif self.dw_conv2:
                source = self.aggregate2(self.norm1(source.permute(0,2,3,1)).permute(0,3,1,2)).permute(0,2,3,1)
            else:
                source = self.max_pool(self.norm1(source.permute(0,2,3,1)).permute(0,3,1,2)).permute(0,2,3,1)
        else:
            # post-norm: no normalisation here (norm1 is applied to the
            # message after upsampling instead)
            if self.pool_size == 1:
                query = x.permute(0,2,3,1)
            elif self.dw_conv:
                query = self.aggregate(x).permute(0,2,3,1)
            else:
                query = self.max_pool(x).permute(0,2,3,1)
            if self.pool_size2 == 1:
                source = source.permute(0,2,3,1)
            elif self.dw_conv2:
                source = self.aggregate2(source).permute(0,2,3,1)
            else:
                source = self.max_pool(source).permute(0,2,3,1)

        # projection
        if self.dw_proj:
            query = self.q_proj(query.permute(0,3,1,2)).permute(0,2,3,1)
            key = self.k_proj(source.permute(0,3,1,2)).permute(0,2,3,1)
            value = self.v_proj(source.permute(0,3,1,2)).permute(0,2,3,1)
        else:
            query, key, value = self.q_proj(query), self.k_proj(source), self.v_proj(source)

        # RoPE; when query is full-resolution but key is pooled 4x, the key
        # positions are scaled by 4 to keep both in the same coordinate frame.
        if self.rope:
            query = self.rope_pos_enc(query)
            if self.pool_size == 1 and self.pool_size2 == 4:
                key = self.rope_pos_enc(key, 4)
            else:
                key = self.rope_pos_enc(key)

        use_mask = x_mask is not None and source_mask is not None
        if bs == 1 or not use_mask:
            if use_mask:
                # downsample masks to the pooled grids
                if self.pool_size != 1:
                    x_mask = self.max_pool(x_mask.float()).bool()  # [N, H1//pool, W1//pool]
                if self.pool_size2 != 1:
                    source_mask = self.max_pool(source_mask.float()).bool()

                # NOTE(review): assumes a top-left-aligned valid rectangle —
                # only row/column 0 of the masks is inspected.
                mask_h0, mask_w0 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0]
                mask_h1, mask_w1 = source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0]

                query = query[:, :mask_h0, :mask_w0, :]
                key = key[:, :mask_h1, :mask_w1, :]
                value = value[:, :mask_h1, :mask_w1, :]
            else:
                assert x_mask is None and source_mask is None

            if PFLASH_AVAILABLE:  # [N, H, W, C] -> [N, h, L, D]
                query = rearrange(query, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
                key = rearrange(key, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
                value = rearrange(value, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
            else:  # N L H D
                query = rearrange(query, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)
                key = rearrange(key, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)
                value = rearrange(value, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)

            message = self.attention(query, key, value, q_mask=None, kv_mask=None)  # [N, L, h, D] or [N, h, L, D]

            if PFLASH_AVAILABLE:  # [N, h, L, D] -> [N, L, h, D]
                message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)

            if use_mask:  # padding zero
                message = message.view(bs, mask_h0, mask_w0, self.nhead, self.dim)  # [N L h D]
                if mask_h0 != x_mask.size(-2):
                    message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=1)
                elif mask_w0 != x_mask.size(-1):
                    message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=2)
            else:
                assert x_mask is None and source_mask is None

            message = self.merge(message.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]
            message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size)  # [N, C, H, W]

            # upsample message back to the resolution of x
            if self.pool_size != 1:
                if self.scatter:
                    # nearest scatter re-weighted by the depth-wise aggregation
                    # kernel (requires dw_conv / self.aggregate)
                    message = torch.repeat_interleave(message, self.pool_size, dim=-2)
                    message = torch.repeat_interleave(message, self.pool_size, dim=-1)
                    message = message * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1,1,message.shape[-2]//self.pool_size,message.shape[-1]//self.pool_size)
                else:
                    message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False)  # [N, C, H1, W1]

            if not self.norm_before and not self.vit_norm:
                message = self.norm1(message.permute(0,2,3,1)).permute(0,3,1,2)  # [N, C, H, W]

            # feed-forward network
            if self.vit_norm:
                message_inter = (x + message)
                del x
                message = self.norm2(message_inter.permute(0, 2, 3, 1))
                message = self.mlp(message).permute(0, 3, 1, 2)  # [N, C, H1, W1]
                return message_inter + message
            else:
                message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1))  # [N, H1, W1, C]
                message = self.norm2(message).permute(0, 3, 1, 2)  # [N, C, H1, W1]
                return x + message
        else:  # mask with bs > 1: per-sample valid rectangles, loop over batch
            if self.pool_size != 1:
                x_mask = self.max_pool(x_mask.float()).bool()
            if self.pool_size2 != 1:
                source_mask = self.max_pool(source_mask.float()).bool()
            m_list = []
            for i in range(bs):
                mask_h0, mask_w0 = x_mask[i].sum(-2)[0], x_mask[i].sum(-1)[0]
                mask_h1, mask_w1 = source_mask[i].sum(-2)[0], source_mask[i].sum(-1)[0]

                q = query[i:i+1, :mask_h0, :mask_w0, :]
                k = key[i:i+1, :mask_h1, :mask_w1, :]
                v = value[i:i+1, :mask_h1, :mask_w1, :]

                if PFLASH_AVAILABLE:  # [N, H, W, C] -> [N, h, L, D]
                    q = rearrange(q, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
                    k = rearrange(k, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
                    v = rearrange(v, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
                else:  # N L H D
                    q = rearrange(q, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)
                    k = rearrange(k, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)
                    v = rearrange(v, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)

                m = self.attention(q, k, v, q_mask=None, kv_mask=None)  # [N, L, h, D] or [N, h, L, D]

                if PFLASH_AVAILABLE:  # [N, h, L, D] -> [N, L, h, D]
                    m = rearrange(m, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)

                # zero-pad each per-sample message to the common pooled grid
                m = m.view(1, mask_h0, mask_w0, self.nhead, self.dim)
                if mask_h0 != x_mask.size(-2):
                    m = torch.cat([m, torch.zeros(1, x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=1)
                elif mask_w0 != x_mask.size(-1):
                    m = torch.cat([m, torch.zeros(1, x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=2)
                m_list.append(m)
            m = torch.cat(m_list, dim=0)

            m = self.merge(m.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]
            m = rearrange(m, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size)  # [N, C, H, W]

            if self.pool_size != 1:
                if self.scatter:
                    m = torch.repeat_interleave(m, self.pool_size, dim=-2)
                    m = torch.repeat_interleave(m, self.pool_size, dim=-1)
                    m = m * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1,1,m.shape[-2]//self.pool_size,m.shape[-1]//self.pool_size)
                else:
                    m = torch.nn.functional.interpolate(m, scale_factor=self.pool_size, mode='bilinear', align_corners=False)  # [N, C, H1, W1]

            if not self.norm_before and not self.vit_norm:
                m = self.norm1(m.permute(0,2,3,1)).permute(0,3,1,2)  # [N, C, H, W]

            # feed-forward network
            if self.vit_norm:
                m_inter = (x + m)
                del x
                m = self.norm2(m_inter.permute(0, 2, 3, 1))
                m = self.mlp(m).permute(0, 3, 1, 2)  # [N, C, H1, W1]
                return m_inter + m
            else:
                m = self.mlp(torch.cat([x, m], dim=1).permute(0, 2, 3, 1))  # [N, H1, W1, C]
                m = self.norm2(m).permute(0, 3, 1, 2)  # [N, C, H1, W1]
                return x + m
725 |
+
|
726 |
+
class AG_Conv_EncoderLayer(nn.Module):
|
def __init__(self,
             d_model,
             nhead,
             attention='linear',
             pool_size=4,
             bn=True,
             xformer=False,
             leaky=-1.0,
             dw_conv=False,
             dw_conv2=False,
             scatter=False,
             norm_before=True,
             ):
    """Aggregated-attention encoder layer with convolutional q/k/v projections.

    Args:
        d_model (int): feature dimension.
        nhead (int): number of attention heads.
        attention (str): 'linear' for LinearAttention, anything else for FullAttention.
        pool_size (int): downsampling factor for both query and source sides.
        bn (bool): add BatchNorm after the depth-wise projection convs.
        xformer (bool): use the xformers-backed XAttention instead.
        leaky (float): negative slope for LeakyReLU in the FFN; <= 0 selects ReLU.
        dw_conv (bool): learn the query-side pooling as a depth-wise strided conv.
        dw_conv2 (bool): same for the source side.
        scatter (bool): upsample the message by scattering instead of interpolation.
        norm_before (bool): apply norm1 before attention rather than after.
    """
    super(AG_Conv_EncoderLayer, self).__init__()

    self.pool_size = pool_size
    self.dw_conv = dw_conv
    self.dw_conv2 = dw_conv2
    self.scatter = scatter
    self.norm_before = norm_before

    # Learned depth-wise strided "pooling" convs (used instead of max-pool
    # when the corresponding flag is set).
    if self.dw_conv:
        self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size,
                                   padding=0, stride=pool_size,
                                   bias=False, groups=d_model)
    if self.dw_conv2:
        self.aggregate2 = nn.Conv2d(d_model, d_model, kernel_size=pool_size,
                                    padding=0, stride=pool_size,
                                    bias=False, groups=d_model)

    self.dim = d_model // nhead  # per-head dimension
    self.nhead = nhead

    self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size,
                                       stride=self.pool_size)

    # q/k/v projections: depth-wise conv, optionally followed by BatchNorm.
    proj_method = 'dw_bn' if bn else 'dw'
    self.q_proj_conv = self._build_projection(d_model, d_model, method=proj_method)
    self.k_proj_conv = self._build_projection(d_model, d_model, method=proj_method)
    self.v_proj_conv = self._build_projection(d_model, d_model, method=proj_method)

    # attention backend
    if xformer:
        self.attention = XAttention()
    else:
        self.attention = LinearAttention() if attention == 'linear' else FullAttention()
    self.merge = nn.Linear(d_model, d_model, bias=False)

    # feed-forward network over [x ; message] (hence 2*d_model inputs)
    ffn_act = nn.LeakyReLU(leaky, True) if leaky > 0 else nn.ReLU(True)
    self.mlp = nn.Sequential(
        nn.Linear(d_model*2, d_model*2, bias=False),
        ffn_act,
        nn.Linear(d_model*2, d_model, bias=False),
    )

    # norm layers (applied channel-last)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
789 |
+
|
790 |
+
def forward(self, x, source, x_mask=None, source_mask=None):
|
791 |
+
"""
|
792 |
+
Args:
|
793 |
+
x (torch.Tensor): [N, C, H1, W1]
|
794 |
+
source (torch.Tensor): [N, C, H2, W2]
|
795 |
+
x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
|
796 |
+
source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
|
797 |
+
"""
|
798 |
+
bs = x.size(0)
|
799 |
+
H1, W1 = x.size(-2), x.size(-1)
|
800 |
+
H2, W2 = source.size(-2), source.size(-1)
|
801 |
+
C = x.shape[-3]
|
802 |
+
|
803 |
+
if self.norm_before:
|
804 |
+
if self.dw_conv:
|
805 |
+
query = self.norm1(self.aggregate(x).permute(0,2,3,1)).permute(0,3,1,2)
|
806 |
+
else:
|
807 |
+
query = self.norm1(self.max_pool(x).permute(0,2,3,1)).permute(0,3,1,2)
|
808 |
+
if self.dw_conv2:
|
809 |
+
source = self.norm1(self.aggregate2(source).permute(0,2,3,1)).permute(0,3,1,2)
|
810 |
+
else:
|
811 |
+
source = self.norm1(self.max_pool(source).permute(0,2,3,1)).permute(0,3,1,2)
|
812 |
+
else:
|
813 |
+
if self.dw_conv:
|
814 |
+
query = self.aggregate(x)
|
815 |
+
else:
|
816 |
+
query = self.max_pool(x)
|
817 |
+
if self.dw_conv2:
|
818 |
+
source = self.aggregate2(source)
|
819 |
+
else:
|
820 |
+
source = self.max_pool(source)
|
821 |
+
|
822 |
+
key, value = source, source
|
823 |
+
|
824 |
+
query = self.q_proj_conv(query) # [N, C, H1//pool, W1//pool]
|
825 |
+
key = self.k_proj_conv(key)
|
826 |
+
value = self.v_proj_conv(value)
|
827 |
+
|
828 |
+
use_mask = x_mask is not None and source_mask is not None
|
829 |
+
if bs == 1 or not use_mask:
|
830 |
+
if use_mask:
|
831 |
+
x_mask = self.max_pool(x_mask.float()).bool() # [N, H1//pool, W1//pool]
|
832 |
+
source_mask = self.max_pool(source_mask.float()).bool()
|
833 |
+
|
834 |
+
mask_h0, mask_w0 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0]
|
835 |
+
mask_h1, mask_w1 = source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0]
|
836 |
+
|
837 |
+
query = query[:, :, :mask_h0, :mask_w0]
|
838 |
+
key = key[:, :, :mask_h1, :mask_w1]
|
839 |
+
value = value[:, :, :mask_h1, :mask_w1]
|
840 |
+
|
841 |
+
else:
|
842 |
+
assert x_mask is None and source_mask is None
|
843 |
+
|
844 |
+
if PFLASH_AVAILABLE: # N H L D
|
845 |
+
query = rearrange(query, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
|
846 |
+
key = rearrange(key, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
|
847 |
+
value = rearrange(value, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
|
848 |
+
|
849 |
+
else: # N L H D
|
850 |
+
query = rearrange(query, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D]
|
851 |
+
key = rearrange(key, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
|
852 |
+
value = rearrange(value, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
|
853 |
+
|
854 |
+
message = self.attention(query, key, value, q_mask=None, kv_mask=None) # [N, L, H, D] or [N, H, L, D]
|
855 |
+
|
856 |
+
if PFLASH_AVAILABLE: # N H L D
|
857 |
+
message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
|
858 |
+
|
859 |
+
if use_mask: # padding zero
|
860 |
+
message = message.view(bs, mask_h0, mask_w0, self.nhead, self.dim) # [N L H D]
|
861 |
+
if mask_h0 != x_mask.size(-2):
|
862 |
+
message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=1)
|
863 |
+
elif mask_w0 != x_mask.size(-1):
|
864 |
+
message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=2)
|
865 |
+
else:
|
866 |
+
assert x_mask is None and source_mask is None
|
867 |
+
|
868 |
+
message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
|
869 |
+
message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W]
|
870 |
+
|
871 |
+
if self.scatter:
|
872 |
+
message = torch.repeat_interleave(message, self.pool_size, dim=-2)
|
873 |
+
message = torch.repeat_interleave(message, self.pool_size, dim=-1)
|
874 |
+
message = message * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1,1,message.shape[-2]//self.pool_size,message.shape[-1]//self.pool_size)
|
875 |
+
else:
|
876 |
+
message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
|
877 |
+
|
878 |
+
if not self.norm_before:
|
879 |
+
message = self.norm1(message.permute(0,2,3,1)).permute(0,3,1,2) # [N, C, H, W]
|
880 |
+
|
881 |
+
# feed-forward network
|
882 |
+
message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
|
883 |
+
message = self.norm2(message).permute(0, 3, 1, 2) # [N, C, H1, W1]
|
884 |
+
|
885 |
+
return x + message
|
886 |
+
else: # mask with bs > 1
|
887 |
+
x_mask = self.max_pool(x_mask.float()).bool()
|
888 |
+
source_mask = self.max_pool(source_mask.float()).bool()
|
889 |
+
m_list = []
|
890 |
+
for i in range(bs):
|
891 |
+
mask_h0, mask_w0 = x_mask[i].sum(-2)[0], x_mask[i].sum(-1)[0]
|
892 |
+
mask_h1, mask_w1 = source_mask[i].sum(-2)[0], source_mask[i].sum(-1)[0]
|
893 |
+
|
894 |
+
q = query[i:i+1, :, :mask_h0, :mask_w0]
|
895 |
+
k = key[i:i+1, :, :mask_h1, :mask_w1]
|
896 |
+
v = value[i:i+1, :, :mask_h1, :mask_w1]
|
897 |
+
|
898 |
+
if PFLASH_AVAILABLE: # N H L D
|
899 |
+
q = rearrange(q, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
|
900 |
+
k = rearrange(k, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
|
901 |
+
v = rearrange(v, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
|
902 |
+
|
903 |
+
else: # N L H D
|
904 |
+
q = rearrange(q, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D]
|
905 |
+
k = rearrange(k, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
|
906 |
+
v = rearrange(v, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
|
907 |
+
|
908 |
+
m = self.attention(q, k, v, q_mask=None, kv_mask=None) # [N, L, H, D]
|
909 |
+
|
910 |
+
if PFLASH_AVAILABLE: # N H L D
|
911 |
+
m = rearrange(m, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
|
912 |
+
|
913 |
+
m = m.view(1, mask_h0, mask_w0, self.nhead, self.dim)
|
914 |
+
if mask_h0 != x_mask.size(-2):
|
915 |
+
m = torch.cat([m, torch.zeros(1, x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=1)
|
916 |
+
elif mask_w0 != x_mask.size(-1):
|
917 |
+
m = torch.cat([m, torch.zeros(1, x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=2)
|
918 |
+
m_list.append(m)
|
919 |
+
m = torch.cat(m_list, dim=0)
|
920 |
+
|
921 |
+
m = self.merge(m.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
|
922 |
+
|
923 |
+
# m = m.reshape(bs, C, H1//self.pool_size, W1//self.pool_size) # [N, C, H, W] why this bug worked
|
924 |
+
m = rearrange(m, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W]
|
925 |
+
|
926 |
+
if self.scatter:
|
927 |
+
m = torch.repeat_interleave(m, self.pool_size, dim=-2)
|
928 |
+
m = torch.repeat_interleave(m, self.pool_size, dim=-1)
|
929 |
+
m = m * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1,1,m.shape[-2]//self.pool_size,m.shape[-1]//self.pool_size)
|
930 |
+
else:
|
931 |
+
m = torch.nn.functional.interpolate(m, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
|
932 |
+
|
933 |
+
if not self.norm_before:
|
934 |
+
m = self.norm1(m.permute(0,2,3,1)).permute(0,3,1,2) # [N, C, H, W]
|
935 |
+
|
936 |
+
# feed-forward network
|
937 |
+
m = self.mlp(torch.cat([x, m], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
|
938 |
+
m = self.norm2(m).permute(0, 3, 1, 2) # [N, C, H1, W1]
|
939 |
+
|
940 |
+
return x + m
|
941 |
+
|
942 |
+
def _build_projection(self,
|
943 |
+
dim_in,
|
944 |
+
dim_out,
|
945 |
+
kernel_size=3,
|
946 |
+
padding=1,
|
947 |
+
stride=1,
|
948 |
+
method='dw_bn',
|
949 |
+
):
|
950 |
+
if method == 'dw_bn':
|
951 |
+
proj = nn.Sequential(OrderedDict([
|
952 |
+
('conv', nn.Conv2d(
|
953 |
+
dim_in,
|
954 |
+
dim_in,
|
955 |
+
kernel_size=kernel_size,
|
956 |
+
padding=padding,
|
957 |
+
stride=stride,
|
958 |
+
bias=False,
|
959 |
+
groups=dim_in
|
960 |
+
)),
|
961 |
+
('bn', nn.BatchNorm2d(dim_in)),
|
962 |
+
# ('rearrage', Rearrange('b c h w -> b (h w) c')),
|
963 |
+
]))
|
964 |
+
elif method == 'avg':
|
965 |
+
proj = nn.Sequential(OrderedDict([
|
966 |
+
('avg', nn.AvgPool2d(
|
967 |
+
kernel_size=kernel_size,
|
968 |
+
padding=padding,
|
969 |
+
stride=stride,
|
970 |
+
ceil_mode=True
|
971 |
+
)),
|
972 |
+
# ('rearrage', Rearrange('b c h w -> b (h w) c')),
|
973 |
+
]))
|
974 |
+
elif method == 'linear':
|
975 |
+
proj = None
|
976 |
+
elif method == 'dw':
|
977 |
+
proj = nn.Sequential(OrderedDict([
|
978 |
+
('conv', nn.Conv2d(
|
979 |
+
dim_in,
|
980 |
+
dim_in,
|
981 |
+
kernel_size=kernel_size,
|
982 |
+
padding=padding,
|
983 |
+
stride=stride,
|
984 |
+
bias=False,
|
985 |
+
groups=dim_in
|
986 |
+
)),
|
987 |
+
# ('rearrage', Rearrange('b c h w -> b (h w) c')),
|
988 |
+
]))
|
989 |
+
else:
|
990 |
+
raise ValueError('Unknown method ({})'.format(method))
|
991 |
+
|
992 |
+
return proj
|
993 |
+
|
994 |
+
|
995 |
+
class RoPELoFTREncoderLayer(nn.Module):
    """LoFTR encoder layer supporting RoPE linear attention, or a depth-wise
    conv token mixer that replaces attention entirely."""

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 rope=False,
                 token_mixer=None,
                 ):
        """
        Args:
            d_model (int): feature dimension (split across ``nhead`` heads).
            nhead (int): number of attention heads.
            attention (str): must be 'linear' when rope=True.
            rope (bool): use rotary position embedding (RoPELinearAttention).
            token_mixer (str or None): 'dwcn' replaces attention with a 3x3
                depth-wise conv; None keeps q/k/v attention.
        """
        super(RoPELoFTREncoderLayer, self).__init__()

        self.dim = d_model // nhead
        self.nhead = nhead

        # multi-head attention projections (not needed for the conv mixer)
        if token_mixer is None:
            self.q_proj = nn.Linear(d_model, d_model, bias=False)
            self.k_proj = nn.Linear(d_model, d_model, bias=False)
            self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.rope = rope
        self.token_mixer = None
        if token_mixer is not None:
            self.token_mixer = token_mixer
            if token_mixer == 'dwcn':
                # Depth-wise 3x3 conv used as the token mixer.
                self.attention = nn.Sequential(OrderedDict([
                    ('conv', nn.Conv2d(
                        d_model,
                        d_model,
                        kernel_size=3,
                        padding=1,
                        stride=1,
                        bias=False,
                        groups=d_model
                    )),
                ]))
        elif self.rope:
            assert attention == 'linear'
            self.attention = RoPELinearAttention()
        # NOTE(review): with token_mixer=None and rope=False no self.attention
        # is created here — presumably that combination is never configured;
        # forward() would raise AttributeError in that case. Confirm.

        if token_mixer is None:
            self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network: concat-style [x; message] MLP for attention,
        # plain pre-norm residual MLP for the conv mixer.
        if token_mixer is None:
            self.mlp = nn.Sequential(
                nn.Linear(d_model*2, d_model*2, bias=False),
                nn.ReLU(True),
                nn.Linear(d_model*2, d_model, bias=False),
            )
        else:
            self.mlp = nn.Sequential(
                nn.Linear(d_model, d_model, bias=False),
                nn.ReLU(True),
                nn.Linear(d_model, d_model, bias=False),
            )
        # norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None, H=None, W=None):
        """
        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, L, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)
            H, W (int): spatial grid size with H*W == L (needed to restore
                the 2D layout for the conv mixer / RoPE).
        """
        bs = x.size(0)
        assert H*W == x.size(-2)

        # x = rearrange(x, 'n c h w -> n (h w) c')
        # source = rearrange(source, 'n c h w -> n (h w) c')
        query, key, value = x, source, source

        if self.token_mixer is not None:
            # Pre-norm conv-mixer path (acts on x only; ``source`` ignored).
            m = self.norm1(x)
            m = rearrange(m, 'n (h w) c -> n c h w', h=H, w=W)
            m = self.attention(m)
            m = rearrange(m, 'n c h w -> n (h w) c')

            x = x + m
            x = x + self.mlp(self.norm2(x))
            return x
        else:
            # multi-head attention
            query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
            key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
            value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
            message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask, H=H, W=W)  # [N, L, (H, D)]
            message = self.merge(message.view(bs, -1, self.nhead*self.dim))  # [N, L, C]
            message = self.norm1(message)

            # feed-forward network
            message = self.mlp(torch.cat([x, message], dim=2))
            message = self.norm2(message)

            return x + message
|
1093 |
+
|
1094 |
+
class LoFTREncoderLayer(nn.Module):
    """Vanilla LoFTR transformer layer.

    Multi-head (linear / full / xformers) attention between ``x`` and
    ``source``, followed by a position-wise MLP over the concatenation of the
    input and the attention message, with a residual connection.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 xformer=False,
                 ):
        """
        Args:
            d_model (int): token feature dimension.
            nhead (int): number of attention heads (d_model must divide evenly).
            attention (str): 'linear' for LinearAttention, anything else for
                FullAttention (ignored when xformer=True).
            xformer (bool): use the xformers-backed attention implementation.
        """
        super(LoFTREncoderLayer, self).__init__()

        self.dim = d_model // nhead
        self.nhead = nhead

        # Per-token projections feeding the multi-head attention.
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        if xformer:
            self.attention = XAttention()
        else:
            self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # Position-wise feed-forward applied to [x ; message].
        self.mlp = nn.Sequential(
            nn.Linear(d_model * 2, d_model * 2, bias=False),
            nn.ReLU(True),
            nn.Linear(d_model * 2, d_model, bias=False),
        )

        # Post-attention and post-MLP layer norms.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, S, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)
        Returns:
            torch.Tensor: [N, L, C] — x plus the attention/MLP message.
        """
        batch = x.size(0)

        # Project tokens to per-head queries/keys/values: [N, *, nhead, dim].
        q = self.q_proj(x).view(batch, -1, self.nhead, self.dim)
        k = self.k_proj(source).view(batch, -1, self.nhead, self.dim)
        v = self.v_proj(source).view(batch, -1, self.nhead, self.dim)

        # Attend, merge heads back to d_model, normalize.
        msg = self.attention(q, k, v, q_mask=x_mask, kv_mask=source_mask)
        msg = self.norm1(self.merge(msg.view(batch, -1, self.nhead * self.dim)))

        # Feed-forward over the concatenation of the input and the message.
        msg = self.norm2(self.mlp(torch.cat([x, msg], dim=2)))

        return x + msg

    def pro(self, x, source, x_mask=None, source_mask=None, profiler=None):
        """Same computation as :meth:`forward`, instrumented with a profiler.

        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, S, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)
            profiler: object exposing ``profile(name)`` context managers.
        """
        batch = x.size(0)

        with profiler.profile("proj*3"):
            q = self.q_proj(x).view(batch, -1, self.nhead, self.dim)
            k = self.k_proj(source).view(batch, -1, self.nhead, self.dim)
            v = self.v_proj(source).view(batch, -1, self.nhead, self.dim)
        with profiler.profile("attention"):
            msg = self.attention(q, k, v, q_mask=x_mask, kv_mask=source_mask)
        with profiler.profile("merge"):
            msg = self.merge(msg.view(batch, -1, self.nhead * self.dim))
        with profiler.profile("norm1"):
            msg = self.norm1(msg)

        with profiler.profile("mlp"):
            msg = self.mlp(torch.cat([x, msg], dim=2))
        with profiler.profile("norm2"):
            msg = self.norm2(msg)

        return x + msg
|
1183 |
+
|
1184 |
+
class PANEncoderLayer_cross(nn.Module):
    """Bidirectional cross-attention on pooled (coarse) feature maps.

    Both maps are max-pooled by ``pool_size``; a single depth-wise-conv QK
    projection is shared for both directions, one similarity volume is built,
    normalized along each side, and both inputs get a residual update.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 pool_size=4,
                 bn=True,
                 ):
        """
        Args:
            d_model (int): channel dimension of the input feature maps.
            nhead (int): number of attention heads.
            attention (str): unused — this layer hard-codes FullAttention.
            pool_size (int): spatial down/up-sampling factor.
            bn (bool): add BatchNorm after the depth-wise projection convs.
        """
        super(PANEncoderLayer_cross, self).__init__()

        self.pool_size = pool_size

        self.dim = d_model // nhead
        self.nhead = nhead

        self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size)
        # multi-head attention: shared projection for Q and K (both sides live
        # in the same space), separate projection for V.
        if bn:
            method = 'dw_bn'
        else:
            method = 'dw'
        self.qk_proj_conv = self._build_projection(d_model, d_model, method=method)
        self.v_proj_conv = self._build_projection(d_model, d_model, method=method)

        # self.q_proj = nn.Linear(d_mosdel, d_model, bias=False)
        # self.k_proj = nn.Linear(d_model, d_model, bias=False)
        # self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.attention = FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network
        self.mlp = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            nn.ReLU(True),
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # self.norm1 = nn.BatchNorm2d(d_model)

    def forward(self, x1, x2, x1_mask=None, x2_mask=None):
        """
        Args:
            x1 (torch.Tensor): [N, C, H1, W1]
            x2 (torch.Tensor): [N, C, H2, W2]
            x1_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
            x2_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
        Returns:
            (torch.Tensor, torch.Tensor): residual-updated x1 and x2.
        """
        bs = x1.size(0)
        # Coarse grid sizes after pooling.
        H1, W1 = x1.size(-2) // self.pool_size, x1.size(-1) // self.pool_size
        H2, W2 = x2.size(-2) // self.pool_size, x2.size(-1) // self.pool_size

        # Pool + channels-last LayerNorm; x1 feeds query/v1, x2 feeds key/v2.
        query = self.norm1(self.max_pool(x1).permute(0,2,3,1)).permute(0,3,1,2)
        key = self.norm1(self.max_pool(x2).permute(0,2,3,1)).permute(0,3,1,2)
        v2 = self.norm1(self.max_pool(x2).permute(0,2,3,1)).permute(0,3,1,2)
        v1 = self.norm1(self.max_pool(x1).permute(0,2,3,1)).permute(0,3,1,2)

        # multi-head attention projections (depth-wise convs)
        query = self.qk_proj_conv(query)  # [N, C, H1//pool, W1//pool]
        key = self.qk_proj_conv(key)
        v2 = self.v_proj_conv(v2)
        v1 = self.v_proj_conv(v1)

        C = query.shape[-3]
        if x1_mask is not None and x2_mask is not None:
            # Pool masks to coarse resolution and crop away padded borders.
            # NOTE(review): extents come from sample 0 — presumably all batch
            # samples share identical padding here; confirm against callers.
            x1_mask = self.max_pool(x1_mask.float()).bool()  # [N, H1//pool, W1//pool]
            x2_mask = self.max_pool(x2_mask.float()).bool()

            mask_h1, mask_w1 = x1_mask[0].sum(-2)[0], x1_mask[0].sum(-1)[0]
            mask_h2, mask_w2 = x2_mask[0].sum(-2)[0], x2_mask[0].sum(-1)[0]

            query = query[:, :, :mask_h1, :mask_w1]
            key = key[:, :, :mask_h2, :mask_w2]
            v1 = v1[:, :, :mask_h1, :mask_w1]
            v2 = v2[:, :, :mask_h2, :mask_w2]
            x1_mask = x1_mask[:, :mask_h1, :mask_w1]
            x2_mask = x2_mask[:, :mask_h2, :mask_w2]

        else:
            assert x1_mask is None and x2_mask is None

        # Flatten grids to token sequences: [N, tokens, nhead, dim].
        query = rearrange(query, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, L, H, D]
        key = rearrange(key, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, S, H, D]
        v2 = rearrange(v2, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, S, H, D]
        v1 = rearrange(v1, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, S, H, D]
        if x2_mask is not None or x1_mask is not None:
            x1_mask = x1_mask.flatten(-2)
            x2_mask = x2_mask.flatten(-2)

        # One similarity volume [N, L, S, H] serves both directions.
        QK = torch.einsum("nlhd,nshd->nlsh", query, key)
        # Force fp32 for the masked softmax (numerical stability under AMP).
        with torch.autocast(enabled=False, device_type='cuda'):
            if x2_mask is not None or x1_mask is not None:
                # S1 = S2.transpose(-2,-3).masked_fill(~(x_mask[:, None, :, None] * source_mask[:, :, None, None]), -1e9) # float('-inf')
                QK = QK.float().masked_fill_(~(x1_mask[:, :, None, None] * x2_mask[:, None, :, None]), -1e9)  # float('-inf')

            # Compute the attention and the weighted average
            softmax_temp = 1. / query.size(3)**.5  # sqrt(D)
            # S1 normalizes over the x2 tokens (dim 2 = S) -> message for x1.
            S1 = torch.softmax(softmax_temp * QK, dim=2)
            # NOTE(review): dim=3 is the *head* axis of [N, L, S, H]; a softmax
            # over x1 tokens would be dim=1.  Looks suspicious — confirm
            # against the upstream implementation before relying on x2's update.
            S2 = torch.softmax(softmax_temp * QK, dim=3)

        m1 = torch.einsum("nlsh,nshd->nlhd", S1, v2)  # message for x1 from x2
        m2 = torch.einsum("nlsh,nlhd->nshd", S2, v1)  # message for x2 from x1

        # --- x1 update: pad message to the full coarse grid, merge heads,
        # upsample, concat-MLP, residual. ---
        if x1_mask is not None and x2_mask is not None:
            m1 = m1.view(bs, mask_h1, mask_w1, self.nhead, self.dim)
            # NOTE(review): pads along H *or* W, never both (if/elif).
            if mask_h1 != H1:
                m1 = torch.cat([m1, torch.zeros(m1.size(0), H1-mask_h1, W1, self.nhead, self.dim, device=m1.device, dtype=m1.dtype)], dim=1)
            elif mask_w1 != W1:
                m1 = torch.cat([m1, torch.zeros(m1.size(0), H1, W1-mask_w1, self.nhead, self.dim, device=m1.device, dtype=m1.dtype)], dim=2)
        else:
            assert x1_mask is None and x2_mask is None

        m1 = self.merge(m1.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]
        m1 = rearrange(m1, 'b (h w) c -> b c h w', h=H1, w=W1)  # [N, C, H, W]
        m1 = torch.nn.functional.interpolate(m1, scale_factor=self.pool_size, mode='bilinear', align_corners=False)  # [N, C, H1, W1]
        # feed-forward network
        m1 = self.mlp(torch.cat([x1, m1], dim=1).permute(0, 2, 3, 1))  # [N, H1, W1, C]
        m1 = self.norm2(m1).permute(0, 3, 1, 2)  # [N, C, H1, W1]

        # --- x2 update, symmetric to the x1 path. ---
        if x1_mask is not None and x2_mask is not None:
            m2 = m2.view(bs, mask_h2, mask_w2, self.nhead, self.dim)
            if mask_h2 != H2:
                m2 = torch.cat([m2, torch.zeros(m2.size(0), H2-mask_h2, W2, self.nhead, self.dim, device=m2.device, dtype=m2.dtype)], dim=1)
            elif mask_w2 != W2:
                m2 = torch.cat([m2, torch.zeros(m2.size(0), H2, W2-mask_w2, self.nhead, self.dim, device=m2.device, dtype=m2.dtype)], dim=2)
        else:
            assert x1_mask is None and x2_mask is None

        m2 = self.merge(m2.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]
        m2 = rearrange(m2, 'b (h w) c -> b c h w', h=H2, w=W2)  # [N, C, H, W]
        m2 = torch.nn.functional.interpolate(m2, scale_factor=self.pool_size, mode='bilinear', align_corners=False)  # [N, C, H1, W1]
        # feed-forward network
        m2 = self.mlp(torch.cat([x2, m2], dim=1).permute(0, 2, 3, 1))  # [N, H1, W1, C]
        m2 = self.norm2(m2).permute(0, 3, 1, 2)  # [N, C, H1, W1]

        return x1 + m1, x2 + m2

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size=3,
                          padding=1,
                          stride=1,
                          method='dw_bn',
                          ):
        """Build the depth-wise conv / pooling projection.

        ``dim_out`` is unused — every variant keeps ``dim_in`` channels.
        Returns None for method='linear'; raises ValueError otherwise on an
        unknown method.
        """
        if method == 'dw_bn':
            proj = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
                ('bn', nn.BatchNorm2d(dim_in)),
                # ('rearrage', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'avg':
            proj = nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool2d(
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    ceil_mode=True
                )),
                # ('rearrage', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'linear':
            proj = None
        elif method == 'dw':
            proj = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
                # ('rearrage', Rearrange('b c h w -> b (h w) c')),
            ]))
        else:
            raise ValueError('Unknown method ({})'.format(method))

        return proj
|
1377 |
+
|
1378 |
+
class LocalFeatureTransformer(nn.Module):
|
1379 |
+
"""A Local Feature Transformer (LoFTR) module."""
|
1380 |
+
|
1381 |
+
    def __init__(self, config):
        """Build the coarse- or fine-level attention stack from the config.

        Args:
            config (dict): full model config.  When it lacks a 'coarse' key the
                module acts as the fine-level attention and ``config`` itself
                holds the layer options; otherwise ``config['coarse']`` is used.
        """
        super(LocalFeatureTransformer, self).__init__()

        self.full_config = config
        self.fine = False
        if 'coarse' not in config:
            self.fine = True # fine attention
        else:
            config = config['coarse']
        self.d_model = config['d_model']
        self.nhead = config['nhead']
        self.layer_names = config['layer_names']
        self.pan = config['pan']
        self.bidirect = config['bidirection']
        # prune
        self.pool_size = config['pool_size']
        # Token pruning / early-stop is hard-coded off; the commented lines
        # show how it would be driven by the config instead.
        self.matchability = False
        self.depth_confidence = -1.0
        self.width_confidence = -1.0
        # self.depth_confidence = config['depth_confidence']
        # self.width_confidence = config['width_confidence']
        # self.matchability = self.depth_confidence > 0 or self.width_confidence > 0
        # self.thr = self.full_config['match_coarse']['thr']
        if not self.fine:
            # asy
            self.asymmetric = config['asymmetric']
            self.asymmetric_self = config['asymmetric_self']
            # aggregate
            self.aggregate = config['dwconv']
            # RoPE
            self.rope = config['rope']
            # absPE
            self.abspe = config['abspe']

        else:
            # Fine-level attention disables all coarse-only options.
            self.rope, self.asymmetric, self.asymmetric_self, self.aggregate = False, False, False, False
        if self.matchability:
            # Dead branch while matchability is hard-coded False above.
            self.n_layers = len(self.layer_names) // 2
            assert self.n_layers == 4
            self.log_assignment = nn.ModuleList(
                [MatchAssignment(self.d_model) for _ in range(self.n_layers)])
            self.token_confidence = nn.ModuleList([
                TokenConfidence(self.d_model) for _ in range(self.n_layers-1)])

            # NOTE(review): kept inside the matchability branch — the original
            # indentation is ambiguous here; confirm against upstream.
            self.CoarseMatching = CoarseMatching(self.full_config['match_coarse'])

        # self only
        # if self.rope:
        #     self_layer = RoPELoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'], config['rope'], config['token_mixer'])
        #     self.layers = nn.ModuleList([copy.deepcopy(self_layer) for _ in range(len(self.layer_names))])

        if self.bidirect:
            # Distinct self / cross layer types, interleaved per layer_names.
            assert config['xformer'] is False and config['pan'] is True
            self_layer = PANEncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['bn'], config['xformer'])
            cross_layer = PANEncoderLayer_cross(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['bn'])
            self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names])
        else:
            if self.aggregate:
                if self.rope:
                    # RoPE only on the self layers; cross layers get rope=False.
                    # assert config['npe'][0] == 832 and config['npe'][1] == 832 and config['npe'][2] == 832 and config['npe'][3] == 832
                    logger.info(f'npe trainH,trainW,testH,testW: {config["npe"][0]}, {config["npe"][1]}, {config["npe"][2]}, {config["npe"][3]}')
                    self_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
                                                      config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
                                                      config['norm_before'], config['rope'], config['npe'], config['vit_norm'], config['rope_dwproj'])
                    cross_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
                                                       config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
                                                       config['norm_before'], False, config['npe'], config['vit_norm'], config['rope_dwproj'])
                    self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names])
                elif self.abspe:
                    # Absolute PE: same layer type with rope disabled on both.
                    logger.info(f'npe trainH,trainW,testH,testW: {config["npe"][0]}, {config["npe"][1]}, {config["npe"][2]}, {config["npe"][3]}')
                    self_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
                                                      config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
                                                      config['norm_before'], False, config['npe'], config['vit_norm'], config['rope_dwproj'])
                    cross_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
                                                       config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
                                                       config['norm_before'], False, config['npe'], config['vit_norm'], config['rope_dwproj'])
                    self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names])

                else:
                    # Aggregated conv attention, identical layer for self/cross.
                    encoder_layer = AG_Conv_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['bn'],
                                                         config['xformer'], config['leaky'], config['dwconv'], config['scatter'],
                                                         config['norm_before'])
                    self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
            else:
                # No aggregation: plain PAN layer or the vanilla LoFTR layer.
                encoder_layer = PANEncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'],
                                                config['bn'], config['xformer'], config['leaky'], config['dwconv'], config['scatter']) \
                    if config['pan'] else LoFTREncoderLayer(config['d_model'], config['nhead'],
                                                            config['attention'], config['xformer'])
                self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
        self._reset_parameters()
|
1471 |
+
|
1472 |
+
def _reset_parameters(self):
|
1473 |
+
for p in self.parameters():
|
1474 |
+
if p.dim() > 1:
|
1475 |
+
nn.init.xavier_uniform_(p)
|
1476 |
+
|
1477 |
+
def forward(self, feat0, feat1, mask0=None, mask1=None, data=None):
    """Interleave self-/cross-attention layers over two coarse feature maps.

    Supports (a) cropping batch-size-1 inputs to their valid (mask) region and
    re-padding at the end, and (b) LightGlue-style early stopping / point
    pruning with per-layer matchability supervision written into ``data``.

    Args:
        feat0 (torch.Tensor): [N, C, H, W]
        feat1 (torch.Tensor): [N, C, H, W]
        mask0 (torch.Tensor): [N, H, W] valid-pixel mask (optional)
        mask1 (torch.Tensor): [N, H, W] valid-pixel mask (optional)
        data (dict): updated in place with per-layer supervision tensors
            ('matchability{0,1}_k', 'matches{0,1}_k', 'conf_matrix_k', ...)
            when ``self.matchability`` is enabled.

    Returns:
        (feat0, feat1): attended features, zero-padded back to the original
        masked resolution when cropping was applied.
    """
    # nchw for pan and n(hw)c for loftr
    assert self.d_model == feat0.size(1) or self.d_model == feat0.size(-1), "the feature number of src and transformer must be equal"
    H0, W0, H1, W1 = feat0.size(-2), feat0.size(-1), feat1.size(-2), feat1.size(-1)
    bs = feat0.shape[0]
    padding = False
    if bs == 1 and mask0 is not None and mask1 is not None and self.pan:  # NCHW for pan
        mask_H0, mask_W0 = mask0.size(-2), mask0.size(-1)
        mask_H1, mask_W1 = mask1.size(-2), mask1.size(-1)
        # valid height/width of the (single) sample, read off the mask
        mask_h0, mask_w0 = mask0[0].sum(-2)[0], mask0[0].sum(-1)[0]
        mask_h1, mask_w1 = mask1[0].sum(-2)[0], mask1[0].sum(-1)[0]

        # round the valid extents down to a multiple of self.pool_size
        if self.pan:
            mask_h0, mask_w0, mask_h1, mask_w1 = mask_h0//self.pool_size*self.pool_size, mask_w0//self.pool_size*self.pool_size, mask_h1//self.pool_size*self.pool_size, mask_w1//self.pool_size*self.pool_size

        # crop away padded borders so attention only sees valid pixels
        feat0 = feat0[:, :, :mask_h0, :mask_w0]
        feat1 = feat1[:, :, :mask_h1, :mask_w1]

        padding = True

    # token counts used for pruning bookkeeping
    if padding:
        l0, l1 = mask_h0 * mask_w0, mask_h1 * mask_w1
    else:
        l0, l1 = H0 * W0, H1 * W1
    do_early_stop = self.depth_confidence > 0
    do_point_pruning = self.width_confidence > 0
    if do_point_pruning:
        ind0 = torch.arange(0, l0, device=feat0.device)[None]
        ind1 = torch.arange(0, l1, device=feat0.device)[None]
        # We store the index of the layer at which pruning is detected.
        prune0 = torch.ones_like(ind0)
        prune1 = torch.ones_like(ind1)
    if do_early_stop:
        token0, token1 = None, None

    for i, (layer, name) in enumerate(zip(self.layers, self.layer_names)):
        if padding:
            # features were already cropped to the valid region
            mask0, mask1 = None, None
        if name == 'self':
            if self.asymmetric:
                assert False, 'not worked'
                feat1 = layer(feat1, feat1, mask1, mask1)
            else:
                feat0 = layer(feat0, feat0, mask0, mask0)
                feat1 = layer(feat1, feat1, mask1, mask1)
        elif name == 'cross':
            if self.bidirect:
                feat0, feat1 = layer(feat0, feat1, mask0, mask1)
            else:
                if self.asymmetric or self.asymmetric_self:
                    assert False, 'not worked'
                    feat0 = layer(feat0, feat1, mask0, mask1)
                else:
                    feat0 = layer(feat0, feat1, mask0, mask1)
                    feat1 = layer(feat1, feat0, mask1, mask0)

            # last cross layer at inference: supervision outputs are not needed
            if i == len(self.layer_names) - 1 and not self.training:
                continue
            if self.matchability:
                desc0, desc1 = rearrange(feat0, 'b c h w -> b (h w) c'), rearrange(feat1, 'b c h w -> b (h w) c')
                if do_early_stop:
                    token0, token1 = self.token_confidence[i//2](desc0, desc1)
                    if self.check_if_stop(token0, token1, i, l0+l1) and not self.training:
                        break
                if do_point_pruning:
                    scores0, scores1 = self.log_assignment[i//2].scores(desc0, desc1)
                    mask0 = self.get_pruning_mask(token0, scores0, i)
                    mask1 = self.get_pruning_mask(token1, scores1, i)
                    ind0, ind1 = ind0[mask0][None], ind1[mask1][None]
                    feat0, feat1 = desc0[mask0][None], desc1[mask1][None]
                    # BUGFIX: was `desc1.shape[-2]`, the *un-pruned* length, which
                    # can never become 0 here; check the pruned feat1 instead.
                    if feat0.shape[-2] == 0 or feat1.shape[-2] == 0:
                        break
                    prune0[:, ind0] += 1
                    prune1[:, ind1] += 1
                if self.training and self.matchability:
                    scores, _, matchability0, matchability1 = self.log_assignment[i//2](desc0, desc1)
                    m0_full = torch.zeros((bs, mask_h0 * mask_w0), device=matchability0.device, dtype=matchability0.dtype)
                    # BUGFIX: Tensor.scatter is out-of-place and its return value was
                    # discarded, leaving the buffer all-zero; use in-place scatter_.
                    m0_full.scatter_(1, ind0, matchability0.squeeze(-1))
                    if padding and self.d_model == feat0.size(1):
                        m0_full = m0_full.reshape(bs, mask_h0, mask_w0)
                        bs, c, mask_h0, mask_w0 = feat0.size()
                        if mask_h0 != mask_H0:
                            m0_full = torch.cat([m0_full, torch.zeros(bs, mask_H0-mask_h0, mask_w0, device=m0_full.device, dtype=m0_full.dtype)], dim=1)
                        elif mask_w0 != mask_W0:
                            m0_full = torch.cat([m0_full, torch.zeros(bs, mask_h0, mask_W0-mask_w0, device=m0_full.device, dtype=m0_full.dtype)], dim=2)
                        m0_full = m0_full.reshape(bs, mask_H0*mask_W0)
                    m1_full = torch.zeros((bs, mask_h1 * mask_w1), device=matchability0.device, dtype=matchability0.dtype)
                    # BUGFIX: in-place scatter_ (see above).
                    m1_full.scatter_(1, ind1, matchability1.squeeze(-1))
                    if padding and self.d_model == feat1.size(1):
                        m1_full = m1_full.reshape(bs, mask_h1, mask_w1)
                        bs, c, mask_h1, mask_w1 = feat1.size()
                        if mask_h1 != mask_H1:
                            m1_full = torch.cat([m1_full, torch.zeros(bs, mask_H1-mask_h1, mask_w1, device=m1_full.device, dtype=m1_full.dtype)], dim=1)
                        elif mask_w1 != mask_W1:
                            m1_full = torch.cat([m1_full, torch.zeros(bs, mask_h1, mask_W1-mask_w1, device=m1_full.device, dtype=m1_full.dtype)], dim=2)
                        m1_full = m1_full.reshape(bs, mask_H1*mask_W1)
                    data.update({'matchability0_'+str(i//2): m0_full, 'matchability1_'+str(i//2): m1_full})
                    m0, m1, mscores0, mscores1 = filter_matches(
                        scores, self.thr)
                    if do_point_pruning:
                        # remap pruned indices back to full-resolution indices
                        m0_ = torch.full((bs, l0), -1, device=m0.device, dtype=m0.dtype)
                        m1_ = torch.full((bs, l1), -1, device=m1.device, dtype=m1.dtype)
                        m0_[:, ind0] = torch.where(
                            m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
                        m1_[:, ind1] = torch.where(
                            m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
                        mscores0_ = torch.zeros((bs, l0), device=mscores0.device)
                        mscores1_ = torch.zeros((bs, l1), device=mscores1.device)
                        mscores0_[:, ind0] = mscores0
                        mscores1_[:, ind1] = mscores1
                        m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
                    if padding and self.d_model == feat0.size(1):
                        m0 = m0.reshape(bs, mask_h0, mask_w0)
                        bs, c, mask_h0, mask_w0 = feat0.size()
                        if mask_h0 != mask_H0:
                            m0 = torch.cat([m0, -torch.ones(bs, mask_H0-mask_h0, mask_w0, device=m0.device, dtype=m0.dtype)], dim=1)
                        elif mask_w0 != mask_W0:
                            m0 = torch.cat([m0, -torch.ones(bs, mask_h0, mask_W0-mask_w0, device=m0.device, dtype=m0.dtype)], dim=2)
                        m0 = m0.reshape(bs, mask_H0*mask_W0)
                    if padding and self.d_model == feat1.size(1):
                        m1 = m1.reshape(bs, mask_h1, mask_w1)
                        bs, c, mask_h1, mask_w1 = feat1.size()
                        if mask_h1 != mask_H1:
                            m1 = torch.cat([m1, -torch.ones(bs, mask_H1-mask_h1, mask_w1, device=m1.device, dtype=m1.dtype)], dim=1)
                        elif mask_w1 != mask_W1:
                            m1 = torch.cat([m1, -torch.ones(bs, mask_h1, mask_W1-mask_w1, device=m1.device, dtype=m1.dtype)], dim=2)
                        m1 = m1.reshape(bs, mask_H1*mask_W1)
                    data.update({'matches0_'+str(i//2): m0, 'matches1_'+str(i//2): m1})
                    conf = torch.zeros((bs, l0 * l1), device=scores.device, dtype=scores.dtype)
                    ind = ind0[...,None] * l1 + ind1[:,None,:]
                    # BUGFIX: in-place scatter_ (see above); equivalent to
                    # conf[ind.reshape(bs, -1)] = scores.reshape(bs, -1).exp()
                    conf.scatter_(1, ind.reshape(bs, -1), scores.reshape(bs, -1).exp())
                    if padding and self.d_model == feat0.size(1):
                        conf = conf.reshape(bs, mask_h0, mask_w0, mask_h1, mask_w1)
                        bs, c, mask_h0, mask_w0 = feat0.size()
                        if mask_h0 != mask_H0:
                            conf = torch.cat([conf, torch.zeros(bs, mask_H0-mask_h0, mask_w0, mask_h1, mask_w1, device=conf.device, dtype=conf.dtype)], dim=1)
                        elif mask_w0 != mask_W0:
                            conf = torch.cat([conf, torch.zeros(bs, mask_h0, mask_W0-mask_w0, mask_h1, mask_w1, device=conf.device, dtype=conf.dtype)], dim=2)
                        bs, c, mask_h1, mask_w1 = feat1.size()
                        if mask_h1 != mask_H1:
                            conf = torch.cat([conf, torch.zeros(bs, mask_H0, mask_W0, mask_H1-mask_h1, mask_W1, device=conf.device, dtype=conf.dtype)], dim=3)
                        elif mask_w1 != mask_W1:
                            conf = torch.cat([conf, torch.zeros(bs, mask_H0, mask_W0, mask_H1, mask_W1-mask_w1, device=conf.device, dtype=conf.dtype)], dim=4)
                        conf = conf.reshape(bs, mask_H0*mask_W0, mask_H1*mask_W1)
                    data.update({'conf_matrix_'+str(i//2): conf})
        else:
            raise KeyError

    # inference-time matchability head: build the final confidence matrix
    if self.matchability and not self.training:
        scores, _, matchability0, matchability1 = self.log_assignment[i//2](desc0, desc1)
        conf = torch.zeros((bs, l0 * l1), device=scores.device, dtype=scores.dtype)
        ind = ind0[...,None] * l1 + ind1[:,None,:]
        # BUGFIX: in-place scatter_ (see above); equivalent to
        # conf[ind.reshape(bs, -1)] = scores.reshape(bs, -1).exp()
        conf.scatter_(1, ind.reshape(bs, -1), scores.reshape(bs, -1).exp())
        if padding and self.d_model == feat0.size(1):
            conf = conf.reshape(bs, mask_h0, mask_w0, mask_h1, mask_w1)
            bs, c, mask_h0, mask_w0 = feat0.size()
            if mask_h0 != mask_H0:
                conf = torch.cat([conf, torch.zeros(bs, mask_H0-mask_h0, mask_w0, mask_h1, mask_w1, device=conf.device, dtype=conf.dtype)], dim=1)
            elif mask_w0 != mask_W0:
                conf = torch.cat([conf, torch.zeros(bs, mask_h0, mask_W0-mask_w0, mask_h1, mask_w1, device=conf.device, dtype=conf.dtype)], dim=2)
            bs, c, mask_h1, mask_w1 = feat1.size()
            if mask_h1 != mask_H1:
                conf = torch.cat([conf, torch.zeros(bs, mask_H0, mask_W0, mask_H1-mask_h1, mask_W1, device=conf.device, dtype=conf.dtype)], dim=3)
            elif mask_w1 != mask_W1:
                conf = torch.cat([conf, torch.zeros(bs, mask_H0, mask_W0, mask_H1, mask_W1-mask_w1, device=conf.device, dtype=conf.dtype)], dim=4)
            conf = conf.reshape(bs, mask_H0*mask_W0, mask_H1*mask_W1)
        data.update({'conf_matrix': conf})
        data.update(**self.CoarseMatching.get_coarse_match(conf, data))

    # re-pad the cropped features back to the full masked resolution
    if padding and self.d_model == feat0.size(1):
        bs, c, mask_h0, mask_w0 = feat0.size()
        if mask_h0 != mask_H0:
            feat0 = torch.cat([feat0, torch.zeros(bs, c, mask_H0-mask_h0, mask_W0, device=feat0.device, dtype=feat0.dtype)], dim=-2)
        elif mask_w0 != mask_W0:
            feat0 = torch.cat([feat0, torch.zeros(bs, c, mask_H0, mask_W0-mask_w0, device=feat0.device, dtype=feat0.dtype)], dim=-1)
        bs, c, mask_h1, mask_w1 = feat1.size()
        if mask_h1 != mask_H1:
            feat1 = torch.cat([feat1, torch.zeros(bs, c, mask_H1-mask_h1, mask_W1, device=feat1.device, dtype=feat1.dtype)], dim=-2)
        elif mask_w1 != mask_W1:
            feat1 = torch.cat([feat1, torch.zeros(bs, c, mask_H1, mask_W1-mask_w1, device=feat1.device, dtype=feat1.dtype)], dim=-1)

    return feat0, feat1
def pro(self, feat0, feat1, mask0=None, mask1=None, profiler=None):
    """Profiled variant of the attention stack (no padding / pruning paths).

    Args:
        feat0 (torch.Tensor): [N, C, H, W]
        feat1 (torch.Tensor): [N, C, H, W]
        mask0 (torch.Tensor): [N, L] (optional)
        mask1 (torch.Tensor): [N, S] (optional)
        profiler: object exposing a ``profile(name)`` context manager.
    """
    assert self.d_model == feat0.size(1) or self.d_model == feat0.size(-1), "the feature number of src and transformer must be equal"
    with profiler.profile("LoFTR_transformer_attention"):
        for layer, attend in zip(self.layers, self.layer_names):
            if attend == 'self':
                feat0 = layer.pro(feat0, feat0, mask0, mask0, profiler=profiler)
                feat1 = layer.pro(feat1, feat1, mask1, mask1, profiler=profiler)
                continue
            if attend != 'cross':
                raise KeyError
            # note: the second call attends over the already-updated feat0
            feat0 = layer.pro(feat0, feat1, mask0, mask1, profiler=profiler)
            feat1 = layer.pro(feat1, feat0, mask1, mask0, profiler=profiler)

    return feat0, feat1
def confidence_threshold(self, layer_index: int) -> float:
    """Depth-scaled confidence threshold: decays from 0.9 towards 0.8."""
    decay = np.exp(-4.0 * layer_index / self.n_layers)
    return np.clip(0.8 + 0.1 * decay, 0, 1)
def get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor,
                     layer_index: int) -> torch.Tensor:
    """Return a boolean keep-mask; tokens below the matchability floor are pruned.

    Tokens whose confidence has not yet exceeded the layer threshold are
    treated as fully matchable (score forced to 1.0) and therefore kept.
    """
    keep_floor = 1 - self.width_confidence
    if confidences is None:
        return scores > keep_floor
    threshold = self.confidence_threshold(layer_index)
    adjusted = torch.where(
        confidences > threshold, scores, scores.new_tensor(1.0))
    return adjusted > keep_floor
def check_if_stop(self,
                  confidences0: torch.Tensor,
                  confidences1: torch.Tensor,
                  layer_index: int, num_points: int) -> torch.Tensor:
    """Early-exit test: stop once enough tokens are confidently classified."""
    all_conf = torch.cat((confidences0, confidences1), dim=-1)
    threshold = self.confidence_threshold(layer_index)
    # fraction of tokens at or above the layer threshold
    ratio_confident = 1.0 - (all_conf < threshold).float().sum() / num_points
    return ratio_confident > self.depth_confidence
imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer_utils.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
class TokenConfidence(nn.Module):
    """Predicts a per-token confidence in [0, 1] for each descriptor set."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        # single linear unit + sigmoid -> one scalar confidence per token
        self.token = nn.Sequential(
            nn.Linear(dim, 1),
            nn.Sigmoid()
        )

    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """Return confidence vectors ([B, M], [B, N]) for both descriptor sets."""
        confidences = []
        for desc in (desc0, desc1):
            # detach: confidence head must not backprop into the descriptors
            confidences.append(self.token(desc.detach().float()).squeeze(-1))
        return tuple(confidences)
def sigmoid_log_double_softmax(
        sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor) -> torch.Tensor:
    """Create the log-assignment matrix from matchability logits and similarity.

    Args:
        sim: [B, M, N] similarity matrix.
        z0:  [B, M, 1] matchability logits for set 0.
        z1:  [B, N, 1] matchability logits for set 1.
    Returns:
        (scores, m0, m1): log-assignment [B, M, N] plus both sigmoid
        matchabilities.
    """
    m0 = torch.sigmoid(z0)
    m1 = torch.sigmoid(z1)
    # log-matchability of every (i, j) pair
    certainties = m0.log() + m1.log().transpose(1, 2)
    # double softmax: row-wise and column-wise log-normalization of sim
    log_rows = F.log_softmax(sim, 2)
    log_cols = F.log_softmax(
        sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
    scores = log_rows + log_cols + certainties
    return scores, m0, m1
class MatchAssignment(nn.Module):
    """Builds a soft (log) assignment matrix between two descriptor sets."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim
        # scalar matchability logit per descriptor
        self.matchability = nn.Linear(dim, 1, bias=True)
        # projection applied before the similarity product
        self.final_proj = nn.Linear(dim, dim, bias=True)

    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """Return (scores, sim, m0, m1) for descriptors [B, M, D] / [B, N, D]."""
        proj0 = self.final_proj(desc0)
        proj1 = self.final_proj(desc1)
        d = proj0.shape[-1]
        # split the 1/sqrt(d) scaling evenly across both operands
        proj0 = proj0 / d**.25
        proj1 = proj1 / d**.25
        sim = torch.einsum('bmd,bnd->bmn', proj0, proj1)
        z0 = self.matchability(desc0)
        z1 = self.matchability(desc1)
        scores, m0, m1 = sigmoid_log_double_softmax(sim, z0, z1)
        return scores, sim, m0, m1

    def scores(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """Sigmoid matchability scores ([B, M], [B, N]) without the assignment."""
        return (torch.sigmoid(self.matchability(desc0)).squeeze(-1),
                torch.sigmoid(self.matchability(desc1)).squeeze(-1))
def filter_matches(scores: torch.Tensor, th: float):
    """Extract mutual-nearest matches from a log assignment matrix [B, M, N].

    Returns:
        m0: [B, M] index into set 1 for each element of set 0 (-1 = unmatched)
        m1: [B, N] index into set 0 for each element of set 1 (-1 = unmatched)
        mscores0, mscores1: corresponding match confidences (exp of scores).
    """
    max0 = scores.max(2)
    max1 = scores.max(1)
    m0, m1 = max0.indices, max1.indices
    idx0 = torch.arange(m0.shape[1], device=m0.device)[None]
    idx1 = torch.arange(m1.shape[1], device=m1.device)[None]
    # mutual nearest-neighbor check in both directions
    mutual0 = idx0 == m1.gather(1, m0)
    mutual1 = idx1 == m0.gather(1, m1)
    max0_exp = max0.values.exp()
    zero = max0_exp.new_tensor(0)
    mscores0 = torch.where(mutual0, max0_exp, zero)
    mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
    # optional confidence threshold on top of mutuality
    valid0 = mutual0 if th is None else mutual0 & (mscores0 > th)
    valid1 = mutual1 & valid0.gather(1, m1)
    m0 = torch.where(valid0, m0, -1)
    m1 = torch.where(valid1, m1, -1)
    return m0, m1, mscores0, mscores1
imcui/third_party/MatchAnything/src/loftr/utils/coarse_matching.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops.einops import rearrange, repeat
|
5 |
+
|
6 |
+
from loguru import logger
|
7 |
+
|
8 |
+
INF = 1e9
|
9 |
+
|
10 |
+
def mask_border(m, b: int, v):
    """Overwrite a border of width ``b`` along all four spatial axes with ``v``.

    Args:
        m (torch.Tensor): [N, H0, W0, H1, W1], modified in place
        b (int): border width; non-positive values are a no-op
        v (m.dtype): fill value
    """
    if b <= 0:
        return

    for axis in (1, 2, 3, 4):
        head = [slice(None)] * 5
        tail = [slice(None)] * 5
        head[axis] = slice(None, b)
        tail[axis] = slice(-b, None)
        m[tuple(head)] = v
        m[tuple(tail)] = v
def mask_border_with_padding(m, bd, v, p_m0, p_m1):
    """Mask borders of ``m`` with ``v``, placing the trailing border at each
    sample's true (un-padded) extent derived from the padding masks.

    Args:
        m (torch.Tensor): [N, H0, W0, H1, W1], modified in place
        bd (int): border width; non-positive values are a no-op
        v: fill value
        p_m0, p_m1 (torch.Tensor): [N, H, W] padding masks of image 0 / 1
    """
    if bd <= 0:
        return

    # leading borders are independent of the padding
    for axis in (1, 2, 3, 4):
        sel = [slice(None)] * 5
        sel[axis] = slice(None, bd)
        m[tuple(sel)] = v

    # trailing borders start at each sample's valid height / width
    h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
    h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
    for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
        m[b_idx, h0 - bd:] = v
        m[b_idx, :, w0 - bd:] = v
        m[b_idx, :, :, h1 - bd:] = v
        m[b_idx, :, :, :, w1 - bd:] = v
def compute_max_candidates(p_m0, p_m1):
    """Compute the max candidates of all pairs within a batch.

    Args:
        p_m0, p_m1 (torch.Tensor): [N, H, W] padding masks
    Returns:
        torch.Tensor: scalar sum over the batch of min(valid_area0, valid_area1)
    """
    # valid area per sample = max valid height * max valid width
    area0 = p_m0.sum(1).max(-1)[0] * p_m0.sum(-1).max(-1)[0]
    area1 = p_m1.sum(1).max(-1)[0] * p_m1.sum(-1).max(-1)[0]
    return torch.minimum(area0, area1).sum()
class CoarseMatching(nn.Module):
    """Coarse-level matching head: turns two sets of coarse features into a
    confidence matrix (dual-softmax or Sinkhorn) and extracts coarse matches.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # general config
        self.thr = config['thr']                  # confidence threshold for match selection
        self.border_rm = config['border_rm']      # border width removed from candidate matches
        # -- # for training fine-level LoFTR
        self.train_coarse_percent = config['train_coarse_percent']
        self.train_pad_num_gt_min = config['train_pad_num_gt_min']

        # we provide 2 options for differentiable matching
        self.match_type = config['match_type']
        if self.match_type == 'dual_softmax':
            self.temperature = config['dsmax_temperature']
        elif self.match_type == 'sinkhorn':
            # optional dependency: ported SuperGlue optimal-transport routine
            try:
                from .superglue import log_optimal_transport
            except ImportError:
                raise ImportError("download superglue.py first!")
            self.log_optimal_transport = log_optimal_transport
            self.bin_score = nn.Parameter(
                torch.tensor(config['skh_init_bin_score'], requires_grad=True))
            self.skh_iters = config['skh_iters']
            self.skh_prefilter = config['skh_prefilter']
        else:
            raise NotImplementedError()

        self.mtd = config['mtd_spvs']      # many-to-demand supervision: keep all above-threshold pairs
        self.fix_bias = config['fix_bias'] # use a fixed stride-8 scale instead of hw0_i/hw0_c

    def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
        """
        Args:
            feat0 (torch.Tensor): [N, L, C]
            feat1 (torch.Tensor): [N, S, C]
            data (dict)
            mask_c0 (torch.Tensor): [N, L] (optional)
            mask_c1 (torch.Tensor): [N, S] (optional)
        Update:
            data (dict): {
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'gt_mask' (torch.Tensor): [M'],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]}
            NOTE: M' != M during training.
        """
        N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)

        # normalize features by sqrt(C) before the similarity product
        feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
                               [feat_c0, feat_c1])

        # NOTE(review): only the 'dual_softmax' path is implemented here;
        # with match_type == 'sinkhorn' conf_matrix would be unbound — confirm
        # the sinkhorn branch is handled elsewhere before enabling it.
        if self.match_type == 'dual_softmax':
            # force fp32 for the softmax even when the outer context runs AMP
            with torch.autocast(enabled=False, device_type='cuda'):
                sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
                                          feat_c1) / self.temperature
                if mask_c0 is not None:
                    # invalidate padded positions before the softmax
                    sim_matrix = sim_matrix.float().masked_fill_(
                        ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
                        -INF
                        # float("-inf") if sim_matrix.dtype == torch.float16 else -INF
                    )
                if self.config['fp16log']:
                    # debug path: log value ranges to diagnose fp16 overflow
                    t1 = F.softmax(sim_matrix, 1)
                    t2 = F.softmax(sim_matrix, 2)
                    conf_matrix = t1*t2
                    logger.info(f'feat_c0absmax: {feat_c0.abs().max()}')
                    logger.info(f'feat_c1absmax: {feat_c1.abs().max()}')
                    logger.info(f'sim_matrix: {sim_matrix.dtype}')
                    logger.info(f'sim_matrixabsmax: {sim_matrix.abs().max()}')
                    logger.info(f't1: {t1.dtype}, t2: {t2.dtype}, conf_matrix: {conf_matrix.dtype}')
                    logger.info(f't1absmax: {t1.abs().max()}, t2absmax: {t2.abs().max()}, conf_matrixabsmax: {conf_matrix.abs().max()}')
                else:
                    # dual-softmax: product of row-wise and column-wise softmax
                    conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)

        data.update({'conf_matrix': conf_matrix})

        # predict coarse matches from conf_matrix
        data.update(**self.get_coarse_match(conf_matrix, data))

    @torch.no_grad()
    def get_coarse_match(self, conf_matrix, data):
        """
        Args:
            conf_matrix (torch.Tensor): [N, L, S]
            data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
        Returns:
            coarse_matches (dict): {
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'gt_mask' (torch.Tensor): [M'],
                'm_bids' (torch.Tensor): [M],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]}
        """
        axes_lengths = {
            'h0c': data['hw0_c'][0],
            'w0c': data['hw0_c'][1],
            'h1c': data['hw1_c'][0],
            'w1c': data['hw1_c'][1]
        }
        _device = conf_matrix.device
        # 1. confidence thresholding
        mask = conf_matrix > self.thr
        mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
                         **axes_lengths)
        if 'mask0' not in data:
            mask_border(mask, self.border_rm, False)
        else:
            mask_border_with_padding(mask, self.border_rm, False,
                                     data['mask0'], data['mask1'])
        mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
                         **axes_lengths)

        # 2. mutual nearest
        if self.mtd:
            # keep every above-threshold pair (multi-candidate supervision)
            b_ids, i_ids, j_ids = torch.where(mask)
            mconf = conf_matrix[b_ids, i_ids, j_ids]
        else:
            # keep only pairs that are the argmax of both their row and column
            mask = mask \
                * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
                * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])

            # 3. find all valid coarse matches
            # this only works when at most one `True` in each row
            mask_v, all_j_ids = mask.max(dim=2)
            b_ids, i_ids = torch.where(mask_v)
            j_ids = all_j_ids[b_ids, i_ids]
            mconf = conf_matrix[b_ids, i_ids, j_ids]

        # 4. Random sampling of training samples for fine-level LoFTR
        # (optional) pad samples with gt coarse-level matches
        if self.training:
            # NOTE:
            # The sampling is performed across all pairs in a batch without manually balancing
            # #samples for fine-level increases w.r.t. batch_size
            if 'mask0' not in data:
                num_candidates_max = mask.size(0) * max(
                    mask.size(1), mask.size(2))
            else:
                num_candidates_max = compute_max_candidates(
                    data['mask0'], data['mask1'])
            num_matches_train = int(num_candidates_max *
                                    self.train_coarse_percent)
            num_matches_pred = len(b_ids)
            assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"

            # pred_indices is to select from prediction
            if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
                pred_indices = torch.arange(num_matches_pred, device=_device)
            else:
                pred_indices = torch.randint(
                    num_matches_pred,
                    (num_matches_train - self.train_pad_num_gt_min, ),
                    device=_device)

            # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
            gt_pad_indices = torch.randint(
                len(data['spv_b_ids']),
                (max(num_matches_train - num_matches_pred,
                     self.train_pad_num_gt_min), ),
                device=_device)
            mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device)  # set conf of gt paddings to all zero

            b_ids, i_ids, j_ids, mconf = map(
                lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
                                       dim=0),
                *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
                     [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))

        # These matches select patches that feed into fine-level network
        coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}

        # 4. Update with matches in original image resolution
        if self.fix_bias:
            scale = 8
        else:
            scale = data['hw0_i'][0] / data['hw0_c'][0]
        scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
        scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
        # flat coarse index -> (x, y) in original image coordinates
        mkpts0_c = torch.stack(
            [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
            dim=1) * scale0
        mkpts1_c = torch.stack(
            [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
            dim=1) * scale1

        m_bids = b_ids[mconf != 0]

        # batch ids repeated x3 — presumably for downstream 3-way expansion; verify against caller
        m_bids_f = repeat(m_bids, 'b -> b k', k = 3).reshape(-1)
        coarse_matches.update({
            'gt_mask': mconf == 0,
            'm_bids': m_bids,  # mconf == 0 => gt matches
            'm_bids_f': m_bids_f,
            'mkpts0_c': mkpts0_c[mconf != 0],
            'mkpts1_c': mkpts1_c[mconf != 0],
            'mconf': mconf[mconf != 0]
        })

        return coarse_matches
imcui/third_party/MatchAnything/src/loftr/utils/fine_matching.py
ADDED
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from kornia.geometry.subpix import dsnt
|
7 |
+
from kornia.utils.grid import create_meshgrid
|
8 |
+
|
9 |
+
from loguru import logger
|
10 |
+
|
11 |
+
class FineMatching(nn.Module):
|
12 |
+
"""FineMatching with s2d paradigm"""
|
13 |
+
|
14 |
+
def __init__(self, config):
|
15 |
+
super().__init__()
|
16 |
+
self.config = config
|
17 |
+
self.topk = config['match_fine']['topk']
|
18 |
+
self.mtd_spvs = config['fine']['mtd_spvs']
|
19 |
+
self.align_corner = config['align_corner']
|
20 |
+
self.fix_bias = config['fix_bias']
|
21 |
+
self.normfinem = config['match_fine']['normfinem']
|
22 |
+
self.fix_fine_matching = config['match_fine']['fix_fine_matching']
|
23 |
+
self.mutual_nearest = config['match_fine']['force_nearest']
|
24 |
+
self.skip_fine_softmax = config['match_fine']['skip_fine_softmax']
|
25 |
+
self.normfeat = config['match_fine']['normfeat']
|
26 |
+
self.use_sigmoid = config['match_fine']['use_sigmoid']
|
27 |
+
self.local_regress = config['match_fine']['local_regress']
|
28 |
+
self.local_regress_rmborder = config['match_fine']['local_regress_rmborder']
|
29 |
+
self.local_regress_nomask = config['match_fine']['local_regress_nomask']
|
30 |
+
self.local_regress_temperature = config['match_fine']['local_regress_temperature']
|
31 |
+
self.local_regress_padone = config['match_fine']['local_regress_padone']
|
32 |
+
self.local_regress_slice = config['match_fine']['local_regress_slice']
|
33 |
+
self.local_regress_slicedim = config['match_fine']['local_regress_slicedim']
|
34 |
+
self.local_regress_inner = config['match_fine']['local_regress_inner']
|
35 |
+
self.multi_regress = config['match_fine']['multi_regress']
|
36 |
+
    def forward(self, feat_0, feat_1, data):
        """Fine-level matching: refine coarse matches to sub-pixel accuracy.

        Args:
            feat0 (torch.Tensor): [M, WW, C]
            feat1 (torch.Tensor): [M, WW, C]
            data (dict)
        Update:
            data (dict):{
                'expec_f' (torch.Tensor): [M, 3],
                'mkpts0_f' (torch.Tensor): [M, 2],
                'mkpts1_f' (torch.Tensor): [M, 2]}

        NOTE(review): this body was reconstructed from a line-mangled diff;
        the exact extent of the ``torch.autocast`` block below (everything in
        the ``mtd_spvs`` branch is assumed to run inside it) should be
        confirmed against the upstream EfficientLoFTR/MatchAnything source.
        """
        M, WW, C = feat_0.shape
        W = int(math.sqrt(WW))  # fine window is W x W (WW = W * W)
        if self.fix_bias:
            scale = 2
        else:
            # fine-feature stride relative to the original image resolution
            scale = data['hw0_i'][0] / data['hw0_f'][0]
        self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale

        # corner case: if no coarse matches found
        if M == 0:
            assert self.training == False, "M is always >0, when training, see coarse_matching.py"
            # logger.warning('No matches found in coarse-level.')
            if self.mtd_spvs:
                data.update({
                    'conf_matrix_f': torch.empty(0, WW, WW, device=feat_0.device),
                    'mkpts0_f': data['mkpts0_c'],
                    'mkpts1_f': data['mkpts1_c'],
                })
                # if self.local_regress:
                #     data.update({
                #         'sim_matrix_f': torch.empty(0, WW, WW, device=feat_0.device),
                #     })
                return
            else:
                data.update({
                    'expec_f': torch.empty(0, 3, device=feat_0.device),
                    'mkpts0_f': data['mkpts0_c'],
                    'mkpts1_f': data['mkpts1_c'],
                })
                return

        if self.mtd_spvs:
            # Run in full precision: the WW x WW similarity can overflow fp16.
            with torch.autocast(enabled=False, device_type='cuda'):
                # feat_0 = feat_0 / feat_0.size(-2)
                if self.local_regress_slice:
                    # Reserve the last `slicedim` channels for the local-regression
                    # similarity; the rest are used for the fine match itself.
                    feat_ff0, feat_ff1 = feat_0[...,-self.local_regress_slicedim:], feat_1[...,-self.local_regress_slicedim:]
                    feat_f0, feat_f1 = feat_0[...,:-self.local_regress_slicedim], feat_1[...,:-self.local_regress_slicedim]
                    conf_matrix_ff = torch.einsum('mlc,mrc->mlr', feat_ff0, feat_ff1 / (self.local_regress_slicedim)**.5)
                else:
                    feat_f0, feat_f1 = feat_0, feat_1
                if self.normfinem:
                    feat_f0 = feat_f0 / C**.5
                    feat_f1 = feat_f1 / C**.5
                    conf_matrix_f = torch.einsum('mlc,mrc->mlr', feat_f0, feat_f1)
                else:
                    if self.local_regress_slice:
                        conf_matrix_f = torch.einsum('mlc,mrc->mlr', feat_f0, feat_f1 / (C - self.local_regress_slicedim)**.5)
                    else:
                        conf_matrix_f = torch.einsum('mlc,mrc->mlr', feat_f0, feat_f1 / C**.5)

                if self.normfeat:
                    feat_f0, feat_f1 = torch.nn.functional.normalize(feat_f0.float(), p=2, dim=-1), torch.nn.functional.normalize(feat_f1.float(), p=2, dim=-1)

                if self.config['fp16log']:
                    logger.info(f'sim_matrix: {conf_matrix_f.abs().max()}')
                # sim_matrix *= 1. / C**.5 # normalize

                if self.multi_regress:
                    # Heatmap expectation per source position (DSNT-style regression).
                    assert not self.local_regress
                    assert not self.normfinem and not self.normfeat
                    heatmap = F.softmax(conf_matrix_f, 2).view(M, WW, W, W) # [M, WW, W, W]

                    assert (W - 2) == (self.config['resolution'][0] // self.config['resolution'][1]) # c8
                    windows_scale = (W - 1) / (self.config['resolution'][0] // self.config['resolution'][1])

                    coords_normalized = dsnt.spatial_expectation2d(heatmap, True) * windows_scale # [M, WW, 2]
                    grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2)[:,None,:,:] * windows_scale # [1, 1, WW, 2]

                    # compute std over <x, y>
                    var = torch.sum(grid_normalized**2 * heatmap.view(M, WW, WW, 1), dim=-2) - coords_normalized**2 # ([1,1,WW,2] * [M,WW,WW,1])->[M,WW,2]
                    std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M,WW] clamp needed for numerical stability

                    # for fine-level supervision
                    data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(-1)], -1)}) # [M, WW, 2]

                    # get the least uncertain matches
                    val, idx = torch.topk(std, self.topk, dim=-1, largest=False) # [M,topk]
                    coords_normalized = coords_normalized[torch.arange(M, device=conf_matrix_f.device, dtype=torch.long)[:,None], idx] # [M,topk]

                    grid = create_meshgrid(W, W, False, idx.device) - W // 2 + 0.5 # [1, W, W, 2]
                    grid = grid.reshape(1, -1, 2).expand(M, -1, -1) # [M, WW, 2]
                    delta_l = torch.gather(grid, 1, idx.unsqueeze(-1).expand(-1, -1, 2)) # [M, topk, 2] in (x, y)

                    # compute absolute kpt coords
                    self.get_multi_fine_match_align(delta_l, coords_normalized, data)

                else:
                    if self.skip_fine_softmax:
                        pass
                    elif self.use_sigmoid:
                        conf_matrix_f = torch.sigmoid(conf_matrix_f)
                    else:
                        if self.local_regress:
                            del feat_f0, feat_f1
                            # Dual-softmax confidence over both window dimensions.
                            softmax_matrix_f = F.softmax(conf_matrix_f, 1) * F.softmax(conf_matrix_f, 2)
                            # softmax_matrix_f = conf_matrix_f
                            if self.local_regress_inner:
                                # Window was padded by 1 px on each side; crop back.
                                softmax_matrix_f = softmax_matrix_f.reshape(M, self.WW, self.W+2, self.W+2)
                                softmax_matrix_f = softmax_matrix_f[...,1:-1,1:-1].reshape(M, self.WW, self.WW)
                            # if self.training:
                            # for fine-level supervision
                            data.update({'conf_matrix_f': softmax_matrix_f})
                            if self.local_regress_slice:
                                data.update({'sim_matrix_ff': conf_matrix_ff})
                            else:
                                data.update({'sim_matrix_f': conf_matrix_f})

                        else:
                            conf_matrix_f = F.softmax(conf_matrix_f, 1) * F.softmax(conf_matrix_f, 2)

                            # for fine-level supervision
                            data.update({'conf_matrix_f': conf_matrix_f})

                    # compute absolute kpt coords
                    if self.local_regress:
                        # First pick top-k fine matches, then refine each with a
                        # 3x3 local softmax expectation around the selected pixel.
                        self.get_fine_ds_match(softmax_matrix_f, data)
                        del softmax_matrix_f
                        idx_l, idx_r = data['idx_l'], data['idx_r']
                        del data['idx_l'], data['idx_r']
                        m_ids = torch.arange(M, device=idx_l.device, dtype=torch.long).unsqueeze(-1).expand(-1, self.topk)
                        # if self.training:
                        m_ids = m_ids[:len(data['mconf']) // self.topk]
                        idx_r_iids, idx_r_jids = idx_r // W, idx_r % W

                        # remove boarder
                        if self.local_regress_nomask:
                            # log for inner precent
                            # mask = (idx_r_iids >= 1) & (idx_r_iids <= W-2) & (idx_r_jids >= 1) & (idx_r_jids <= W-2)
                            # mask_sum = mask.sum()
                            # logger.info(f'total fine match: {mask.numel()}; regressed fine match: {mask_sum}, per: {mask_sum / mask.numel()}')
                            mask = None
                            m_ids, idx_l, idx_r_iids, idx_r_jids = m_ids.reshape(-1), idx_l.reshape(-1), idx_r_iids.reshape(-1), idx_r_jids.reshape(-1)
                            if self.local_regress_inner: # been sliced before
                                delta = create_meshgrid(3, 3, True, conf_matrix_f.device).to(torch.long) # [1, 3, 3, 2]
                            else:
                                # no mask + 1 for padding
                                delta = create_meshgrid(3, 3, True, conf_matrix_f.device).to(torch.long) + torch.tensor([1], dtype=torch.long, device=conf_matrix_f.device) # [1, 3, 3, 2]

                            m_ids = m_ids[...,None,None].expand(-1, 3, 3)
                            idx_l = idx_l[...,None,None].expand(-1, 3, 3) # [m, k, 3, 3]

                            # meshgrid is (x, y): channel 1 is the row offset, 0 the column.
                            idx_r_iids = idx_r_iids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 1]
                            idx_r_jids = idx_r_jids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 0]

                            if idx_l.numel() == 0:
                                data.update({
                                    'mkpts0_f': data['mkpts0_c'],
                                    'mkpts1_f': data['mkpts1_c'],
                                })
                                return

                            if self.local_regress_slice:
                                conf_matrix_f = conf_matrix_ff
                            if self.local_regress_inner:
                                conf_matrix_f = conf_matrix_f.reshape(M, self.WW, self.W+2, self.W+2)
                            else:
                                conf_matrix_f = conf_matrix_f.reshape(M, self.WW, self.W, self.W)
                                conf_matrix_f = F.pad(conf_matrix_f, (1,1,1,1))
                        else:
                            # Keep only matches whose 3x3 neighbourhood is fully inside the window.
                            mask = (idx_r_iids >= 1) & (idx_r_iids <= W-2) & (idx_r_jids >= 1) & (idx_r_jids <= W-2)
                            if W == 10:
                                idx_l_iids, idx_l_jids = idx_l // W, idx_l % W
                                mask = mask & (idx_l_iids >= 1) & (idx_l_iids <= W-2) & (idx_l_jids >= 1) & (idx_l_jids <= W-2)

                            m_ids = m_ids[mask].to(torch.long)
                            idx_l, idx_r_iids, idx_r_jids = idx_l[mask].to(torch.long), idx_r_iids[mask].to(torch.long), idx_r_jids[mask].to(torch.long)

                            m_ids, idx_l, idx_r_iids, idx_r_jids = m_ids.reshape(-1), idx_l.reshape(-1), idx_r_iids.reshape(-1), idx_r_jids.reshape(-1)
                            mask = mask.reshape(-1)

                            delta = create_meshgrid(3, 3, True, conf_matrix_f.device).to(torch.long) # [1, 3, 3, 2]

                            m_ids = m_ids[:,None,None].expand(-1, 3, 3)
                            idx_l = idx_l[:,None,None].expand(-1, 3, 3) # [m, 3, 3]
                            # bug !!!!!!!!! 1,0 rather 0,1
                            # idx_r_iids = idx_r_iids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 0]
                            # idx_r_jids = idx_r_jids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 1]
                            idx_r_iids = idx_r_iids[:,None,None].expand(-1, 3, 3) + delta[..., 1]
                            idx_r_jids = idx_r_jids[:,None,None].expand(-1, 3, 3) + delta[..., 0]

                            if idx_l.numel() == 0:
                                data.update({
                                    'mkpts0_f': data['mkpts0_c'],
                                    'mkpts1_f': data['mkpts1_c'],
                                })
                                return
                            if not self.local_regress_slice:
                                conf_matrix_f = conf_matrix_f.reshape(M, self.WW, self.W, self.W)
                            else:
                                conf_matrix_f = conf_matrix_ff.reshape(M, self.WW, self.W, self.W)

                        # Gather the 3x3 similarity patch around each selected match.
                        conf_matrix_f = conf_matrix_f[m_ids, idx_l, idx_r_iids, idx_r_jids]
                        conf_matrix_f = conf_matrix_f.reshape(-1, 9)
                        if self.local_regress_padone: # follow the training detach the gradient of center
                            conf_matrix_f[:,4] = -1e4
                            heatmap = F.softmax(conf_matrix_f / self.local_regress_temperature, -1)
                            logger.info(f'maxmax&maxmean of heatmap: {heatmap.view(-1).max()}, {heatmap.view(-1).min(), heatmap.max(-1)[0].mean()}')
                            heatmap[:,4] = 1.0 # no need gradient calculation in inference
                            logger.info(f'min of heatmap: {heatmap.view(-1).min()}')
                            heatmap = heatmap.reshape(-1, 3, 3)
                            # heatmap = torch.ones_like(softmax) # ones_like for detach the gradient of center
                            # heatmap[:,:4], heatmap[:,5:] = softmax[:,:4], softmax[:,5:]
                            # heatmap = heatmap.reshape(-1, 3, 3)
                        else:
                            conf_matrix_f = F.softmax(conf_matrix_f / self.local_regress_temperature, -1)
                            # logger.info(f'max&min&mean of heatmap: {conf_matrix_f.view(-1).max()}, {conf_matrix_f.view(-1).min(), conf_matrix_f.max(-1)[0].mean()}')
                            heatmap = conf_matrix_f.reshape(-1, 3, 3)

                        # compute coordinates from heatmap
                        coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]

                        # coords_normalized_l2 = coords_normalized.norm(p=2, dim=-1)
                        # logger.info(f'mean&max&min abs of local: {coords_normalized_l2.mean(), coords_normalized_l2.max(), coords_normalized_l2.min()}')

                        # compute absolute kpt coords

                        if data['bs'] == 1:
                            scale1 = scale * data['scale1'] if 'scale0' in data else scale
                        else:
                            # Broadcast per-image scales to per-match, dropping padded matches.
                            if mask is not None:
                                scale1 = scale * data['scale1'][data['b_ids']][:len(data['mconf']) // self.topk,...][:,None,:].expand(-1, self.topk, 2).reshape(-1, 2)[mask] if 'scale0' in data else scale
                            else:
                                scale1 = scale * data['scale1'][data['b_ids']][:len(data['mconf']) // self.topk,...][:,None,:].expand(-1, self.topk, 2).reshape(-1, 2) if 'scale0' in data else scale

                        self.get_fine_match_local(coords_normalized, data, scale1, mask, True)

                    else:
                        self.get_fine_ds_match(conf_matrix_f, data)

        else:
            if self.align_corner is True:
                feat_f0, feat_f1 = feat_0, feat_1
                # NOTE(review): doubled assignment below is redundant but harmless.
                feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :]
                sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
                softmax_temp = 1. / C**.5
                heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)

                # compute coordinates from heatmap
                coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2]
                grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2]

                # compute std over <x, y>
                var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2]
                std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability

                # for fine-level supervision
                data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})

                # compute absolute kpt coords
                self.get_fine_match(coords_normalized, data)
            else:
                feat_f0, feat_f1 = feat_0, feat_1
                # even matching windows while coarse grid not aligned to fine grid!!!
                # assert W == 5, "others size not checked"
                if self.fix_bias:
                    assert W % 2 == 1, "W must be odd when select"
                    feat_f0_picked = feat_f0[:, WW//2]

                else:
                    # assert W == 6, "others size not checked"
                    assert W % 2 == 0, "W must be even when coarse grid not aligned to fine grid(average)"
                    # Average the four center pixels of the even-sized window.
                    feat_f0_picked = (feat_f0[:, WW//2 - W//2 - 1] + feat_f0[:, WW//2 - W//2] + feat_f0[:, WW//2 + W//2] + feat_f0[:, WW//2 + W//2 - 1]) / 4
                sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
                softmax_temp = 1. / C**.5
                heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)

                # compute coordinates from heatmap
                windows_scale = (W - 1) / (self.config['resolution'][0] // self.config['resolution'][1])

                coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] * windows_scale # [M, 2]
                grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) * windows_scale # [1, WW, 2]

                # compute std over <x, y>
                var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2]
                std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability

                # for fine-level supervision
                data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})

                # compute absolute kpt coords
                self.get_fine_match_align(coords_normalized, data)
|
332 |
+
|
333 |
+
|
334 |
+
@torch.no_grad()
|
335 |
+
def get_fine_match(self, coords_normed, data):
|
336 |
+
W, WW, C, scale = self.W, self.WW, self.C, self.scale
|
337 |
+
|
338 |
+
# mkpts0_f and mkpts1_f
|
339 |
+
mkpts0_f = data['mkpts0_c']
|
340 |
+
scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
|
341 |
+
mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])]
|
342 |
+
|
343 |
+
data.update({
|
344 |
+
"mkpts0_f": mkpts0_f,
|
345 |
+
"mkpts1_f": mkpts1_f
|
346 |
+
})
|
347 |
+
|
348 |
+
def get_fine_match_local(self, coords_normed, data, scale1, mask, reserve_border=True):
|
349 |
+
W, WW, C, scale = self.W, self.WW, self.C, self.scale
|
350 |
+
|
351 |
+
if mask is None:
|
352 |
+
mkpts0_c, mkpts1_c = data['mkpts0_c'], data['mkpts1_c']
|
353 |
+
else:
|
354 |
+
data['mkpts0_c'], data['mkpts1_c'] = data['mkpts0_c'].reshape(-1, 2), data['mkpts1_c'].reshape(-1, 2)
|
355 |
+
mkpts0_c, mkpts1_c = data['mkpts0_c'][mask], data['mkpts1_c'][mask]
|
356 |
+
mask_sum = mask.sum()
|
357 |
+
logger.info(f'total fine match: {mask.numel()}; regressed fine match: {mask_sum}, per: {mask_sum / mask.numel()}')
|
358 |
+
# print(mkpts0_c.shape, mkpts1_c.shape, coords_normed.shape, scale1.shape)
|
359 |
+
# print(data['mkpts0_c'].shape, data['mkpts1_c'].shape)
|
360 |
+
# mkpts0_f and mkpts1_f
|
361 |
+
mkpts0_f = mkpts0_c
|
362 |
+
mkpts1_f = mkpts1_c + (coords_normed * (3 // 2) * scale1)
|
363 |
+
|
364 |
+
if reserve_border and mask is not None:
|
365 |
+
mkpts0_f, mkpts1_f = torch.cat([mkpts0_f, data['mkpts0_c'][~mask].reshape(-1, 2)]), torch.cat([mkpts1_f, data['mkpts1_c'][~mask].reshape(-1, 2)])
|
366 |
+
else:
|
367 |
+
pass
|
368 |
+
|
369 |
+
del data['mkpts0_c'], data['mkpts1_c']
|
370 |
+
data.update({
|
371 |
+
"mkpts0_f": mkpts0_f,
|
372 |
+
"mkpts1_f": mkpts1_f
|
373 |
+
})
|
374 |
+
|
375 |
+
# can be used for both aligned and not aligned
|
376 |
+
@torch.no_grad()
|
377 |
+
def get_fine_match_align(self, coord_normed, data):
|
378 |
+
W, WW, C, scale = self.W, self.WW, self.C, self.scale
|
379 |
+
c2f = self.config['resolution'][0] // self.config['resolution'][1]
|
380 |
+
# mkpts0_f and mkpts1_f
|
381 |
+
mkpts0_f = data['mkpts0_c']
|
382 |
+
scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
|
383 |
+
mkpts1_f = data['mkpts1_c'] + (coord_normed * (c2f // 2) * scale1)[:len(data['mconf'])]
|
384 |
+
|
385 |
+
data.update({
|
386 |
+
"mkpts0_f": mkpts0_f,
|
387 |
+
"mkpts1_f": mkpts1_f
|
388 |
+
})
|
389 |
+
|
390 |
+
@torch.no_grad()
|
391 |
+
def get_multi_fine_match_align(self, delta_l, coord_normed, data):
|
392 |
+
W, WW, C, scale = self.W, self.WW, self.C, self.scale
|
393 |
+
c2f = self.config['resolution'][0] // self.config['resolution'][1]
|
394 |
+
# mkpts0_f and mkpts1_f
|
395 |
+
scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else torch.tensor([[scale, scale]], device=delta_l.device)
|
396 |
+
scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else torch.tensor([[scale, scale]], device=delta_l.device)
|
397 |
+
mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2)
|
398 |
+
mkpts1_f = (data['mkpts1_c'][:,None,:] + (coord_normed * (c2f // 2) * scale1[:,None,:])[:len(data['mconf'])]).reshape(-1, 2)
|
399 |
+
|
400 |
+
data.update({
|
401 |
+
"mkpts0_f": mkpts0_f,
|
402 |
+
"mkpts1_f": mkpts1_f,
|
403 |
+
"mconf": data['mconf'][:,None].expand(-1, self.topk).reshape(-1)
|
404 |
+
})
|
405 |
+
|
406 |
+
    @torch.no_grad()
    def get_fine_ds_match(self, conf_matrix, data):
        """Select top-k fine matches per coarse match from the dual-softmax
        confidence matrix and convert them to absolute keypoint coordinates.

        When ``self.local_regress`` is set, results are written back as
        ``mkpts{0,1}_c`` (plus ``idx_l``/``idx_r``) so a later 3x3 local
        regression can refine them; otherwise they are final ``mkpts{0,1}_f``.

        NOTE(review): reconstructed from a mangled diff. If
        ``self.mutual_nearest`` is True this method falls through without
        computing ``mkpts0_f``/``mkpts1_f`` and the trailing update would
        raise — presumably that flag is never combined with this path;
        confirm against upstream.
        """
        W, WW, C, scale = self.W, self.WW, self.C, self.scale

        # select topk matches
        m, _, _ = conf_matrix.shape


        if self.mutual_nearest:
            pass


        elif not self.fix_fine_matching: # only allow one2mul but mul2one

            # Row-wise nearest neighbour first, then top-k over the row maxima.
            val, idx_r = conf_matrix.max(-1) # (m, WW), (m, WW)
            val, idx_l = torch.topk(val, self.topk, dim = -1) # (m, topk), (m, topk)
            idx_r = torch.gather(idx_r, 1, idx_l) # (m, topk)

            # mkpts0_c use xy coordinate, so we don't need to convert it to hw coordinate
            # grid = create_meshgrid(W, W, False, conf_matrix.device).transpose(-3,-2) - W // 2 + 0.5 # (1, W, W, 2)
            grid = create_meshgrid(W, W, False, conf_matrix.device) - W // 2 + 0.5 # (1, W, W, 2)
            grid = grid.reshape(1, -1, 2).expand(m, -1, -1) # (m, WW, 2)
            delta_l = torch.gather(grid, 1, idx_l.unsqueeze(-1).expand(-1, -1, 2)) # (m, topk, 2)
            delta_r = torch.gather(grid, 1, idx_r.unsqueeze(-1).expand(-1, -1, 2)) # (m, topk, 2)

            # mkpts0_f and mkpts1_f
            scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale
            scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale

            if torch.is_tensor(scale0) and scale0.numel() > 1: # num of scale0 > 1
                mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2)
                mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2)
            else: # scale0 is a float
                mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0)[:len(data['mconf']),...]).reshape(-1, 2)
                mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1)[:len(data['mconf']),...]).reshape(-1, 2)

        else: # allow one2mul mul2one and mul2mul
            # Top-k over the flattened WW x WW matrix: both endpoints are free.
            conf_matrix = conf_matrix.reshape(m, -1)
            if self.local_regress: # for the compatibility of former config
                conf_matrix = conf_matrix[:len(data['mconf']),...]
            val, idx = torch.topk(conf_matrix, self.topk, dim = -1)
            idx_l = idx // WW
            idx_r = idx % WW

            if self.local_regress:
                # Saved so forward() can build the 3x3 regression neighbourhood.
                data.update({'idx_l': idx_l, 'idx_r': idx_r})

            # mkpts0_c use xy coordinate, so we don't need to convert it to hw coordinate
            # grid = create_meshgrid(W, W, False, conf_matrix.device).transpose(-3,-2) - W // 2 + 0.5 # (1, W, W, 2)
            grid = create_meshgrid(W, W, False, conf_matrix.device) - W // 2 + 0.5
            grid = grid.reshape(1, -1, 2).expand(m, -1, -1)
            delta_l = torch.gather(grid, 1, idx_l.unsqueeze(-1).expand(-1, -1, 2))
            delta_r = torch.gather(grid, 1, idx_r.unsqueeze(-1).expand(-1, -1, 2))

            # mkpts0_f and mkpts1_f
            scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale
            scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale

            if self.local_regress:
                if torch.is_tensor(scale0) and scale0.numel() > 1: # num of scale0 > 1
                    mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:len(data['mconf']),...][:,None,:])).reshape(-1, 2)
                    mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1[:len(data['mconf']),...][:,None,:])).reshape(-1, 2)
                else: # scale0 is a float
                    mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0)).reshape(-1, 2)
                    mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1)).reshape(-1, 2)

            else:
                if torch.is_tensor(scale0) and scale0.numel() > 1: # num of scale0 > 1
                    mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2)
                    mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2)
                else: # scale0 is a float
                    mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0)[:len(data['mconf']),...]).reshape(-1, 2)
                    mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1)[:len(data['mconf']),...]).reshape(-1, 2)
        del data['mkpts0_c'], data['mkpts1_c']
        # Replicate each coarse confidence for its k fine matches.
        data['mconf'] = data['mconf'].reshape(-1, 1).expand(-1, self.topk).reshape(-1)
        # data['mconf'] = val.reshape(-1)[:len(data['mconf'])]*0.1 + data['mconf']

        if self.local_regress:
            data.update({
                "mkpts0_c": mkpts0_f,
                "mkpts1_c": mkpts1_f
            })
        else:
            data.update({
                "mkpts0_f": mkpts0_f,
                "mkpts1_f": mkpts1_f
            })
|
493 |
+
|
imcui/third_party/MatchAnything/src/loftr/utils/geometry.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from src.utils.homography_utils import warp_points_torch
|
3 |
+
|
4 |
+
def get_unique_indices(input_tensor):
    """Return, for each unique row of ``input_tensor`` (sorted order), the
    index of that row's first occurrence.

    ``scatter_`` keeps the last write per target slot, so scattering the
    flipped index sequence makes the earliest original index win.
    """
    n_rows = input_tensor.shape[0]
    if n_rows <= 1:
        # 0 or 1 rows: every row is trivially its own first occurrence.
        return torch.zeros((n_rows,), dtype=torch.long, device=input_tensor.device)
    uniq, inv = torch.unique(input_tensor, sorted=True, return_inverse=True, dim=0)
    order = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
    inv_rev, order_rev = inv.flip([0]), order.flip([0])
    return inv_rev.new_empty(uniq.size(0)).scatter_(0, inv_rev, order_rev)
|
13 |
+
|
14 |
+
|
15 |
+
@torch.no_grad()
def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, consistency_thr=0.2, cycle_proj_distance_thr=3.0):
    """ Warp kpts0 from I0 to I1 with depth, K and Rt
    Also check covisibility and depth consistency.
    Depth is consistent if relative error < 0.2 (hard-coded).

    Args:
        kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
        depth0 (torch.Tensor): [N, H, W],
        depth1 (torch.Tensor): [N, H, W],
        T_0to1 (torch.Tensor): [N, 3, 4],
        K0 (torch.Tensor): [N, 3, 3],
        K1 (torch.Tensor): [N, 3, 3],
        consistency_thr (float): max allowed relative error between warped
            depth and the depth sampled from depth1.
        cycle_proj_distance_thr (float): max pixel distance between kpts0 and
            its cycle-reprojection back into I0.
    Returns:
        calculable_mask (torch.Tensor): [N, L]
        warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
    """
    # Nearest-pixel depth lookup (keypoints are rounded, not interpolated).
    kpts0_long = kpts0.round().long()

    # Sample depth, get calculable_mask on depth != 0
    kpts0_depth = torch.stack(
        [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
    )  # (N, L)
    nonzero_mask = kpts0_depth != 0

    # Unproject
    kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None]  # (N, L, 3)
    kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)

    # Rigid Transform
    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)
    w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]

    # Project
    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
    w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4)  # (N, L, 2), +1e-4 to avoid zero depth

    # Covisible Check
    h, w = depth1.shape[1:3]
    covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
        (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
    w_kpts0_long = w_kpts0.long()
    # Out-of-view points are clamped to (0, 0) so the gather below stays legal;
    # they are excluded by covisible_mask anyway.
    w_kpts0_long[~covisible_mask, :] = 0

    w_kpts0_depth = torch.stack(
        [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
    )  # (N, L)
    consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < consistency_thr

    # Cycle Consistency Check
    # Re-project the warped points back into I0 using depth1 and the inverse
    # pose; points that land far from where they started are rejected.
    dst_pts_h = torch.cat([w_kpts0, torch.ones_like(w_kpts0[..., [0]], device=w_kpts0.device)], dim=-1) * w_kpts0_depth[..., None]  # B * N_dst * N_pts * 3
    dst_pts_cam = K1.inverse() @ dst_pts_h.transpose(2, 1)  # (N, 3, L)
    dst_pose = T_0to1.inverse()
    world_points_cycle_back = dst_pose[:, :3, :3] @ dst_pts_cam + dst_pose[:, :3, [3]]
    src_warp_back_h = (K0 @ world_points_cycle_back).transpose(2, 1)  # (N, L, 3)
    src_back_proj_pts = src_warp_back_h[..., :2] / (src_warp_back_h[..., [2]] + 1e-4)
    cycle_reproj_distance_mask = torch.linalg.norm(src_back_proj_pts - kpts0[:, None], dim=-1) < cycle_proj_distance_thr

    valid_mask = nonzero_mask * covisible_mask * consistent_mask * cycle_reproj_distance_mask

    return valid_mask, w_kpts0
|
76 |
+
|
77 |
+
@torch.no_grad()
def warp_kpts_by_sparse_gt_matches_batches(kpts0, gt_matches, dist_thr):
    """Batched wrapper around :func:`warp_kpts_by_sparse_gt_matches`.

    For very large point sets (> 200k points per batch element) the pairwise
    distance matrix built inside the per-element function would be huge, so
    the batch is processed one element at a time and results concatenated.

    Args:
        kpts0 (torch.Tensor): [B, N, 2] query keypoints.
        gt_matches (torch.Tensor): [B, M, 4] GT correspondences
            (source xy in the first two columns, target xy in the last two).
        dist_thr (torch.Tensor | None): [B] per-element max match distance,
            or None to accept any mutual nearest neighbour.
    Returns:
        kpts_valid_mask (torch.Tensor): [B, N] bool.
        kpts_warpped (torch.Tensor): [B, N, 2].
    """
    B, n_pts = kpts0.shape[0], kpts0.shape[1]
    if n_pts > 20 * 10000:
        all_kpts_valid_mask, all_kpts_warpped = [], []
        for b_id in range(B):
            # BUGFIX: dist_thr may legitimately be None (the per-element
            # function handles it); indexing None would raise TypeError.
            thr = dist_thr[[b_id]] if dist_thr is not None else None
            kpts_valid_mask, kpts_warpped = warp_kpts_by_sparse_gt_matches(
                kpts0[[b_id]], gt_matches[[b_id]], thr)
            all_kpts_valid_mask.append(kpts_valid_mask)
            all_kpts_warpped.append(kpts_warpped)
        return torch.cat(all_kpts_valid_mask, dim=0), torch.cat(all_kpts_warpped, dim=0)
    else:
        return warp_kpts_by_sparse_gt_matches(kpts0, gt_matches, dist_thr)
|
89 |
+
|
90 |
+
@torch.no_grad()
def warp_kpts_by_sparse_gt_matches(kpts0, gt_matches, dist_thr):
    """Warp query keypoints by snapping them to sparse GT correspondences.

    Each keypoint in ``kpts0`` is matched to its mutual nearest GT source
    point (optionally within ``dist_thr``); matched keypoints receive the
    corresponding GT target coordinates.

    Args:
        kpts0 (torch.Tensor): [B, N, 2] query keypoints.
        gt_matches (torch.Tensor): [B, M, 4]; columns 0:2 are source xy,
            columns 2:4 are target xy. All-zero rows are treated as padding.
        dist_thr (torch.Tensor | None): [B] max allowed keypoint-to-source
            distance, or None to accept any mutual nearest neighbour.
    Returns:
        kpts_valid_mask (torch.Tensor): [B, N] bool, True where warped.
        kpts_warpped (torch.Tensor): [B, N, 2], zeros where invalid.
    """
    kpts_warpped = torch.zeros_like(kpts0)
    kpts_valid_mask = torch.zeros_like(kpts0[..., 0], dtype=torch.bool)
    # Padding rows are all-zero; they must not validate a match.
    gt_matches_non_padding_mask = gt_matches.sum(-1) > 0

    dist_matrix = torch.cdist(kpts0, gt_matches[..., :2])  # B * N * M
    if dist_thr is not None:
        mask = dist_matrix < dist_thr[:, None, None]
    else:
        mask = torch.ones_like(dist_matrix, dtype=torch.bool)
    # Mutual-Nearest check:
    mask = mask \
        * (dist_matrix == dist_matrix.min(dim=2, keepdim=True)[0]) \
        * (dist_matrix == dist_matrix.min(dim=1, keepdim=True)[0])

    mask_v, all_j_ids = mask.max(dim=2)
    b_ids, i_ids = torch.where(mask_v)
    j_ids = all_j_ids[b_ids, i_ids]

    # Deduplicate: keep at most one (b, j) pair per GT match ...
    j_uq_indices = get_unique_indices(torch.stack([b_ids, j_ids], dim=-1))
    b_ids, i_ids, j_ids = map(lambda x: x[j_uq_indices], [b_ids, i_ids, j_ids])

    # ... and at most one (b, i) pair per query keypoint.
    i_uq_indices = get_unique_indices(torch.stack([b_ids, i_ids], dim=-1))
    b_ids, i_ids, j_ids = map(lambda x: x[i_uq_indices], [b_ids, i_ids, j_ids])

    kpts_valid_mask[b_ids, i_ids] = gt_matches_non_padding_mask[b_ids, j_ids]
    kpts_warpped[b_ids, i_ids] = gt_matches[..., 2:][b_ids, j_ids]

    return kpts_valid_mask, kpts_warpped
|
120 |
+
|
121 |
+
@torch.no_grad()
def warp_kpts_by_sparse_gt_matches_fine_chunks(kpts0, gt_matches, dist_thr):
    """Chunked driver around `warp_kpts_by_sparse_gt_matches_fine`.

    Processes the fine-level windows in slices of 500 to bound the peak
    memory of the dense distance matrix, then concatenates the results.
    """
    chunk_size = 500
    total = kpts0.shape[0]
    valid_chunks, warped_chunks = [], []
    for start in range(0, total, chunk_size):
        chunk_valid, chunk_warped = warp_kpts_by_sparse_gt_matches_fine(
            kpts0[start:start + chunk_size], gt_matches, dist_thr)
        valid_chunks.append(chunk_valid)
        warped_chunks.append(chunk_warped)
    return torch.cat(valid_chunks, dim=0), torch.cat(warped_chunks, dim=0)
|
131 |
+
|
132 |
+
@torch.no_grad()
def warp_kpts_by_sparse_gt_matches_fine(kpts0, gt_matches, dist_thr):
    """
    Only support single batch: all fine-level windows come from one pair.
    Input:
        kpts0: N * ww * 2 (N fine windows, ww points each)
        gt_matches: 1 * M * 4 rows of (x0, y0, x1, y1); all-zero rows are padding
        dist_thr: broadcastable distance threshold, or None to accept any distance
    """
    assert gt_matches.shape[0] == 1
    warped = torch.zeros_like(kpts0)
    valid = torch.zeros_like(kpts0[..., 0], dtype=torch.bool)
    non_padding = gt_matches.sum(-1) > 0

    # Window-points vs GT source-points distances: B * N * M
    dists = torch.cdist(kpts0, gt_matches[..., :2])
    if dist_thr is None:
        candidate = torch.ones_like(dists, dtype=torch.bool)
    else:
        candidate = dists < dist_thr[:, None, None]
    # Mutual-nearest check.
    candidate = candidate \
        & (dists == dists.min(dim=2, keepdim=True)[0]) \
        & (dists == dists.min(dim=1, keepdim=True)[0])

    hit, nearest_j = candidate.max(dim=2)
    b_ids, i_ids = torch.where(hit)
    j_ids = nearest_j[b_ids, i_ids]

    # De-duplicate on the GT side, then on the query side.
    keep = get_unique_indices(torch.stack([b_ids, j_ids], dim=-1))
    b_ids, i_ids, j_ids = b_ids[keep], i_ids[keep], j_ids[keep]
    keep = get_unique_indices(torch.stack([b_ids, i_ids], dim=-1))
    b_ids, i_ids, j_ids = b_ids[keep], i_ids[keep], j_ids[keep]

    # GT matches belong to the single pair, hence batch index 0.
    valid[b_ids, i_ids] = non_padding[0, j_ids]
    warped[b_ids, i_ids] = gt_matches[..., 2:][0, j_ids]

    return valid, warped
|
170 |
+
|
171 |
+
@torch.no_grad()
def warp_kpts_by_sparse_gt_matches_fast(kpts0, gt_matches, scale0, current_h, current_w):
    """Scatter sparse GT matches onto a regular grid without a distance search.

    Each GT source point is rescaled from original-image resolution to the
    current grid resolution and rounded to its nearest cell, so the cost is
    O(B*M) instead of a dense cdist against every grid point.

    Args:
        kpts0: [B, h*w, 2] grid keypoints (only shape/dtype are used).
        gt_matches: [B, M, 4] rows of (x0, y0, x1, y1); all-zero rows are padding.
        scale0: rescale factor from original image 0 to the current grid.
        current_h, current_w: grid size.
    """
    num_pairs, num_gt = gt_matches.shape[0], gt_matches.shape[1]
    warped = torch.zeros_like(kpts0)
    valid = torch.zeros_like(kpts0[..., 0], dtype=torch.bool)
    non_padding = gt_matches.sum(-1) > 0

    device = gt_matches.device
    all_j = torch.arange(num_gt, device=device, dtype=torch.long)[None].expand(num_pairs, num_gt)
    all_b = torch.arange(num_pairs, device=device, dtype=torch.long)[:, None].expand(num_pairs, num_gt)
    rescaled = gt_matches[..., :2] / scale0  # original image scale -> resized scale
    inside = (rescaled[..., 0] <= current_w - 1) & (rescaled[..., 0] >= 0) \
        & (rescaled[..., 1] <= current_h - 1) & (rescaled[..., 1] >= 0)

    rescaled = rescaled.round().to(torch.long)
    all_i = rescaled[..., 1] * current_w + rescaled[..., 0]  # idx = y * w + x

    # Drop padded entries and points that fall outside the grid.
    keep_mask = non_padding & inside
    b_ids, i_ids, j_ids = all_b[keep_mask], all_i[keep_mask], all_j[keep_mask]

    # De-duplicate on the GT side, then on the grid-cell side.
    keep = get_unique_indices(torch.stack([b_ids, j_ids], dim=-1))
    b_ids, i_ids, j_ids = b_ids[keep], i_ids[keep], j_ids[keep]
    keep = get_unique_indices(torch.stack([b_ids, i_ids], dim=-1))
    b_ids, i_ids, j_ids = b_ids[keep], i_ids[keep], j_ids[keep]

    valid[b_ids, i_ids] = non_padding[b_ids, j_ids]
    warped[b_ids, i_ids] = gt_matches[..., 2:][b_ids, j_ids]

    return valid, warped
|
199 |
+
|
200 |
+
|
201 |
+
@torch.no_grad()
def homo_warp_kpts(kpts0, norm_pixel_mat, homo_sample_normed, original_size0=None, original_size1=None):
    """Warp keypoints by a homography expressed in normalized coordinates.

    original_size1: N * 2, (h, w) — bounds used to validate the warped points.
    original_size0 (optional): N * 2, (h, w) — additionally require the source
        points to lie inside image 0.
    """
    def _inside(pts, size):
        # size columns are (h, w): index 1 is width, index 0 is height.
        return (pts[..., 0] > 0) & (pts[..., 0] < size[:, [1]]) \
            & (pts[..., 1] > 0) & (pts[..., 1] < size[:, [0]])

    ones = torch.ones_like(kpts0[:, :, [0]])
    homog_pts = torch.cat([kpts0, ones], dim=-1).transpose(2, 1)  # (N * 3 * L)
    normed = norm_pixel_mat @ homog_pts
    # Normalize -> apply homography -> map back to pixel coordinates.
    warped_h = (torch.linalg.inv(norm_pixel_mat) @ homo_sample_normed @ normed).transpose(2, 1)  # (N * L * 3)
    warped = warped_h[..., :2] / warped_h[..., [2]]  # N * L * 2

    valid = _inside(warped, original_size1)  # N * L
    if original_size0 is not None:
        valid = valid * _inside(kpts0, original_size0)  # N * L

    return valid, warped
|
216 |
+
|
217 |
+
@torch.no_grad()
# if using mask in homo warp (for coarse supervision)
def homo_warp_kpts_with_mask(kpts0, scale, depth_mask, norm_pixel_mat, homo_sample_normed, original_size0=None, original_size1=None):
    """Normalized-homography warp gated by a validity mask (coarse level).

    original_size1: N * 2, (h, w)
    depth_mask: [N, 1, H, W] full-resolution validity map; it is subsampled
        by `scale` to align with the coarse grid.
    """
    def _inside(pts, size):
        # size columns are (h, w): index 1 is width, index 0 is height.
        return (pts[..., 0] > 0) & (pts[..., 0] < size[:, [1]]) \
            & (pts[..., 1] > 0) & (pts[..., 1] < size[:, [0]])

    ones = torch.ones_like(kpts0[:, :, [0]])
    normed = norm_pixel_mat @ torch.cat([kpts0, ones], dim=-1).transpose(2, 1)  # (N * 3 * L)
    warped_h = (torch.linalg.inv(norm_pixel_mat) @ homo_sample_normed @ normed).transpose(2, 1)  # (N * L * 3)
    warped = warped_h[..., :2] / warped_h[..., [2]]  # N * L * 2

    # Subsample the mask to the coarse grid and flatten to [N, h*w].
    coarse_mask = depth_mask[:, :, ::scale, ::scale].reshape(depth_mask.shape[0], -1)
    ok = coarse_mask != 0

    valid = _inside(warped, original_size1) & ok  # N * L
    if original_size0 is not None:
        valid = valid * (_inside(kpts0, original_size0) & ok)  # N * L

    return valid, warped
|
237 |
+
|
238 |
+
@torch.no_grad()
# if using mask in homo warp (for fine supervision)
def homo_warp_kpts_with_mask_f(kpts0, depth_mask, norm_pixel_mat, homo_sample_normed, original_size0=None, original_size1=None):
    """Normalized-homography warp gated by a per-point validity mask (fine level).

    original_size1: N * 2, (h, w)
    depth_mask: [N, L] validity values aligned with kpts0; zero marks invalid.
    """
    def _inside(pts, size):
        # size columns are (h, w): index 1 is width, index 0 is height.
        return (pts[..., 0] > 0) & (pts[..., 0] < size[:, [1]]) \
            & (pts[..., 1] > 0) & (pts[..., 1] < size[:, [0]])

    ones = torch.ones_like(kpts0[:, :, [0]])
    normed = norm_pixel_mat @ torch.cat([kpts0, ones], dim=-1).transpose(2, 1)  # (N * 3 * L)
    warped_h = (torch.linalg.inv(norm_pixel_mat) @ homo_sample_normed @ normed).transpose(2, 1)  # (N * L * 3)
    warped = warped_h[..., :2] / warped_h[..., [2]]  # N * L * 2

    ok = depth_mask != 0
    valid = _inside(warped, original_size1) & ok  # N * L
    if original_size0 is not None:
        valid = valid * (_inside(kpts0, original_size0) & ok)  # N * L

    return valid, warped
|
254 |
+
|
255 |
+
@torch.no_grad()
def homo_warp_kpts_glue(kpts0, homo, original_size0=None, original_size1=None):
    """Warp keypoints by a pixel-space homography (glue-style convention).

    original_size1: N * 2, (h, w)
    """
    def _inside(pts, size):
        # size columns are (h, w): index 1 is width, index 0 is height.
        return (pts[..., 0] > 0) & (pts[..., 0] < size[:, [1]]) \
            & (pts[..., 1] > 0) & (pts[..., 1] < size[:, [0]])

    warped = warp_points_torch(kpts0, homo, inverse=False)
    valid = _inside(warped, original_size1)  # N * L
    if original_size0 is not None:
        valid = valid * _inside(kpts0, original_size0)  # N * L
    return valid, warped
|
267 |
+
|
268 |
+
@torch.no_grad()
# if using mask in homo warp (for coarse supervision)
def homo_warp_kpts_glue_with_mask(kpts0, scale, depth_mask, homo, original_size0=None, original_size1=None):
    """Pixel-space homography warp gated by a validity mask (coarse level).

    original_size1: N * 2, (h, w)
    depth_mask: [N, 1, H, W] full-resolution validity map, subsampled by `scale`.
    """
    def _inside(pts, size):
        # size columns are (h, w): index 1 is width, index 0 is height.
        return (pts[..., 0] > 0) & (pts[..., 0] < size[:, [1]]) \
            & (pts[..., 1] > 0) & (pts[..., 1] < size[:, [0]])

    warped = warp_points_torch(kpts0, homo, inverse=False)
    # Subsample the mask to the coarse grid and flatten to [N, h*w].
    coarse_mask = depth_mask[:, :, ::scale, ::scale].reshape(depth_mask.shape[0], -1)
    ok = coarse_mask != 0

    valid = _inside(warped, original_size1) & ok  # N * L
    if original_size0 is not None:
        valid = valid * (_inside(kpts0, original_size0) & ok)  # N * L
    return valid, warped
|
285 |
+
|
286 |
+
@torch.no_grad()
# if using mask in homo warp (for fine supervision)
def homo_warp_kpts_glue_with_mask_f(kpts0, depth_mask, homo, original_size0=None, original_size1=None):
    """Pixel-space homography warp gated by a per-point validity mask (fine level).

    original_size1: N * 2, (h, w)
    depth_mask: [N, L] validity values aligned with kpts0; zero marks invalid.
    """
    def _inside(pts, size):
        # size columns are (h, w): index 1 is width, index 0 is height.
        return (pts[..., 0] > 0) & (pts[..., 0] < size[:, [1]]) \
            & (pts[..., 1] > 0) & (pts[..., 1] < size[:, [0]])

    warped = warp_points_torch(kpts0, homo, inverse=False)
    ok = depth_mask != 0
    valid = _inside(warped, original_size1) & ok  # N * L
    if original_size0 is not None:
        valid = valid * (_inside(kpts0, original_size0) & ok)  # N * L
    return valid, warped
|
imcui/third_party/MatchAnything/src/loftr/utils/position_encoding.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
|
6 |
+
class PositionEncodingSine(nn.Module):
    """
    This is a sinusoidal position encoding that generalized to 2-dimensional images.

    A dense [1, C, H, W] table is built once in __init__ and sliced in forward.
    """

    def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True, npe=False):
        """
        Args:
            max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
            temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
                the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
                on the final performance. For now, we keep both impls for backward compatability.
                We will remove the buggy impl after re-training all variants of our released models.
            npe: False disables train/test resolution rescaling (factor 1);
                a 4-sequence gives (train_res_H, train_res_W, test_res_H, test_res_W).
        """
        super().__init__()

        table = torch.zeros((d_model, *max_shape))
        # 1-based row (y) and column (x) indices, each shaped [1, H, W].
        row_pos = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
        col_pos = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)

        assert npe is not None
        if npe is not None:
            if isinstance(npe, bool):
                # Boolean npe: train and test resolutions are equal, so the
                # rescale factor below is exactly 1.
                train_res_H, train_res_W, test_res_H, test_res_W = 832, 832, 832, 832
                print('loftr no npe!!!!', npe)
            else:
                print('absnpe!!!!', npe)
                train_res_H, train_res_W, test_res_H, test_res_W = npe[0], npe[1], npe[2], npe[3]  # train_res_H, train_res_W, test_res_H, test_res_W
            row_pos = row_pos * train_res_H / test_res_H
            col_pos = col_pos * train_res_W / test_res_W

        if temp_bug_fix:
            freq = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
        else:  # a buggy implementation (for backward compatability only)
            freq = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
        freq = freq[:, None, None]  # [C//4, 1, 1]
        # Interleave x/y sin/cos over groups of four channels.
        table[0::4, :, :] = torch.sin(col_pos * freq)
        table[1::4, :, :] = torch.cos(col_pos * freq)
        table[2::4, :, :] = torch.sin(row_pos * freq)
        table[3::4, :, :] = torch.cos(row_pos * freq)

        self.register_buffer('pe', table.unsqueeze(0), persistent=False)  # [1, C, H, W]

    def forward(self, x):
        """
        Args:
            x: [N, C, H, W]
        """
        _, _, h, w = x.shape
        return x + self.pe[:, :, :h, :w]
|
54 |
+
|
55 |
+
class RoPEPositionEncodingSine(nn.Module):
    """2-D rotary position encoding (RoPE) for image feature maps.

    Precomputes per-position sin/cos tables at full resolution and at 1/4
    resolution (selected in `forward` with ``ratio=4``); `forward` applies
    the rotation on consecutive channel pairs via `rotate_half`.
    """

    def __init__(self, d_model, max_shape=(256, 256), npe=None, ropefp16=True):
        """
        Args:
            d_model (int): channel dimension; must be divisible by 4.
            max_shape (tuple): (H, W) of the largest supported feature map;
                for a 1/8 featmap, 256 corresponds to 2048 pixels. Both
                entries must be divisible by 4 (for the 1/4-resolution tables).
            npe (sequence): (train_res_H, train_res_W, test_res_H, test_res_W)
                used to rescale positions for a train/test resolution mismatch;
                required (asserted non-None).
            ropefp16 (bool): store the tables in fp16 to halve buffer memory.
        """
        super().__init__()

        # 1-based row (i) and column (j) indices, each shaped [H, W, 1].
        i_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(-1)
        j_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(-1)

        assert npe is not None
        if npe is not None:
            train_res_H, train_res_W, test_res_H, test_res_W = npe[0], npe[1], npe[2], npe[3]  # train_res_H, train_res_W, test_res_H, test_res_W
            i_position, j_position = i_position * train_res_H / test_res_H, j_position * train_res_W / test_res_W

        div_term = torch.exp(torch.arange(0, d_model//4, 1).float() * (-math.log(10000.0) / (d_model//4)))
        div_term = div_term[None, None, :]  # [1, 1, C//4]

        dtype = torch.float16 if ropefp16 else torch.float32

        def _build_tables(i_pos, j_pos, h, w):
            # Even channels carry row-driven frequencies, odd channels
            # column-driven ones; each entry is then duplicated so the tables
            # line up with the (x1, x2) pairs rotated by `rotate_half`.
            sin = torch.zeros(h, w, d_model//2, dtype=dtype)
            cos = torch.zeros(h, w, d_model//2, dtype=dtype)
            sin[:, :, 0::2] = torch.sin(i_pos * div_term).to(dtype)
            sin[:, :, 1::2] = torch.sin(j_pos * div_term).to(dtype)
            cos[:, :, 0::2] = torch.cos(i_pos * div_term).to(dtype)
            cos[:, :, 1::2] = torch.cos(j_pos * div_term).to(dtype)
            return sin.repeat_interleave(2, dim=-1), cos.repeat_interleave(2, dim=-1)

        sin, cos = _build_tables(i_position, j_position, max_shape[0], max_shape[1])
        self.register_buffer('sin', sin.unsqueeze(0), persistent=False)  # [1, H, W, C]
        self.register_buffer('cos', cos.unsqueeze(0), persistent=False)  # [1, H, W, C]

        # 1/4-resolution tables. The original implementation hard-coded
        # reshape(64, 4, 64, 4, 1), which only worked for max_shape=(256, 256);
        # the factors are derived from max_shape instead (identical result for
        # the default shape).
        h4, w4 = max_shape[0] // 4, max_shape[1] // 4
        i_position4 = i_position.reshape(h4, 4, w4, 4, 1)[..., 0, :].mean(-3)
        j_position4 = j_position.reshape(h4, 4, w4, 4, 1)[:, 0, ...].mean(-2)
        sin4, cos4 = _build_tables(i_position4, j_position4, h4, w4)
        self.register_buffer('sin4', sin4.unsqueeze(0), persistent=False)  # [1, H/4, W/4, C]
        self.register_buffer('cos4', cos4.unsqueeze(0), persistent=False)  # [1, H/4, W/4, C]

    def forward(self, x, ratio=1):
        """
        Args:
            x: [N, H, W, C]
            ratio (int): 4 selects the 1/4-resolution tables, anything else
                the full-resolution ones.
        Returns:
            Tensor of the same shape as x with RoPE applied.
        """
        if ratio == 4:
            return (x * self.cos4[:, :x.size(1), :x.size(2), :]) + (self.rotate_half(x) * self.sin4[:, :x.size(1), :x.size(2), :])
        else:
            return (x * self.cos[:, :x.size(1), :x.size(2), :]) + (self.rotate_half(x) * self.sin[:, :x.size(1), :x.size(2), :])

    def rotate_half(self, x):
        # Map each consecutive channel pair (x1, x2) -> (-x2, x1).
        x = x.unflatten(-1, (-1, 2))
        x1, x2 = x.unbind(dim=-1)
        return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
|
imcui/third_party/MatchAnything/src/loftr/utils/supervision.py
ADDED
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import log
|
2 |
+
from loguru import logger as loguru_logger
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
from kornia.utils import create_meshgrid
|
8 |
+
|
9 |
+
from .geometry import warp_kpts, homo_warp_kpts, homo_warp_kpts_glue, homo_warp_kpts_with_mask, homo_warp_kpts_with_mask_f, homo_warp_kpts_glue_with_mask, homo_warp_kpts_glue_with_mask_f, warp_kpts_by_sparse_gt_matches_fast, warp_kpts_by_sparse_gt_matches_fine_chunks
|
10 |
+
|
11 |
+
from kornia.geometry.subpix import dsnt
|
12 |
+
from kornia.utils.grid import create_meshgrid
|
13 |
+
|
14 |
+
def static_vars(**kwargs):
    """Decorator factory: attach the given name/value pairs as attributes on
    the decorated function (emulating C-style static variables)."""
    def decorate(func):
        for name, value in kwargs.items():
            setattr(func, name, value)
        return func
    return decorate
|
20 |
+
|
21 |
+
############## ↓ Coarse-Level supervision ↓ ##############
|
22 |
+
|
23 |
+
@torch.no_grad()
def mask_pts_at_padded_regions(grid_pt, mask):
    """For megadepth dataset, zero-padding exists in images.

    Zero out (in place) every grid point whose cell falls in a padded
    region, and return the modified tensor.
    """
    # Broadcast the [n, h, w] padding mask to both (x, y) coordinates.
    expanded = repeat(mask, 'n h w -> n (h w) c', c=2)
    grid_pt[~expanded.bool()] = 0
    return grid_pt
|
29 |
+
|
30 |
+
|
31 |
+
@torch.no_grad()
def spvs_coarse(data, config):
    """
    Build coarse-level ground-truth supervision: warp a regular grid of
    image0 onto image1 using whichever GT source each sample carries
    (homography, sparse GT matches, or depth + pose), then keep mutual
    nearest neighbours on the coarse grid.

    Update:
        data (dict): {
            "conf_matrix_gt": [N, hw0, hw1],
            'spv_b_ids': [M]
            'spv_i_ids': [M]
            'spv_j_ids': [M]
            'spv_w_pt0_i': [N, hw0, 2], in original image resolution
            'spv_pt1_i': [N, hw1, 2], in original image resolution
        }

    NOTE:
        - for scannet dataset, there're 3 kinds of resolution {i, c, f}
        - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f}
    """
    # 1. misc
    device = data['image0'].device
    N, _, H0, W0 = data['image0'].shape
    _, _, H1, W1 = data['image1'].shape

    if 'loftr' in config.METHOD:
        scale = config['LOFTR']['RESOLUTION'][0]  # coarse-level downsampling factor

    # Per-sample resize factors map coarse-grid coordinates back to
    # original-image coordinates ('scale0'/'scale1' exist when images were resized).
    scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
    scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale
    h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])

    if config['LOFTR']['MATCH_COARSE']['MTD_SPVS'] and not config['LOFTR']['FORCE_LOOP_BACK']:
        # 2. warp grids
        # create kpts in meshgrid and resize them to image resolution
        grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1)    # [N, hw, 2]
        grid_pt0_i = scale0 * grid_pt0_c
        grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
        grid_pt1_i = scale1 * grid_pt1_c

        # Accumulators: which grid points have a valid GT warp, and where they land.
        correct_0to1 = torch.zeros((grid_pt0_i.shape[0], grid_pt0_i.shape[1]), dtype=torch.bool, device=grid_pt0_i.device)
        w_pt0_i = torch.zeros_like(grid_pt0_i)

        # Samples mark unused GT sources with all-zero tensors; these masks
        # select which warp applies to which sample in the batch.
        valid_dpt_b_mask = data['T_0to1'].sum(dim=-1).sum(dim=-1) != 0
        valid_homo_warp_mask = (data['homography'].sum(dim=-1).sum(dim=-1) != 0) | (data['homo_sample_normed'].sum(dim=-1).sum(dim=-1) != 0)
        valid_gt_match_warp_mask = (data['gt_matches_mask'][:, 0] != 0) # N

        if valid_homo_warp_mask.sum() != 0:
            # Zero 'homography' means the normalized-homography form is used.
            if data['homography'].sum()==0:
                if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0): # the key 'homo_mask0' only exists when using the dataset "CommonDataSetHomoWarp"
                    correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_with_mask(grid_pt0_i[valid_homo_warp_mask], scale, data['homo_mask0'][valid_homo_warp_mask], data['norm_pixel_mat'][valid_homo_warp_mask], data['homo_sample_normed'][valid_homo_warp_mask], original_size1=data['origin_img_size1'][valid_homo_warp_mask])
                else:
                    correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts(grid_pt0_i[valid_homo_warp_mask], data['norm_pixel_mat'][valid_homo_warp_mask], \
                                                                     data['homo_sample_normed'][valid_homo_warp_mask], original_size1=data['origin_img_size1'][valid_homo_warp_mask])
            else:
                # Pixel-space homography ("glue"-style) branch.
                if 'homo_mask0' in data and (data['homo_mask0']==0).sum()!=0:
                    correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_glue_with_mask(grid_pt0_i[valid_homo_warp_mask], scale, data['homo_mask0'][valid_homo_warp_mask], data['homography'][valid_homo_warp_mask], original_size1=data['origin_img_size1'][valid_homo_warp_mask])
                else:
                    correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_glue(grid_pt0_i[valid_homo_warp_mask], data['homography'][valid_homo_warp_mask], \
                                                                          original_size1=data['origin_img_size1'][valid_homo_warp_mask])
            correct_0to1[valid_homo_warp_mask] = correct_0to1_homo
            w_pt0_i[valid_homo_warp_mask] = w_pt0_i_homo

        # Sparse GT-match warp (fast nearest-cell scatter).
        if valid_gt_match_warp_mask.sum() != 0:
            correct_0to1_dpt, w_pt0_i_dpt = warp_kpts_by_sparse_gt_matches_fast(grid_pt0_i[valid_gt_match_warp_mask], data['gt_matches'][valid_gt_match_warp_mask], scale0=scale0[valid_gt_match_warp_mask], current_h=h0, current_w=w0)
            correct_0to1[valid_gt_match_warp_mask] = correct_0to1_dpt
            w_pt0_i[valid_gt_match_warp_mask] = w_pt0_i_dpt

        # Depth + relative-pose warp.
        if valid_dpt_b_mask.sum() != 0:
            correct_0to1_dpt, w_pt0_i_dpt = warp_kpts(grid_pt0_i[valid_dpt_b_mask], data['depth0'][valid_dpt_b_mask], data['depth1'][valid_dpt_b_mask], data['T_0to1'][valid_dpt_b_mask], data['K0'][valid_dpt_b_mask], data['K1'][valid_dpt_b_mask], consistency_thr=0.05)
            correct_0to1[valid_dpt_b_mask] = correct_0to1_dpt
            w_pt0_i[valid_dpt_b_mask] = w_pt0_i_dpt

        # Warped points expressed on the coarse grid of image1.
        w_pt0_c = w_pt0_i / scale1

        # 3. check if mutual nearest neighbor
        w_pt0_c_round = w_pt0_c[:, :, :].round() # [N, hw, 2]
        if config.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT:
            # Per-point weight in [0, 1]: 1 when the warp lands exactly on a
            # cell center, decaying with sub-cell offset in both axes.
            w_pt0_c_error = (1.0 - 2*torch.abs(w_pt0_c - w_pt0_c_round)).prod(-1)
        w_pt0_c_round = w_pt0_c_round.long() # [N, hw, 2]
        nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1 # [N, hw]

        # corner case: out of boundary
        def out_bound_mask(pt, w, h):
            return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
        nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = -1

        correct_0to1[:, 0] = False  # ignore the top-left corner

        # 4. construct a gt conf_matrix
        # Keep only matches whose target cell is valid in image1's padding mask.
        mask1 = torch.stack([data['mask1'].reshape(-1, h1*w1)[_b, _i] for _b, _i in enumerate(nearest_index1)], dim=0)
        correct_0to1 = correct_0to1 * data['mask0'].reshape(-1, h0*w0) * mask1

        conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device, dtype=torch.bool)
        b_ids, i_ids = torch.where(correct_0to1 != 0)
        j_ids = nearest_index1[b_ids, i_ids]
        # Drop matches whose target fell out of bounds (marked -1 above).
        valid_j_ids = j_ids != -1
        b_ids, i_ids, j_ids = map(lambda x: x[valid_j_ids], [b_ids, i_ids, j_ids])

        conf_matrix_gt[b_ids, i_ids, j_ids] = 1

        # overlap weight
        if config.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT:
            conf_matrix_error_gt = w_pt0_c_error[b_ids, i_ids]
            # Sanity: weights stay within [0, 1] modulo float noise.
            assert torch.all(conf_matrix_error_gt >= -0.001)
            assert torch.all(conf_matrix_error_gt <= 1.001)
            data.update({'conf_matrix_error_gt': conf_matrix_error_gt})
        data.update({'conf_matrix_gt': conf_matrix_gt})

        # 5. save coarse matches(gt) for training fine level
        if len(b_ids) == 0:
            loguru_logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
            # this won't affect fine-level loss calculation
            b_ids = torch.tensor([0], device=device)
            i_ids = torch.tensor([0], device=device)
            j_ids = torch.tensor([0], device=device)

        data.update({
            'spv_b_ids': b_ids,
            'spv_i_ids': i_ids,
            'spv_j_ids': j_ids
        })

        # GT coarse matches in original-image coordinates (index -> (x, y) * scale).
        data.update({'mkpts0_c_gt_b_ids': b_ids})
        data.update({'mkpts0_c_gt': torch.stack([i_ids % w0, i_ids // w0], dim=-1) * scale0[b_ids, 0]})
        data.update({'mkpts1_c_gt': torch.stack([j_ids % w1, j_ids // w1], dim=-1) * scale1[b_ids, 0]})

        # 6. save intermediate results (for fast fine-level computation)
        data.update({
            'spv_w_pt0_i': w_pt0_i,
            'spv_pt1_i': grid_pt1_i,
            # 'correct_0to1_c': correct_0to1
        })
    else:
        raise NotImplementedError
|
163 |
+
|
164 |
+
def compute_supervision_coarse(data, config):
    """Entry point for coarse-level GT supervision; delegates to
    `spvs_coarse`, which updates `data` in place."""
    spvs_coarse(data, config)
|
166 |
+
|
167 |
+
@torch.no_grad()
def get_gt_flow(data, h, w):
    """
    Build a dense GT flow field at resolution (h, w): warp an align_corners=False
    normalized grid of image0 onto image1 using the sample's GT source
    (homography, sparse GT matches, or depth + pose).

    Returns:
        w_pt0_n: [B, h, w, 2] warped grid in normalized [-1+1/n, 1-1/n] coords.
        correct_0to1: [B, h, w] float validity map (1.0 where the warp is valid).
    """
    device = data['image0'].device
    B, _, H0, W0 = data['image0'].shape
    scale = H0 / h  # downsampling factor of this pyramid level

    # Per-sample resize factors map grid coordinates back to original images.
    scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
    scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale

    # Normalized meshgrid with align_corners=False spacing: [-1+1/n, 1-1/n].
    x1_n = torch.meshgrid(
        *[
            torch.linspace(
                -1 + 1 / n, 1 - 1 / n, n, device=device
            )
            for n in (B, h, w)
        ]
    )
    grid_coord = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, h*w, 2)  # normalized
    grid_coord = torch.stack(
        (w * (grid_coord[..., 0] + 1) / 2, h * (grid_coord[..., 1] + 1) / 2), dim=-1
    )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
    grid_coord_in_origin = grid_coord * scale0

    # Accumulators: validity and warped positions per grid point.
    correct_0to1 = torch.zeros((grid_coord_in_origin.shape[0], grid_coord_in_origin.shape[1]), dtype=torch.bool, device=device)
    w_pt0_i = torch.zeros_like(grid_coord_in_origin)

    # Samples mark unused GT sources with all-zero tensors.
    valid_dpt_b_mask = data['T_0to1'].sum(dim=-1).sum(dim=-1) != 0
    valid_homo_warp_mask = (data['homography'].sum(dim=-1).sum(dim=-1) != 0) | (data['homo_sample_normed'].sum(dim=-1).sum(dim=-1) != 0)
    valid_gt_match_warp_mask = (data['gt_matches_mask'] != 0)[:, 0]

    if valid_homo_warp_mask.sum() != 0:
        # Zero 'homography' means the normalized-homography form is used.
        if data['homography'].sum()==0:
            if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0):
                # data['load_mask'] = True or False, data['depth_mask'] = depth_mask or None
                correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_with_mask(grid_coord_in_origin[valid_homo_warp_mask], int(scale), data['homo_mask0'][valid_homo_warp_mask], data['norm_pixel_mat'][valid_homo_warp_mask], \
                                                                           data['homo_sample_normed'][valid_homo_warp_mask], original_size1=data['origin_img_size1'][valid_homo_warp_mask])
            else:
                correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts(grid_coord_in_origin[valid_homo_warp_mask], data['norm_pixel_mat'][valid_homo_warp_mask], data['homo_sample_normed'][valid_homo_warp_mask], \
                                                                 original_size1=data['origin_img_size1'][valid_homo_warp_mask])
        else:
            # Pixel-space homography ("glue"-style) branch.
            if 'homo_mask0' in data and (data['homo_mask0']==0).sum()!=0:
                correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_glue_with_mask(grid_coord_in_origin[valid_homo_warp_mask], int(scale), data['homo_mask0'][valid_homo_warp_mask], data['homography'][valid_homo_warp_mask], \
                                                                                original_size1=data['origin_img_size1'][valid_homo_warp_mask])
            else:
                correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_glue(grid_coord_in_origin[valid_homo_warp_mask], data['homography'][valid_homo_warp_mask], \
                                                                      original_size1=data['origin_img_size1'][valid_homo_warp_mask])
        correct_0to1[valid_homo_warp_mask] = correct_0to1_homo
        w_pt0_i[valid_homo_warp_mask] = w_pt0_i_homo

    # Sparse GT-match warp (fast nearest-cell scatter).
    if valid_gt_match_warp_mask.sum() != 0:
        correct_0to1_dpt, w_pt0_i_dpt = warp_kpts_by_sparse_gt_matches_fast(grid_coord_in_origin[valid_gt_match_warp_mask], data['gt_matches'][valid_gt_match_warp_mask], scale0=scale0[valid_gt_match_warp_mask], current_h=h, current_w=w)
        correct_0to1[valid_gt_match_warp_mask] = correct_0to1_dpt
        w_pt0_i[valid_gt_match_warp_mask] = w_pt0_i_dpt
    # Depth + relative-pose warp.
    if valid_dpt_b_mask.sum() != 0:
        correct_0to1_dpt, w_pt0_i_dpt = warp_kpts(grid_coord_in_origin[valid_dpt_b_mask], data['depth0'][valid_dpt_b_mask], data['depth1'][valid_dpt_b_mask], data['T_0to1'][valid_dpt_b_mask], data['K0'][valid_dpt_b_mask], data['K1'][valid_dpt_b_mask], consistency_thr=0.05)
        correct_0to1[valid_dpt_b_mask] = correct_0to1_dpt
        w_pt0_i[valid_dpt_b_mask] = w_pt0_i_dpt

    # Warped points on image1's (h, w) grid.
    w_pt0_c = w_pt0_i / scale1

    # Invalidate warps that leave the target grid.
    def out_bound_mask(pt, w, h):
        return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
    correct_0to1[out_bound_mask(w_pt0_c, w, h)] = 0

    w_pt0_n = torch.stack(
        (2 * w_pt0_c[..., 0] / w - 1, 2 * w_pt0_c[..., 1] / h - 1), dim=-1
    )  # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
    # w_pt1_c = w_pt1_i / scale0

    # At coarse levels (scale > 8) also expose the matches for visualization/eval.
    if scale > 8:
        data.update({'mkpts0_c_gt': grid_coord_in_origin[correct_0to1]})
        data.update({'mkpts1_c_gt': w_pt0_i[correct_0to1]})

    return w_pt0_n.reshape(B, h, w, 2), correct_0to1.float().reshape(B, h, w)
|
241 |
+
|
242 |
+
@torch.no_grad()
def compute_roma_supervision(data, config):
    """Build per-scale ground-truth warps and validity masks for ROMA supervision.

    For every scale present in data['corresps'], the predicted flow tensor is
    used only to discover the (h, w) grid size; the actual ground truth comes
    from get_gt_flow. Results are stored under data['gt'][scale].
    """
    gt_flow = {}
    for scale, scale_corresps in data["corresps"].items():
        # Some scales store the prediction under 'flow', others under
        # 'dense_flow'; prefer 'flow' when both could exist.
        flow_key = 'flow' if 'flow' in scale_corresps else 'dense_flow'
        flow_pre_delta = rearrange(scale_corresps[flow_key], "b d h w -> b h w d")
        b, h, w, d = flow_pre_delta.shape
        gt_warp, gt_prob = get_gt_flow(data, h, w)
        gt_flow[scale] = {'gt_warp': gt_warp, "gt_prob": gt_prob}

    data.update({"gt": gt_flow})
|
253 |
+
|
254 |
+
############## ↓ Fine-Level supervision ↓ ##############
|
255 |
+
|
256 |
+
@static_vars(counter = 0)
@torch.no_grad()
def spvs_fine(data, config, logger = None):
    """
    Compute fine-level ground-truth supervision for the predicted coarse matches.

    Update:
        data (dict):{
            "expec_f_gt": [M, 2]}

    Also (in the MTD_SPVS path) writes 'conf_matrix_f_gt', optionally
    'conf_matrix_f_error_gt', the local-regress index tensors, and the
    plotting keys 'mkpts0_f_gt' / 'mkpts1_f_gt'.
    NOTE(review): `logger` is accepted but never used in this body.
    """
    # 1. misc
    # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
    if config.LOFTR.FINE.MTD_SPVS:
        pt1_i = data['spv_pt1_i']
    else:
        spv_w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i']
    if 'loftr' in config.METHOD:
        # fine resolution, coarse resolution, and half fine-window size
        # NOTE(review): `radius` is computed but not used below.
        scale = config['LOFTR']['RESOLUTION'][1]
        scale_c = config['LOFTR']['RESOLUTION'][0]
        radius = config['LOFTR']['FINE_WINDOW_SIZE'] // 2

    # 2. get coarse prediction
    b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']

    # 3. compute gt
    # scale0/scale1 rescale model-resolution coords back to original images.
    # NOTE(review): scalei1's condition tests 'scale0' (not 'scale1') —
    # presumably both keys are always set together; confirm against loaders.
    scalei0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
    scale0 = scale * data['scale0'] if 'scale0' in data else scale
    scalei1 = scale * data['scale1'][b_ids] if 'scale0' in data else scale

    if config.LOFTR.FINE.MTD_SPVS:
        W = config['LOFTR']['FINE_WINDOW_SIZE']
        WW = W*W
        device = data['image0'].device

        N, _, H0, W0 = data['image0'].shape
        _, _, H1, W1 = data['image1'].shape

        if config.LOFTR.ALIGN_CORNER is False:
            # fine/coarse feature-map sizes precomputed by the model
            hf0, wf0, hf1, wf1 = data['hw0_f'][0], data['hw0_f'][1], data['hw1_f'][0], data['hw1_f'][1]
            hc0, wc0, hc1, wc1 = data['hw0_c'][0], data['hw0_c'][1], data['hw1_c'][0], data['hw1_c'][1]
            # loguru_logger.info('hf0, wf0, hf1, wf1', hf0, wf0, hf1, wf1)
        else:
            hf0, wf0, hf1, wf1 = map(lambda x: x // scale, [H0, W0, H1, W1])
            hc0, wc0, hc1, wc1 = map(lambda x: x // scale_c, [H0, W0, H1, W1])

        m = b_ids.shape[0]
        if m == 0:
            # No coarse matches at all: emit empty/zero supervision tensors so
            # downstream losses still receive the expected keys.
            conf_matrix_f_gt = torch.zeros(m, WW, WW, device=device)

            data.update({'conf_matrix_f_gt': conf_matrix_f_gt})
            if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT:
                conf_matrix_f_error_gt = torch.zeros(1, device=device)
                data.update({'conf_matrix_f_error_gt': conf_matrix_f_error_gt})
            if config.LOFTR.MATCH_FINE.MULTI_REGRESS:
                data.update({'expec_f': torch.zeros(1, 3, device=device)})
                data.update({'expec_f_gt': torch.zeros(1, 2, device=device)})

            if config.LOFTR.MATCH_FINE.LOCAL_REGRESS:
                data.update({'expec_f': torch.zeros(1, 2, device=device)})
                data.update({'expec_f_gt': torch.zeros(1, 2, device=device)})
        else:
            # Fine-grid coordinates, shifted so each local window is centered.
            grid_pt0_f = create_meshgrid(hf0, wf0, False, device) - W // 2 + 0.5  # [1, hf0, wf0, 2] # use fine coordinates
            # grid_pt0_f = create_meshgrid(hf0, wf0, False, device) + 0.5 # [1, hf0, wf0, 2] # use fine coordinates
            grid_pt0_f = rearrange(grid_pt0_f, 'n h w c -> n c h w')
            # 1. unfold(crop) all local windows
            if config.LOFTR.ALIGN_CORNER is False:  # even windows
                if config.LOFTR.MATCH_FINE.MULTI_REGRESS or (config.LOFTR.MATCH_FINE.LOCAL_REGRESS and W == 10):
                    grid_pt0_f_unfold = F.unfold(grid_pt0_f, kernel_size=(W, W), stride=W-2, padding=1)  # overlap windows W-2 padding=1
                else:
                    grid_pt0_f_unfold = F.unfold(grid_pt0_f, kernel_size=(W, W), stride=W, padding=0)
            else:
                grid_pt0_f_unfold = F.unfold(grid_pt0_f[..., :-1, :-1], kernel_size=(W, W), stride=W, padding=W//2)
            grid_pt0_f_unfold = rearrange(grid_pt0_f_unfold, 'n (c ww) l -> n l ww c', ww=W**2)  # [1, hc0*wc0, W*W, 2]
            grid_pt0_f_unfold = repeat(grid_pt0_f_unfold[0], 'l ww c -> N l ww c', N=N)

            # 2. select only the predicted matches
            grid_pt0_f_unfold = grid_pt0_f_unfold[data['b_ids'], data['i_ids']]  # [m, ww, 2]
            grid_pt0_f_unfold = scalei0[:,None,:] * grid_pt0_f_unfold  # [m, ww, 2]

            # use depth mask
            if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0):
                # depth_mask --> (n, 1, hf, wf); unfold to the same window
                # layout as the keypoint grid, then select predicted matches.
                homo_mask0 = data['homo_mask0']
                homo_mask0 = F.unfold(homo_mask0[..., :-1, :-1], kernel_size=(W, W), stride=W, padding=W//2)
                homo_mask0 = rearrange(homo_mask0, 'n (c ww) l -> n l ww c', ww=W**2)  # [1, hc0*wc0, W*W, 1]
                homo_mask0 = repeat(homo_mask0[0], 'l ww c -> N l ww c', N=N)
                # select only the predicted matches
                homo_mask0 = homo_mask0[data['b_ids'], data['i_ids']]

            correct_0to1_f_list, w_pt0_i_list = [], []

            correct_0to1_f = torch.zeros(m, WW, device=device, dtype=torch.bool)
            w_pt0_i = torch.zeros(m, WW, 2, device=device, dtype=torch.float32)
            # Warp each sample's fine-window keypoints with whichever GT source
            # that sample provides: homography, depth+pose, or sparse matches.
            for b in range(N):
                mask = b_ids == b

                match = int(mask.sum())
                skip_reshape = False
                if match == 0:
                    print(f"no pred fine matches, skip!")
                    continue
                if (data['homography'][b].sum() != 0) | (data['homo_sample_normed'][b].sum() != 0):
                    if data['homography'][b].sum()==0:
                        # sampled (normalized) homography branch
                        if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0):
                            correct_0to1_f_mask, w_pt0_i_mask = homo_warp_kpts_with_mask_f(grid_pt0_f_unfold[mask].reshape(1,-1,2), homo_mask0[mask].reshape(1,-1), data['norm_pixel_mat'][[b]], \
                                                                                           data['homo_sample_normed'][[b]], data['origin_img_size0'][[b]], data['origin_img_size1'][[b]])
                        else:
                            correct_0to1_f_mask, w_pt0_i_mask = homo_warp_kpts(grid_pt0_f_unfold[mask].reshape(1,-1,2), data['norm_pixel_mat'][[b]], \
                                                                               data['homo_sample_normed'][[b]], data['origin_img_size0'][[b]], data['origin_img_size1'][[b]])
                    else:
                        # pixel-space homography ("glue"-style) branch
                        if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0):
                            correct_0to1_f_mask, w_pt0_i_mask = homo_warp_kpts_glue_with_mask_f(grid_pt0_f_unfold[mask].reshape(1,-1,2), homo_mask0[mask].reshape(1,-1), data['homography'][[b]], \
                                                                                                data['origin_img_size0'][[b]], data['origin_img_size1'][[b]])
                        else:
                            correct_0to1_f_mask, w_pt0_i_mask = homo_warp_kpts_glue(grid_pt0_f_unfold[mask].reshape(1,-1,2), data['homography'][[b]], \
                                                                                    data['origin_img_size0'][[b]], data['origin_img_size1'][[b]])
                elif data['T_0to1'][b].sum() != 0:
                    # depth + relative pose warp
                    correct_0to1_f_mask, w_pt0_i_mask = warp_kpts(grid_pt0_f_unfold[mask].reshape(1,-1,2), data['depth0'][[b],...],
                                                                  data['depth1'][[b],...], data['T_0to1'][[b],...],
                                                                  data['K0'][[b],...], data['K1'][[b],...])  # [k, WW], [k, WW, 2]
                elif data['gt_matches_mask'][b].sum() != 0:
                    # sparse GT correspondences; returns already per-window shaped
                    correct_0to1_f_mask, w_pt0_i_mask = warp_kpts_by_sparse_gt_matches_fine_chunks(grid_pt0_f_unfold[mask], gt_matches=data['gt_matches'][[b]], dist_thr=scale0[[b]].max(dim=-1)[0])
                    skip_reshape = True
                correct_0to1_f[mask] = correct_0to1_f_mask.reshape(match, WW) if not skip_reshape else correct_0to1_f_mask
                w_pt0_i[mask] = w_pt0_i_mask.reshape(match, WW, 2) if not skip_reshape else w_pt0_i_mask

            # Offset of each warped point from its window center in image1,
            # converted into fine-window coordinates [0, W).
            delta_w_pt0_i = w_pt0_i - pt1_i[b_ids, j_ids][:,None,:]  # [m, WW, 2]
            delta_w_pt0_f = delta_w_pt0_i / scalei1[:,None,:] + W // 2 - 0.5
            delta_w_pt0_f_round = delta_w_pt0_f[:, :, :].round()
            if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT and config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT2:
                w_pt0_f_error = (1.0 - torch.abs(delta_w_pt0_f - delta_w_pt0_f_round)).prod(-1)  # [0.25, 1]
            elif config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT:
                w_pt0_f_error = (1.0 - 2*torch.abs(delta_w_pt0_f - delta_w_pt0_f_round)).prod(-1)  # [0, 1]
            delta_w_pt0_f_round = delta_w_pt0_f_round.long()

            # Flattened index of the nearest fine cell in window 1.
            nearest_index1 = delta_w_pt0_f_round[..., 0] + delta_w_pt0_f_round[..., 1] * W  # [m, WW]

            def out_bound_mask(pt, w, h):
                # True where a point falls outside the [0,w) x [0,h) window.
                return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
            ob_mask = out_bound_mask(delta_w_pt0_f_round, W, W)
            nearest_index1[ob_mask] = 0

            correct_0to1_f[ob_mask] = 0
            m_ids_d, i_ids_d = torch.where(correct_0to1_f != 0)

            j_ids_d = nearest_index1[m_ids_d, i_ids_d]

            # For plotting:
            mkpts0_f_gt = grid_pt0_f_unfold[m_ids_d, i_ids_d]  # [m, 2]
            mkpts1_f_gt = w_pt0_i[m_ids_d, i_ids_d]  # [m, 2]
            data.update({'mkpts0_f_gt_b_ids': m_ids_d})
            data.update({'mkpts0_f_gt': mkpts0_f_gt})
            data.update({'mkpts1_f_gt': mkpts1_f_gt})

            if config.LOFTR.MATCH_FINE.MULTI_REGRESS:
                assert not config.LOFTR.MATCH_FINE.LOCAL_REGRESS
                expec_f_gt = delta_w_pt0_f - W // 2 + 0.5  # use delta(e.g. [-3.5,3.5]) in regression rather than [0,W] (e.g. [0,7])
                expec_f_gt = expec_f_gt[m_ids_d, i_ids_d] / (W // 2 - 1)  # specific radius for overlaped even windows & align_corner=False
                data.update({'expec_f_gt': expec_f_gt})
                data.update({'m_ids_d': m_ids_d, 'i_ids_d': i_ids_d})
            else:  # spv fine dual softmax
                if config.LOFTR.MATCH_FINE.LOCAL_REGRESS:
                    # Sub-pixel residual after snapping to the nearest fine cell.
                    expec_f_gt = delta_w_pt0_f - delta_w_pt0_f_round

                    # mask fine windows boarder
                    j_ids_d_il, j_ids_d_jl = j_ids_d // W, j_ids_d % W
                    if config.LOFTR.MATCH_FINE.LOCAL_REGRESS_NOMASK:
                        mask = None
                        m_ids_dl, i_ids_dl, j_ids_d_il, j_ids_d_jl = m_ids_d.to(torch.long), i_ids_d.to(torch.long), j_ids_d_il.to(torch.long), j_ids_d_jl.to(torch.long)
                    else:
                        mask = (j_ids_d_il >= 1) & (j_ids_d_il < W-1) & (j_ids_d_jl >= 1) & (j_ids_d_jl < W-1)
                        if W == 10:
                            # overlapped windows also mask the source-window border
                            i_ids_d_il, i_ids_d_jl = i_ids_d // W, i_ids_d % W
                            mask = mask & (i_ids_d_il >= 1) & (i_ids_d_il <= W-2) & (i_ids_d_jl >= 1) & (i_ids_d_jl <= W-2)

                        m_ids_dl, i_ids_dl, j_ids_d_il, j_ids_d_jl = m_ids_d[mask].to(torch.long), i_ids_d[mask].to(torch.long), j_ids_d_il[mask].to(torch.long), j_ids_d_jl[mask].to(torch.long)
                    if mask is not None:
                        loguru_logger.info(f'percent of gt mask.sum / mask.numel: {mask.sum().float()/mask.numel():.2f}')
                    if m_ids_dl.numel() == 0:
                        loguru_logger.warning(f"No groundtruth fine match found for local regress: {data['pair_names']}")
                        data.update({'expec_f_gt': torch.zeros(1, 2, device=device)})
                        data.update({'expec_f': torch.zeros(1, 2, device=device)})
                    else:
                        expec_f_gt = expec_f_gt[m_ids_dl, i_ids_dl]
                        data.update({"expec_f_gt": expec_f_gt})

                        data.update({"m_ids_dl": m_ids_dl,
                                     "i_ids_dl": i_ids_dl,
                                     "j_ids_d_il": j_ids_d_il,
                                     "j_ids_d_jl": j_ids_d_jl
                                     })
                else:  # no fine regress
                    pass

                # spv fine dual softmax
                conf_matrix_f_gt = torch.zeros(m, WW, WW, device=device, dtype=torch.bool)
                conf_matrix_f_gt[m_ids_d, i_ids_d, j_ids_d] = 1
                data.update({'conf_matrix_f_gt': conf_matrix_f_gt})
                if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT:
                    w_pt0_f_error = w_pt0_f_error[m_ids_d, i_ids_d]
                    assert torch.all(w_pt0_f_error >= -0.001)
                    assert torch.all(w_pt0_f_error <= 1.001)
                    data.update({'conf_matrix_f_error_gt': w_pt0_f_error})

                conf_matrix_f_gt_sum = conf_matrix_f_gt.sum()
                if conf_matrix_f_gt_sum != 0:
                    pass
                else:
                    loguru_logger.info(f'[no gt plot]no fine matches to supervise')

    else:
        # Legacy (non-MTD) supervision: a single sub-pixel expectation target.
        expec_f_gt = (spv_w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scalei1 / 4  # [M, 2]
        data.update({"expec_f_gt": expec_f_gt})
|
468 |
+
|
469 |
+
|
470 |
+
def compute_supervision_fine(data, config, logger=None):
    """Dispatch fine-level supervision by dataset source.

    Only 'scannet'/'megadepth'-style datasets are handled (case-insensitive);
    anything else raises NotImplementedError.
    """
    source = data['dataset_name'][0]
    if source.lower() not in ['scannet', 'megadepth']:
        raise NotImplementedError
    spvs_fine(data, config, logger)
|
imcui/third_party/MatchAnything/src/optimizers/__init__.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR
|
3 |
+
|
4 |
+
|
5 |
+
def build_optimizer(model, config):
    """Build the training optimizer selected by ``config.TRAINER.OPTIMIZER``.

    Args:
        model: module whose parameters will be optimized.
        config: config node; reads TRAINER.OPTIMIZER, TRAINER.TRUE_LR and the
            per-optimizer decay/eps settings, plus METHOD for the ROMA/DKM
            backbone learning-rate split.

    Returns:
        torch.optim.Optimizer

    Raises:
        ValueError: if TRAINER.OPTIMIZER names an unknown optimizer.
    """
    name = config.TRAINER.OPTIMIZER
    lr = config.TRAINER.TRUE_LR

    if name == "adam":
        return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY, eps=config.TRAINER.OPTIMIZER_EPS)
    elif name == "adamw":
        if ("ROMA" in config.METHOD) or ("DKM" in config.METHOD):
            # Train the pretrained backbone (parameters under 'model.encoder')
            # with a 20x smaller learning rate than the rest of the model.
            keyword = 'model.encoder'
            backbone_params = [param for n, param in model.named_parameters() if keyword in n]
            base_params = [param for n, param in model.named_parameters() if keyword not in n]
            params = [{'params': backbone_params, 'lr': lr * 0.05}, {'params': base_params}]
            # BUGFIX: `params` was previously built but never used —
            # model.parameters() was passed instead, so the backbone LR
            # scaling silently had no effect.
            return torch.optim.AdamW(params, lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY, eps=config.TRAINER.OPTIMIZER_EPS)
        else:
            return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY, eps=config.TRAINER.OPTIMIZER_EPS)
    else:
        raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
|
23 |
+
|
24 |
+
|
25 |
+
def build_scheduler(config, optimizer):
    """Create the LR-scheduler dict consumed by PyTorch-Lightning.

    Returns:
        scheduler (dict):{
            'scheduler': lr_scheduler,
            'interval': 'step', # or 'epoch'
            'monitor': 'val_f1', (optional)
            'frequency': x, (optional)
        }

    Raises:
        NotImplementedError: for an unrecognized TRAINER.SCHEDULER name.
    """
    name = config.TRAINER.SCHEDULER
    out = {'interval': config.TRAINER.SCHEDULER_INTERVAL}

    if name == 'MultiStepLR':
        out['scheduler'] = MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)
    elif name == 'CosineAnnealing':
        out['scheduler'] = CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)
    elif name == 'ExponentialLR':
        out['scheduler'] = ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)
    else:
        raise NotImplementedError()

    return out
|
imcui/third_party/MatchAnything/src/utils/__init__.py
ADDED
File without changes
|
imcui/third_party/MatchAnything/src/utils/augment.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import albumentations as A
|
2 |
+
|
3 |
+
|
4 |
+
class DarkAug(object):
    """Extreme dark augmentation aiming at Aachen Day-Night."""

    def __init__(self) -> None:
        transforms = [
            A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)),
            A.Blur(p=0.1, blur_limit=(3, 9)),
            A.MotionBlur(p=0.2, blur_limit=(3, 25)),
            A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
            A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)),
        ]
        # The whole pipeline is applied with probability 0.75.
        self.augmentor = A.Compose(transforms, p=0.75)

    def __call__(self, x):
        # albumentations returns a dict; the augmented image is under 'image'.
        return self.augmentor(image=x)['image']
|
20 |
+
|
21 |
+
|
22 |
+
class MobileAug(object):
    """Random augmentations aiming at images of mobile/handhold devices."""

    def __init__(self):
        transforms = [
            A.MotionBlur(p=0.25),
            A.ColorJitter(p=0.5),
            A.RandomRain(p=0.1),  # random occlusion
            A.RandomSunFlare(p=0.1),
            A.JpegCompression(p=0.25),
            A.ISONoise(p=0.25),
        ]
        # Always applied (p=1.0); each transform fires with its own probability.
        self.augmentor = A.Compose(transforms, p=1.0)

    def __call__(self, x):
        # albumentations returns a dict; the augmented image is under 'image'.
        return self.augmentor(image=x)['image']
|
39 |
+
|
40 |
+
|
41 |
+
def build_augmentor(method=None, **kwargs):
    """Build an image augmentor by name.

    Args:
        method: augmentation name ('dark', 'mobile', ...) or None.

    Returns:
        None when ``method`` is None (no augmentation).

    Raises:
        NotImplementedError: for every non-None method — augmentation is
            deliberately disabled for now.
    """
    if method is None:
        return None
    # DEAD-CODE FIX: the original guard raised for every non-None method,
    # which made the subsequent dispatch to DarkAug()/MobileAug() (and the
    # trailing ValueError) unreachable. Behavior is unchanged; re-introduce
    # the dispatch here once augmentation is supported again.
    raise NotImplementedError('Using of augmentation functions are not supported yet!')
|
52 |
+
|
53 |
+
|
54 |
+
# Smoke test: 'FDA' is non-None, so this raises NotImplementedError
# (augmentation methods are currently disabled in build_augmentor).
if __name__ == '__main__':
    augmentor = build_augmentor('FDA')
|
imcui/third_party/MatchAnything/src/utils/colmap.py
ADDED
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, ETH Zurich and UNC Chapel Hill.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# Redistribution and use in source and binary forms, with or without
|
5 |
+
# modification, are permitted provided that the following conditions are met:
|
6 |
+
#
|
7 |
+
# * Redistributions of source code must retain the above copyright
|
8 |
+
# notice, this list of conditions and the following disclaimer.
|
9 |
+
#
|
10 |
+
# * Redistributions in binary form must reproduce the above copyright
|
11 |
+
# notice, this list of conditions and the following disclaimer in the
|
12 |
+
# documentation and/or other materials provided with the distribution.
|
13 |
+
#
|
14 |
+
# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
|
15 |
+
# its contributors may be used to endorse or promote products derived
|
16 |
+
# from this software without specific prior written permission.
|
17 |
+
#
|
18 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
19 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
20 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
21 |
+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
|
22 |
+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
23 |
+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
24 |
+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
25 |
+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
26 |
+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
27 |
+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
28 |
+
# POSSIBILITY OF SUCH DAMAGE.
|
29 |
+
#
|
30 |
+
# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
|
31 |
+
|
32 |
+
from typing import List, Tuple, Dict
|
33 |
+
import os
|
34 |
+
import collections
|
35 |
+
import numpy as np
|
36 |
+
import struct
|
37 |
+
import argparse
|
38 |
+
|
39 |
+
|
40 |
+
# Lightweight records mirroring COLMAP's C++ reconstruction entities.
CameraModel = collections.namedtuple(
    "CameraModel", ["model_id", "model_name", "num_params"])
# Base records; Camera/Image below subclass these to add convenience accessors.
BaseCamera = collections.namedtuple(
    "Camera", ["id", "model", "width", "height", "params"])
BaseImage = collections.namedtuple(
    "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
Point3D = collections.namedtuple(
    "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
|
48 |
+
|
49 |
+
|
50 |
+
class Image(BaseImage):
    """COLMAP image record (pose stored as quaternion qvec + translation tvec)."""

    def qvec2rotmat(self):
        # Delegate to the module-level quaternion -> rotation-matrix helper.
        return qvec2rotmat(self.qvec)

    @property
    def world_to_camera(self) -> np.ndarray:
        """4x4 homogeneous world-to-camera transform built from qvec/tvec."""
        T = np.eye(4)
        T[:3, :3] = qvec2rotmat(self.qvec)
        T[:3, 3] = self.tvec
        return T
|
62 |
+
|
63 |
+
|
64 |
+
class Camera(BaseCamera):
    """COLMAP camera record exposing the pinhole intrinsic matrix ``K``."""

    # Model families sharing the same leading parameter layout.
    _SINGLE_FOCAL = ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL",
                     "SIMPLE_RADIAL_FISHEYE", "RADIAL_FISHEYE")
    _DUAL_FOCAL = ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV",
                   "FOV", "THIN_PRISM_FISHEYE")

    @property
    def K(self):
        """3x3 intrinsics; distortion parameters are ignored."""
        K = np.eye(3)
        if self.model in self._SINGLE_FOCAL:
            # params = [f, cx, cy, ...]
            K[0, 0] = self.params[0]
            K[1, 1] = self.params[0]
            K[0, 2] = self.params[1]
            K[1, 2] = self.params[2]
        elif self.model in self._DUAL_FOCAL:
            # params = [fx, fy, cx, cy, ...]
            K[0, 0] = self.params[0]
            K[1, 1] = self.params[1]
            K[0, 2] = self.params[2]
            K[1, 2] = self.params[3]
        else:
            raise NotImplementedError
        return K
|
81 |
+
|
82 |
+
|
83 |
+
# All camera models supported by COLMAP, with their parameter counts.
CAMERA_MODELS = {
    CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
    CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
    CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
    CameraModel(model_id=3, model_name="RADIAL", num_params=5),
    CameraModel(model_id=4, model_name="OPENCV", num_params=8),
    CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
    CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
    CameraModel(model_id=7, model_name="FOV", num_params=5),
    CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
    CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
    CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
}
# Lookup by numeric id (used by the binary readers) ...
CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
                         for camera_model in CAMERA_MODELS])
# ... and by model name (used by the text readers/writers).
CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
                           for camera_model in CAMERA_MODELS])
|
100 |
+
|
101 |
+
|
102 |
+
def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
    """Read and unpack the next bytes from a binary file.

    :param fid: binary file object.
    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
    :param endian_character: Any of {@, =, <, >, !}
    :return: Tuple of read and unpacked values.
    """
    raw = fid.read(num_bytes)
    return struct.unpack(endian_character + format_char_sequence, raw)
|
112 |
+
|
113 |
+
|
114 |
+
def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
    """Pack values and write them to a binary file.

    :param fid: binary file object opened for writing.
    :param data: data to send; multiple elements must be wrapped in a
        list or tuple matching the format sequence.
    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
        should be the same length as the data list or tuple
    :param endian_character: Any of {@, =, <, >, !}
    """
    fmt = endian_character + format_char_sequence
    if isinstance(data, (list, tuple)):
        packed = struct.pack(fmt, *data)
    else:
        packed = struct.pack(fmt, data)
    fid.write(packed)
|
128 |
+
|
129 |
+
|
130 |
+
def read_cameras_text(path):
    """Parse a COLMAP cameras.txt file into {camera_id: Camera}.

    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasText(const std::string& path)
        void Reconstruction::ReadCamerasText(const std::string& path)
    """
    cameras = {}
    with open(path, "r") as fid:
        for raw in fid:
            line = raw.strip()
            # Skip blank lines and '#' comment lines.
            if not line or line.startswith("#"):
                continue
            elems = line.split()
            cam_id = int(elems[0])
            cameras[cam_id] = Camera(
                id=cam_id,
                model=elems[1],
                width=int(elems[2]),
                height=int(elems[3]),
                # Everything after HEIGHT is the model-dependent parameter list.
                params=np.array(tuple(map(float, elems[4:]))),
            )
    return cameras
|
154 |
+
|
155 |
+
|
156 |
+
def read_cameras_binary(path_to_model_file):
    """
    Parse a COLMAP cameras.bin file into {camera_id: Camera}.

    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasBinary(const std::string& path)
        void Reconstruction::ReadCamerasBinary(const std::string& path)
    """
    cameras = {}
    with open(path_to_model_file, "rb") as fid:
        # Leading uint64: number of cameras in the file.
        num_cameras = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_cameras):
            # Fixed-size record: camera_id (i), model_id (i), width (Q), height (Q).
            camera_properties = read_next_bytes(
                fid, num_bytes=24, format_char_sequence="iiQQ")
            camera_id = camera_properties[0]
            model_id = camera_properties[1]
            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
            width = camera_properties[2]
            height = camera_properties[3]
            # The parameter count depends on the camera model; stored as doubles.
            num_params = CAMERA_MODEL_IDS[model_id].num_params
            params = read_next_bytes(fid, num_bytes=8*num_params,
                                     format_char_sequence="d"*num_params)
            cameras[camera_id] = Camera(id=camera_id,
                                        model=model_name,
                                        width=width,
                                        height=height,
                                        params=np.array(params))
        assert len(cameras) == num_cameras
    return cameras
|
183 |
+
|
184 |
+
|
185 |
+
def write_cameras_text(cameras, path):
    """Write {camera_id: Camera} to a COLMAP cameras.txt file.

    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasText(const std::string& path)
        void Reconstruction::ReadCamerasText(const std::string& path)
    """
    header = (
        "# Camera list with one line of data per camera:\n"
        "#   CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n"
        "# Number of cameras: {}\n".format(len(cameras))
    )
    with open(path, "w") as fid:
        fid.write(header)
        for cam in cameras.values():
            # One space-separated line per camera, params appended at the end.
            fields = [cam.id, cam.model, cam.width, cam.height, *cam.params]
            fid.write(" ".join(str(f) for f in fields) + "\n")
|
200 |
+
|
201 |
+
|
202 |
+
def write_cameras_binary(cameras, path_to_model_file):
    """Write {camera_id: Camera} to a COLMAP cameras.bin file.

    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasBinary(const std::string& path)
        void Reconstruction::ReadCamerasBinary(const std::string& path)
    """
    with open(path_to_model_file, "wb") as fid:
        # Leading uint64: number of cameras.
        write_next_bytes(fid, len(cameras), "Q")
        for cam in cameras.values():
            model_id = CAMERA_MODEL_NAMES[cam.model].model_id
            # Fixed-size record: id, model_id, width, height.
            write_next_bytes(fid, [cam.id, model_id, cam.width, cam.height], "iiQQ")
            # Model parameters are stored as consecutive doubles.
            for p in cam.params:
                write_next_bytes(fid, float(p), "d")
    return cameras
|
220 |
+
|
221 |
+
|
222 |
+
def read_images_text(path):
    """Parse a COLMAP images.txt file into {image_id: Image}.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesText(const std::string& path)
        void Reconstruction::WriteImagesText(const std::string& path)
    """
    images = {}
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            # Skip blank lines and '#' comment lines.
            if not line or line.startswith("#"):
                continue
            # First line of each record: pose, camera id and image name.
            elems = line.split()
            image_id = int(elems[0])
            qvec = np.array(tuple(map(float, elems[1:5])))
            tvec = np.array(tuple(map(float, elems[5:8])))
            camera_id = int(elems[8])
            # NOTE(review): assumes image names contain no spaces — confirm.
            image_name = elems[9]
            # Second line: (x, y, point3D_id) triplets for all 2D points.
            pts = fid.readline().split()
            xys = np.column_stack([tuple(map(float, pts[0::3])),
                                   tuple(map(float, pts[1::3]))])
            point3D_ids = np.array(tuple(map(int, pts[2::3])))
            images[image_id] = Image(
                id=image_id, qvec=qvec, tvec=tvec,
                camera_id=camera_id, name=image_name,
                xys=xys, point3D_ids=point3D_ids)
    return images
|
251 |
+
|
252 |
+
|
253 |
+
def read_images_binary(path_to_model_file):
    """
    Parse a COLMAP images.bin file into {image_id: Image}.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesBinary(const std::string& path)
        void Reconstruction::WriteImagesBinary(const std::string& path)
    """
    images = {}
    with open(path_to_model_file, "rb") as fid:
        # Leading uint64: number of registered images.
        num_reg_images = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_reg_images):
            # Fixed-size record: image_id (i), qvec (4d), tvec (3d), camera_id (i).
            binary_image_properties = read_next_bytes(
                fid, num_bytes=64, format_char_sequence="idddddddi")
            image_id = binary_image_properties[0]
            qvec = np.array(binary_image_properties[1:5])
            tvec = np.array(binary_image_properties[5:8])
            camera_id = binary_image_properties[8]
            # Image name is a NUL-terminated byte string.
            image_name = ""
            current_char = read_next_bytes(fid, 1, "c")[0]
            while current_char != b"\x00":  # look for the ASCII 0 entry
                image_name += current_char.decode("utf-8")
                current_char = read_next_bytes(fid, 1, "c")[0]
            # Then num_points2D (x, y, point3D_id) triplets as (d, d, q).
            num_points2D = read_next_bytes(fid, num_bytes=8,
                                           format_char_sequence="Q")[0]
            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
                                       format_char_sequence="ddq"*num_points2D)
            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
                                   tuple(map(float, x_y_id_s[1::3]))])
            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
            images[image_id] = Image(
                id=image_id, qvec=qvec, tvec=tvec,
                camera_id=camera_id, name=image_name,
                xys=xys, point3D_ids=point3D_ids)
    return images
|
286 |
+
|
287 |
+
|
288 |
+
def write_images_text(images, path):
    """
    Write a dict of Image entries to a COLMAP images.txt file.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesText(const std::string& path)
        void Reconstruction::WriteImagesText(const std::string& path)
    """
    # Mean observation count is only informational (goes into the header comment).
    if len(images) == 0:
        mean_observations = 0
    else:
        mean_observations = sum((len(img.point3D_ids) for _, img in images.items()))/len(images)
    HEADER = "# Image list with two lines of data per image:\n" + \
        "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + \
        "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + \
        "# Number of images: {}, mean observations per image: {}\n".format(len(images), mean_observations)

    with open(path, "w") as fid:
        fid.write(HEADER)
        for _, img in images.items():
            # Line 1: pose + camera + name.
            image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name]
            first_line = " ".join(map(str, image_header))
            fid.write(first_line + "\n")

            # Line 2: all 2D observations as flattened (x, y, point3D_id) triples.
            points_strings = []
            for xy, point3D_id in zip(img.xys, img.point3D_ids):
                points_strings.append(" ".join(map(str, [*xy, point3D_id])))
            fid.write(" ".join(points_strings) + "\n")
|
314 |
+
|
315 |
+
|
316 |
+
def write_images_binary(images, path_to_model_file):
    """
    Write a dict of Image entries to a COLMAP images.bin file.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesBinary(const std::string& path)
        void Reconstruction::WriteImagesBinary(const std::string& path)
    """
    with open(path_to_model_file, "wb") as fid:
        # Leading uint64: number of images.
        write_next_bytes(fid, len(images), "Q")
        for _, img in images.items():
            write_next_bytes(fid, img.id, "i")
            write_next_bytes(fid, img.qvec.tolist(), "dddd")
            write_next_bytes(fid, img.tvec.tolist(), "ddd")
            write_next_bytes(fid, img.camera_id, "i")
            # Name is written character-by-character, NUL-terminated,
            # mirroring the read side in read_images_binary.
            for char in img.name:
                write_next_bytes(fid, char.encode("utf-8"), "c")
            write_next_bytes(fid, b"\x00", "c")
            write_next_bytes(fid, len(img.point3D_ids), "Q")
            for xy, p3d_id in zip(img.xys, img.point3D_ids):
                write_next_bytes(fid, [*xy, p3d_id], "ddq")
|
335 |
+
|
336 |
+
|
337 |
+
def read_points3D_text(path):
    """
    Parse a COLMAP points3D.txt file into a dict mapping point3D_id -> Point3D.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DText(const std::string& path)
        void Reconstruction::WritePoints3DText(const std::string& path)
    """
    points3D = {}
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            # Skip blank lines and '#' comment/header lines.
            if len(line) > 0 and line[0] != "#":
                # POINT3D_ID, X Y Z, R G B, ERROR, then TRACK[] as (IMAGE_ID, POINT2D_IDX) pairs.
                elems = line.split()
                point3D_id = int(elems[0])
                xyz = np.array(tuple(map(float, elems[1:4])))
                rgb = np.array(tuple(map(int, elems[4:7])))
                error = float(elems[7])
                image_ids = np.array(tuple(map(int, elems[8::2])))
                point2D_idxs = np.array(tuple(map(int, elems[9::2])))
                points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
                                               error=error, image_ids=image_ids,
                                               point2D_idxs=point2D_idxs)
    return points3D
|
362 |
+
|
363 |
+
|
364 |
+
def read_points3D_binary(path_to_model_file):
    """
    Parse a COLMAP points3D.bin file into a dict mapping point3D_id -> Point3D.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DBinary(const std::string& path)
        void Reconstruction::WritePoints3DBinary(const std::string& path)
    """
    points3D = {}
    with open(path_to_model_file, "rb") as fid:
        # Leading uint64: number of 3D points.
        num_points = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_points):
            # Fixed part: id (Q), xyz (3d), rgb (3B), reprojection error (d) = 43 bytes.
            binary_point_line_properties = read_next_bytes(
                fid, num_bytes=43, format_char_sequence="QdddBBBd")
            point3D_id = binary_point_line_properties[0]
            xyz = np.array(binary_point_line_properties[1:4])
            rgb = np.array(binary_point_line_properties[4:7])
            error = np.array(binary_point_line_properties[7])
            # Variable part: track length, then (image_id, point2D_idx) int pairs.
            track_length = read_next_bytes(
                fid, num_bytes=8, format_char_sequence="Q")[0]
            track_elems = read_next_bytes(
                fid, num_bytes=8*track_length,
                format_char_sequence="ii"*track_length)
            image_ids = np.array(tuple(map(int, track_elems[0::2])))
            point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
            points3D[point3D_id] = Point3D(
                id=point3D_id, xyz=xyz, rgb=rgb,
                error=error, image_ids=image_ids,
                point2D_idxs=point2D_idxs)
    return points3D
|
392 |
+
|
393 |
+
|
394 |
+
def write_points3D_text(points3D, path):
    """
    Write a dict of Point3D entries to a COLMAP points3D.txt file.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DText(const std::string& path)
        void Reconstruction::WritePoints3DText(const std::string& path)
    """
    # Mean track length is only informational (goes into the header comment).
    if len(points3D) == 0:
        mean_track_length = 0
    else:
        mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items()))/len(points3D)
    HEADER = "# 3D point list with one line of data per point:\n" + \
        "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + \
        "# Number of points: {}, mean track length: {}\n".format(len(points3D), mean_track_length)

    with open(path, "w") as fid:
        fid.write(HEADER)
        for _, pt in points3D.items():
            point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
            # Trailing space joins the header fields to the track that follows
            # on the same line.
            fid.write(" ".join(map(str, point_header)) + " ")
            track_strings = []
            for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
                track_strings.append(" ".join(map(str, [image_id, point2D])))
            fid.write(" ".join(track_strings) + "\n")
|
417 |
+
|
418 |
+
|
419 |
+
def write_points3D_binary(points3D, path_to_model_file):
    """
    Write a dict of Point3D entries to a COLMAP points3D.bin file.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DBinary(const std::string& path)
        void Reconstruction::WritePoints3DBinary(const std::string& path)
    """
    with open(path_to_model_file, "wb") as fid:
        # Leading uint64: number of 3D points.
        write_next_bytes(fid, len(points3D), "Q")
        for _, pt in points3D.items():
            write_next_bytes(fid, pt.id, "Q")
            write_next_bytes(fid, pt.xyz.tolist(), "ddd")
            write_next_bytes(fid, pt.rgb.tolist(), "BBB")
            write_next_bytes(fid, pt.error, "d")
            # Track: length (uint64), then (image_id, point2D_idx) int pairs.
            track_length = pt.image_ids.shape[0]
            write_next_bytes(fid, track_length, "Q")
            for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
                write_next_bytes(fid, [image_id, point2D_id], "ii")
|
436 |
+
|
437 |
+
|
438 |
+
def detect_model_format(path, ext):
    """Return True iff all three model files (cameras/images/points3D) with
    the given extension exist under *path*; prints the detected format."""
    required = ("cameras", "images", "points3D")
    if all(os.path.isfile(os.path.join(path, stem + ext)) for stem in required):
        print("Detected model format: '" + ext + "'")
        return True
    return False
|
446 |
+
|
447 |
+
|
448 |
+
def read_model(path, ext="") -> Tuple[Dict[int, Camera], Dict[int, Image], Dict[int, Point3D]]:
    """
    Load a COLMAP model (cameras, images, points3D) from *path*.

    If *ext* is empty, the format is auto-detected by probing for .bin
    then .txt files; otherwise it must be ".bin" or ".txt".
    Raises ValueError when neither format is found.
    """
    # try to detect the extension automatically
    if ext == "":
        if detect_model_format(path, ".bin"):
            ext = ".bin"
        elif detect_model_format(path, ".txt"):
            ext = ".txt"
        else:
            raise ValueError("Provide model format: '.bin' or '.txt'")

    if ext == ".txt":
        cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
        images = read_images_text(os.path.join(path, "images" + ext))
        points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
    else:
        cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
        images = read_images_binary(os.path.join(path, "images" + ext))
        points3D = read_points3D_binary(os.path.join(path, "points3D") + ext)
    return cameras, images, points3D
|
467 |
+
|
468 |
+
|
469 |
+
def write_model(cameras, images, points3D, path, ext=".bin"):
    """
    Write a COLMAP model (cameras, images, points3D) under *path* in the
    given format (".bin" or ".txt").  Returns the inputs unchanged so the
    call can be chained.
    """
    if ext == ".txt":
        write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
        write_images_text(images, os.path.join(path, "images" + ext))
        write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
    else:
        write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
        write_images_binary(images, os.path.join(path, "images" + ext))
        write_points3D_binary(points3D, os.path.join(path, "points3D") + ext)
    return cameras, images, points3D
|
479 |
+
|
480 |
+
|
481 |
+
def qvec2rotmat(qvec):
    """Convert a unit quaternion (w, x, y, z) into a 3x3 rotation matrix."""
    w, x, y, z = qvec
    return np.array([
        [1 - 2 * y * y - 2 * z * z,
         2 * x * y - 2 * w * z,
         2 * z * x + 2 * w * y],
        [2 * x * y + 2 * w * z,
         1 - 2 * x * x - 2 * z * z,
         2 * y * z - 2 * w * x],
        [2 * z * x - 2 * w * y,
         2 * y * z + 2 * w * x,
         1 - 2 * x * x - 2 * y * y]])
|
492 |
+
|
493 |
+
|
494 |
+
def rotmat2qvec(R):
    """Convert a 3x3 rotation matrix into a unit quaternion (w, x, y, z), w >= 0.

    Uses the eigenvector method: the quaternion is the dominant eigenvector
    of the symmetric 4x4 matrix K built from the rotation entries.
    """
    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
    K = np.array([
        [Rxx - Ryy - Rzz, 0, 0, 0],
        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
    eigvals, eigvecs = np.linalg.eigh(K)
    # eigh returns ascending eigenvalues; take the dominant eigenvector and
    # reorder its components from (x, y, z, w) into (w, x, y, z).
    dominant = eigvecs[:, np.argmax(eigvals)]
    qvec = dominant[[3, 0, 1, 2]]
    # q and -q encode the same rotation; canonicalize to w >= 0.
    return -qvec if qvec[0] < 0 else qvec
|
506 |
+
|
507 |
+
|
508 |
+
def main():
    """CLI entry point: load a COLMAP model, print its statistics, and
    optionally convert it to another format."""
    parser = argparse.ArgumentParser(description="Read and write COLMAP binary and text models")
    parser.add_argument("--input_model", help="path to input model folder")
    parser.add_argument("--input_format", choices=[".bin", ".txt"],
                        help="input model format", default="")
    parser.add_argument("--output_model",
                        help="path to output model folder")
    # NOTE(review): "outut" typo in the help text below; left untouched because
    # help strings are runtime output.
    parser.add_argument("--output_format", choices=[".bin", ".txt"],
                        help="outut model format", default=".txt")
    args = parser.parse_args()

    cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format)

    print("num_cameras:", len(cameras))
    print("num_images:", len(images))
    print("num_points3D:", len(points3D))

    # Only convert when an output folder is given.
    if args.output_model is not None:
        write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format)
|
527 |
+
|
528 |
+
|
529 |
+
# Run the converter CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
imcui/third_party/MatchAnything/src/utils/colmap/__init__.py
ADDED
File without changes
|
imcui/third_party/MatchAnything/src/utils/colmap/database.py
ADDED
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# Redistribution and use in source and binary forms, with or without
|
5 |
+
# modification, are permitted provided that the following conditions are met:
|
6 |
+
#
|
7 |
+
# * Redistributions of source code must retain the above copyright
|
8 |
+
# notice, this list of conditions and the following disclaimer.
|
9 |
+
#
|
10 |
+
# * Redistributions in binary form must reproduce the above copyright
|
11 |
+
# notice, this list of conditions and the following disclaimer in the
|
12 |
+
# documentation and/or other materials provided with the distribution.
|
13 |
+
#
|
14 |
+
# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
|
15 |
+
# its contributors may be used to endorse or promote products derived
|
16 |
+
# from this software without specific prior written permission.
|
17 |
+
#
|
18 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
19 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
20 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
21 |
+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
|
22 |
+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
23 |
+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
24 |
+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
25 |
+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
26 |
+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
27 |
+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
28 |
+
# POSSIBILITY OF SUCH DAMAGE.
|
29 |
+
#
|
30 |
+
# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
|
31 |
+
|
32 |
+
# This script is based on an original implementation by True Price.
|
33 |
+
|
34 |
+
import sys
|
35 |
+
import sqlite3
|
36 |
+
import numpy as np
|
37 |
+
from loguru import logger
|
38 |
+
|
39 |
+
|
40 |
+
# Python-2/3 switch for the buffer/bytes helpers below.
IS_PYTHON3 = sys.version_info[0] >= 3

# Image ids must fit in 31 bits so that an (id1, id2) pair packs into a
# single SQLite INTEGER key (see image_ids_to_pair_id).
MAX_IMAGE_ID = 2**31 - 1

# --- COLMAP database schema (mirrors colmap/src/base/database.cc) ---

CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
    camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    model INTEGER NOT NULL,
    width INTEGER NOT NULL,
    height INTEGER NOT NULL,
    params BLOB,
    prior_focal_length INTEGER NOT NULL)"""

CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
    image_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""

CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
    image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    name TEXT NOT NULL UNIQUE,
    camera_id INTEGER NOT NULL,
    prior_qw REAL,
    prior_qx REAL,
    prior_qy REAL,
    prior_qz REAL,
    prior_tx REAL,
    prior_ty REAL,
    prior_tz REAL,
    CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}),
    FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))
""".format(MAX_IMAGE_ID)

CREATE_TWO_VIEW_GEOMETRIES_TABLE = """
CREATE TABLE IF NOT EXISTS two_view_geometries (
    pair_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    config INTEGER NOT NULL,
    F BLOB,
    E BLOB,
    H BLOB,
    qvec BLOB,
    tvec BLOB)
"""

CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
    image_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)
"""

CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
    pair_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB)"""

CREATE_NAME_INDEX = \
    "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"

# One script that creates the whole schema in a single executescript() call.
CREATE_ALL = "; ".join([
    CREATE_CAMERAS_TABLE,
    CREATE_IMAGES_TABLE,
    CREATE_KEYPOINTS_TABLE,
    CREATE_DESCRIPTORS_TABLE,
    CREATE_MATCHES_TABLE,
    CREATE_TWO_VIEW_GEOMETRIES_TABLE,
    CREATE_NAME_INDEX
])
|
114 |
+
|
115 |
+
|
116 |
+
def image_ids_to_pair_id(image_id1, image_id2):
    """Pack an unordered image-id pair into COLMAP's single-integer pair key.

    The ids are canonicalized (smaller id first) so (a, b) and (b, a)
    produce the same key.
    """
    return min(image_id1, image_id2) * MAX_IMAGE_ID + max(image_id1, image_id2)
|
120 |
+
|
121 |
+
|
122 |
+
def pair_id_to_image_ids(pair_id):
    """Invert image_ids_to_pair_id, returning (image_id1, image_id2).

    Uses floor division so both ids come back as integers; the original
    true division (`/`) produced a float image_id1 on Python 3.  Integer
    and equal-valued float keys hash identically, so callers that used the
    result as a dict key are unaffected.
    """
    image_id2 = pair_id % MAX_IMAGE_ID
    image_id1 = (pair_id - image_id2) // MAX_IMAGE_ID
    return image_id1, image_id2
|
126 |
+
|
127 |
+
|
128 |
+
def array_to_blob(array):
    """Serialize a numpy array into raw bytes for storage as an SQLite BLOB.

    The original kept a Python-2 branch using np.getbuffer, but this module
    is Python-3-only (it uses f-strings elsewhere), so that branch was dead
    code and np.getbuffer no longer exists; tobytes() is the supported path.
    """
    return array.tobytes()
|
133 |
+
|
134 |
+
|
135 |
+
def blob_to_array(blob, dtype, shape=(-1,)):
    """Decode an SQLite BLOB back into a numpy array of the given dtype/shape.

    np.fromstring is deprecated and removed in modern NumPy; np.frombuffer
    is the supported equivalent on every Python version this module targets.
    Note the result is a read-only view over the blob bytes.
    """
    return np.frombuffer(blob, dtype=dtype).reshape(*shape)
|
140 |
+
|
141 |
+
|
142 |
+
class COLMAPDatabase(sqlite3.Connection):
    """SQLite connection subclass that knows COLMAP's database schema and
    provides typed insert/update helpers for its tables."""

    @staticmethod
    def connect(database_path):
        # factory= makes sqlite3 return instances of this subclass.
        return sqlite3.connect(str(database_path), factory=COLMAPDatabase)


    def __init__(self, *args, **kwargs):
        super(COLMAPDatabase, self).__init__(*args, **kwargs)

        # Convenience closures that create each table (or all of them) on demand.
        self.create_tables = lambda: self.executescript(CREATE_ALL)
        self.create_cameras_table = \
            lambda: self.executescript(CREATE_CAMERAS_TABLE)
        self.create_descriptors_table = \
            lambda: self.executescript(CREATE_DESCRIPTORS_TABLE)
        self.create_images_table = \
            lambda: self.executescript(CREATE_IMAGES_TABLE)
        self.create_two_view_geometries_table = \
            lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE)
        self.create_keypoints_table = \
            lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
        self.create_matches_table = \
            lambda: self.executescript(CREATE_MATCHES_TABLE)
        self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)

    def add_camera(self, model, width, height, params,
                   prior_focal_length=False, camera_id=None):
        """Insert a camera row; camera_id=None lets SQLite auto-assign.
        Returns the assigned camera id."""
        params = np.asarray(params, np.float64)
        cursor = self.execute(
            "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
            (camera_id, model, width, height, array_to_blob(params),
             prior_focal_length))
        return cursor.lastrowid

    def add_image(self, name, camera_id,
                  prior_q=np.full(4, np.NaN), prior_t=np.full(3, np.NaN),
                  image_id=None):
        """Insert an image row with optional pose priors; returns the image id.

        NOTE(review): np.NaN was removed in NumPy 2.0 (use np.nan); also the
        array defaults are evaluated once at import time — harmless while
        they are only read, but worth confirming before a NumPy upgrade.
        """
        cursor = self.execute(
            "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2],
             prior_q[3], prior_t[0], prior_t[1], prior_t[2]))
        return cursor.lastrowid

    def add_keypoints(self, image_id, keypoints):
        """Insert a keypoints blob (rows x {2,4,6} float32) for an image."""
        assert(len(keypoints.shape) == 2)
        # COLMAP supports (x, y), (x, y, theta, scale), or 6-dof affine keypoints.
        assert(keypoints.shape[1] in [2, 4, 6])

        keypoints = np.asarray(keypoints, np.float32)
        self.execute(
            "INSERT INTO keypoints VALUES (?, ?, ?, ?)",
            (image_id,) + keypoints.shape + (array_to_blob(keypoints),))

    def add_descriptors(self, image_id, descriptors):
        """Insert a descriptors blob (uint8) for an image."""
        descriptors = np.ascontiguousarray(descriptors, np.uint8)
        self.execute(
            "INSERT INTO descriptors VALUES (?, ?, ?, ?)",
            (image_id,) + descriptors.shape + (array_to_blob(descriptors),))

    def add_matches(self, image_id1, image_id2, matches):
        """Insert raw feature matches (N x 2 keypoint indices) for a pair."""
        assert(len(matches.shape) == 2)
        assert(matches.shape[1] == 2)

        # Column order must follow the canonical (smaller image id first)
        # convention used by image_ids_to_pair_id.
        if image_id1 > image_id2:
            matches = matches[:,::-1]

        pair_id = image_ids_to_pair_id(image_id1, image_id2)
        matches = np.asarray(matches, np.uint32)
        self.execute(
            "INSERT INTO matches VALUES (?, ?, ?, ?)",
            (pair_id,) + matches.shape + (array_to_blob(matches),))

    def add_two_view_geometry(self, image_id1, image_id2, matches,
                              F=np.eye(3), E=np.eye(3), H=np.eye(3),
                              qvec=np.array([1.0, 0.0, 0.0, 0.0]),
                              tvec=np.zeros(3), config=2):
        """Insert verified (geometry-consistent) matches plus the estimated
        two-view geometry (F/E/H, relative pose) for a pair."""
        assert(len(matches.shape) == 2)
        assert(matches.shape[1] == 2)

        # Same canonical column ordering as add_matches.
        if image_id1 > image_id2:
            matches = matches[:,::-1]

        pair_id = image_ids_to_pair_id(image_id1, image_id2)
        matches = np.asarray(matches, np.uint32)
        F = np.asarray(F, dtype=np.float64)
        E = np.asarray(E, dtype=np.float64)
        H = np.asarray(H, dtype=np.float64)
        qvec = np.asarray(qvec, dtype=np.float64)
        tvec = np.asarray(tvec, dtype=np.float64)
        self.execute(
            "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (pair_id,) + matches.shape + (array_to_blob(matches), config,
             array_to_blob(F), array_to_blob(E), array_to_blob(H),
             array_to_blob(qvec), array_to_blob(tvec)))

    def update_two_view_geometry(self, image_id1, image_id2, matches,
                                 F=np.eye(3), E=np.eye(3), H=np.eye(3), config=2):
        """Merge new matches into an existing two_view_geometries row for the
        pair, de-duplicating on either endpoint keypoint index.

        Raises NotImplementedError when the row does not already contain a
        decodable matches blob, or when an ambiguous duplicate case occurs.
        """
        assert(len(matches.shape) == 2)
        assert(matches.shape[1] == 2)

        if image_id1 > image_id2:
            matches = matches[:,::-1]

        pair_id = image_ids_to_pair_id(image_id1, image_id2)
        matches = np.asarray(matches, np.uint32)
        F = np.asarray(F, dtype=np.float64)
        E = np.asarray(E, dtype=np.float64)
        H = np.asarray(H, dtype=np.float64)

        # Find whether exists:
        row = self.execute(f"SELECT * FROM two_view_geometries WHERE pair_id = {pair_id} ")
        data = list(next(row))
        try:
            matches_old = blob_to_array(data[3], np.uint32, (-1, 2))
        # NOTE(review): bare except silently treats any decode failure
        # (e.g. NULL blob) as "no previous matches" — consider narrowing.
        except:
            matches_old = None

        if matches_old is not None:
            for match in matches:
                img0_id, img1_id = match

                # Find duplicated pts
                img0_dup_idxs = np.where(matches_old[:, 0] == img0_id)
                img1_dup_idxs = np.where(matches_old[:, 1] == img1_id)

                if len(img0_dup_idxs[0]) == 0 and len(img1_dup_idxs[0]) == 0:
                    # No duplicated matches:
                    matches_old = np.concatenate([matches_old, match[None]], axis=0)
                elif len(img0_dup_idxs[0]) == 1 and len(img1_dup_idxs[0]) == 0:
                    # NOTE(review): fancy indexing (matches_old[idx_array])
                    # returns a copy, so this assignment — and the similar
                    # ones below — likely never modifies matches_old.
                    # Confirm intended behavior before relying on it.
                    matches_old[img0_dup_idxs[0]][0,1] = img1_id
                elif len(img0_dup_idxs[0]) == 0 and len(img1_dup_idxs[0]) == 1:
                    matches_old[img1_dup_idxs[0]][0,0] = img0_id
                elif len(img0_dup_idxs[0]) == 1 and len(img1_dup_idxs[0]) == 1:
                    if img0_dup_idxs[0] != img1_dup_idxs[0]:
                        # logger.warning(f"Duplicated matches exists!")
                        matches_old[img0_dup_idxs[0]][0,1] = img1_id
                        matches_old[img1_dup_idxs[0]][0,0] = img0_id
                else:
                    # More than one duplicate on a side is not handled.
                    raise NotImplementedError

            # matches = np.concatenate([matches_old, matches], axis=0) # N * 2
            matches = matches_old
            # Replace the row wholesale: delete, then re-insert with the
            # merged matches while keeping the remaining geometry columns.
            self.execute(f"DELETE FROM two_view_geometries WHERE pair_id = {pair_id}")

            data[1:4] = matches.shape + (array_to_blob(np.asarray(matches, np.uint32)),)
            self.execute("INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", tuple(data))
        else:
            raise NotImplementedError

        # self.add_two_view_geometry(image_id1, image_id2, matches)
|
291 |
+
|
292 |
+
|
293 |
+
def example_usage():
    """Self-contained demo: create a database, insert dummy cameras/images/
    keypoints/matches, read everything back and verify it round-trips."""
    import os
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--database_path", default="database.db")
    args = parser.parse_args()

    # Refuse to touch a pre-existing database.
    if os.path.exists(args.database_path):
        print("ERROR: database path already exists -- will not modify it.")
        return

    # Open the database.

    db = COLMAPDatabase.connect(args.database_path)

    # For convenience, try creating all the tables upfront.

    db.create_tables()

    # Create dummy cameras.

    model1, width1, height1, params1 = \
        0, 1024, 768, np.array((1024., 512., 384.))
    model2, width2, height2, params2 = \
        2, 1024, 768, np.array((1024., 512., 384., 0.1))

    camera_id1 = db.add_camera(model1, width1, height1, params1)
    camera_id2 = db.add_camera(model2, width2, height2, params2)

    # Create dummy images.

    image_id1 = db.add_image("image1.png", camera_id1)
    image_id2 = db.add_image("image2.png", camera_id1)
    image_id3 = db.add_image("image3.png", camera_id2)
    image_id4 = db.add_image("image4.png", camera_id2)

    # Create dummy keypoints.
    #
    # Note that COLMAP supports:
    #      - 2D keypoints: (x, y)
    #      - 4D keypoints: (x, y, theta, scale)
    #      - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22)

    num_keypoints = 1000
    keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1)
    keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1)
    keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2)
    keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2)

    db.add_keypoints(image_id1, keypoints1)
    db.add_keypoints(image_id2, keypoints2)
    db.add_keypoints(image_id3, keypoints3)
    db.add_keypoints(image_id4, keypoints4)

    # Create dummy matches.

    M = 50
    matches12 = np.random.randint(num_keypoints, size=(M, 2))
    matches23 = np.random.randint(num_keypoints, size=(M, 2))
    matches34 = np.random.randint(num_keypoints, size=(M, 2))

    db.add_matches(image_id1, image_id2, matches12)
    db.add_matches(image_id2, image_id3, matches23)
    db.add_matches(image_id3, image_id4, matches34)

    # Commit the data to the file.

    db.commit()

    # Read and check cameras.

    rows = db.execute("SELECT * FROM cameras")

    camera_id, model, width, height, params, prior = next(rows)
    params = blob_to_array(params, np.float64)
    assert camera_id == camera_id1
    assert model == model1 and width == width1 and height == height1
    assert np.allclose(params, params1)

    camera_id, model, width, height, params, prior = next(rows)
    params = blob_to_array(params, np.float64)
    assert camera_id == camera_id2
    assert model == model2 and width == width2 and height == height2
    assert np.allclose(params, params2)

    # Read and check keypoints.

    keypoints = dict(
        (image_id, blob_to_array(data, np.float32, (-1, 2)))
        for image_id, data in db.execute(
            "SELECT image_id, data FROM keypoints"))

    assert np.allclose(keypoints[image_id1], keypoints1)
    assert np.allclose(keypoints[image_id2], keypoints2)
    assert np.allclose(keypoints[image_id3], keypoints3)
    assert np.allclose(keypoints[image_id4], keypoints4)

    # Read and check matches.

    pair_ids = [image_ids_to_pair_id(*pair) for pair in
                ((image_id1, image_id2),
                 (image_id2, image_id3),
                 (image_id3, image_id4))]

    matches = dict(
        (pair_id_to_image_ids(pair_id),
         blob_to_array(data, np.uint32, (-1, 2)))
        for pair_id, data in db.execute("SELECT pair_id, data FROM matches")
    )

    assert np.all(matches[(image_id1, image_id2)] == matches12)
    assert np.all(matches[(image_id2, image_id3)] == matches23)
    assert np.all(matches[(image_id3, image_id4)] == matches34)

    # Clean up.

    db.close()

    if os.path.exists(args.database_path):
        os.remove(args.database_path)
|
414 |
+
|
415 |
+
|
416 |
+
# Run the round-trip demo only when executed as a script, not on import.
if __name__ == "__main__":
    example_usage()
|
imcui/third_party/MatchAnything/src/utils/colmap/eval_helper.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import cv2
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
from .read_write_model import read_images_binary
|
6 |
+
|
7 |
+
|
8 |
+
def align_model(model, rot, trans, scale):
    """Apply a similarity transform to a 3xN point cloud.

    Rotates `model` by `rot`, shifts by `trans`, then scales uniformly.
    """
    transformed = np.matmul(rot, model) + trans
    return transformed * scale
|
10 |
+
|
11 |
+
|
12 |
+
def align(model, data):
    '''
    Source: https://vision.in.tum.de/data/datasets/rgbd-dataset/tools
    #absolute_trajectory_error_ate
    Align two trajectories using the method of Horn (closed-form).

    Input:
    model -- first trajectory (3xn)
    data -- second trajectory (3xn)

    Output:
    rot -- rotation matrix (3x3)
    trans -- translation vector (3x1)
    scale -- median ratio of point distances mapping model onto data

    '''

    # Degenerate input: Horn's method needs >= 3 point pairs.
    if model.shape[1] < 3:
        print('Need at least 3 points for ATE: {}'.format(model))
        return np.identity(3), np.zeros((3, 1)), 1

    # Get zero centered point clouds.
    model_zerocentered = model - model.mean(1, keepdims=True)
    data_zerocentered = data - data.mean(1, keepdims=True)

    # Cross-covariance matrix between the two centered clouds.
    W = np.zeros((3, 3))
    for column in range(model.shape[1]):
        W += np.outer(model_zerocentered[:, column],
                      data_zerocentered[:, column])

    # SVD.  BUGFIX: the original called np.linalg.linalg.svd, a private
    # alias that is removed in NumPy 2.0; use the public np.linalg.svd.
    U, d, Vh = np.linalg.svd(W.transpose())
    S = np.identity(3)
    # Reflection guard: force a proper rotation (det(rot) = +1).
    if (np.linalg.det(U) * np.linalg.det(Vh) < 0):
        S[2, 2] = -1
    rot = np.matmul(np.matmul(U, S), Vh)
    trans = data.mean(1, keepdims=True) - np.matmul(
        rot, model.mean(1, keepdims=True))

    # Apply rot and trans to the point cloud (scale fixed at 1.0 here).
    model_aligned = np.matmul(rot, model) + trans
    model_aligned_zerocentered = model_aligned - model_aligned.mean(
        1, keepdims=True)

    # Estimate scale from the distances to the point-cloud centers;
    # the median makes the estimate robust to outlier points.
    data_dist = np.sqrt((data_zerocentered * data_zerocentered).sum(axis=0))
    model_aligned_dist = np.sqrt(
        (model_aligned_zerocentered * model_aligned_zerocentered).sum(axis=0))
    scale_array = data_dist / model_aligned_dist
    scale = np.median(scale_array)

    return rot, trans, scale
|
65 |
+
|
66 |
+
|
67 |
+
def quaternion_matrix(quaternion):
    '''Return homogeneous rotation matrix from quaternion.

    Quaternion is in [w, x, y, z] order.

    >>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0])
    >>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0]))
    True
    >>> M = quaternion_matrix([1, 0, 0, 0])
    >>> numpy.allclose(M, numpy.identity(4))
    True
    >>> M = quaternion_matrix([0, 1, 0, 0])
    >>> numpy.allclose(M, numpy.diag([1, -1, -1, 1]))
    True
    '''

    q = np.array(quaternion, dtype=np.float64, copy=True)
    n = np.dot(q, q)
    # BUGFIX: the original compared against a module global `_EPS` that is
    # never defined in this file (a leftover from transformations.py), so
    # this branch raised NameError.  Use the same threshold value inline:
    # a quaternion with ~zero norm cannot be normalized, return identity.
    if n < np.finfo(np.float64).eps * 4.0:
        return np.identity(4)

    # Normalize so the outer product directly yields rotation terms.
    q *= math.sqrt(2.0 / n)
    q = np.outer(q, q)

    return np.array(
        [[1.0 - q[2, 2] - q[3, 3], q[1, 2] - q[3, 0], q[1, 3] + q[2, 0], 0.0],
         [q[1, 2] + q[3, 0], 1.0 - q[1, 1] - q[3, 3], q[2, 3] - q[1, 0], 0.0],
         [q[1, 3] - q[2, 0], q[2, 3] + q[1, 0], 1.0 - q[1, 1] - q[2, 2], 0.0],
         [0.0, 0.0, 0.0, 1.0]])
|
94 |
+
|
95 |
+
|
96 |
+
def quaternion_from_matrix(matrix, isprecise=False):
    '''Return quaternion (w, x, y, z) from rotation matrix.

    If isprecise is True, the input matrix is assumed to be a precise rotation
    matrix and a faster algorithm is used.

    >>> q = quaternion_from_matrix(numpy.identity(4), True)
    >>> numpy.allclose(q, [1, 0, 0, 0])
    True
    >>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1]))
    >>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0])
    True
    >>> R = rotation_matrix(0.123, (1, 2, 3))
    >>> q = quaternion_from_matrix(R, True)
    >>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786])
    True
    >>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0],
    ...      [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]]
    >>> q = quaternion_from_matrix(R)
    >>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611])
    True
    >>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0],
    ...      [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]]
    >>> q = quaternion_from_matrix(R)
    >>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603])
    True
    >>> R = random_rotation_matrix()
    >>> q = quaternion_from_matrix(R)
    >>> is_same_transform(R, quaternion_matrix(q))
    True
    >>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0)
    >>> numpy.allclose(quaternion_from_matrix(R, isprecise=False),
    ...                quaternion_from_matrix(R, isprecise=True))
    True

    '''

    # NOTE(review): np.array(..., copy=False) raises under NumPy >= 2.0 when
    # a copy is unavoidable; np.asarray would be the modern spelling — confirm
    # the pinned NumPy version before changing.
    M = np.array(matrix, dtype=np.float64, copy=False)[:4, :4]
    if isprecise:
        # Fast path: Shepperd-style branch selection on the largest
        # diagonal element, valid only for exact rotation matrices.
        q = np.empty((4, ))
        t = np.trace(M)
        if t > M[3, 3]:
            # Trace-dominant case: w is the largest component.
            q[0] = t
            q[3] = M[1, 0] - M[0, 1]
            q[2] = M[0, 2] - M[2, 0]
            q[1] = M[2, 1] - M[1, 2]
        else:
            # Pick the dominant diagonal axis to keep the sqrt argument large
            # (numerical stability); (i, j, k) is a cyclic index permutation.
            i, j, k = 1, 2, 3
            if M[1, 1] > M[0, 0]:
                i, j, k = 2, 3, 1
            if M[2, 2] > M[i, i]:
                i, j, k = 3, 1, 2
            t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
            q[i] = t
            q[j] = M[i, j] + M[j, i]
            q[k] = M[k, i] + M[i, k]
            q[3] = M[k, j] - M[j, k]
        q *= 0.5 / math.sqrt(t * M[3, 3])
    else:
        # Robust path: the quaternion is the eigenvector of the symmetric
        # matrix K corresponding to its largest eigenvalue (Bar-Itzhack).
        m00 = M[0, 0]
        m01 = M[0, 1]
        m02 = M[0, 2]
        m10 = M[1, 0]
        m11 = M[1, 1]
        m12 = M[1, 2]
        m20 = M[2, 0]
        m21 = M[2, 1]
        m22 = M[2, 2]

        # symmetric matrix K
        K = np.array([[m00 - m11 - m22, 0.0, 0.0, 0.0],
                      [m01 + m10, m11 - m00 - m22, 0.0, 0.0],
                      [m02 + m20, m12 + m21, m22 - m00 - m11, 0.0],
                      [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22]])
        K /= 3.0

        # quaternion is eigenvector of K that corresponds to largest eigenvalue
        w, V = np.linalg.eigh(K)
        # Reorder from eigh's (x, y, z, w) layout into (w, x, y, z).
        q = V[[3, 0, 1, 2], np.argmax(w)]

    # Canonicalize sign: q and -q encode the same rotation; prefer w >= 0.
    if q[0] < 0.0:
        np.negative(q, q)

    return q
|
180 |
+
|
181 |
+
def is_colmap_img_valid(colmap_img_file):
    '''Return validity of a colmap reconstruction.

    Reads `images.bin` at `colmap_img_file` and verifies that every
    registered image has a well-formed, finite pose: a 4-vector
    quaternion and a 3-vector translation with no NaN/Inf entries.
    '''

    images_bin = read_images_binary(colmap_img_file)
    # Check if everything is finite for this subset
    for key in images_bin.keys():
        q = np.asarray(images_bin[key].qvec).flatten()
        t = np.asarray(images_bin[key].tvec).flatten()

        # Accumulate all checks for this image; shapes first, then values.
        is_cur_valid = True
        is_cur_valid = is_cur_valid and q.shape == (4, )
        is_cur_valid = is_cur_valid and t.shape == (3, )
        is_cur_valid = is_cur_valid and np.all(np.isfinite(q))
        is_cur_valid = is_cur_valid and np.all(np.isfinite(t))

        # If any is invalid, immediately return
        if not is_cur_valid:
            return False

    return True
|
201 |
+
|
202 |
+
def get_best_colmap_index(colmap_output_path):
    '''
    Determines the colmap model with the most images if there is more than one.

    :param colmap_output_path: directory containing numbered sub-model
        folders (each with an `images.bin`).
    :return: the best sub-model's folder name as a string; "-1" when no
        valid sub-reconstruction exists.
    '''

    # First find the colmap reconstruction with the most number of images.
    best_index, best_num_images = -1, 0

    # Check all valid sub reconstructions.
    if os.path.exists(colmap_output_path):
        idx_list = [
            _d for _d in os.listdir(colmap_output_path)
            if os.path.isdir(os.path.join(colmap_output_path, _d))
        ]
    else:
        idx_list = []

    for cur_index in idx_list:
        cur_output_path = os.path.join(colmap_output_path, cur_index)
        if os.path.isdir(cur_output_path):
            colmap_img_file = os.path.join(cur_output_path, 'images.bin')
            # NOTE: images.bin is parsed twice (here and inside
            # is_colmap_img_valid); kept as-is for byte-identical behavior.
            images_bin = read_images_binary(colmap_img_file)
            # Check validity
            if not is_colmap_img_valid(colmap_img_file):
                continue
            # Find the reconstruction with most number of images
            if len(images_bin) > best_num_images:
                best_index = int(cur_index)
                best_num_images = len(images_bin)

    return str(best_index)
|
imcui/third_party/MatchAnything/src/utils/colmap/read_write_model.py
ADDED
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# Redistribution and use in source and binary forms, with or without
|
5 |
+
# modification, are permitted provided that the following conditions are met:
|
6 |
+
#
|
7 |
+
# * Redistributions of source code must retain the above copyright
|
8 |
+
# notice, this list of conditions and the following disclaimer.
|
9 |
+
#
|
10 |
+
# * Redistributions in binary form must reproduce the above copyright
|
11 |
+
# notice, this list of conditions and the following disclaimer in the
|
12 |
+
# documentation and/or other materials provided with the distribution.
|
13 |
+
#
|
14 |
+
# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
|
15 |
+
# its contributors may be used to endorse or promote products derived
|
16 |
+
# from this software without specific prior written permission.
|
17 |
+
#
|
18 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
19 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
20 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
21 |
+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
|
22 |
+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
23 |
+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
24 |
+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
25 |
+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
26 |
+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
27 |
+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
28 |
+
# POSSIBILITY OF SUCH DAMAGE.
|
29 |
+
#
|
30 |
+
# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
|
31 |
+
|
32 |
+
import os
|
33 |
+
import sys
|
34 |
+
import collections
|
35 |
+
import numpy as np
|
36 |
+
import struct
|
37 |
+
import argparse
|
38 |
+
|
39 |
+
|
40 |
+
# Lightweight record types mirroring COLMAP's reconstruction entities.
# One camera-model descriptor: numeric id, COLMAP name, intrinsics count.
CameraModel = collections.namedtuple(
    "CameraModel", ["model_id", "model_name", "num_params"])
# Camera intrinsics; `params` layout depends on `model` (see CAMERA_MODELS).
Camera = collections.namedtuple(
    "Camera", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"][:0] or ["id", "model", "width", "height", "params"])
# Image pose (qvec/tvec), its camera, and the 2D keypoints with their
# associated 3D point ids.  Subclassed below as `Image`.
BaseImage = collections.namedtuple(
    "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
# Triangulated 3D point with color, reprojection error, and observing track.
Point3D = collections.namedtuple(
    "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
|
48 |
+
|
49 |
+
|
50 |
+
class Image(BaseImage):
    """COLMAP image record; BaseImage plus a rotation-matrix helper."""

    def qvec2rotmat(self):
        # Delegate to the module-level qvec2rotmat() on this image's qvec.
        return qvec2rotmat(self.qvec)
|
53 |
+
|
54 |
+
|
55 |
+
# All camera models supported by COLMAP, with their intrinsics counts.
CAMERA_MODELS = {
    CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
    CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
    CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
    CameraModel(model_id=3, model_name="RADIAL", num_params=5),
    CameraModel(model_id=4, model_name="OPENCV", num_params=8),
    CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
    CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
    CameraModel(model_id=7, model_name="FOV", num_params=5),
    CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
    CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
    CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
}
# Lookup tables by numeric id and by name.
CAMERA_MODEL_IDS = {model.model_id: model for model in CAMERA_MODELS}
CAMERA_MODEL_NAMES = {model.model_name: model for model in CAMERA_MODELS}
|
72 |
+
|
73 |
+
|
74 |
+
def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
    """Read and unpack the next bytes from a binary file.
    :param fid: open binary file object.
    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
    :param endian_character: Any of {@, =, <, >, !}
    :return: Tuple of read and unpacked values.
    """
    raw = fid.read(num_bytes)
    fmt = endian_character + format_char_sequence
    return struct.unpack(fmt, raw)
|
84 |
+
|
85 |
+
|
86 |
+
def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
    """Pack and write to a binary file.
    :param fid: open binary file object.
    :param data: data to send; if multiple elements are sent at the same time,
    they should be encapsulated either in a list or a tuple
    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
    should be the same length as the data list or tuple
    :param endian_character: Any of {@, =, <, >, !}
    """
    # Renamed local from `bytes` to avoid shadowing the builtin type.
    if isinstance(data, (list, tuple)):
        packed = struct.pack(endian_character + format_char_sequence, *data)
    else:
        packed = struct.pack(endian_character + format_char_sequence, data)
    fid.write(packed)
|
100 |
+
|
101 |
+
|
102 |
+
def read_cameras_text(path):
    """Parse a COLMAP cameras text file into a dict {camera_id: Camera}.

    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasText(const std::string& path)
        void Reconstruction::ReadCamerasText(const std::string& path)
    """
    cameras = {}
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            # Skip blank lines and '#' header/comment lines.
            if len(line) > 0 and line[0] != "#":
                elems = line.split()
                camera_id = int(elems[0])
                model = elems[1]
                width = int(elems[2])
                height = int(elems[3])
                # Remaining columns are the model-specific intrinsics.
                params = np.array(tuple(map(float, elems[4:])))
                cameras[camera_id] = Camera(id=camera_id, model=model,
                                            width=width, height=height,
                                            params=params)
    return cameras
|
126 |
+
|
127 |
+
|
128 |
+
def read_cameras_binary(path_to_model_file):
    """Parse a COLMAP cameras binary file into a dict {camera_id: Camera}.

    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasBinary(const std::string& path)
        void Reconstruction::ReadCamerasBinary(const std::string& path)
    """
    cameras = {}
    with open(path_to_model_file, "rb") as fid:
        # File starts with the number of cameras (uint64).
        num_cameras = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_cameras):
            # Fixed-size record: id, model_id (int32), width, height (uint64).
            camera_properties = read_next_bytes(
                fid, num_bytes=24, format_char_sequence="iiQQ")
            camera_id = camera_properties[0]
            model_id = camera_properties[1]
            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
            width = camera_properties[2]
            height = camera_properties[3]
            # Intrinsics count depends on the camera model; read that many
            # doubles.
            num_params = CAMERA_MODEL_IDS[model_id].num_params
            params = read_next_bytes(fid, num_bytes=8*num_params,
                                     format_char_sequence="d"*num_params)
            cameras[camera_id] = Camera(id=camera_id,
                                        model=model_name,
                                        width=width,
                                        height=height,
                                        params=np.array(params))
        # Sanity check: ids must be unique for the dict to hold them all.
        assert len(cameras) == num_cameras
    return cameras
|
155 |
+
|
156 |
+
|
157 |
+
def write_cameras_text(cameras, path):
    """Write cameras to a COLMAP-format text file.

    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasText(const std::string& path)
        void Reconstruction::ReadCamerasText(const std::string& path)

    :param cameras: dict {camera_id: Camera}.
    :param path: output file path.
    """
    # BUGFIX: the original assigned only the first string to HEADER; the
    # second and third lines were bare expression statements and were never
    # written.  Concatenate all three header lines explicitly.
    HEADER = "# Camera list with one line of data per camera:\n" + \
             "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + \
             "# Number of cameras: {}\n".format(len(cameras))
    with open(path, "w") as fid:
        fid.write(HEADER)
        for _, cam in cameras.items():
            to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
            line = " ".join([str(elem) for elem in to_write])
            fid.write(line + "\n")
|
172 |
+
|
173 |
+
|
174 |
+
def write_cameras_binary(cameras, path_to_model_file):
    """Write cameras to a COLMAP-format binary file.

    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasBinary(const std::string& path)
        void Reconstruction::ReadCamerasBinary(const std::string& path)
    """
    with open(path_to_model_file, "wb") as fid:
        # Leading uint64: number of cameras.
        write_next_bytes(fid, len(cameras), "Q")
        for _, cam in cameras.items():
            # Model is stored as its numeric id, not the name.
            model_id = CAMERA_MODEL_NAMES[cam.model].model_id
            camera_properties = [cam.id,
                                 model_id,
                                 cam.width,
                                 cam.height]
            write_next_bytes(fid, camera_properties, "iiQQ")
            # Intrinsics: one double per parameter.
            for p in cam.params:
                write_next_bytes(fid, float(p), "d")
    # NOTE: returning the input dict is unusual for a writer; kept for
    # compatibility with existing callers.
    return cameras
|
192 |
+
|
193 |
+
|
194 |
+
def read_images_text(path):
    """Parse a COLMAP images text file into a dict {image_id: Image}.

    Each image occupies two lines: the pose/header line, then the 2D
    points as repeating (X, Y, POINT3D_ID) triples.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesText(const std::string& path)
        void Reconstruction::WriteImagesText(const std::string& path)
    """
    images = {}
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            # Skip blank lines and '#' header/comment lines.
            if len(line) > 0 and line[0] != "#":
                elems = line.split()
                image_id = int(elems[0])
                # Pose: quaternion (QW QX QY QZ) then translation (TX TY TZ).
                qvec = np.array(tuple(map(float, elems[1:5])))
                tvec = np.array(tuple(map(float, elems[5:8])))
                camera_id = int(elems[8])
                image_name = elems[9]
                # Second line: flat triples of x, y, point3D_id.
                elems = fid.readline().split()
                xys = np.column_stack([tuple(map(float, elems[0::3])),
                                       tuple(map(float, elems[1::3]))])
                point3D_ids = np.array(tuple(map(int, elems[2::3])))
                images[image_id] = Image(
                    id=image_id, qvec=qvec, tvec=tvec,
                    camera_id=camera_id, name=image_name,
                    xys=xys, point3D_ids=point3D_ids)
    return images
|
223 |
+
|
224 |
+
|
225 |
+
def read_images_binary(path_to_model_file):
    """Parse a COLMAP images binary file into a dict {image_id: Image}.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesBinary(const std::string& path)
        void Reconstruction::WriteImagesBinary(const std::string& path)
    """
    images = {}
    with open(path_to_model_file, "rb") as fid:
        # Leading uint64: number of registered images.
        num_reg_images = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_reg_images):
            # Fixed record: image_id (i), qvec (4d), tvec (3d), camera_id (i).
            binary_image_properties = read_next_bytes(
                fid, num_bytes=64, format_char_sequence="idddddddi")
            image_id = binary_image_properties[0]
            qvec = np.array(binary_image_properties[1:5])
            tvec = np.array(binary_image_properties[5:8])
            camera_id = binary_image_properties[8]
            # Image name is a NUL-terminated byte string; read char by char.
            image_name = ""
            current_char = read_next_bytes(fid, 1, "c")[0]
            while current_char != b"\x00":   # look for the ASCII 0 entry
                image_name += current_char.decode("utf-8")
                current_char = read_next_bytes(fid, 1, "c")[0]
            num_points2D = read_next_bytes(fid, num_bytes=8,
                                           format_char_sequence="Q")[0]
            # Each 2D point is (x: double, y: double, point3D_id: int64).
            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
                                       format_char_sequence="ddq"*num_points2D)
            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
                                   tuple(map(float, x_y_id_s[1::3]))])
            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
            images[image_id] = Image(
                id=image_id, qvec=qvec, tvec=tvec,
                camera_id=camera_id, name=image_name,
                xys=xys, point3D_ids=point3D_ids)
    return images
|
258 |
+
|
259 |
+
|
260 |
+
def write_images_text(images, path):
    """Write images to a COLMAP-format text file (two lines per image).

    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesText(const std::string& path)
        void Reconstruction::WriteImagesText(const std::string& path)

    :param images: dict {image_id: Image}.
    :param path: output file path.
    """
    if len(images) == 0:
        mean_observations = 0
    else:
        mean_observations = sum((len(img.point3D_ids) for _, img in images.items()))/len(images)
    # BUGFIX: the original assigned only the first string to HEADER; the
    # remaining header lines were bare expression statements and were never
    # written.  Concatenate all four header lines explicitly.
    HEADER = "# Image list with two lines of data per image:\n" + \
             "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + \
             "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + \
             "# Number of images: {}, mean observations per image: {}\n".format(len(images), mean_observations)

    with open(path, "w") as fid:
        fid.write(HEADER)
        for _, img in images.items():
            # First line: pose + camera + name.
            image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name]
            first_line = " ".join(map(str, image_header))
            fid.write(first_line + "\n")

            # Second line: flat (x, y, point3D_id) triples.
            points_strings = []
            for xy, point3D_id in zip(img.xys, img.point3D_ids):
                points_strings.append(" ".join(map(str, [*xy, point3D_id])))
            fid.write(" ".join(points_strings) + "\n")
|
286 |
+
|
287 |
+
|
288 |
+
def write_images_binary(images, path_to_model_file):
    """Write images to a COLMAP-format binary file.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesBinary(const std::string& path)
        void Reconstruction::WriteImagesBinary(const std::string& path)
    """
    with open(path_to_model_file, "wb") as fid:
        # Leading uint64: number of images.
        write_next_bytes(fid, len(images), "Q")
        for _, img in images.items():
            write_next_bytes(fid, img.id, "i")
            write_next_bytes(fid, img.qvec.tolist(), "dddd")
            write_next_bytes(fid, img.tvec.tolist(), "ddd")
            write_next_bytes(fid, img.camera_id, "i")
            # Name is written char-by-char and NUL-terminated.
            for char in img.name:
                write_next_bytes(fid, char.encode("utf-8"), "c")
            write_next_bytes(fid, b"\x00", "c")
            write_next_bytes(fid, len(img.point3D_ids), "Q")
            # Each observation: x (double), y (double), point3D_id (int64).
            for xy, p3d_id in zip(img.xys, img.point3D_ids):
                write_next_bytes(fid, [*xy, p3d_id], "ddq")
+
|
308 |
+
|
309 |
+
def read_points3D_text(path):
    """Parse a COLMAP points3D text file into a dict {point3D_id: Point3D}.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DText(const std::string& path)
        void Reconstruction::WritePoints3DText(const std::string& path)
    """
    points3D = {}
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            # Skip blank lines and '#' header/comment lines.
            if len(line) > 0 and line[0] != "#":
                elems = line.split()
                point3D_id = int(elems[0])
                xyz = np.array(tuple(map(float, elems[1:4])))
                rgb = np.array(tuple(map(int, elems[4:7])))
                error = float(elems[7])
                # Track: repeating (image_id, point2D_idx) pairs.
                image_ids = np.array(tuple(map(int, elems[8::2])))
                point2D_idxs = np.array(tuple(map(int, elems[9::2])))
                points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
                                               error=error, image_ids=image_ids,
                                               point2D_idxs=point2D_idxs)
    return points3D
|
334 |
+
|
335 |
+
|
336 |
+
def read_points3d_binary(path_to_model_file):
    """Parse a COLMAP points3D binary file into a dict {point3D_id: Point3D}.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DBinary(const std::string& path)
        void Reconstruction::WritePoints3DBinary(const std::string& path)
    """
    points3D = {}
    with open(path_to_model_file, "rb") as fid:
        # Leading uint64: number of 3D points.
        num_points = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_points):
            # Fixed record: id (uint64), xyz (3d), rgb (3 bytes), error (d).
            binary_point_line_properties = read_next_bytes(
                fid, num_bytes=43, format_char_sequence="QdddBBBd")
            point3D_id = binary_point_line_properties[0]
            xyz = np.array(binary_point_line_properties[1:4])
            rgb = np.array(binary_point_line_properties[4:7])
            error = np.array(binary_point_line_properties[7])
            # Variable-length track: (image_id, point2D_idx) int32 pairs.
            track_length = read_next_bytes(
                fid, num_bytes=8, format_char_sequence="Q")[0]
            track_elems = read_next_bytes(
                fid, num_bytes=8*track_length,
                format_char_sequence="ii"*track_length)
            image_ids = np.array(tuple(map(int, track_elems[0::2])))
            point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
            points3D[point3D_id] = Point3D(
                id=point3D_id, xyz=xyz, rgb=rgb,
                error=error, image_ids=image_ids,
                point2D_idxs=point2D_idxs)
    return points3D
|
364 |
+
|
365 |
+
def write_points3D_text(points3D, path):
    """Write 3D points to a COLMAP-format text file.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DText(const std::string& path)
        void Reconstruction::WritePoints3DText(const std::string& path)

    :param points3D: dict {point3D_id: Point3D}.
    :param path: output file path.
    """
    if len(points3D) == 0:
        mean_track_length = 0
    else:
        mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items()))/len(points3D)
    # BUGFIX: the original assigned only the first string to HEADER; the
    # remaining header lines were bare expression statements and were never
    # written.  Concatenate all three header lines explicitly.
    HEADER = "# 3D point list with one line of data per point:\n" + \
             "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + \
             "# Number of points: {}, mean track length: {}\n".format(len(points3D), mean_track_length)

    with open(path, "w") as fid:
        fid.write(HEADER)
        for _, pt in points3D.items():
            point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
            fid.write(" ".join(map(str, point_header)) + " ")
            # Track: flat (image_id, point2D_idx) pairs on the same line.
            track_strings = []
            for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
                track_strings.append(" ".join(map(str, [image_id, point2D])))
            fid.write(" ".join(track_strings) + "\n")
|
388 |
+
|
389 |
+
|
390 |
+
def write_points3d_binary(points3D, path_to_model_file):
    """Write 3D points to a COLMAP-format binary file.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DBinary(const std::string& path)
        void Reconstruction::WritePoints3DBinary(const std::string& path)
    """
    with open(path_to_model_file, "wb") as fid:
        # Leading uint64: number of 3D points.
        write_next_bytes(fid, len(points3D), "Q")
        for _, pt in points3D.items():
            write_next_bytes(fid, pt.id, "Q")
            write_next_bytes(fid, pt.xyz.tolist(), "ddd")
            write_next_bytes(fid, pt.rgb.tolist(), "BBB")
            write_next_bytes(fid, pt.error, "d")
            # Track length read via .shape: image_ids must be a numpy array.
            track_length = pt.image_ids.shape[0]
            write_next_bytes(fid, track_length, "Q")
            for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
                write_next_bytes(fid, [image_id, point2D_id], "ii")
|
407 |
+
|
408 |
+
|
409 |
+
def detect_model_format(path, ext):
    """Return True if `path` holds a complete model with extension `ext`.

    A model is present when cameras, images, and points3D files with the
    given extension all exist in the directory.
    """
    required = ("cameras", "images", "points3D")
    if all(os.path.isfile(os.path.join(path, name + ext)) for name in required):
        print("Detected model format: '" + ext + "'")
        return True

    return False
|
417 |
+
|
418 |
+
|
419 |
+
def read_model(path, ext=""):
    """Read a COLMAP model (cameras, images, points3D) from a directory.

    :param path: model directory.
    :param ext: ".bin" or ".txt"; auto-detected when empty.
    :return: (cameras, images, points3D) dicts, or None when the format
        cannot be determined.
    """
    # try to detect the extension automatically
    if ext == "":
        if detect_model_format(path, ".bin"):
            ext = ".bin"
        elif detect_model_format(path, ".txt"):
            ext = ".txt"
        else:
            print("Provide model format: '.bin' or '.txt'")
            return

    if ext == ".txt":
        cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
        images = read_images_text(os.path.join(path, "images" + ext))
        points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
    else:
        cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
        images = read_images_binary(os.path.join(path, "images" + ext))
        points3D = read_points3d_binary(os.path.join(path, "points3D") + ext)
    return cameras, images, points3D
|
439 |
+
|
440 |
+
|
441 |
+
def write_model(cameras, images, points3D, path, ext=".bin"):
    """Write a COLMAP model to `path` in text or binary format.

    :param cameras: dict {camera_id: Camera}.
    :param images: dict {image_id: Image}.
    :param points3D: dict {point3D_id: Point3D}.
    :param path: output directory.
    :param ext: ".txt" for text output; anything else writes binary.
    :return: the (cameras, images, points3D) tuple passed in.
    """
    if ext == ".txt":
        write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
        write_images_text(images, os.path.join(path, "images" + ext))
        write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
    else:
        write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
        write_images_binary(images, os.path.join(path, "images" + ext))
        write_points3d_binary(points3D, os.path.join(path, "points3D") + ext)
    return cameras, images, points3D
|
451 |
+
|
452 |
+
|
453 |
+
def qvec2rotmat(qvec):
    """Convert a quaternion (w, x, y, z) to a 3x3 rotation matrix."""
    w, x, y, z = qvec[0], qvec[1], qvec[2], qvec[3]
    row0 = [1 - 2 * y * y - 2 * z * z,
            2 * x * y - 2 * w * z,
            2 * z * x + 2 * w * y]
    row1 = [2 * x * y + 2 * w * z,
            1 - 2 * x * x - 2 * z * z,
            2 * y * z - 2 * w * x]
    row2 = [2 * z * x - 2 * w * y,
            2 * y * z + 2 * w * x,
            1 - 2 * x * x - 2 * y * y]
    return np.array([row0, row1, row2])
|
464 |
+
|
465 |
+
|
466 |
+
def rotmat2qvec(R):
    """Convert a 3x3 rotation matrix to a quaternion (w, x, y, z).

    Uses the eigenvector of the symmetric matrix K associated with its
    largest eigenvalue; the sign is fixed so that w >= 0.
    """
    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
    K = np.array([
        [Rxx - Ryy - Rzz, 0, 0, 0],
        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
    eigvals, eigvecs = np.linalg.eigh(K)
    # eigh returns components as (x, y, z, w); reorder to (w, x, y, z).
    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
    # Canonical sign: q and -q are the same rotation, prefer w >= 0.
    return qvec if qvec[0] >= 0 else -qvec
|
478 |
+
|
479 |
+
|
480 |
+
def main():
    """CLI entry point: read a COLMAP model, print its statistics, and
    optionally convert it to another format.

    Omitting --output_model only prints the model statistics.
    """
    parser = argparse.ArgumentParser(description="Read and write COLMAP binary and text models")
    parser.add_argument("--input_model", help="path to input model folder")
    parser.add_argument("--input_format", choices=[".bin", ".txt"],
                        help="input model format", default="")
    parser.add_argument("--output_model",
                        help="path to output model folder")
    # BUGFIX: help text typo "outut" -> "output".
    parser.add_argument("--output_format", choices=[".bin", ".txt"],
                        help="output model format", default=".txt")
    args = parser.parse_args()

    cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format)

    print("num_cameras:", len(cameras))
    print("num_images:", len(images))
    print("num_points3D:", len(points3D))

    if args.output_model is not None:
        write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format)
|
506 |
+
|
507 |
+
|
508 |
+
# Run the CLI only when executed as a script (not when imported as a module).
if __name__ == "__main__":
    main()
|
imcui/third_party/MatchAnything/src/utils/comm.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
[Copied from detectron2]
|
4 |
+
This file contains primitives for multi-gpu communication.
|
5 |
+
This is useful when doing distributed training.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import functools
|
9 |
+
import logging
|
10 |
+
import numpy as np
|
11 |
+
import pickle
|
12 |
+
import torch
|
13 |
+
import torch.distributed as dist
|
14 |
+
|
15 |
+
_LOCAL_PROCESS_GROUP = None
|
16 |
+
"""
|
17 |
+
A torch process group which only includes processes that on the same machine as the current process.
|
18 |
+
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
|
19 |
+
"""
|
20 |
+
|
21 |
+
|
22 |
+
def get_world_size() -> int:
    """Number of processes in the default group; 1 when not running distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
|
28 |
+
|
29 |
+
|
30 |
+
def get_rank() -> int:
    """Global rank of this process; 0 when not running distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
|
36 |
+
|
37 |
+
|
38 |
+
def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process group.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    # Set by the launcher; distributed mode without it is a programming error.
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
|
49 |
+
|
50 |
+
|
51 |
+
def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group, i.e. the number of
        processes per machine; 1 when not running distributed.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
|
62 |
+
|
63 |
+
|
64 |
+
def is_main_process() -> bool:
    """True only on the global rank-0 process."""
    rank = get_rank()
    return rank == 0
|
66 |
+
|
67 |
+
|
68 |
+
def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training.  A no-op outside distributed mode.
    """
    if not dist.is_available() or not dist.is_initialized():
        return
    if dist.get_world_size() == 1:
        return
    dist.barrier()
|
81 |
+
|
82 |
+
|
83 |
+
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks.
    The result is cached.
    """
    # CPU tensors cannot be gathered over NCCL, so create a gloo fallback.
    backend = dist.get_backend()
    if backend == "nccl":
        return dist.new_group(backend="gloo")
    return dist.group.WORLD
|
93 |
+
|
94 |
+
|
95 |
+
def _serialize_to_tensor(data, group):
    """Pickle `data` into a flat uint8 tensor on the device that matches
    the `group`'s backend (CPU for gloo, CUDA for nccl)."""
    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    device = torch.device("cpu") if backend == "gloo" else torch.device("cuda")

    buffer = pickle.dumps(data)
    # Warn on very large payloads; all-gathering gigabytes is usually a bug.
    if len(buffer) > 1024 ** 3:
        logging.getLogger(__name__).warning(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                get_rank(), len(buffer) / (1024 ** 3), device
            )
        )
    storage = torch.ByteStorage.from_buffer(buffer)
    return torch.ByteTensor(storage).to(device=device)
|
111 |
+
|
112 |
+
|
113 |
+
def _pad_to_largest_tensor(tensor, group):
    """
    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"
    local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
    size_list = [torch.zeros([1], dtype=torch.int64, device=tensor.device)
                 for _ in range(world_size)]
    dist.all_gather(size_list, local_size, group=group)

    size_list = [int(s.item()) for s in size_list]
    max_size = max(size_list)

    # torch.distributed.all_gather needs identical shapes on every rank,
    # so zero-pad the local tensor up to the global maximum length.
    if local_size != max_size:
        pad = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
        tensor = torch.cat((tensor, pad), dim=0)
    return size_list, tensor
|
139 |
+
|
140 |
+
|
141 |
+
def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    # Fast path: a single process has nothing to communicate.
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # Gather the (padded) byte tensors from every rank...
    recv_buffers = [
        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
        for _ in size_list
    ]
    dist.all_gather(recv_buffers, tensor, group=group)

    # ...then trim each to its true length and unpickle.
    results = []
    for size, buf in zip(size_list, recv_buffers):
        payload = buf.cpu().numpy().tobytes()[:size]
        results.append(pickle.loads(payload))
    return results
|
177 |
+
|
178 |
+
|
179 |
+
def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group=group) == 1:
        return [data]
    rank = dist.get_rank(group=group)

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)

    if rank != dst:
        # Non-destination ranks only contribute their tensor; they receive nothing.
        dist.gather(tensor, [], dst=dst, group=group)
        return []

    # Destination rank allocates one buffer per rank and unpickles each payload.
    max_size = max(size_list)
    recv_buffers = [
        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
        for _ in size_list
    ]
    dist.gather(tensor, recv_buffers, dst=dst, group=group)

    results = []
    for size, buf in zip(size_list, recv_buffers):
        results.append(pickle.loads(buf.cpu().numpy().tobytes()[:size]))
    return results
|
220 |
+
|
221 |
+
|
222 |
+
def shared_random_seed():
    """
    Returns:
        int: a random number that is the same across all workers.
        If workers need a shared RNG, they can use this shared seed to
        create one.

    All workers must call this function, otherwise it will deadlock.
    """
    # Every rank draws its own seed; rank 0's draw wins for everyone.
    local_seed = np.random.randint(2 ** 31)
    return all_gather(local_seed)[0]
|
234 |
+
|
235 |
+
|
236 |
+
def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the reduced results.

    Args:
        input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
        average (bool): whether to do average or sum

    Returns:
        a dict with the same keys as input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # Sort keys so every rank stacks the values in the same order.
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[k] for k in names], dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # Only the main process holds the accumulated sum, so only it
            # divides by world_size.
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict
|
imcui/third_party/MatchAnything/src/utils/dataloader.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
# --- PL-DATAMODULE ---
|
5 |
+
|
6 |
+
def get_local_split(items: list, world_size: int, rank: int, seed: int):
    """ The local rank only loads a split of the dataset.

    All ranks shuffle `items` with the same seed (so they agree on the
    permutation), pad with resampled items until the total divides evenly
    by `world_size`, and return this rank's contiguous slice.

    Args:
        items: full list of dataset items.
        world_size: total number of ranks.
        rank: this process's rank in [0, world_size).
        seed: RNG seed shared by all ranks.

    Returns:
        np.ndarray: this rank's share of the (padded) items.
    """
    n_items = len(items)
    items_permute = np.random.RandomState(seed).permutation(items)
    n_pad = (world_size - n_items % world_size) % world_size
    if n_pad == 0:
        padded_items = items_permute
    else:
        padding = np.random.RandomState(seed).choice(
            items,
            n_pad,
            replace=True)
        padded_items = np.concatenate([items_permute, padding])
    # BUGFIX: the old assertion message interpolated `len(padding)`, which is
    # unbound on the evenly-divisible path and would raise NameError instead
    # of AssertionError; report the pad count instead.
    assert len(padded_items) % world_size == 0, \
        f'len(padded_items): {len(padded_items)}; world_size: {world_size}; n_pad: {n_pad}'
    n_per_rank = len(padded_items) // world_size
    local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)]

    return local_items
|
imcui/third_party/MatchAnything/src/utils/dataset.py
ADDED
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
from loguru import logger
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
from pathlib import Path
|
7 |
+
import h5py
|
8 |
+
import torch
|
9 |
+
import re
|
10 |
+
from PIL import Image
|
11 |
+
from numpy.linalg import inv
|
12 |
+
from torchvision.transforms import Normalize
|
13 |
+
from .sample_homo import sample_homography_sap
|
14 |
+
from kornia.geometry import homography_warp, normalize_homography, normal_transform_pixel
|
15 |
+
OSS_FOLDER_PATH = '???'
|
16 |
+
PCACHE_FOLDER_PATH = '???'
|
17 |
+
|
18 |
+
import fsspec
|
19 |
+
from PIL import Image
|
20 |
+
|
21 |
+
# Initialize pcache (an internal distributed file-cache filesystem).  The
# "???" placeholder values indicate the real host/port/paths were scrubbed
# before release, so this initialization is expected to fail (and only log)
# outside the original infrastructure.
try:
    PCACHE_HOST = "???"
    PCACHE_PORT = 00000
    pcache_kwargs = {"host": PCACHE_HOST, "port": PCACHE_PORT}
    pcache_fs = fsspec.filesystem("pcache", pcache_kwargs=pcache_kwargs)
    root_dir='???'
except Exception as e:
    # NOTE(review): on failure `pcache_fs`/`root_dir` stay undefined; any
    # later oss://-or-pcache:// load would raise NameError — confirm intended.
    logger.error(f"Error captured:{e}")

try:
    # for internal use only
    from pcache_fileio import fileio
except Exception:
    # NOTE(review): these two names are only defined when the import fails;
    # nothing in this file reads them otherwise — presumably legacy.
    MEGADEPTH_CLIENT = SCANNET_CLIENT = None
|
36 |
+
|
37 |
+
# --- DATA IO ---
|
38 |
+
|
39 |
+
def load_pfm(pfm_path):
    """Load a PFM (Portable Float Map) file into a numpy array.

    Supports both color ('PF', h x w x 3) and grayscale ('Pf', h x w)
    variants; a negative scale in the header means little-endian data.

    Args:
        pfm_path (str): path to the .pfm file.

    Returns:
        np.ndarray: float image, flipped vertically into top-down row order.

    Raises:
        Exception: if the header magic or dimension line is malformed.
    """
    with open(pfm_path, 'rb') as fin:
        header = str(fin.readline().decode('UTF-8')).rstrip()
        if header == 'PF':
            color = True
        elif header == 'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(r'^(\d+)\s(\d+)\s$', fin.readline().decode('UTF-8'))
        if dim_match:
            width, height = map(int, dim_match.groups())
        else:
            raise Exception('Malformed PFM header.')
        scale = float((fin.readline().decode('UTF-8')).rstrip())
        if scale < 0:  # little-endian
            data_type = '<f'
        else:
            data_type = '>f'  # big-endian
        data_string = fin.read()
        # BUGFIX: np.fromstring is deprecated (and removed in recent numpy);
        # np.frombuffer is the supported equivalent.  .copy() keeps the
        # result writable, matching fromstring's old behavior.
        data = np.frombuffer(data_string, dtype=data_type).copy()
        shape = (height, width, 3) if color else (height, width)
        data = np.reshape(data, shape)
        # PFM stores rows bottom-up; flip to conventional top-down order.
        data = np.flip(data, 0)
    return data
|
71 |
+
|
72 |
+
def load_array_from_pcache(
    path, cv_type,
    use_h5py=False,
):
    """Load an array from the pcache distributed filesystem.

    Decodes the file either as a grayscale image (via PIL) or, when
    `use_h5py` is set, as an HDF5 file whose '/depth' dataset is returned.

    Args:
        path: path containing the module-level `root_dir`; the portion after
            `root_dir` is re-rooted onto the pcache mount.
        cv_type: unused in this function (kept for interface parity with the
            cv2 loading path in `imread_gray`).
        use_h5py (bool): read the HDF5 '/depth' dataset instead of decoding
            an image.

    Returns:
        np.ndarray: the decoded image or depth data.
    """
    filename = path.split(root_dir)[1]
    pcache_path = Path(root_dir) / filename
    try:
        if not use_h5py:
            load_failed = True
            failed_num = 0
            # NOTE(review): this loop retries indefinitely; after 5000
            # failures it only logs on every further attempt — presumably to
            # survive transient pcache outages, but a permanently missing
            # file spins forever.  Confirm a hard failure isn't wanted.
            while load_failed:
                try:
                    with pcache_fs.open(str(pcache_path), 'rb') as f:
                        data = Image.open(f).convert("L")
                        data = np.array(data)
                    load_failed = False
                except:  # noqa: E722 — deliberately broad: any read/decode error triggers a retry
                    failed_num += 1
                    if failed_num > 5000:
                        logger.error(f"Try to load: {pcache_path}, but failed {failed_num} times")
                    continue
        else:
            load_failed = True
            failed_num = 0
            while load_failed:
                try:
                    # Read the whole file into memory, then parse as HDF5.
                    with pcache_fs.open(str(pcache_path), 'rb') as f:
                        data = np.array(h5py.File(io.BytesIO(f.read()), 'r')['/depth'])
                    load_failed = False
                except:  # noqa: E722
                    failed_num += 1
                    if failed_num > 5000:
                        logger.error(f"Try to load: {pcache_path}, but failed {failed_num} times")
                    continue

    except Exception as ex:
        print(f"==> Data loading failure: {path}")
        raise ex

    assert data is not None
    return data
|
114 |
+
|
115 |
+
|
116 |
+
def imread_gray(path, augment_fn=None, cv_type=None):
    """Read an image in grayscale (or color when `augment_fn` is given).

    Handles local paths as well as oss:// / pcache:// remote paths.

    Args:
        path (str): image location.
        augment_fn (callable, optional): photometric augmentation applied to
            the RGB image; the result is converted back to grayscale.
        cv_type: explicit cv2 imread flag; inferred from `augment_fn` when None.

    Returns:
        np.ndarray: (h, w) image.
    """
    if path.startswith('oss://'):
        path = path.replace(OSS_FOLDER_PATH, PCACHE_FOLDER_PATH)
    if path.startswith('pcache://'):
        # Collapse any run of duplicated '/' after the scheme prefix.
        path = path[:9] + path[9:].replace('////', '/').replace('///', '/').replace('//', '/')

    if cv_type is None:
        cv_type = cv2.IMREAD_COLOR if augment_fn is not None else cv2.IMREAD_GRAYSCALE

    remote = str(path).startswith('oss://') or str(path).startswith('pcache://')
    if remote:
        image = load_array_from_pcache(str(path), cv_type)
    else:
        image = cv2.imread(str(path), cv_type)

    if augment_fn is not None:
        # Re-read as BGR, augment in RGB, then collapse back to grayscale.
        image = cv2.imread(str(path), cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = augment_fn(image)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return image  # (h, w)
|
136 |
+
|
137 |
+
def imread_color(path, augment_fn=None):
    """Read an RGB image from a local, oss://, or pcache:// path.

    Args:
        path (str): image location.
        augment_fn (callable, optional): augmentation applied to the RGB array.

    Returns:
        np.ndarray: (h, w, 3) RGB image.
    """
    if path.startswith('oss://'):
        path = path.replace(OSS_FOLDER_PATH, PCACHE_FOLDER_PATH)
    if path.startswith('pcache://'):
        # Collapse any run of duplicated '/' after the scheme prefix.
        path = path[:9] + path[9:].replace('////', '/').replace('///', '/').replace('//', '/')

    if str(path).startswith('oss://') or str(path).startswith('pcache://'):
        filename = path.split(root_dir)[1]
        pcache_path = Path(root_dir) / filename
        failed_num = 0
        # Retry until the (possibly flaky) remote read succeeds; log after
        # an excessive number of failures.
        while True:
            try:
                with pcache_fs.open(str(pcache_path), 'rb') as f:
                    pil_image = Image.open(f).convert("RGB")
                break
            except:
                failed_num += 1
                if failed_num > 5000:
                    logger.error(f"Try to load: {pcache_path}, but failed {failed_num} times")
    else:
        pil_image = Image.open(str(path)).convert("RGB")
    image = np.array(pil_image)

    if augment_fn is not None:
        image = augment_fn(image)
    return image  # (h, w)
|
165 |
+
|
166 |
+
|
167 |
+
def get_resized_wh(w, h, resize=None):
    """Scale (w, h) so the longer edge equals `resize`; unchanged when None."""
    if resize is None:
        return w, h
    factor = resize / max(h, w)
    return int(round(w * factor)), int(round(h * factor))
|
174 |
+
|
175 |
+
|
176 |
+
def get_divisible_wh(w, h, df=None):
    """Round (w, h) down to the nearest multiple of `df`; unchanged when None."""
    if df is None:
        return w, h
    return int(w // df * df), int(h // df * df)
|
182 |
+
|
183 |
+
|
184 |
+
def pad_bottom_right(inp, pad_size, ret_mask=False):
    """Zero-pad a (h, w) or (c, h, w) array at the bottom/right edges to a
    square `pad_size` x `pad_size`, optionally returning a validity mask.

    Args:
        inp (np.ndarray): 2-D or 3-D (channel-first) array.
        pad_size (int): target square size; must cover the spatial extent.
        ret_mask (bool): also build a boolean mask of valid (original) pixels.

    Returns:
        (padded, mask): padded array and (h, w) bool mask (None if not requested).
    """
    assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
    mask = None
    if inp.ndim == 2:
        h, w = inp.shape
        padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
        padded[:h, :w] = inp
        if ret_mask:
            mask = np.zeros((pad_size, pad_size), dtype=bool)
            mask[:h, :w] = True
    elif inp.ndim == 3:
        c, h, w = inp.shape
        padded = np.zeros((c, pad_size, pad_size), dtype=inp.dtype)
        padded[:, :h, :w] = inp
        if ret_mask:
            mask = np.zeros((c, pad_size, pad_size), dtype=bool)
            mask[:, :h, :w] = True
            # Validity is identical across channels; keep one (h, w) slice.
            mask = mask[0]
    else:
        raise NotImplementedError()
    return padded, mask
|
203 |
+
|
204 |
+
|
205 |
+
# --- MEGADEPTH ---
|
206 |
+
|
207 |
+
def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None, read_gray=True, normalize_img=False, resize_by_stretch=False):
    """
    Args:
        resize (int, optional): the longer edge of resized images. None for no resize.
        df (int, optional): make the resized dimensions divisible by this factor.
        padding (bool): If set to 'True', zero-pad resized images to squared size.
        augment_fn (callable, optional): augments images with pre-defined visual effects
        read_gray (bool): read one grayscale channel instead of RGB.
        normalize_img (bool): apply ImageNet mean/std normalization (RGB only).
        resize_by_stretch (bool): stretch to exactly `resize` instead of
            preserving aspect ratio on the longer edge.
    Returns:
        image (torch.tensor): (1, h, w) grayscale or (3, h, w) RGB, in [0, 1]
        mask (torch.tensor): (h, w) valid-pixel mask, or None when not padding
        scale (torch.tensor): [w/w_new, h/h_new]
        origin_img_size (torch.tensor): [h, w] before resizing
    """
    # read image
    if read_gray:
        image = imread_gray(path, augment_fn)
    else:
        image = imread_color(path, augment_fn)

    # resize image
    try:
        w, h = image.shape[1], image.shape[0]
    except:
        # NOTE(review): this only logs and falls through — `w`/`h` remain
        # unbound and the next use raises NameError; confirm whether a
        # re-raise with a clear message is wanted instead.
        logger.error(f"{path} not exist or read image error!")
    if resize_by_stretch:
        w_new, h_new = (resize, resize) if isinstance(resize, int) else (resize[1], resize[0])
    else:
        if resize:
            if not isinstance(resize, int):
                # Only square aspect-preserving targets are supported here.
                assert resize[0] == resize[1]
                resize = resize[0]
            w_new, h_new = get_resized_wh(w, h, resize)
            w_new, h_new = get_divisible_wh(w_new, h_new, df)
        else:
            w_new, h_new = w, h

    image = cv2.resize(image, (w_new, h_new))
    scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
    origin_img_size = torch.tensor([h, w], dtype=torch.float)

    if not read_gray:
        # (h, w, 3) -> (3, h, w) channel-first layout for torch.
        image = image.transpose(2,0,1)

    if padding:  # padding
        pad_to = max(h_new, w_new)
        image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
    else:
        mask = None

    if len(image.shape) == 2:
        image = torch.from_numpy(image).float()[None] / 255  # (h, w) -> (1, h, w) and normalized
    else:
        image = torch.from_numpy(image).float() / 255  # (h, w) -> (1, h, w) and normalized
    if mask is not None:
        mask = torch.from_numpy(mask)

    if image.shape[0] == 3 and normalize_img:
        # Normalize image:
        image = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)  # Input: 3*H*W

    return image, mask, scale, origin_img_size
|
266 |
+
|
267 |
+
def read_megadepth_gray_sample_homowarp(path, resize=None, df=None, padding=False, augment_fn=None, read_gray=True, normalize_img=False, resize_by_stretch=False):
    """
    Like `read_megadepth_gray`, but additionally warps the image with a
    randomly sampled homography before resizing (used to synthesize
    homography-supervised training pairs).

    Args:
        resize (int, optional): the longer edge of resized images. None for no resize.
        padding (bool): If set to 'True', zero-pad resized images to squared size.
        augment_fn (callable, optional): augments images with pre-defined visual effects
    Returns:
        image (torch.tensor): (1, h, w) or (3, h, w), in [0, 1]
        mask (torch.tensor): (h, w), or None when not padding
        scale (torch.tensor): [w/w_new, h/h_new]
        origin_img_size (torch.tensor): [h, w] before resizing
        norm_pixel_mat (torch.tensor): (3, 3) pixel -> normalized-coords transform
        homo_sampled_normed (torch.tensor): (3, 3) sampled homography in
            kornia's normalized coordinates
    """
    # read image
    if read_gray:
        image = imread_gray(path, augment_fn)
    else:
        image = imread_color(path, augment_fn)

    # resize image
    w, h = image.shape[1], image.shape[0]
    if resize_by_stretch:
        w_new, h_new = (resize, resize) if isinstance(resize, int) else (resize[1], resize[0])
    else:
        if not isinstance(resize, int):
            assert resize[0] == resize[1]
            resize = resize[0]
        w_new, h_new = get_resized_wh(w, h, resize)
        w_new, h_new = get_divisible_wh(w_new, h_new, df)

    # NOTE(review): repeated call — a no-op since the sizes are already
    # divisible by `df`; harmless but could be removed.
    w_new, h_new = get_divisible_wh(w_new, h_new, df)

    origin_img_size = torch.tensor([h, w], dtype=torch.float)

    # Sample homography and warp:
    homo_sampled = sample_homography_sap(h, w)  # 3*3
    homo_sampled_normed = normalize_homography(
        torch.from_numpy(homo_sampled[None]).to(torch.float32),
        (h, w),
        (h, w),
    )

    # Convert to a batched channel-first float tensor for kornia.
    if len(image.shape) == 2:
        image = torch.from_numpy(image).float()[None, None] / 255  # B * C * H * W
    else:
        image = torch.from_numpy(image).float().permute(2,0,1)[None] / 255

    homo_warpped_image = homography_warp(
        image,  # 1 * C * H * W
        torch.linalg.inv(homo_sampled_normed),
        (h, w),
    )
    # Back to a uint8 HWC numpy image for cv2.resize below.
    image = (homo_warpped_image[0].permute(1,2,0).numpy() * 255).astype(np.uint8)
    norm_pixel_mat = normal_transform_pixel(h, w)  # 1 * 3 * 3

    image = cv2.resize(image, (w_new, h_new))
    scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)

    if not read_gray:
        image = image.transpose(2,0,1)

    if padding:  # padding
        pad_to = max(h_new, w_new)
        image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
    else:
        mask = None

    if len(image.shape) == 2:
        image = torch.from_numpy(image).float()[None] / 255  # (h, w) -> (1, h, w) and normalized
    else:
        image = torch.from_numpy(image).float() / 255  # (h, w) -> (1, h, w) and normalized
    if mask is not None:
        mask = torch.from_numpy(mask)

    if image.shape[0] == 3 and normalize_img:
        # Normalize image:
        image = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)  # Input: 3*H*W

    return image, mask, scale, origin_img_size, norm_pixel_mat[0], homo_sampled_normed[0]
|
344 |
+
|
345 |
+
|
346 |
+
def read_megadepth_depth_gray(path, resize=None, df=None, padding=False, augment_fn=None, read_gray=True, normalize_img=False, resize_by_stretch=False):
    """
    Load a depth map and render it as an inverse-depth grayscale image
    (near = bright), following the "controlnet 1-depth" convention noted below.

    Args:
        resize (int, optional): the longer edge of resized images. None for no resize.
        padding (bool): If set to 'True', zero-pad resized images to squared size.
        augment_fn (callable, optional): unused in this function; kept for
            interface parity with the other readers.
    Returns:
        image (torch.tensor): (1, h, w) or (3, h, w), in [0, 1]
        mask (torch.tensor): (h, w), or None when not padding
        scale (torch.tensor): [w/w_new, h/h_new]
        origin_img_size (torch.tensor): [h, w] before resizing
    """
    depth = read_megadepth_depth(path, return_tensor=False)

    # following controlnet 1-depth: robust-normalize using the 2nd/85th
    # percentiles of the non-zero depths, then invert so near is bright.
    depth = depth.astype(np.float64)
    depth_non_zero = depth[depth!=0]
    vmin = np.percentile(depth_non_zero, 2)
    vmax = np.percentile(depth_non_zero, 85)
    depth -= vmin
    depth /= (vmax - vmin + 1e-4)
    depth = 1.0 - depth
    image = (depth * 255.0).clip(0, 255).astype(np.uint8)

    # resize image
    w, h = image.shape[1], image.shape[0]
    if resize_by_stretch:
        w_new, h_new = (resize, resize) if isinstance(resize, int) else (resize[1], resize[0])
    else:
        if not isinstance(resize, int):
            assert resize[0] == resize[1]
            resize = resize[0]
        w_new, h_new = get_resized_wh(w, h, resize)
        w_new, h_new = get_divisible_wh(w_new, h_new, df)
    # NOTE(review): repeated call is a no-op (sizes already divisible).
    w_new, h_new = get_divisible_wh(w_new, h_new, df)
    origin_img_size = torch.tensor([h, w], dtype=torch.float)

    image = cv2.resize(image, (w_new, h_new))
    scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)

    if padding:  # padding
        pad_to = max(h_new, w_new)
        image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
    else:
        mask = None

    if read_gray:
        image = torch.from_numpy(image).float()[None] / 255  # (h, w) -> (1, h, w) and normalized
    else:
        # Replicate the single channel to fake an RGB image.
        image = np.stack([image]*3)  # 3 * H * W
        image = torch.from_numpy(image).float() / 255  # (h, w) -> (1, h, w) and normalized
    if mask is not None:
        mask = torch.from_numpy(mask)

    if image.shape[0] == 3 and normalize_img:
        # Normalize image:
        image = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)  # Input: 3*H*W

    return image, mask, scale, origin_img_size
|
404 |
+
|
405 |
+
def read_megadepth_depth(path, pad_to=None, return_tensor=True):
    """Load a depth map from .png (millimeters -> meters), .pfm (BlendedMVS),
    or HDF5 (MegaDepth) files, with local / oss:// / pcache:// support.

    Args:
        path (str): depth file location.
        pad_to (int, optional): zero-pad to a (pad_to, pad_to) square.
        return_tensor (bool): return a torch tensor instead of np.ndarray.

    Returns:
        depth: (h, w) depth map.
    """
    if path.startswith('oss://'):
        path = path.replace(OSS_FOLDER_PATH, PCACHE_FOLDER_PATH)
    if path.startswith('pcache://'):
        path = path[:9] + path[9:].replace('////', '/').replace('///', '/').replace('//', '/')  # remove all continuous '/'

    load_failed = True
    failed_num = 0
    # NOTE(review): retries forever on persistent errors (only logs past 5000
    # attempts), so a genuinely missing file loops indefinitely — confirm.
    while load_failed:
        try:
            if '.png' in path:
                if 'scannet_plus' in path:
                    depth = imread_gray(path, cv_type=cv2.IMREAD_UNCHANGED).astype(np.float32)

                with open(path, 'rb') as f:
                    # CO3D
                    # NOTE(review): this read appears unconditional, which
                    # would overwrite the scannet_plus result above — confirm
                    # whether it belongs in an `else` branch.
                    depth = np.asarray(Image.open(f)).astype(np.float32)
                    depth = depth / 1000
            elif '.pfm' in path:
                # For BlendedMVS dataset (not support pcache):
                depth = load_pfm(path).copy()
            else:
                # For MegaDepth
                if str(path).startswith('oss://') or str(path).startswith('pcache://'):
                    depth = load_array_from_pcache(path, None, use_h5py=True)
                else:
                    depth = np.array(h5py.File(path, 'r')['depth'])
            load_failed = False
        except:  # noqa: E722 — deliberately broad: any read/parse error triggers a retry
            failed_num += 1
            if failed_num > 5000:
                logger.error(f"Try to load: {path}, but failed {failed_num} times")
            continue

    if pad_to is not None:
        depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
    if return_tensor:
        depth = torch.from_numpy(depth).float()  # (h, w)
    return depth
|
444 |
+
|
445 |
+
|
446 |
+
# --- ScanNet ---
|
447 |
+
|
448 |
+
def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
    """Load a grayscale ScanNet frame, resized to match the depth maps.

    Args:
        path: image location.
        resize (tuple): target size (w, h), aligning the image to the depthmap.
        augment_fn (callable, optional): augments images with pre-defined visual effects.

    Returns:
        image (torch.Tensor): (1, h, w), intensities scaled to [0, 1].
    """
    raw = imread_gray(path, augment_fn)
    resized = cv2.resize(raw, resize)
    # Add the channel dimension and rescale to [0, 1].
    return torch.from_numpy(resized).float().unsqueeze(0) / 255
|
467 |
+
def read_scannet_depth(path):
    """Load a ScanNet depth map and convert it from millimeters to meters.

    Returns:
        depth (torch.Tensor): (h, w) float depth in meters.
    """
    path_str = str(path)
    if path_str.startswith('s3://'):
        raw = load_array_from_s3(path_str, SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
    else:
        raw = cv2.imread(path_str, cv2.IMREAD_UNCHANGED)
    # ScanNet stores depth as 16-bit millimeters.
    return torch.from_numpy(raw / 1000).float()  # (h, w)
|
477 |
+
def read_scannet_pose(path):
    """ Read ScanNet's Camera2World pose and transform it to World2Camera.

    Returns:
        pose_w2c (np.ndarray): (4, 4)
    """
    # Poses are stored as space-separated 4x4 camera-to-world matrices.
    pose_c2w = np.loadtxt(path, delimiter=' ')
    return inv(pose_c2w)
|
488 |
+
def read_scannet_intrinsic(path):
    """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
    """
    # Files store a 4x4 matrix; drop the last row and column.
    full = np.loadtxt(path, delimiter=' ')
    return full[:-1, :-1]
|
494 |
+
def dict_to_cuda(data_dict):
    """Recursively move every tensor inside a (possibly nested) dict to the GPU.

    Dicts and lists are traversed recursively; any other value is kept as-is.
    """
    def _move(value):
        if isinstance(value, torch.Tensor):
            return value.cuda()
        if isinstance(value, dict):
            return dict_to_cuda(value)
        if isinstance(value, list):
            return list_to_cuda(value)
        return value

    return {key: _move(value) for key, value in data_dict.items()}
|
507 |
+
def list_to_cuda(data_list):
    """Recursively move every tensor inside a (possibly nested) list to the GPU.

    Lists and dicts are traversed recursively; any other element is kept as-is.
    """
    def _move(obj):
        if isinstance(obj, torch.Tensor):
            return obj.cuda()
        if isinstance(obj, dict):
            return dict_to_cuda(obj)
        if isinstance(obj, list):
            return list_to_cuda(obj)
        return obj

    return [_move(obj) for obj in data_list]
imcui/third_party/MatchAnything/src/utils/easydict.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class EasyDict(dict):
    """
    Get attributes

    >>> d = EasyDict({'foo':3})
    >>> d['foo']
    3
    >>> d.foo
    3
    >>> d.bar
    Traceback (most recent call last):
    ...
    AttributeError: 'EasyDict' object has no attribute 'bar'

    Works recursively

    >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
    >>> isinstance(d.bar, dict)
    True
    >>> d.bar.x
    1

    Bullet-proof

    >>> EasyDict({})
    {}
    >>> EasyDict(d={})
    {}
    >>> EasyDict(None)
    {}
    >>> d = {'a': 1}
    >>> EasyDict(**d)
    {'a': 1}

    Set attributes

    >>> d = EasyDict()
    >>> d.foo = 3
    >>> d.foo
    3
    >>> d.bar = {'prop': 'value'}
    >>> d.bar.prop
    'value'
    >>> d
    {'foo': 3, 'bar': {'prop': 'value'}}
    >>> d.bar.prop = 'newer'
    >>> d.bar.prop
    'newer'


    Values extraction

    >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
    >>> isinstance(d.bar, list)
    True
    >>> from operator import attrgetter
    >>> map(attrgetter('x'), d.bar)
    [1, 3]
    >>> map(attrgetter('y'), d.bar)
    [2, 4]
    >>> d = EasyDict()
    >>> d.keys()
    []
    >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
    >>> d.foo
    3
    >>> d.bar.x
    1

    Still like a dict though

    >>> o = EasyDict({'clean':True})
    >>> o.items()
    [('clean', True)]

    And like a class

    >>> class Flower(EasyDict):
    ...     power = 1
    ...
    >>> f = Flower()
    >>> f.power
    1
    >>> f = Flower({'height': 12})
    >>> f.height
    12
    >>> f['power']
    1
    >>> sorted(f.keys())
    ['height', 'power']

    update and pop items
    >>> d = EasyDict(a=1, b='2')
    >>> e = EasyDict(c=3.0, a=9.0)
    >>> d.update(e)
    >>> d.c
    3.0
    >>> d['c']
    3.0
    >>> d.get('c')
    3.0
    >>> d.update(a=4, b=4)
    >>> d.b
    4
    >>> d.pop('a')
    4
    >>> d.a
    Traceback (most recent call last):
    ...
    AttributeError: 'EasyDict' object has no attribute 'a'
    """

    def __init__(self, d=None, **kwargs):
        # BUGFIX: copy the input mapping so merging kwargs never mutates the
        # caller's dict (the original did `d.update(**kwargs)` in place).
        d = dict(d) if d else {}
        if kwargs:
            d.update(**kwargs)
        for k, v in d.items():
            setattr(self, k, v)
        # Materialize subclass-level class attributes (e.g. `power = 1` above)
        # as instance items so they are visible through the dict interface.
        for k in self.__class__.__dict__.keys():
            if not (k.startswith("__") and k.endswith("__")) and k not in ("update", "pop"):
                setattr(self, k, getattr(self, k))

    def __setattr__(self, name, value):
        # Recursively convert nested dicts (and dicts inside lists/tuples) to
        # EasyDict so attribute access works at any depth. Note: tuples are
        # converted to lists, matching the original behavior.
        if isinstance(value, (list, tuple)):
            value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
        elif isinstance(value, dict) and not isinstance(value, self.__class__):
            value = self.__class__(value)
        super(EasyDict, self).__setattr__(name, value)
        super(EasyDict, self).__setitem__(name, value)

    __setitem__ = __setattr__

    def update(self, e=None, **f):
        # BUGFIX: work on a copy — the original mutated `e` in place when
        # keyword arguments were also supplied.
        d = dict(e) if e else {}
        d.update(f)
        for k in d:
            setattr(self, k, d[k])

    def pop(self, k, d=None):
        # Remove the attribute mirror first, then the underlying dict entry.
        if hasattr(self, k):
            delattr(self, k)
        return super(EasyDict, self).pop(k, d)
|
146 |
+
|
147 |
+
if __name__ == "__main__":
|
148 |
+
import doctest
|
imcui/third_party/MatchAnything/src/utils/geometry.py
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import division
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
# from numba import jit
|
6 |
+
|
7 |
+
pixel_coords = None
|
8 |
+
|
9 |
+
def set_id_grid(depth):
    """Build a homogeneous pixel-coordinate grid matching a depth batch.

    Args:
        depth (torch.Tensor): (B, H, W) — used only for its size/dtype/device.
    Returns:
        torch.Tensor: (1, 3, H, W) where channels are (x, y, 1) per pixel.
    """
    _, h, w = depth.size()
    ys = torch.arange(0, h).view(1, h, 1).expand(1, h, w).type_as(depth)  # row index
    xs = torch.arange(0, w).view(1, 1, w).expand(1, h, w).type_as(depth)  # column index
    ones = torch.ones(1, h, w).type_as(depth)
    return torch.stack((xs, ys, ones), dim=1)  # [1, 3, H, W]
+
|
18 |
+
def check_sizes(input, input_name, expected):
    """Assert that `input` matches a shape spec such as 'B3HW'.

    Digits in `expected` pin the exact size of that dimension; letters only
    require the dimension to exist (the rank must equal len(expected)).
    """
    conditions = [input.ndimension() == len(expected)]
    conditions += [
        input.size(i) == int(token)
        for i, token in enumerate(expected)
        if token.isdigit()
    ]
    assert all(conditions), "wrong size for {}, expected {}, got {}".format(
        input_name, 'x'.join(expected), list(input.size()))
|
25 |
+
|
26 |
+
def pixel2cam(depth, intrinsics_inv):
    """Transform coordinates in the pixel frame to the camera frame.
    Args:
        depth: depth maps -- [B, H, W]
        intrinsics_inv: inverse intrinsics per batch element -- [B, 3, 3]
    Returns:
        camera coordinates scaled by depth -- [B, 3, H, W]
    """
    b, h, w = depth.size()
    grid = set_id_grid(depth)  # [1, 3, H, W] homogeneous pixel grid
    pix = grid[:, :, :h, :w].expand(b, 3, h, w).reshape(b, 3, -1)  # [B, 3, H*W]
    cam = (intrinsics_inv.float() @ pix.float()).reshape(b, 3, h, w)
    # Rays scaled by per-pixel depth give 3D points in the camera frame.
    return cam * depth.unsqueeze(1)
|
40 |
+
def cam2pixel_depth(cam_coords, proj_c2p_rot, proj_c2p_tr):
    """Project camera-frame points to the pixel frame and also return depth.
    Args:
        cam_coords: points in the first camera frame -- [B, 3, H, W]
        proj_c2p_rot: camera rotation -- [B, 3, 3] (None means identity)
        proj_c2p_tr: camera translation -- [B, 3, 1] (None means zero)
    Returns:
        normalized [-1, 1] pixel coordinates -- [B, H, W, 2]
        depth map in the projected frame -- [B, H, W]
    """
    b, _, h, w = cam_coords.size()
    flat = cam_coords.reshape(b, 3, -1)  # [B, 3, H*W]
    pcoords = flat if proj_c2p_rot is None else proj_c2p_rot @ flat
    if proj_c2p_tr is not None:
        pcoords = pcoords + proj_c2p_tr  # [B, 3, H*W]

    X, Y = pcoords[:, 0], pcoords[:, 1]
    Z = pcoords[:, 2].clamp(min=1e-3)  # min depth = 1 mm, avoids div-by-zero

    # Normalize to [-1, 1]: -1 at the extreme left/top, +1 at the extreme right/bottom.
    X_norm = 2 * (X / Z) / (w - 1) - 1  # [B, H*W]
    Y_norm = 2 * (Y / Z) / (h - 1) - 1  # [B, H*W]

    grid = torch.stack([X_norm, Y_norm], dim=2)  # [B, H*W, 2]
    return grid.reshape(b, h, w, 2), Z.reshape(b, h, w)
|
69 |
+
|
70 |
+
def cam2pixel(cam_coords, proj_c2p_rot, proj_c2p_tr):
    """Project camera-frame points to the pixel frame.
    Args:
        cam_coords: points in the first camera frame -- [B, 3, H, W]
        proj_c2p_rot: camera rotation -- [B, 3, 3] (None means identity)
        proj_c2p_tr: camera translation -- [B, 3, 1] (None means zero)
    Returns:
        normalized [-1, 1] pixel coordinates -- [B, H, W, 2]
    """
    b, _, h, w = cam_coords.size()
    flat = cam_coords.reshape(b, 3, -1)  # [B, 3, H*W]
    pcoords = flat if proj_c2p_rot is None else proj_c2p_rot @ flat
    if proj_c2p_tr is not None:
        pcoords = pcoords + proj_c2p_tr  # [B, 3, H*W]

    X, Y = pcoords[:, 0], pcoords[:, 1]
    Z = pcoords[:, 2].clamp(min=1e-3)  # min depth = 1 mm, avoids div-by-zero

    # Normalize to [-1, 1]: -1 at the extreme left/top, +1 at the extreme right/bottom.
    X_norm = 2 * (X / Z) / (w - 1) - 1  # [B, H*W]
    Y_norm = 2 * (Y / Z) / (h - 1) - 1  # [B, H*W]

    grid = torch.stack([X_norm, Y_norm], dim=2)  # [B, H*W, 2]
    return grid.reshape(b, h, w, 2)
98 |
+
|
99 |
+
def reproject_kpts(dim0_idxs, kpts, depth, rel_pose, K0, K1):
    """ Reproject keypoints with depth, relative pose and camera intrinsics.
    Args:
        dim0_idxs (torch.LongTensor): (B*max_kpts, ) batch index of each keypoint
        kpts (torch.LongTensor): (B, max_kpts, 2) - <x,y>
        depth (torch.Tensor): (B, H, W)
        rel_pose (torch.Tensor): (B, 3, 4) relative transformation from target to source (T_0to1)
        K0 (torch.Tensor): (N, 3, 3)
        K1 (torch.Tensor): (N, 3, 3)
    Returns:
        (torch.Tensor): (B, max_kpts, 2) the reprojected kpts
    """
    device = kpts.device
    B, max_kpts, _ = kpts.shape

    # Lift pixels to homogeneous coordinates scaled by their sampled depth.
    flat_kpts = kpts.reshape(-1, 2)  # (B*K, 2)
    kpt_depths = depth[dim0_idxs, flat_kpts[:, 1], flat_kpts[:, 0]]  # (B*K, )
    homo = torch.cat(
        [flat_kpts.float(),
         torch.ones((flat_kpts.shape[0], 1), dtype=torch.float32, device=device)], -1)  # (B*K, 3)
    pixel_coords = (homo * kpt_depths[:, None]).reshape(B, max_kpts, 3).permute(0, 2, 1)  # (B, 3, K)

    # pixel -> camera0 -> camera1 -> pixel.
    cam_coords = K0.inverse() @ pixel_coords  # (B, 3, K)
    R = rel_pose[:, :, :-1]  # (B, 3, 3)
    t = rel_pose[:, :, -1][..., None]  # (B, 3, 1)
    cam1_coords = R @ cam_coords + t  # (B, 3, K)
    pix1 = K1 @ cam1_coords  # (B, 3, K)

    # Perspective division by the homogeneous coordinate.
    reproj = pix1[:, :-1, :] / pix1[:, -1, :][:, None].expand(-1, 2, -1)
    return reproj.permute(0, 2, 1)
131 |
+
|
132 |
+
def check_depth_consistency(b_idxs, kpts0, depth0, kpts1, depth1, T_0to1, K0, K1,
                            atol=0.1, rtol=0.0):
    """Check that matched keypoints agree on depth after reprojection.

    Args:
        b_idxs (torch.LongTensor): (n_kpts, ) the batch indices which each keypoints pairs belong to
        kpts0 (torch.LongTensor): (n_kpts, 2) - <x, y>
        depth0 (torch.Tensor): (B, H, W)
        kpts1 (torch.LongTensor): (n_kpts, 2)
        depth1 (torch.Tensor): (B, H, W)
        T_0to1 (torch.Tensor): (B, 3, 4)
        K0: (torch.Tensor): (N, 3, 3) - (K_0)
        K1: (torch.Tensor): (N, 3, 3) - (K_1)
        atol (float): the absolute tolerance for depth consistency check
        rtol (float): the relative tolerance for depth consistency check
    Returns:
        valid_mask (torch.Tensor): (n_kpts, )
    Notes:
        The two corresponding keypoints are depth consistent if the following equation is held:
            abs(kpt_0to1_depth - kpt1_depth) <= (atol + rtol * abs(kpt1_depth))
        * In the initial reimplementation, `atol=0.1, rtol=0` is used, and the result is better with
          `atol=1.0, rtol=0` (which nearly ignore the depth consistency check).
        * However, the author suggests using `atol=0.0, rtol=0.1` as in https://github.com/magicleap/SuperGluePretrainedNetwork/issues/31#issuecomment-681866054
    """
    device = kpts0.device
    n_kpts = kpts0.shape[0]

    # Sample the depth value under each keypoint (kpts are integer <x, y>).
    kpts0_depth = depth0[b_idxs, kpts0[:, 1], kpts0[:, 0]]  # (n_kpts, )
    kpts1_depth = depth1[b_idxs, kpts1[:, 1], kpts1[:, 0]]  # (n_kpts, )
    # Homogeneous pixel coordinates scaled by depth.
    kpts0 = torch.cat([kpts0.float(),
                       torch.ones((n_kpts, 1), dtype=torch.float32, device=device)], -1)  # (n_kpts, 3)
    pixel_coords = (kpts0 * kpts0_depth[:, None])[..., None]  # (n_kpts, 3, 1)

    # indexing from T_0to1 and K - treat all kpts as a batch
    K0 = K0[b_idxs, :, :]  # (n_kpts, 3, 3)
    T_0to1 = T_0to1[b_idxs, :, :]  # (n_kpts, 3, 4)
    cam_coords = K0.inverse() @ pixel_coords  # (n_kpts, 3, 1)

    # camera1 to camera2
    R_0to1 = T_0to1[:, :, :-1]  # (n_kpts, 3, 3)
    t_0to1 = T_0to1[:, :, -1][..., None]  # (n_kpts, 3, 1)
    cam1_coords = R_0to1 @ cam_coords + t_0to1  # (n_kpts, 3, 1)
    K1 = K1[b_idxs, :, :]  # (n_kpts, 3, 3)
    pixel1_coords = K1 @ cam1_coords  # (n_kpts, 3, 1)
    # The homogeneous (last) coordinate after projection is the depth of the
    # reprojected point in frame 1.
    kpts_0to1_depth = pixel1_coords[:, -1, 0]  # (n_kpts, )
    return (kpts_0to1_depth - kpts1_depth).abs() <= atol + rtol * kpts1_depth.abs()
+
|
179 |
+
def inverse_warp(img, depth, pose, intrinsics, mode='bilinear', padding_mode='zeros'):
    """
    Inverse warp a source image to the target image plane.

    Args:
        img: the source image (where to sample pixels) -- [B, 3, H, W]
        depth: depth map of the target image -- [B, H, W]
        pose: relative transfomation from target to source -- [B, 3, 4]
        intrinsics: camera intrinsic matrix -- [B, 3, 3]
    Returns:
        projected_img: Source image warped to the target image plane
        valid_points: Boolean array indicating point validity
    """
    # check_sizes(img, 'img', 'B3HW')
    check_sizes(depth, 'depth', 'BHW')
    # check_sizes(pose, 'pose', 'B6')
    check_sizes(intrinsics, 'intrinsics', 'B33')

    batch_size, _, img_height, img_width = img.size()

    # Lift target pixels into the target camera frame.
    cam_coords = pixel2cam(depth, intrinsics.inverse())  # [B, 3, H, W]

    # Projection matrix: target camera frame -> source pixel frame.
    proj_cam_to_src_pixel = intrinsics @ pose  # [B, 3, 4]
    rot = proj_cam_to_src_pixel[:, :, :3]
    tr = proj_cam_to_src_pixel[:, :, -1:]

    sample_grid = cam2pixel(cam_coords, rot, tr)  # [B, H, W, 2] in [-1, 1]
    projected_img = F.grid_sample(
        img, sample_grid, mode=mode, padding_mode=padding_mode, align_corners=True)

    # A target pixel is valid iff its sample location lands inside the source image.
    valid_points = sample_grid.abs().max(dim=-1)[0] <= 1
    return projected_img, valid_points
|
215 |
+
def depth_inverse_warp(depth_source, depth, pose, intrinsic_source, intrinsic, mode='nearest', padding_mode='zeros'):
    """
    1. Inversely warp a source depth map to the target image plane (warped depth map still in source frame)
    2. Transform the target depth map to the source image frame
    Args:
        depth_source: the source image (where to sample pixels) -- [B, H, W]
        depth: depth map of the target image -- [B, H, W]
        pose: relative transfomation from target to source -- [B, 3, 4]
        intrinsics: camera intrinsic matrix -- [B, 3, 3]
    Returns:
        warped_depth: Source depth warped to the target image plane -- [B, H, W]
        projected_depth: Target depth projected to the source image frame -- [B, H, W]
        valid_points: Boolean array indicating point validity -- [B, H, W]
    """
    check_sizes(depth_source, 'depth', 'BHW')
    check_sizes(depth, 'depth', 'BHW')
    check_sizes(intrinsic_source, 'intrinsics', 'B33')

    b, h, w = depth.size()

    # Lift target pixels to 3D points in the target camera frame.
    cam_coords = pixel2cam(depth, intrinsic.inverse())  # [B,3,H,W]

    pose_mat = pose  # (B, 3, 4)

    # Get projection matrix from target camera frame to source pixel frame
    proj_cam_to_src_pixel = intrinsic_source @ pose_mat  # [B, 3, 4]

    rot, tr = proj_cam_to_src_pixel[:,:,:3], proj_cam_to_src_pixel[:,:,-1:]
    # src_pixel_coords: where each target pixel lands in the source image;
    # depth_target2src: target depth expressed in the source camera frame.
    src_pixel_coords, depth_target2src = cam2pixel_depth(cam_coords, rot, tr)  # [B,H,W,2]
    # 'nearest' by default — interpolating depth across object edges is undesirable.
    warped_depth = F.grid_sample(depth_source[:, None], src_pixel_coords, mode=mode,
                                 padding_mode=padding_mode, align_corners=True)  # [B, 1, H, W]

    # Valid: sample lands inside the source image and both depths are positive.
    valid_points = (src_pixel_coords.abs().max(dim=-1)[0] <= 1) &\
                   (depth > 0.0) & (warped_depth[:, 0] > 0.0)  # [B, H, W]
    return warped_depth[:, 0], depth_target2src, valid_points
251 |
+
def to_skew(t):
    """ Transform the translation vector t to its skew-symmetric matrix [t]_x.

    Args:
        t (torch.Tensor): (B, 3)
    Returns:
        torch.Tensor: (B, 3, 3) with zero diagonal, satisfying [t]_x @ v = t x v.
    """
    # BUGFIX: the matrix was previously initialized with `new_ones`, which left
    # ones on the diagonal — a skew-symmetric matrix must have a zero diagonal.
    t_skew = t.new_zeros((t.shape[0], 3, 3))
    t_skew[:, 0, 1] = -t[:, 2]
    t_skew[:, 1, 0] = t[:, 2]
    t_skew[:, 0, 2] = t[:, 1]
    t_skew[:, 2, 0] = -t[:, 1]
    t_skew[:, 1, 2] = -t[:, 0]
    t_skew[:, 2, 1] = t[:, 0]
    return t_skew  # (B, 3, 3)
|
265 |
+
|
266 |
+
def to_homogeneous(pts):
    """Append a ones column to make 2D points homogeneous.

    Args:
        pts (torch.Tensor): (B, K, 2)
    Returns:
        torch.Tensor: (B, K, 3)
    """
    ones = torch.ones_like(pts[..., :1])
    return torch.cat([pts, ones], dim=-1)  # (B, K, 3)
|
273 |
+
|
274 |
+
def pix2img(pts, K):
    """Convert pixel coordinates to normalized image coordinates.

    Args:
        pts (torch.Tensor): (B, K, 2)
        K (torch.Tensor): (B, 3, 3) camera intrinsics
    Returns:
        torch.Tensor: (B, K, 2) -- (pts - principal_point) / focal_lengths
    """
    principal = K[:, [0, 1], [2, 2]][:, None]  # (B, 1, 2) = (cx, cy)
    focal = K[:, [0, 1], [0, 1]][:, None]      # (B, 1, 2) = (fx, fy)
    return (pts - principal) / focal
282 |
+
|
283 |
+
def weighted_blind_sed(kpts0, kpts1, weights, E, K0, K1):
    """ Calculate the squared weighted blind symmetric epipolar distance, which is the sed between
    all possible keypoints pairs.
    Args:
        kpts0 (torch.Tensor): (B, K0, 2)
        ktps1 (torch.Tensor): (B, K1, 2)
        weights (torch.Tensor): (B, K0, K1)
        E (torch.Tensor): (B, 3, 3) - the essential matrix
        K0 (torch.Tensor): (B, 3, 3)
        K1 (torch.Tensor): (B, 3, 3)
    Returns:
        w_sed (torch.Tensor): (B, K0, K1)
    """
    M, N = kpts0.shape[1], kpts1.shape[1]

    # Move both keypoint sets to normalized (homogeneous) image coordinates.
    kpts0 = to_homogeneous(pix2img(kpts0, K0))
    kpts1 = to_homogeneous(pix2img(kpts1, K1))  # (B, K1, 3)

    # Epipolar residual x0^T E^T x1 for every candidate pair.
    R = kpts0 @ E.transpose(1, 2) @ kpts1.transpose(1, 2)  # (B, K0, K1)
    # w_R = weights * R  # (B, K0, K1)

    # Epipolar lines: their first two components give the symmetric
    # normalization factors of the epipolar distance.
    Ep0 = kpts0 @ E.transpose(1, 2)  # (B, K0, 3)
    Etp1 = kpts1 @ E  # (B, K1, 3)
    d = R**2 * (1.0 / (Ep0[..., 0]**2 + Ep0[..., 1]**2)[..., None].expand(-1, -1, N)
                + 1.0 / (Etp1[..., 0]**2 + Etp1[..., 1]**2)[:, None].expand(-1, M, -1)) * weights  # (B, K0, K1)
    return d
+
|
310 |
+
def weighted_blind_sampson(kpts0, kpts1, weights, E, K0, K1):
    """ Calculate the squared weighted blind sampson distance, which is the sampson distance between
    all possible keypoints pairs weighted by the given weights.

    Shapes match `weighted_blind_sed`; the only difference is that the two
    epipolar-line norms are summed inside a single denominator (Sampson form)
    rather than contributing two separate reciprocal terms.
    """
    M, N = kpts0.shape[1], kpts1.shape[1]

    # Normalized homogeneous image coordinates.
    kpts0 = to_homogeneous(pix2img(kpts0, K0))
    kpts1 = to_homogeneous(pix2img(kpts1, K1))  # (B, K1, 3)

    # Epipolar residual x0^T E^T x1 for every candidate pair.
    R = kpts0 @ E.transpose(1, 2) @ kpts1.transpose(1, 2)  # (B, K0, K1)
    # w_R = weights * R  # (B, K0, K1)

    Ep0 = kpts0 @ E.transpose(1, 2)  # (B, K0, 3)
    Etp1 = kpts1 @ E  # (B, K1, 3)
    d = R**2 * (1.0 / ((Ep0[..., 0]**2 + Ep0[..., 1]**2)[..., None].expand(-1, -1, N)
                + (Etp1[..., 0]**2 + Etp1[..., 1]**2)[:, None].expand(-1, M, -1))) * weights  # (B, K0, K1)
    return d
328 |
+
|
329 |
+
def angular_rel_rot(T_0to1):
    """Rotation angle (in degrees) of a relative transform.

    Args:
        T_0to1 (np.ndarray): (4, 4)
    Returns:
        float: absolute rotation error in degrees.
    """
    # trace(R) = 1 + 2*cos(theta); clamp to arccos's valid domain to absorb
    # numerical noise (same effect as the original if-chain).
    cos = np.clip((np.trace(T_0to1[:-1, :-1]) - 1) / 2, -1.0, 1.0)
    return np.rad2deg(np.abs(np.arccos(cos)))
|
343 |
+
def angular_rel_pose(T0, T1):
    """Angular rotation and translation-direction errors between two poses.

    Args:
        T0 (np.ndarray): (4, 4)
        T1 (np.ndarray): (4, 4)
    Returns:
        (float, float): rotation error and translation angle error, in degrees.
    """
    # Rotation: angle of R0^T R1 via the trace identity, clamped for arccos.
    cos_rot = np.clip((np.trace(T0[:-1, :-1].T @ T1[:-1, :-1]) - 1) / 2, -1.0, 1.0)
    angle_error_rot = np.rad2deg(np.abs(np.arccos(cos_rot)))

    # Translation: angle between the two translation directions.
    norm_prod = np.linalg.norm(T0[:-1, -1]) * np.linalg.norm(T1[:-1, -1])
    cos_trans = np.clip(np.dot(T0[:-1, -1], T1[:-1, -1]) / norm_prod, -1.0, 1.0)
    angle_error_trans = np.rad2deg(np.arccos(cos_trans))

    return angle_error_rot, angle_error_trans
|
imcui/third_party/MatchAnything/src/utils/homography_utils.py
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
def to_homogeneous(points):
    """Convert N-dimensional points to homogeneous coordinates.
    Args:
        points: torch.Tensor or numpy.ndarray with size (..., N).
    Returns:
        A torch.Tensor or numpy.ndarray with size (..., N+1).
    """
    if isinstance(points, torch.Tensor):
        return torch.cat([points, points.new_ones(points.shape[:-1] + (1,))], dim=-1)
    if isinstance(points, np.ndarray):
        ones = np.ones(points.shape[:-1] + (1,), dtype=points.dtype)
        return np.concatenate([points, ones], axis=-1)
    raise ValueError
|
23 |
+
|
24 |
+
def from_homogeneous(points, eps=0.0):
    """Remove the homogeneous dimension of N-dimensional points.
    Args:
        points: torch.Tensor or numpy.ndarray with size (..., N+1).
        eps: Epsilon value to prevent zero division.
    Returns:
        A torch.Tensor or numpy ndarray with size (..., N).
    """
    denom = points[..., -1:] + eps
    return points[..., :-1] / denom
34 |
+
|
35 |
+
def flat2mat(H):
    """Expand an 8-parameter flat homography (shape (1, 8)) to a 3x3 matrix,
    appending the implicit last element H[2, 2] = 1."""
    padded = np.concatenate([H, np.ones_like(H[:, :1])], axis=1)
    return np.reshape(padded, [3, 3])
|
38 |
+
|
39 |
+
# Homography creation
|
40 |
+
|
41 |
+
|
42 |
+
def create_center_patch(shape, patch_shape=None):
    """Corners of a `patch_shape` patch centered in an image of size `shape`.

    Both shapes are (width, height); `patch_shape` defaults to the full image.

    Returns:
        np.ndarray: (4, 2) corners ordered (left, bottom), (left, top),
        (right, top), (right, bottom).
    """
    if patch_shape is None:
        patch_shape = shape
    width, height = shape
    pwidth, pheight = patch_shape
    x0 = int((width - pwidth) / 2)
    y0 = int((height - pheight) / 2)
    x1 = int((width + pwidth) / 2)
    y1 = int((height + pheight) / 2)
    return np.array([[x0, y0], [x0, y1], [x1, y1], [x1, y0]])
|
53 |
+
|
54 |
+
def check_convex(patch, min_convexity=0.05):
    """Checks if given polygon vertices [N,2] form a convex shape.

    Every consecutive edge pair must turn in the same (clockwise) direction,
    with a cross-product magnitude of at least `min_convexity`.
    """
    n = patch.shape[0]
    for i in range(n):
        x1, y1 = patch[(i - 1) % n]
        x2, y2 = patch[i]
        x3, y3 = patch[(i + 1) % n]
        cross = (x2 - x1) * (y3 - y2) - (x3 - x2) * (y2 - y1)
        if cross > -min_convexity:
            return False
    return True
|
64 |
+
|
65 |
+
def sample_homography_corners(
    shape,
    patch_shape,
    difficulty=1.0,
    translation=0.4,
    n_angles=10,
    max_angle=90,
    min_convexity=0.05,
    rng=np.random,
):
    """Sample a random homography by perturbing the four image corners.

    Args:
        shape: source image size (width, height).
        patch_shape: output patch size (width, height).
        difficulty: in [0, 1]; scales corner perturbation and rotation range.
        translation: fraction of the free space used for random translation.
        n_angles: number of candidate rotation angles to try.
        max_angle: maximum rotation (degrees) before scaling by difficulty.
        min_convexity: minimum convexity required of the perturbed quadrilateral.
        rng: numpy random generator/module used for all sampling.

    Returns:
        (H, full, warped, patch_shape): the 3x3 homography, the original corner
        points, those corners warped by H, and the patch shape passed through.
    """
    max_angle = max_angle / 180.0 * math.pi
    width, height = shape
    # The minimal inner patch: the harder the setting, the smaller it gets.
    pwidth, pheight = width * (1 - difficulty), height * (1 - difficulty)
    min_pts1 = create_center_patch(shape, (pwidth, pheight))
    full = create_center_patch(shape)
    pts2 = create_center_patch(patch_shape)
    scale = min_pts1 - full
    found_valid = False
    cnt = -1
    # Rejection-sample corner offsets until the quadrilateral is convex.
    while not found_valid:
        offsets = rng.uniform(0.0, 1.0, size=(4, 2)) * scale
        pts1 = full + offsets
        found_valid = check_convex(pts1 / np.array(shape), min_convexity)
        cnt += 1

    # re-center
    pts1 = pts1 - np.mean(pts1, axis=0, keepdims=True)
    pts1 = pts1 + np.mean(min_pts1, axis=0, keepdims=True)

    # Rotation
    if n_angles > 0 and difficulty > 0:
        angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
        rng.shuffle(angles)
        rng.shuffle(angles)
        # Prepend 0 so the identity rotation is the guaranteed fallback.
        angles = np.concatenate([[0.0], angles], axis=0)

        center = np.mean(pts1, axis=0, keepdims=True)
        rot_mat = np.reshape(
            np.stack(
                [np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)],
                axis=1,
            ),
            [-1, 2, 2],
        )
        # Rotate the corner set by every candidate angle about its centroid.
        rotated = (
            np.matmul(
                np.tile(np.expand_dims(pts1 - center, axis=0), [n_angles + 1, 1, 1]),
                rot_mat,
            )
            + center
        )

        # Keep the first rotation whose corners all stay inside the image.
        for idx in range(1, n_angles):
            warped_points = rotated[idx] / np.array(shape)
            if np.all((warped_points >= 0.0) & (warped_points < 1.0)):
                pts1 = rotated[idx]
                break

    # Translation
    if translation > 0:
        # Free slack between the quadrilateral and the image borders.
        min_trans = -np.min(pts1, axis=0)
        max_trans = shape - np.max(pts1, axis=0)
        trans = rng.uniform(min_trans, max_trans)[None]
        pts1 += trans * translation * difficulty

    H = compute_homography(pts1, pts2, [1.0, 1.0])
    warped = warp_points(full, H, inverse=False)
    return H, full, warped, patch_shape
|
134 |
+
|
135 |
+
def compute_homography(pts1_, pts2_, shape):
    """Compute the homography matrix from 4 point correspondences"""
    # Rescale to actual size ([y, x] convention for `shape`).
    scale = np.array(shape[::-1], dtype=np.float32)[None]
    pts1 = pts1_ * scale
    pts2 = pts2_ * scale

    # Each correspondence contributes two rows of the DLT system A h = p.
    def row_x(p, q):
        return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]

    def row_y(p, q):
        return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]

    a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (row_x, row_y)], axis=0)
    p_mat = np.transpose(
        np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0)
    )
    homography = np.transpose(np.linalg.solve(a_mat, p_mat))
    # Append the implicit H[2, 2] = 1 and reshape to 3x3 (inlined flat2mat).
    return np.reshape(
        np.concatenate([homography, np.ones_like(homography[:, :1])], axis=1), [3, 3]
    )
155 |
+
|
156 |
+
# Point warping utils
|
157 |
+
|
158 |
+
|
159 |
+
def warp_points(points, homography, inverse=True):
    """Warp a list of points with the INVERSE of the given homography.

    The inverse is used to be coherent with tf.contrib.image.transform.

    Arguments:
        points: list of N points, shape (N, 2).
        homography: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
    Returns: a Tensor of shape (N, 2) or (B, N, 2) (depending on whether the
        homography is batched) containing the new coordinates of the warps.
    """
    batched = len(homography.shape) != 2
    H = homography if batched else homography[None]

    # Lift the points to homogeneous coordinates.
    ones = np.ones([points.shape[0], 1], dtype=np.float32)
    pts_h = np.concatenate([points, ones], axis=-1)

    mat = np.transpose(np.linalg.inv(H) if inverse else H)
    warped = np.tensordot(pts_h, mat, axes=[[1], [0]])
    warped = np.transpose(warped, [2, 0, 1])

    # Avoid division by a (near-)zero homogeneous scale.
    warped[np.abs(warped[:, :, 2]) < 1e-8, 2] = 1e-8
    warped = warped[:, :, :2] / warped[:, :, 2:]

    return warped if batched else warped[0]
|
183 |
+
|
184 |
+
|
185 |
+
def warp_points_torch(points, H, inverse=True):
    """Warp a batched list of points with the INVERSE of the given homography.

    The inverse is used to be coherent with tf.contrib.image.transform.

    Arguments:
        points: batched list of N points, shape (B, N, 2).
        H: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
        inverse: Whether to multiply the points by H or the inverse of H.
    Returns: a Tensor of shape (B, N, 2) containing the warped coordinates.
    """
    # Lift to homogeneous coordinates, apply the (possibly inverted)
    # homography, then project back to 2D.
    pts_h = to_homogeneous(points)
    mat = (torch.inverse(H) if inverse else H).transpose(-2, -1)
    warped = torch.einsum("...nj,...ji->...ni", pts_h, mat)
    return from_homogeneous(warped, eps=1e-5)
|
205 |
+
|
206 |
+
|
207 |
+
# Line warping utils
|
208 |
+
|
209 |
+
|
210 |
+
def seg_equation(segs):
    """Return the normalized line equation ax + by + c = 0 (a^2 + b^2 = 1)
    for each segment in a (..., 2, 2) tensor of endpoint pairs."""
    # Endpoints in homogeneous coordinates.
    p_start = to_homogeneous(segs[..., 0, :])
    p_end = to_homogeneous(segs[..., 1, :])
    # The line through two homogeneous points is their cross product.
    lines = torch.cross(p_start, p_end, dim=-1)
    norm = torch.sqrt(lines[..., 0] ** 2 + lines[..., 1] ** 2)[..., None]
    assert torch.all(
        norm > 0
    ), "Error: trying to compute the equation of a line with a single point"
    return lines / norm
|
223 |
+
|
224 |
+
|
225 |
+
def is_inside_img(pts: torch.Tensor, img_shape: Tuple[int, int]):
    """Boolean mask marking points that are finite and inside an (H, W) image."""
    height, width = img_shape
    non_negative = (pts >= 0).all(dim=-1)
    in_width = pts[..., 0] < width
    in_height = pts[..., 1] < height
    finite = ~torch.isinf(pts).any(dim=-1)
    return non_negative & in_width & in_height & finite
|
233 |
+
|
234 |
+
|
235 |
+
def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch.Tensor:
    """
    Shrink an array of segments to fit inside the image.

    Each out-of-bounds endpoint is replaced by the intersection of its
    segment's line with the nearest image border, when that intersection
    itself lies inside the image.

    :param segs: The tensor of segments with shape (N, 2, 2)
    :param img_shape: The image shape in format (H, W)
    """
    EPS = 1e-4
    device = segs.device
    w, h = img_shape[1], img_shape[0]
    # Project the segments to the reference image
    segs = segs.clone()
    # Line equation (a, b, c) of every segment.
    eqs = seg_equation(segs)
    # Border lines x = 0 and y = 0 in homogeneous form.
    x0, y0 = torch.tensor([1.0, 0, 0.0], device=device), torch.tensor(
        [0.0, 1, 0], device=device
    )
    x0 = x0.repeat(eqs.shape[:-1] + (1,))
    y0 = y0.repeat(eqs.shape[:-1] + (1,))
    # Intersections with the left (x = 0) and top (y = 0) borders:
    # cross product of the two homogeneous lines, then dehomogenize.
    pt_x0s = torch.cross(eqs, x0, dim=-1)
    pt_x0s = pt_x0s[..., :-1] / pt_x0s[..., None, -1]
    pt_x0s_valid = is_inside_img(pt_x0s, img_shape)
    pt_y0s = torch.cross(eqs, y0, dim=-1)
    pt_y0s = pt_y0s[..., :-1] / pt_y0s[..., None, -1]
    pt_y0s_valid = is_inside_img(pt_y0s, img_shape)

    # Border lines x = w - EPS and y = h - EPS (EPS keeps the clipped point
    # strictly inside so the final assertion holds).
    xW = torch.tensor([1.0, 0, EPS - w], device=device)
    yH = torch.tensor([0.0, 1, EPS - h], device=device)
    xW = xW.repeat(eqs.shape[:-1] + (1,))
    yH = yH.repeat(eqs.shape[:-1] + (1,))
    pt_xWs = torch.cross(eqs, xW, dim=-1)
    pt_xWs = pt_xWs[..., :-1] / pt_xWs[..., None, -1]
    pt_xWs_valid = is_inside_img(pt_xWs, img_shape)
    pt_yHs = torch.cross(eqs, yH, dim=-1)
    pt_yHs = pt_yHs[..., :-1] / pt_yHs[..., None, -1]
    pt_yHs_valid = is_inside_img(pt_yHs, img_shape)

    # If the X coordinate of the first endpoint is out
    mask = (segs[..., 0, 0] < 0) & pt_x0s_valid
    segs[mask, 0, :] = pt_x0s[mask]
    mask = (segs[..., 0, 0] > (w - 1)) & pt_xWs_valid
    segs[mask, 0, :] = pt_xWs[mask]
    # If the X coordinate of the second endpoint is out
    mask = (segs[..., 1, 0] < 0) & pt_x0s_valid
    segs[mask, 1, :] = pt_x0s[mask]
    # NOTE(review): `segs[:, 1, 0]` (vs `segs[..., 1, 0]` elsewhere) is
    # equivalent for the documented (N, 2, 2) shape — confirm if segs can
    # ever be higher-rank.
    mask = (segs[:, 1, 0] > (w - 1)) & pt_xWs_valid
    segs[mask, 1, :] = pt_xWs[mask]
    # If the Y coordinate of the first endpoint is out
    mask = (segs[..., 0, 1] < 0) & pt_y0s_valid
    segs[mask, 0, :] = pt_y0s[mask]
    mask = (segs[..., 0, 1] > (h - 1)) & pt_yHs_valid
    segs[mask, 0, :] = pt_yHs[mask]
    # If the Y coordinate of the second endpoint is out
    mask = (segs[..., 1, 1] < 0) & pt_y0s_valid
    segs[mask, 1, :] = pt_y0s[mask]
    mask = (segs[..., 1, 1] > (h - 1)) & pt_yHs_valid
    segs[mask, 1, :] = pt_yHs[mask]

    # All endpoints must now be strictly inside the image.
    assert (
        torch.all(segs >= 0)
        and torch.all(segs[..., 0] < w)
        and torch.all(segs[..., 1] < h)
    )
    return segs
|
297 |
+
|
298 |
+
|
299 |
+
def warp_lines_torch(
    lines, H, inverse=True, dst_shape: Tuple[int, int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Warp line segments with a homography, optionally clipping to an image.

    :param lines: A tensor of shape (B, N, 2, 2)
        where B is the batch size, N the number of lines.
    :param H: The homography used to convert the lines.
        batched or not (shapes (B, 3, 3) and (3, 3) respectively).
    :param inverse: Whether to apply H or the inverse of H
    :param dst_shape: If provided (H, W), lines are trimmed to be inside the image
    :return: (warped lines, per-line validity mask of shape (B, N))
    """
    device = lines.device
    batch_size = len(lines)
    # Warp both endpoints at once by flattening to (B, N*2, 2).
    lines = warp_points_torch(lines.reshape(batch_size, -1, 2), H, inverse).reshape(
        lines.shape
    )

    if dst_shape is None:
        # No clipping requested: every line is considered valid.
        return lines, torch.ones(lines.shape[:-2], dtype=torch.bool, device=device)

    # Endpoint-wise out-of-image test; dst_shape[::-1] converts (H, W) -> (W, H)
    # to match the (x, y) coordinate order.
    out_img = torch.any(
        (lines < 0) | (lines >= torch.tensor(dst_shape[::-1], device=device)), -1
    )
    # A line is invalid only when BOTH endpoints are outside; lines with
    # exactly one endpoint outside are clipped below.
    valid = ~out_img.all(-1)
    any_out_of_img = out_img.any(-1)
    lines_to_trim = valid & any_out_of_img

    for b in range(batch_size):
        lines_to_trim_mask_b = lines_to_trim[b]
        lines_to_trim_b = lines[b][lines_to_trim_mask_b]
        corrected_lines = shrink_segs_to_img(lines_to_trim_b, dst_shape)
        lines[b][lines_to_trim_mask_b] = corrected_lines

    return lines, valid
|
333 |
+
|
334 |
+
|
335 |
+
# Homography evaluation utils
|
336 |
+
|
337 |
+
|
338 |
+
def sym_homography_error(kpts0, kpts1, T_0to1):
    """Symmetric reprojection error of matched keypoints under homography T_0to1."""
    # Forward: warp kpts0 into image 1 and compare with kpts1.
    warped0 = from_homogeneous(to_homogeneous(kpts0) @ T_0to1.transpose(-1, -2))
    forward_dist = ((warped0 - kpts1) ** 2).sum(-1).sqrt()

    # Backward: warp kpts1 into image 0 via the pseudo-inverse.
    warped1 = from_homogeneous(
        to_homogeneous(kpts1) @ torch.pinverse(T_0to1.transpose(-1, -2))
    )
    backward_dist = ((warped1 - kpts0) ** 2).sum(-1).sqrt()

    return (forward_dist + backward_dist) / 2.0
|
348 |
+
|
349 |
+
|
350 |
+
def sym_homography_error_all(kpts0, kpts1, H):
    """All-pairs symmetric homography distance matrix of shape [... x M x N]."""
    warped0 = warp_points_torch(kpts0, H, inverse=False)
    warped1 = warp_points_torch(kpts1, H, inverse=True)

    # Pairwise L2 distances in each direction, then averaged.
    forward = torch.sum((warped0.unsqueeze(-2) - kpts1.unsqueeze(-3)) ** 2, -1).sqrt()
    backward = torch.sum((kpts0.unsqueeze(-2) - warped1.unsqueeze(-3)) ** 2, -1).sqrt()
    return (forward + backward) / 2.0
|
358 |
+
|
359 |
+
|
360 |
+
def homography_corner_error(T, T_gt, image_size):
    """Mean displacement of the four image corners between T and T_gt (pixels)."""
    W, H = image_size[..., 0], image_size[..., 1]
    corners = torch.Tensor([[0, 0], [W, 0], [W, H], [0, H]]).float().to(T)
    warped_gt = from_homogeneous(to_homogeneous(corners) @ T_gt.transpose(-1, -2))
    warped_est = from_homogeneous(to_homogeneous(corners) @ T.transpose(-1, -2))
    displacement = ((warped_est - warped_gt) ** 2).sum(-1).sqrt()
    return displacement.mean(-1)
|
imcui/third_party/MatchAnything/src/utils/metrics.py
ADDED
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from collections import OrderedDict
|
5 |
+
from loguru import logger
|
6 |
+
from .homography_utils import warp_points, warp_points_torch
|
7 |
+
from kornia.geometry.epipolar import numeric
|
8 |
+
from kornia.geometry.conversions import convert_points_to_homogeneous
|
9 |
+
import pprint
|
10 |
+
|
11 |
+
|
12 |
+
# --- METRICS ---
|
13 |
+
|
14 |
+
def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
    """Angular translation / rotation errors (degrees) of (R, t) w.r.t. GT pose.

    Args:
        T_0to1: 4x4 ground-truth relative pose.
        R, t: estimated rotation (3x3) and translation (3,).
        ignore_gt_t_thr: if the GT translation norm is below this, t_err is 0.
    Returns:
        (t_err, R_err) in degrees.
    """
    # Translation: angle between estimated and GT translation directions.
    t_gt = T_0to1[:3, 3]
    denom = np.linalg.norm(t) * np.linalg.norm(t_gt)
    cos_t = np.clip(np.dot(t, t_gt) / denom, -1.0, 1.0)
    t_err = np.rad2deg(np.arccos(cos_t))
    t_err = np.minimum(t_err, 180 - t_err)  # essential-matrix sign ambiguity
    if np.linalg.norm(t_gt) < ignore_gt_t_thr:
        # (Near-)pure rotation: the GT translation direction is unreliable.
        t_err = 0

    # Rotation: angle of the relative rotation R^T @ R_gt.
    R_gt = T_0to1[:3, :3]
    cos_r = np.clip((np.trace(np.dot(R.T, R_gt)) - 1) / 2, -1.0, 1.0)
    R_err = np.rad2deg(np.abs(np.arccos(cos_r)))

    return t_err, R_err
|
30 |
+
|
31 |
+
def warp_pts_error(H_est, pts_coord, H_gt=None, pts_gt=None):
    """Mean Euclidean distance between H_est-warped points and a reference.

    Args:
        H_est: estimated 3x3 homography.
        pts_coord: (N, 2) points to warp (e.g. the 4 image corners).
        H_gt: if given, the reference is pts_coord warped by H_gt.
        pts_gt: used directly as the reference when H_gt is None.
    Returns:
        Mean L2 distance between warped points and the reference.
    Raises:
        ValueError: if neither H_gt nor pts_gt is provided (the original
            code hit an UnboundLocalError on `diff` in that case).
    """
    est_warp = warp_points(pts_coord, H_est, False)
    if H_gt is not None:
        diff = est_warp - warp_points(pts_coord, H_gt, False)
    elif pts_gt is not None:
        diff = est_warp - pts_gt
    else:
        raise ValueError("Either H_gt or pts_gt must be provided.")

    return np.mean(np.linalg.norm(diff, axis=1))
|
44 |
+
|
45 |
+
def homo_warp_match_distance(H_gt, kpts0, kpts1, hw):
    """Distance between H_gt-warped kpts0 and kpts1, normalized by image size.

    Args:
        H_gt: ground-truth homography (np.ndarray or torch tensor).
        kpts0, kpts1: (N, 2) matched keypoints.
        hw: image size (H, W); hw[[1, 0]] reorders it to (W, H) to match
            the (x, y) keypoint convention.
    """
    if isinstance(H_gt, np.ndarray):
        warped = warp_points(kpts0, H_gt)
        return np.linalg.norm((warped - kpts1) / hw[None, [1, 0]], axis=1)
    warped = warp_points_torch(kpts0, H_gt)
    return torch.linalg.norm((warped - kpts1) / hw[None, [1, 0]], axis=1)
|
56 |
+
|
57 |
+
def symmetric_epipolar_distance(pts0, pts1, E, K0, K1):
    """Squared symmetric epipolar distance.

    This can be seen as a biased estimation of the reprojection error.
    Args:
        pts0 (torch.Tensor): [N, 2]
        pts1 (torch.Tensor): [N, 2]
        E (torch.Tensor): [3, 3] essential matrix.
        K0, K1 (torch.Tensor): [3, 3] camera intrinsics.
    """
    # Move keypoints to normalized camera coordinates: (p - c) / f.
    pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
    pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
    pts0 = convert_points_to_homogeneous(pts0)
    pts1 = convert_points_to_homogeneous(pts1)

    Ep0 = pts0 @ E.T  # [N, 3]
    p1Ep0 = torch.sum(pts1 * Ep0, -1)  # [N,]
    Etp1 = pts1 @ E  # [N, 3]

    # Algebraic error weighted by both epipolar line normals.
    inv_norms = 1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)
    return p1Ep0**2 * inv_norms  # [N,]
|
75 |
+
|
76 |
+
|
77 |
+
def compute_symmetrical_epipolar_errors(data, config):
    """
    Compute per-match symmetric epipolar errors from the GT relative pose.

    Update:
        data (dict):{"epi_errs": [M]}
    """
    # Essential matrix from the GT pose: E = [t]_x @ R.
    Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
    E_mat = Tx @ data['T_0to1'][:, :3, :3]

    m_bids = data['m_bids']
    pts0 = data['mkpts0_f']
    pts1 = data['mkpts1_f'].clone().detach()

    if config.LOFTR.FINE.MTD_SPVS:
        # Fine-level supervision provides its own batch ids when available.
        m_bids = data['m_bids_f'] if 'm_bids_f' in data else data['m_bids']
    epi_errs = []
    for bs in range(Tx.size(0)):
        # Select the matches belonging to this batch element.
        mask = m_bids == bs
        epi_errs.append(
            symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
    epi_errs = torch.cat(epi_errs, dim=0)

    data.update({'epi_errs': epi_errs})
|
99 |
+
|
100 |
+
def compute_homo_match_warp_errors(data, config):
    """
    Compute per-match warp distances under the GT homography; stored under
    the same 'epi_errs' key as the epipolar variant so downstream
    aggregation is shared.

    Update:
        data (dict):{"epi_errs": [M]}
    """

    homography_gt = data['homography']
    m_bids = data['m_bids']
    pts0 = data['mkpts0_f']
    pts1 = data['mkpts1_f']
    origin_img0_size = data['origin_img_size0']

    if config.LOFTR.FINE.MTD_SPVS:
        # Fine-level supervision provides its own batch ids when available.
        m_bids = data['m_bids_f'] if 'm_bids_f' in data else data['m_bids']
    epi_errs = []
    for bs in range(homography_gt.shape[0]):
        # Select the matches belonging to this batch element.
        mask = m_bids == bs
        epi_errs.append(
            homo_warp_match_distance(homography_gt[bs], pts0[mask], pts1[mask], origin_img0_size[bs]))
    epi_errs = torch.cat(epi_errs, dim=0)

    data.update({'epi_errs': epi_errs})
|
122 |
+
|
123 |
+
|
124 |
+
def compute_symmetrical_epipolar_errors_gt(data, config):
    """
    Compute symmetric epipolar errors of the ground-truth matches.

    Update:
        data (dict):{"epi_errs": [M]}
    """
    # Essential matrix from the GT pose: E = [t]_x @ R.
    Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
    E_mat = Tx @ data['T_0to1'][:, :3, :3]

    pts0 = data['mkpts0_f_gt']
    pts1 = data['mkpts1_f_gt']

    epi_errs = []
    for bs in range(Tx.size(0)):
        # GT matches carry no batch ids; only single-sample batches supported.
        assert bs == 0
        mask = torch.tensor([True]*pts0.shape[0], device = pts0.device)
        # NOTE: the original branched on config.LOFTR.FINE.MTD_SPVS here, but
        # both branches were byte-identical; the dead branch was removed
        # (config is kept in the signature for interface compatibility).
        epi_errs.append(
            symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
    epi_errs = torch.cat(epi_errs, dim=0)

    data.update({'epi_errs': epi_errs})
|
150 |
+
|
151 |
+
|
152 |
+
def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
    """Estimate the relative pose via RANSAC essential matrix + recoverPose.

    Args:
        kpts0, kpts1: (N, 2) matched keypoints in pixel coordinates.
        K0, K1: (3, 3) camera intrinsics.
        thresh: RANSAC inlier threshold in pixels.
        conf: RANSAC confidence.
    Returns:
        (R, t, inlier_mask) or None if estimation failed.
    """
    if len(kpts0) < 5:
        return None
    # normalize keypoints to camera coordinates
    kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
    kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]

    # normalize ransac threshold by the mean focal length
    ransac_thr = thresh / np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]])

    # compute essential matrix with cv2 (may return stacked candidates)
    E, mask = cv2.findEssentialMat(
        kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC)
    if E is None:
        print("\nE is None while trying to recover pose.\n")
        return None

    # recover pose from each stacked 3x3 candidate; keep the one with the
    # most inliers. BUGFIX: np.split requires an integer section count —
    # `len(E) / 3` is a float and raises on modern numpy; use `//`.
    best_num_inliers = 0
    ret = None
    for _E in np.split(E, len(E) // 3):
        n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
        if n > best_num_inliers:
            ret = (R, t[:, 0], mask.ravel() > 0)
            best_num_inliers = n

    return ret
|
179 |
+
|
180 |
+
def estimate_homo(kpts0, kpts1, thresh, conf=0.99999, mode='affine'):
    """Estimate a 3x3 transform between matched keypoints with OpenCV.

    Args:
        kpts0, kpts1: (N, 2) matched keypoints.
        thresh: RANSAC reprojection threshold (NOTE: ignored by LMEDS in
            'homo' mode).
        conf: estimator confidence ('affine' mode only).
        mode: 'affine' (cv2.estimateAffine2D, lifted to 3x3) or 'homo'
            (cv2.findHomography).
    Returns:
        (H_est, inliers); a zero 3x3 matrix and empty array on failure.
    Raises:
        ValueError: for an unknown `mode` (the original fell through to an
            UnboundLocalError).
    """
    if mode == 'affine':
        H_est, inliers = cv2.estimateAffine2D(kpts0, kpts1, ransacReprojThreshold=thresh, confidence=conf, method=cv2.RANSAC)
        if H_est is None:
            return np.eye(3) * 0, np.empty((0))
        H_est = np.concatenate([H_est, np.array([[0, 0, 1]])], axis=0)  # 2x3 -> 3x3
    elif mode == 'homo':
        H_est, inliers = cv2.findHomography(kpts0, kpts1, method=cv2.LMEDS, ransacReprojThreshold=thresh)
        if H_est is None:
            return np.eye(3) * 0, np.empty((0))
    else:
        raise ValueError(f"Unknown estimate_homo mode: {mode}")

    return H_est, inliers
|
192 |
+
|
193 |
+
def compute_homo_corner_warp_errors(data, config):
    """
    Evaluate homography estimation by the mean warp error of the four
    image-0 corners under H_est vs the GT homography.

    Update:
        data (dict):{
            "R_errs" List[float]: [N] # Actually warp error
            "t_errs" List[float]: [N] # Zero, place holder
            "inliers" List[np.ndarray]: [N]
        }
    """
    pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5
    conf = config.TRAINER.RANSAC_CONF # 0.99999
    data.update({'R_errs': [], 't_errs': [], 'inliers': []})

    if config.LOFTR.FINE.MTD_SPVS:
        # Fine-level supervision provides its own batch ids when available.
        m_bids = data['m_bids_f'].cpu().numpy() if 'm_bids_f' in data else data['m_bids'].cpu().numpy()

    else:
        m_bids = data['m_bids'].cpu().numpy()
    pts0 = data['mkpts0_f'].cpu().numpy()
    pts1 = data['mkpts1_f'].cpu().numpy()
    homography_gt = data['homography'].cpu().numpy()
    origin_size_0 = data['origin_img_size0'].cpu().numpy()

    for bs in range(homography_gt.shape[0]):
        mask = m_bids == bs
        ret = estimate_homo(pts0[mask], pts1[mask], pixel_thr, conf=conf)

        # NOTE(review): estimate_homo always returns a tuple, so this None
        # branch looks unreachable — kept for safety.
        if ret is None:
            data['R_errs'].append(np.inf)
            data['t_errs'].append(np.inf)
            data['inliers'].append(np.array([]).astype(bool))
        else:
            H_est, inliers = ret
            # Four corners built from origin_size_0 = (H, W);
            # presumably in (x, y) order — TODO confirm against warp_pts_error.
            corner_coord = np.array([[0, 0], [0, origin_size_0[bs][0]], [origin_size_0[bs][1], 0], [origin_size_0[bs][1], origin_size_0[bs][0]]])
            corner_warp_distance = warp_pts_error(H_est, corner_coord, H_gt=homography_gt[bs])
            # Warp error is stored in the R_errs slot; t_errs is a placeholder.
            data['R_errs'].append(corner_warp_distance)
            data['t_errs'].append(0)
            data['inliers'].append(inliers)
|
231 |
+
|
232 |
+
def compute_warp_control_pts_errors(data, config):
    """
    Evaluate the estimated warp on ground-truth 2D control-point matches.

    Update:
        data (dict):{
            "R_errs" List[float]: [N] # Actually warp error
            "t_errs" List[float]: [N] # Zero, place holder
            "inliers" List[np.ndarray]: [N]
        }
    """
    pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5
    conf = config.TRAINER.RANSAC_CONF # 0.99999
    data.update({'R_errs': [], 't_errs': [], 'inliers': []})

    if config.LOFTR.FINE.MTD_SPVS:
        # Fine-level supervision provides its own batch ids when available.
        m_bids = data['m_bids_f'].cpu().numpy() if 'm_bids_f' in data else data['m_bids'].cpu().numpy()

    else:
        m_bids = data['m_bids'].cpu().numpy()
    pts0 = data['mkpts0_f'].cpu().numpy()
    pts1 = data['mkpts1_f'].cpu().numpy()
    gt_2D_matches = data["gt_2D_matches"].cpu().numpy()

    # Placeholder epi_errs so downstream aggregation always finds the key.
    data.update({'epi_errs': torch.zeros(m_bids.shape[0])})
    for bs in range(gt_2D_matches.shape[0]):
        mask = m_bids == bs
        ret = estimate_homo(pts0[mask], pts1[mask], pixel_thr, conf=conf, mode=config.TRAINER.WARP_ESTIMATOR_MODEL)

        # NOTE(review): estimate_homo always returns a tuple, so this None
        # branch looks unreachable — kept for safety.
        if ret is None:
            data['R_errs'].append(np.inf)
            data['t_errs'].append(np.inf)
            data['inliers'].append(np.array([]).astype(bool))
        else:
            H_est, inliers = ret
            # Control points: columns [:2] in image 0, [2:] in image 1.
            img0_pts, img1_pts = gt_2D_matches[bs][:, :2], gt_2D_matches[bs][:, 2:]
            pts_warp_distance = warp_pts_error(H_est, img0_pts, pts_gt=img1_pts)
            print(pts_warp_distance)  # NOTE(review): debug print leftover?
            # Warp error is stored in the R_errs slot; t_errs is a placeholder.
            data['R_errs'].append(pts_warp_distance)
            data['t_errs'].append(0)
            data['inliers'].append(inliers)
|
271 |
+
|
272 |
+
def compute_pose_errors(data, config):
    """
    Estimate relative poses via RANSAC and record their angular errors.

    Update:
        data (dict):{
            "R_errs" List[float]: [N]
            "t_errs" List[float]: [N]
            "inliers" List[np.ndarray]: [N]
        }
    """
    pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5
    conf = config.TRAINER.RANSAC_CONF # 0.99999
    data.update({'R_errs': [], 't_errs': [], 'inliers': []})

    if config.LOFTR.FINE.MTD_SPVS:
        # Fine-level supervision provides its own batch ids when available.
        m_bids = data['m_bids_f'].cpu().numpy() if 'm_bids_f' in data else data['m_bids'].cpu().numpy()

    else:
        m_bids = data['m_bids'].cpu().numpy()
    pts0 = data['mkpts0_f'].cpu().numpy()
    pts1 = data['mkpts1_f'].cpu().numpy()
    K0 = data['K0'].cpu().numpy()
    K1 = data['K1'].cpu().numpy()
    T_0to1 = data['T_0to1'].cpu().numpy()

    for bs in range(K0.shape[0]):
        mask = m_bids == bs
        if config.LOFTR.EVAL_TIMES >= 1:
            # Repeat RANSAC EVAL_TIMES times on shuffled matches to average
            # out its randomness; per-run errors are appended as a list.
            bpts0, bpts1 = pts0[mask], pts1[mask]
            R_list, T_list, inliers_list = [], [], []
            for _ in range(5):
                # NOTE(review): shuffling is drawn even for skipped
                # iterations (> EVAL_TIMES), presumably to keep the RNG
                # stream fixed across configs — confirm.
                shuffling = np.random.permutation(np.arange(len(bpts0)))
                if _ >= config.LOFTR.EVAL_TIMES:
                    continue
                bpts0 = bpts0[shuffling]
                bpts1 = bpts1[shuffling]

                ret = estimate_pose(bpts0, bpts1, K0[bs], K1[bs], pixel_thr, conf=conf)
                if ret is None:
                    R_list.append(np.inf)
                    T_list.append(np.inf)
                    inliers_list.append(np.array([]).astype(bool))
                    print('Pose error: inf')
                else:
                    R, t, inliers = ret
                    t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
                    R_list.append(R_err)
                    T_list.append(t_err)
                    inliers_list.append(inliers)
                    print(f'Pose error: {max(R_err, t_err)}')
            # NOTE(review): the means below are computed but never used.
            R_err_mean = np.array(R_list).mean()
            T_err_mean = np.array(T_list).mean()
            # inliers_mean = np.array(inliers_list).mean()

            data['R_errs'].append(R_list)
            data['t_errs'].append(T_list)
            data['inliers'].append(inliers_list[0])

        else:
            # Single RANSAC run per batch element.
            ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf)

            if ret is None:
                data['R_errs'].append(np.inf)
                data['t_errs'].append(np.inf)
                data['inliers'].append(np.array([]).astype(bool))
                print('Pose error: inf')
            else:
                R, t, inliers = ret
                t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
                data['R_errs'].append(R_err)
                data['t_errs'].append(t_err)
                data['inliers'].append(inliers)
                print(f'Pose error: {max(R_err, t_err)}')
|
344 |
+
|
345 |
+
|
346 |
+
# --- METRIC AGGREGATION ---
|
347 |
+
def error_rmse(error):
    """Root-mean-square of per-point squared L2 errors for an (N, 2) residual array."""
    per_point_sq = np.sum(np.square(error), axis=1)  # squared L2 norm per point
    return np.sqrt(np.mean(per_point_sq))
|
352 |
+
|
353 |
+
def error_mae(error):
    """Maximum per-point L1 error for an (N, 2) residual array.

    NOTE: despite the name, this returns the *maximum* absolute error
    (matching the original implementation), not the mean.
    """
    per_point_l1 = np.abs(error).sum(axis=1)  # L1 norm per point
    return np.max(per_point_l1)
|
360 |
+
|
361 |
+
def error_auc(errors, thresholds, method='exact_auc'):
    """AUC / success-rate statistics of a list of errors.

    Args:
        errors (list/array): [N,] error values.
        thresholds (list): thresholds to evaluate at.
        method: 'exact_auc' (trapezoidal AUC of the recall-vs-error curve),
            'fire_paper' (discrete accumulation of % below integer bins),
            or 'success_rate' (fraction of errors under each threshold).
    """
    if method == 'exact_auc':
        errors = [0] + sorted(list(errors))
        recall = list(np.linspace(0, 1, len(errors)))

        aucs = []
        for thr in thresholds:
            cut = np.searchsorted(errors, thr)
            # Extend the step curve flat up to the threshold.
            y = recall[:cut] + [recall[cut - 1]]
            x = errors[:cut] + [thr]
            aucs.append(np.trapz(y, x) / thr)
        return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
    if method == 'fire_paper':
        aucs = []
        for threshold in thresholds:
            accum_error = 0
            percent_error_below = np.zeros(threshold + 1)
            for i in range(1, threshold + 1):
                percent_error_below[i] = np.sum(errors < i) * 100 / len(errors)
                accum_error += percent_error_below[i]
            aucs.append(accum_error / (threshold * 100))
        return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
    if method == 'success_rate':
        aucs = [(errors < threshold).astype(float).mean() for threshold in thresholds]
        return {f'SR@{t}': auc for t, auc in zip(thresholds, aucs)}
    raise NotImplementedError
|
397 |
+
|
398 |
+
|
399 |
+
def epidist_prec(errors, thresholds, ret_dict=False):
    """Mean matching precision: fraction of errors below each threshold.

    `errors` is a list of per-sample error arrays; precision is averaged
    within each sample first, then across samples.
    """
    precs = []
    for thr in thresholds:
        per_sample = [
            np.mean(errs < thr) if len(errs) > 0 else 0 for errs in errors
        ]
        precs.append(np.mean(per_sample) if len(per_sample) > 0 else 0)
    if ret_dict:
        return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)}
    return precs
|
411 |
+
|
412 |
+
|
413 |
+
def aggregate_metrics(metrics, epi_err_thr=5e-4, eval_n_time=1, threshold=[5, 10, 20], method='exact_auc'):
    """ Aggregate metrics for the whole dataset:
    (This method should be called once per dataset)
    1. AUC of the pose error (angular) at the threshold [5, 10, 20]
    2. Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth)

    NOTE: `threshold` is a mutable default but is only ever read here.
    """
    # filter duplicates: keep the last index seen for each identifier
    unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers']))
    unq_ids = list(unq_ids.values())
    logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...')

    # pose auc: per-sample error is max(rotation err, translation err)
    angular_thresholds = threshold
    if eval_n_time >= 1:
        # Each sample was evaluated eval_n_time times (repeated RANSAC);
        # keep the per-run errors of the unique samples.
        pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0).reshape(-1, eval_n_time)[unq_ids].reshape(-1)
    else:
        pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids]
    logger.info('num of pose_errors: {}'.format(pose_errors.shape))
    aucs = error_auc(pose_errors, angular_thresholds, method=method) # (auc@5, auc@10, auc@20)

    if eval_n_time >= 1:
        # Also log the AUCs of each individual RANSAC run.
        for i in range(eval_n_time):
            aucs_i = error_auc(pose_errors.reshape(-1, eval_n_time)[:,i], angular_thresholds, method=method)
            logger.info('\n' + f'results of {i}-th RANSAC' + pprint.pformat(aucs_i))
    # matching precision
    dist_thresholds = [epi_err_thr]
    precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr)

    # match-count and inlier-ratio statistics over unique samples
    u_num_mathces = np.array(metrics['num_matches'], dtype=object)[unq_ids]
    u_percent_inliers = np.array(metrics['percent_inliers'], dtype=object)[unq_ids]
    num_matches = {f'num_matches': u_num_mathces.mean() }
    percent_inliers = {f'percent_inliers': u_percent_inliers.mean()}
    return {**aucs, **precs, **num_matches, **percent_inliers}
|
imcui/third_party/MatchAnything/src/utils/misc.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import contextlib
|
3 |
+
import joblib
|
4 |
+
from typing import Union
|
5 |
+
from loguru import _Logger, logger
|
6 |
+
from itertools import chain
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from yacs.config import CfgNode as CN
|
10 |
+
from pytorch_lightning.utilities import rank_zero_only
|
11 |
+
|
12 |
+
|
13 |
+
def lower_config(yacs_cfg):
    """Recursively convert a yacs CfgNode into a plain dict with lowercase keys."""
    if not isinstance(yacs_cfg, CN):
        return yacs_cfg  # leaves are returned unchanged
    return {key.lower(): lower_config(value) for key, value in yacs_cfg.items()}
|
17 |
+
|
18 |
+
|
19 |
+
def upper_config(dict_cfg):
    """Recursively convert a dict into one with uppercase keys (inverse of lower_config)."""
    if not isinstance(dict_cfg, dict):
        return dict_cfg  # leaves are returned unchanged
    return {key.upper(): upper_config(value) for key, value in dict_cfg.items()}
|
23 |
+
|
24 |
+
|
25 |
+
def log_on(condition, message, level):
    """Log `message` at the given loguru `level` only when `condition` holds."""
    if not condition:
        return
    assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL']
    logger.log(level, message)
|
29 |
+
|
30 |
+
|
31 |
+
def get_rank_zero_only_logger(logger: _Logger):
    # On rank 0, return the logger untouched; on other ranks, silence it by
    # replacing every level method (info/debug/...) and the internal _log
    # with no-ops, so distributed workers don't duplicate log output.
    if rank_zero_only.rank == 0:
        return logger
    else:
        for _level in logger._core.levels.keys():
            level = _level.lower()
            setattr(logger, level,
                    lambda x: None)
        logger._log = lambda x: None
        return logger
|
41 |
+
|
42 |
+
|
43 |
+
def setup_gpus(gpus: Union[str, int]) -> int:
    """ A temporary fix for pytorch-lighting 1.3.x: resolve a gpu spec to a count.

    Accepts either a count (e.g. 2, "2", -1 for "all CUDA devices") or a
    comma-separated list of device ids (e.g. "0,1"); returns the number of
    GPUs to use.
    """
    gpus = str(gpus)

    if ',' not in gpus:
        # Plain count; -1 means "use every visible CUDA device".
        count = int(gpus)
        return count if count != -1 else torch.cuda.device_count()

    gpu_ids = [part.strip() for part in gpus.split(',') if part != '']

    # Pin CUDA_VISIBLE_DEVICES ourselves unless the user (or a parent
    # process) already did.
    visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
    if visible_devices is None:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids)
        visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
        logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}')
    else:
        logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.')
    return len(gpu_ids)
|
64 |
+
|
65 |
+
|
66 |
+
def flattenList(x):
    """Flatten one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    return [item for sub in x for item in sub]
|
68 |
+
|
69 |
+
|
70 |
+
@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Context manager to patch joblib to report into tqdm progress bar given as argument

    Usage:
        with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
            Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))

    When iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of finishing)
        ret_vals = Parallel(n_jobs=args.world_size)(
                    delayed(lambda x: _compute_cov_score(pid, *x))(param)
                        for param in tqdm(combinations(image_ids, 2),
                                          desc=f'Computing cov_score of [{pid}]',
                                          total=len(image_ids)*(len(image_ids)-1)/2))
    Src: https://stackoverflow.com/a/58936697
    """
    # Subclass joblib's completion callback so every finished batch also
    # advances the tqdm bar.
    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def __call__(self, *args, **kwargs):
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    # Globally monkeypatch joblib's callback for the duration of the block;
    # restored in the finally clause even if the body raises.
    old_batch_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        joblib.parallel.BatchCompletionCallBack = old_batch_callback
        tqdm_object.close()
|
101 |
+
|