Theo Viel
committed on
Commit 98a67a0 · 1 Parent(s): 694c514
add weights and code
This view is limited to 50 files because it contains too many changes.
See raw diff
- checkpoints/charset.txt +1 -0
- checkpoints/detector.pth +3 -0
- checkpoints/recognizer.pth +3 -0
- checkpoints/relational.pth +3 -0
- example.py +43 -0
- nemo-retriever-ocr/cpp/.gitattributes +1 -0
- nemo-retriever-ocr/cpp/.gitignore +6 -0
- nemo-retriever-ocr/cpp/.gitmodules +3 -0
- nemo-retriever-ocr/cpp/README.md +15 -0
- nemo-retriever-ocr/cpp/beam_decode/beam_decode.cpp +460 -0
- nemo-retriever-ocr/cpp/beam_decode/beam_decode.h +18 -0
- nemo-retriever-ocr/cpp/beam_decode/kn_lm.cpp +86 -0
- nemo-retriever-ocr/cpp/beam_decode/kn_lm.h +27 -0
- nemo-retriever-ocr/cpp/beam_decode/language_model.cpp +147 -0
- nemo-retriever-ocr/cpp/beam_decode/language_model.h +66 -0
- nemo-retriever-ocr/cpp/beam_decode/log_sum_exp.cpp +7 -0
- nemo-retriever-ocr/cpp/beam_decode/log_sum_exp.h +54 -0
- nemo-retriever-ocr/cpp/beam_decode/ngram_lm_base.cpp +330 -0
- nemo-retriever-ocr/cpp/beam_decode/ngram_lm_base.h +80 -0
- nemo-retriever-ocr/cpp/beam_decode/prefix.cpp +23 -0
- nemo-retriever-ocr/cpp/beam_decode/prefix.h +158 -0
- nemo-retriever-ocr/cpp/beam_decode/sbo_lm.cpp +47 -0
- nemo-retriever-ocr/cpp/beam_decode/sbo_lm.h +21 -0
- nemo-retriever-ocr/cpp/better_grid_sample/cpu_indirect_grid_sample.cpp +94 -0
- nemo-retriever-ocr/cpp/better_grid_sample/gpu_grid_sample_utils.cuh +42 -0
- nemo-retriever-ocr/cpp/better_grid_sample/gpu_indirect_grid_sample.cu +328 -0
- nemo-retriever-ocr/cpp/better_grid_sample/grid_sample.h +67 -0
- nemo-retriever-ocr/cpp/common.cpp +13 -0
- nemo-retriever-ocr/cpp/common.h +58 -0
- nemo-retriever-ocr/cpp/cuda_intellisense.cuh +51 -0
- nemo-retriever-ocr/cpp/geometry.h +1101 -0
- nemo-retriever-ocr/cpp/geometry_api/calc_poly_min_rrect.cpp +165 -0
- nemo-retriever-ocr/cpp/geometry_api/geometry_api.cpp +101 -0
- nemo-retriever-ocr/cpp/geometry_api/geometry_api.h +16 -0
- nemo-retriever-ocr/cpp/geometry_api/geometry_api_common.h +121 -0
- nemo-retriever-ocr/cpp/geometry_api/geometry_api_gpu.cu +142 -0
- nemo-retriever-ocr/cpp/geometry_api/get_rel_continuation_cos.cpp +60 -0
- nemo-retriever-ocr/cpp/geometry_api/matrix2x2.h +93 -0
- nemo-retriever-ocr/cpp/geometry_api/poly_bounds_quad.cpp +61 -0
- nemo-retriever-ocr/cpp/graph_detection/encode_util.cpp +272 -0
- nemo-retriever-ocr/cpp/graph_detection/encode_util.h +184 -0
- nemo-retriever-ocr/cpp/half_ops.cu +5 -0
- nemo-retriever-ocr/cpp/half_ops.cuh +149 -0
- nemo-retriever-ocr/cpp/local_ips/local_ips.h +11 -0
- nemo-retriever-ocr/cpp/local_ips/quad_all_2_all_dist_v2.cu +162 -0
- nemo-retriever-ocr/cpp/module.cpp +125 -0
- nemo-retriever-ocr/cpp/non_maximal_suppression/cpu_non_maximal_suppression.cpp +209 -0
- nemo-retriever-ocr/cpp/non_maximal_suppression/cuda_non_maximal_suppression.cu +1720 -0
- nemo-retriever-ocr/cpp/non_maximal_suppression/nms_common.h +227 -0
- nemo-retriever-ocr/cpp/non_maximal_suppression/nms_kd_tree.h +449 -0
checkpoints/charset.txt
ADDED
@@ -0,0 +1 @@
[" ", "!", "\"", "#", "$", "%", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", ":", ";", "<", "=", ">", "?", "@", "A", "B", "C", "D", "E", "F", "FI", "G", "H", "I", "I\u0307", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "SS", "T", "U", "V", "W", "X", "Y", "Z", "[", "\\", "]", "^", "_", "`", "a", "b", "c", "d", "e", "f", "fi", "g", "h", "i", "i\u0307", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "ss", "t", "u", "v", "w", "x", "y", "z", "{", "|", "}", "~", "\u00b2", "\u00b3", "\u00b5", "\u00b9", "\u00ba", "\u00c0", "\u00c1", "\u00c2", "\u00c3", "\u00c4", "\u00c5", "\u00c6", "\u00c7", "\u00c8", "\u00c9", "\u00ca", "\u00cb", "\u00cc", "\u00cd", "\u00ce", "\u00cf", "\u00d0", "\u00d1", "\u00d2", "\u00d3", "\u00d4", "\u00d5", "\u00d6", "\u00d8", "\u00d9", "\u00da", "\u00db", "\u00dc", "\u00dd", "\u00de", "\u00df", "\u00e0", "\u00e1", "\u00e2", "\u00e3", "\u00e4", "\u00e5", "\u00e6", "\u00e7", "\u00e8", "\u00e9", "\u00ea", "\u00eb", "\u00ec", "\u00ed", "\u00ee", "\u00ef", "\u00f0", "\u00f1", "\u00f2", "\u00f3", "\u00f4", "\u00f5", "\u00f6", "\u00f8", "\u00f9", "\u00fa", "\u00fb", "\u00fc", "\u00fd", "\u00fe", "\u00ff", "\u0100", "\u0101", "\u0102", "\u0103", "\u0104", "\u0105", "\u0106", "\u0107", "\u010c", "\u010d", "\u010e", "\u010f", "\u0110", "\u0111", "\u0112", "\u0113", "\u0116", "\u0117", "\u0118", "\u0119", "\u011a", "\u011b", "\u011e", "\u011f", "\u0120", "\u0121", "\u0126", "\u0127", "\u0128", "\u0129", "\u012a", "\u012b", "\u0130", "\u0131", "\u0136", "\u0137", "\u013d", "\u013e", "\u0141", "\u0142", "\u0143", "\u0144", "\u0145", "\u0146", "\u0147", "\u0148", "\u014a", "\u014b", "\u014c", "\u014d", "\u014e", "\u014f", "\u0150", "\u0151", "\u0152", "\u0153", "\u0158", "\u0159", "\u015a", "\u015b", "\u015e", "\u015f", "\u0160", "\u0161", "\u0162", "\u0163", "\u0164", "\u0165", "\u0168", "\u0169", "\u016a", "\u016b", "\u016c", "\u016d", "\u016e", "\u016f", "\u0172", "\u0173", "\u0174", "\u0175", "\u0176", "\u0177", "\u0178", "\u0179", "\u017a", "\u017b", "\u017c", "\u017d", "\u017e", "\u0181", "\u0186", "\u0189", "\u018a", "\u018f", "\u0190", "\u0191", "\u0192", "\u0194", "\u0197", "\u019c", "\u019d", "\u019f", "\u01a0", "\u01a1", "\u01a6", "\u01a9", "\u01ae", "\u01af", "\u01b0", "\u01b1", "\u01b2", "\u01b7", "\u01c2", "\u01cd", "\u01ce", "\u01cf", "\u01d0", "\u01d1", "\u01d2", "\u01d3", "\u01d4", "\u01ea", "\u01eb", "\u0218", "\u0219", "\u021a", "\u021b", "\u0245", "\u0250", "\u0251", "\u0252", "\u0253", "\u0254", "\u0255", "\u0256", "\u0257", "\u0259", "\u025b", "\u025f", "\u0261", "\u0262", "\u0263", "\u0266", "\u0267", "\u0268", "\u026a", "\u026c", "\u026f", "\u0272", "\u0274", "\u0275", "\u0278", "\u027b", "\u027e", "\u0280", "\u0281", "\u0282", "\u0283", "\u0287", "\u0288", "\u028a", "\u028b", "\u028c", "\u028d", "\u028e", "\u0292", "\u0294", "\u0295", "\u0298", "\u029d", "\u029f", "\u02b0", "\u02b2", "\u02b7", "\u02bb", "\u02bc", "\u02be", "\u02bf", "\u02c0", "\u02c1", "\u02c8", "\u02cc", "\u02d0", "\u02e0", "\u02e4", "\u0386", "\u0388", "\u038a", "\u038c", "\u038e", "\u038f", "\u0391", "\u0391\u0342", "\u0392", "\u0393", "\u0394", "\u0395", "\u0396", "\u0397", "\u0397\u0342", "\u0398", "\u0399", "\u0399\u0342", "\u039a", "\u039b", "\u039c", "\u039d", "\u039e", "\u039f", "\u03a0", "\u03a1", "\u03a3", "\u03a4", "\u03a5", "\u03a5\u0313", "\u03a5\u0342", "\u03a6", "\u03a7", "\u03a8", "\u03a9", "\u03a9\u0342", "\u03a9\u0342\u0399", "\u03ac", "\u03ad", "\u03af", "\u03b1", "\u03b1\u0342", "\u03b2", "\u03b3", "\u03b4", 
"\u03b5", "\u03b6", "\u03b7", "\u03b7\u0342", "\u03b8", "\u03b9", "\u03b9\u0342", "\u03ba", "\u03bb", "\u03bc", "\u03bd", "\u03be", "\u03bf", "\u03c0", "\u03c1", "\u03c2", "\u03c3", "\u03c4", "\u03c5", "\u03c5\u0313", "\u03c5\u0342", "\u03c6", "\u03c7", "\u03c8", "\u03c9", "\u03c9\u0342", "\u03c9\u0342\u03b9", "\u03cc", "\u03cd", "\u03ce", "\u03d5", "\u0401", "\u0406", "\u0408", "\u0410", "\u0411", "\u0412", "\u0413", "\u0414", "\u0415", "\u0416", "\u0417", "\u0418", "\u0419", "\u041a", "\u041b", "\u041c", "\u041d", "\u041e", "\u041f", "\u0420", "\u0421", "\u0422", "\u0423", "\u0425", "\u0426", "\u0427", "\u0428", "\u042a", "\u042b", "\u042c", "\u042d", "\u042e", "\u042f", "\u0430", "\u0431", "\u0432", "\u0433", "\u0434", "\u0435", "\u0436", "\u0437", "\u0438", "\u0439", "\u043a", "\u043b", "\u043c", "\u043d", "\u043e", "\u043f", "\u0440", "\u0441", "\u0442", "\u0443", "\u0445", "\u0446", "\u0447", "\u0448", "\u044a", "\u044b", "\u044c", "\u044d", "\u044e", "\u044f", "\u0451", "\u0456", "\u0458", "\u05b5", "\u05b6", "\u05bc", "\u05d0", "\u05d1", "\u05d2", "\u05d3", "\u05d5", "\u05d7", "\u05d9", "\u05dc", "\u05dd", "\u05de", "\u05e0", "\u05e1", "\u05e2", "\u05e6", "\u05e8", "\u05e9", "\u05ea", "\u0621", "\u0623", "\u0625", "\u0627", "\u0628", "\u0629", "\u062a", "\u062c", "\u062d", "\u062e", "\u062f", "\u0631", "\u0632", "\u0633", "\u0634", "\u0635", "\u0637", "\u0639", "\u063a", "\u0641", "\u0642", "\u0643", "\u0644", "\u0645", "\u0646", "\u0647", "\u0648", "\u064a", "\u06cc", "\u0902", "\u0905", "\u0906", "\u0909", "\u0915", "\u0917", "\u091f", "\u0921", "\u0924", "\u0926", "\u0928", "\u092a", "\u092c", "\u092d", "\u092e", "\u092f", "\u0930", "\u0932", "\u0936", "\u0937", "\u0938", "\u0939", "\u093e", "\u093f", "\u0940", "\u0947", "\u094b", "\u0995", "\u09a4", "\u09b2", "\u09be", "\u09bf", "\u0b95", "\u0ba9", "\u0bb3", "\u0e02", "\u0e07", "\u0e08", "\u0e0a", "\u0e10", "\u0e15", "\u0e17", "\u0e19", "\u0e1b", "\u0e1e", "\u0e23", "\u0e27", "\u0e30", "\u0e31", "\u0e32", "\u0e40", "\u0e41", "\u16c3", "\u16cb", "\u16df", "\u1e0c", "\u1e0d", "\u1e24", "\u1e25", "\u1e36", "\u1e37", "\u1e3a", "\u1e3b", "\u1e42", "\u1e43", "\u1e44", "\u1e45", "\u1e46", "\u1e47", "\u1e48", "\u1e49", "\u1e5a", "\u1e5b", "\u1e5e", "\u1e5f", "\u1e62", "\u1e63", "\u1e6c", "\u1e6d", "\u1e6e", "\u1e6f", "\u1ea0", "\u1ea1", "\u1ea2", "\u1ea3", "\u1ea4", "\u1ea5", "\u1ea6", "\u1ea7", "\u1ea8", "\u1ea9", "\u1eaa", "\u1eab", "\u1eac", "\u1ead", "\u1eae", "\u1eaf", "\u1eb4", "\u1eb5", "\u1eb6", "\u1eb7", "\u1eb8", "\u1eb9", "\u1ebe", "\u1ebf", "\u1ec2", "\u1ec3", "\u1ec4", "\u1ec5", "\u1ec6", "\u1ec7", "\u1eca", "\u1ecb", "\u1ecc", "\u1ecd", "\u1ece", "\u1ecf", "\u1ed0", "\u1ed1", "\u1ed2", "\u1ed3", "\u1ed4", "\u1ed5", "\u1ed6", "\u1ed7", "\u1ed8", "\u1ed9", "\u1eda", "\u1edb", "\u1edc", "\u1edd", "\u1ede", "\u1edf", "\u1ee2", "\u1ee3", "\u1ee4", "\u1ee5", "\u1ee6", "\u1ee7", "\u1ee8", "\u1ee9", "\u1eea", "\u1eeb", "\u1eec", "\u1eed", "\u1eee", "\u1eef", "\u1ef0", "\u1ef1", "\u1ef2", "\u1ef3", "\u1ef4", "\u1ef5", "\u1ef8", "\u1ef9", "\u1f00", "\u1f04", "\u1f08", "\u1f0c", "\u1f10", "\u1f15", "\u1f18", "\u1f1d", "\u1f20", "\u1f21", "\u1f28", "\u1f29", "\u1f30", "\u1f31", "\u1f38", "\u1f39", "\u1f41", "\u1f44", "\u1f49", "\u1f4c", "\u1f50", "\u1f51", "\u1f59", "\u1f61", "\u1f69", "\u1f70", "\u1f72", "\u1f74", "\u1f76", "\u1f78", "\u1f7a", "\u1f7c", "\u1fb6", "\u1fba", "\u1fc6", "\u1fc8", "\u1fca", "\u1fd6", "\u1fda", "\u1fe6", "\u1fea", "\u1ff6", "\u1ff7", "\u1ff8", "\u1ffa", "\u2081", "\u2082", "\u2083", "\u2113", "\u2460", 
"\u2461", "\u2463", "\u2c6d", "\u2c6f", "\u2c70", "\u3044", "\u3045", "\u3046", "\u304a", "\u304b", "\u304d", "\u304f", "\u3050", "\u3053", "\u3057", "\u3059", "\u305b", "\u305f", "\u3064", "\u3069", "\u306e", "\u3070", "\u307d", "\u3088", "\u3089", "\u3093", "\u30a1", "\u30a2", "\u30a3", "\u30a4", "\u30a6", "\u30a7", "\u30a8", "\u30a9", "\u30aa", "\u30ab", "\u30ac", "\u30af", "\u30b0", "\u30b3", "\u30b4", "\u30b5", "\u30b6", "\u30b7", "\u30b8", "\u30b9", "\u30ba", "\u30bb", "\u30bc", "\u30bd", "\u30bf", "\u30c1", "\u30c3", "\u30c4", "\u30c6", "\u30c7", "\u30c8", "\u30c9", "\u30ca", "\u30cb", "\u30ce", "\u30cf", "\u30d0", "\u30d1", "\u30d2", "\u30d3", "\u30d5", "\u30d6", "\u30d7", "\u30d9", "\u30da", "\u30dc", "\u30de", "\u30df", "\u30e1", "\u30e3", "\u30e4", "\u30e5", "\u30e6", "\u30e9", "\u30ea", "\u30eb", "\u30ec", "\u30ed", "\u30ef", "\u30f3", "\u30f4", "\u30fc", "\ua7aa", "\ua7ac", "\ua7ad", "\ua7ae", "\ua7b1", "\ua7b2", "\ua7c5", "\uac70", "\ub9c8", "\ub9c9", "\ub9d0", "\uc0ac", "\uc778", "\uc804", "\uc9c0", "\uc9d3", "\ud22c", "\ufb01"]
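The charset is a single JSON array mapping recognizer output indices to strings; note that some entries are multi-character (ligatures such as "fi", case-folded forms such as "SS"). A minimal, hypothetical loader sketch — the reservation of the first three indices for blank, end-of-sequence and word-boundary tokens is an assumption taken from the decoder sources further below, not something stated in this file:

```
import json

# Hypothetical helper: load checkpoints/charset.txt and build an index -> string map.
# The +3 offset (0 = blank, 1 = end-of-sequence, 2 = word boundary '^') is an assumption
# based on how language_model.cpp treats token indices.
with open("checkpoints/charset.txt", "r", encoding="utf-8") as f:
    charset = json.load(f)

token_mapping = {idx + 3: text for idx, text in enumerate(charset)}
print(f"Loaded {len(charset)} charset entries")
```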
checkpoints/detector.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8b7d50c74b2dba9acb8dd76d2fbcf75e6eeae0cb3e9688edf42c91aa5550ade1
size 181677320
checkpoints/recognizer.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:db307d9b0dcb6cd15ab6c71e302fd62ca90ce077c3013c9f63a4ba0dbfdf3f50
size 19823477
checkpoints/relational.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b1db5a62853269aabd8a040eeb05038a871032e8275def77653631657cb8ca4a
size 9048309
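The three .pth files above are stored as Git LFS pointers rather than raw weights; judging by the recorded sizes, the detector is roughly 182 MB, the recognizer roughly 20 MB and the relational model roughly 9 MB, and they are materialized with `git lfs pull` (or an LFS-aware clone) after checking out the repository.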
example.py
ADDED
@@ -0,0 +1,43 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse

from nemo_retriever_ocr.inference.pipeline import NemoRetrieverOCR


def main(image_path, merge_level, no_visualize, model_dir):
    ocr_pipeline = NemoRetrieverOCR()

    predictions = ocr_pipeline(image_path, merge_level=merge_level, visualize=not no_visualize)

    print(f"Found {len(predictions)} text regions.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run OCR inference and annotate image.")
    parser.add_argument("image_path", type=str, help="Path to the input image.")
    parser.add_argument(
        "--merge-level",
        type=str,
        choices=["word", "sentence", "paragraph"],
        default="paragraph",
        help="Merge level for OCR output (word, sentence, paragraph).",
    )
    parser.add_argument("--no-visualize", action="store_true", help="Do not save the annotated image.")
    parser.add_argument(
        "--model-dir",
        type=str,
        help="Path to the model checkpoints.",
        default="./checkpoints",
    )
    args = parser.parse_args()

    main(
        args.image_path,
        merge_level=args.merge_level,
        no_visualize=args.no_visualize,
        model_dir=args.model_dir,
    )
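Given the argument parser above, a typical invocation is `python example.py path/to/image.png --merge-level word`; the annotated image is saved unless `--no-visualize` is passed, and checkpoints are read from `./checkpoints` unless `--model-dir` points elsewhere.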
nemo-retriever-ocr/cpp/.gitattributes
ADDED
@@ -0,0 +1 @@
load_png/wuffs-v0.3.c filter=lfs diff=lfs merge=lfs -text
nemo-retriever-ocr/cpp/.gitignore
ADDED
@@ -0,0 +1,6 @@
__pycache__
.vscode
build
*.egg-info
dist
.vs
nemo-retriever-ocr/cpp/.gitmodules
ADDED
@@ -0,0 +1,3 @@
[submodule "trove"]
	path = trove
	url = https://github.com/bryancatanzaro/trove.git
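Because trove is declared as a submodule, the C++ sources are only complete after it is fetched, e.g. with `git submodule update --init` inside nemo-retriever-ocr/cpp.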
nemo-retriever-ocr/cpp/README.md
ADDED
@@ -0,0 +1,15 @@
# Optimized Image Operations for PyTorch

## Installation

```
python setup.py install
```

## Usage

```
# It's important that you do this first
import torch
from pytorch_image_ops import color_transform, spatial_transform
```
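The "do this first" comment refers to importing torch before the compiled extension: the extension is built against libtorch, so importing torch first ensures the shared libraries the extension links against are already loaded.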
nemo-retriever-ocr/cpp/beam_decode/beam_decode.cpp
ADDED
@@ -0,0 +1,460 @@
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#include "beam_decode.h"
|
| 6 |
+
|
| 7 |
+
#include <vector>
|
| 8 |
+
#include <deque>
|
| 9 |
+
#include <limits>
|
| 10 |
+
#include <memory>
|
| 11 |
+
#include <unordered_set>
|
| 12 |
+
#include <set>
|
| 13 |
+
#include <algorithm>
|
| 14 |
+
#include <chrono>
|
| 15 |
+
|
| 16 |
+
#include "../common.h"
|
| 17 |
+
#include "prefix.h"
|
| 18 |
+
#include "log_sum_exp.h"
|
| 19 |
+
#include "sbo_lm.h"
|
| 20 |
+
|
| 21 |
+
using namespace std;
|
| 22 |
+
|
| 23 |
+
template<typename scalar_t>
|
| 24 |
+
using pred_seq_t = torch::TensorAccessor<scalar_t, 2>;
|
| 25 |
+
|
| 26 |
+
struct PrefixScore
|
| 27 |
+
{
|
| 28 |
+
float_t lProbBlank;
|
| 29 |
+
float_t lProbChar;
|
| 30 |
+
// float_t raw_lProbBlank;
|
| 31 |
+
// float_t raw_lProbChar;
|
| 32 |
+
mutable float_t _lProb;
|
| 33 |
+
|
| 34 |
+
PrefixScore(float_t lProbBlank = NEG_INF /* log P(0) */, float_t lProbChar = NEG_INF /* log P(0) */)
|
| 35 |
+
: lProbBlank(lProbBlank), lProbChar(lProbChar), _lProb(NEG_INF)
|
| 36 |
+
// , raw_lProbBlank(lProbBlank), raw_lProbChar(lProbChar)
|
| 37 |
+
{}
|
| 38 |
+
|
| 39 |
+
float_t get_lScore() const {
|
| 40 |
+
if (_lProb == NEG_INF) {
|
| 41 |
+
_lProb = log_sum_exp(lProbBlank, lProbChar);
|
| 42 |
+
}
|
| 43 |
+
return _lProb;
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
// float_t get_raw_lScore() const {
|
| 47 |
+
// return log_sum_exp(raw_lProbBlank, raw_lProbChar);
|
| 48 |
+
// }
|
| 49 |
+
};
|
| 50 |
+
|
| 51 |
+
typedef std::unordered_map<Prefix*, PrefixScore> PrefixMap;
|
| 52 |
+
typedef std::pair<Prefix*, PrefixScore> BeamItem;
|
| 53 |
+
typedef std::vector<BeamItem> Beam;
|
| 54 |
+
|
| 55 |
+
/*
|
| 56 |
+
Allows us to get an estimate of the vision model confidence, irrespective of how the language
|
| 57 |
+
model guided the decoding. NOTE: This scoring could follow an entirely different path than
|
| 58 |
+
the returned decoded sequence.
|
| 59 |
+
*/
|
| 60 |
+
template<typename scalar_t>
|
| 61 |
+
scalar_t get_vision_confidence(const pred_seq_t<scalar_t> &logProbs, scalar_t minProb)
|
| 62 |
+
{
|
| 63 |
+
const int64_t T = logProbs.size(0);
|
| 64 |
+
const int64_t S = logProbs.size(1);
|
| 65 |
+
|
| 66 |
+
scalar_t ret = 0; // log(1)
|
| 67 |
+
|
| 68 |
+
for (size_t t = 0; t < T; ++t) {
|
| 69 |
+
float_t maxP = logProbs[t][0];
|
| 70 |
+
int64_t maxC = 0;
|
| 71 |
+
for (int64_t c = 1; c < S; ++c) {
|
| 72 |
+
float_t p = logProbs[t][c];
|
| 73 |
+
if (p > maxP) {
|
| 74 |
+
maxP = p;
|
| 75 |
+
maxC = c;
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
ret += maxP;
|
| 79 |
+
// Ignore everything past the sequence terminator
|
| 80 |
+
if (maxC == 1) {
|
| 81 |
+
break;
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
if (ret < minProb) {
|
| 85 |
+
break;
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
return ret;
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
template<typename scalar_t>
|
| 94 |
+
pair<vector<token_t>, float_t>
|
| 95 |
+
ctc_beam_decode_impl(const pred_seq_t<scalar_t> &probs, const int64_t beamSize,
|
| 96 |
+
const int64_t blank, scalar_t minProb,
|
| 97 |
+
const LanguageModel &langModel, scalar_t lmWeight)
|
| 98 |
+
{
|
| 99 |
+
if (blank != 0) {
|
| 100 |
+
throw runtime_error("Currently, only ordinal 0 supported for the blank prediction");
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
const int64_t T = probs.size(0);
|
| 104 |
+
const int64_t S = probs.size(1);
|
| 105 |
+
|
| 106 |
+
// NOTE: In log space, the following is true:
|
| 107 |
+
// 1. Adding two probabilities: log_sum_exp(l_p_a, l_p_b)
|
| 108 |
+
// 2. Multiplying two probabilities: l_p_a + l_p_b
|
| 109 |
+
// 3. log P(0) = -inf
|
| 110 |
+
// 4. log P(1) = 0
|
| 111 |
+
|
| 112 |
+
// Convert to log-space
|
| 113 |
+
if (minProb > 0) {
|
| 114 |
+
minProb = log(minProb);
|
| 115 |
+
} else {
|
| 116 |
+
minProb = NEG_INF;
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
auto retScore = get_vision_confidence(probs, minProb);
|
| 120 |
+
|
| 121 |
+
if (retScore < minProb) {
|
| 122 |
+
return { {}, NEG_INF };
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
PrefixAllocator prefixAlloc;
|
| 126 |
+
|
| 127 |
+
Beam beam;
|
| 128 |
+
beam.emplace_back(prefixAlloc.GetPrefix(), PrefixScore{0, NEG_INF}); // Add a dummy first node
|
| 129 |
+
|
| 130 |
+
Beam terminated;
|
| 131 |
+
|
| 132 |
+
typedef tuple<Prefix*, token_t> lm_cache_key_t;
|
| 133 |
+
unordered_map<lm_cache_key_t, float_t> lmScoreCache;
|
| 134 |
+
|
| 135 |
+
for (int64_t t = 0; t < T; ++t) {
|
| 136 |
+
PrefixMap nextBeam;
|
| 137 |
+
|
| 138 |
+
// Add all of the completed paths to the next beam.
|
| 139 |
+
// This allows us to accumulate new paths into these,
|
| 140 |
+
// but otherwise not process them
|
| 141 |
+
for (const BeamItem &prevNode : beam) {
|
| 142 |
+
if (prevNode.first->Token == 1) {
|
| 143 |
+
nextBeam.insert(prevNode);
|
| 144 |
+
}
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
// Loop over vocab
|
| 148 |
+
for (int64_t s = 0; s < S; ++s) {
|
| 149 |
+
float_t lpEmit = probs[t][s];
|
| 150 |
+
|
| 151 |
+
if (lpEmit < minProb) {
|
| 152 |
+
continue;
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
for (const BeamItem &prevNode : beam) {
|
| 156 |
+
Prefix *prevPrefix = prevNode.first;
|
| 157 |
+
const PrefixScore &prevScore = prevNode.second;
|
| 158 |
+
|
| 159 |
+
// Ignore already completed paths
|
| 160 |
+
if (prevPrefix->Token == 1) {
|
| 161 |
+
continue;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
// Ignore impossible paths
|
| 165 |
+
if (prevScore.lProbBlank == NEG_INF && prevScore.lProbChar == NEG_INF) {
|
| 166 |
+
continue;
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
// If we propose a blank the prefix doesn't change.
|
| 170 |
+
// Only the probability of ending in blank gets updated.
|
| 171 |
+
if (s == blank) {
|
| 172 |
+
PrefixScore &score = nextBeam[prevPrefix];
|
| 173 |
+
score.lProbBlank = log_sum_exp(score.lProbBlank , prevScore.lProbBlank + lpEmit, prevScore.lProbChar + lpEmit);
|
| 174 |
+
// score.raw_lProbBlank = log_sum_exp(score.raw_lProbBlank, prevScore.raw_lProbBlank + lpEmit, prevScore.raw_lProbChar + lpEmit);
|
| 175 |
+
continue;
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
// Extend the prefix by the new character s and add it to the beam.
|
| 179 |
+
// Only the probability of not ending in blank gets updated.
|
| 180 |
+
token_t prevToken = prevPrefix->Token;
|
| 181 |
+
|
| 182 |
+
// NOTE: We always create a new prefix regardless of duplication because the PrefixScore
|
| 183 |
+
// is simultaneously tracking prefixes that do and don't end in a blank. And it's those
|
| 184 |
+
// that end in a blank that would cause the prefix to be extended.
|
| 185 |
+
auto extendPrefix = prefixAlloc.GetPrefix(s, prevPrefix);
|
| 186 |
+
|
| 187 |
+
// Evaluate the language model, but use the cache if we've already considered this string before
|
| 188 |
+
auto lmCacheItem = make_tuple(prevPrefix, s);
|
| 189 |
+
auto lmCacheIter = lmScoreCache.find(lmCacheItem);
|
| 190 |
+
float_t lpLang = 0;
|
| 191 |
+
if (lmCacheIter == lmScoreCache.end()) {
|
| 192 |
+
lpLang = langModel.ScoreTransition(prevPrefix, s);
|
| 193 |
+
lpLang *= lmWeight;
|
| 194 |
+
lmCacheIter = lmScoreCache.emplace(lmCacheItem, lpLang).first;
|
| 195 |
+
}
|
| 196 |
+
lpLang = lmCacheIter->second;
|
| 197 |
+
|
| 198 |
+
PrefixScore &extendScore = nextBeam[extendPrefix];
|
| 199 |
+
// Remember, adding two log probabilities is equivalent to multiplying two probabilities
|
| 200 |
+
if (s != prevToken) {
|
| 201 |
+
extendScore.lProbChar = log_sum_exp(extendScore.lProbChar, prevScore.lProbBlank + lpEmit + lpLang, prevScore.lProbChar + lpEmit + lpLang);
|
| 202 |
+
// extendScore.raw_lProbChar = log_sum_exp(extendScore.raw_lProbChar, prevScore.raw_lProbBlank + lpEmit , prevScore.raw_lProbChar + lpEmit );
|
| 203 |
+
} else {
|
| 204 |
+
// We don't include the previous probability of not ending in blank if s is repeated at the end. The CTC
|
| 205 |
+
// algorithm merges characters not separated by a blank.
|
| 206 |
+
extendScore.lProbChar = log_sum_exp(extendScore.lProbChar , prevScore.lProbBlank + lpEmit + lpLang);
|
| 207 |
+
// extendScore.raw_lProbChar = log_sum_exp(extendScore.raw_lProbChar, prevScore.raw_lProbBlank + lpEmit );
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
// If the token is repeated, we also have to deal with the unchanged prefix since repeated characters are collapsed
|
| 211 |
+
if (s == prevToken) {
|
| 212 |
+
PrefixScore &collapseScore = nextBeam[prevPrefix];
|
| 213 |
+
collapseScore.lProbChar = log_sum_exp(collapseScore.lProbChar , prevScore.lProbChar + lpEmit);
|
| 214 |
+
// collapseScore.raw_lProbChar = log_sum_exp(collapseScore.raw_lProbChar, prevScore.raw_lProbChar + lpEmit);
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
}
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
Beam vecNextBeam(begin(nextBeam), end(nextBeam));
|
| 221 |
+
|
| 222 |
+
if (vecNextBeam.size() > beamSize) {
|
| 223 |
+
partial_sort(begin(vecNextBeam), begin(vecNextBeam) + beamSize, end(vecNextBeam),
|
| 224 |
+
[] (const BeamItem &a, const BeamItem &b) {
|
| 225 |
+
return a.second.get_lScore() > b.second.get_lScore();
|
| 226 |
+
}
|
| 227 |
+
);
|
| 228 |
+
vecNextBeam.resize(beamSize);
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
beam = move(vecNextBeam);
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
// Find the best raw score
|
| 235 |
+
const BeamItem *bestItem = nullptr;
|
| 236 |
+
// for (const BeamItem &b : beam) {
|
| 237 |
+
// if (bestItem == nullptr or b.second.get_raw_lScore() > bestItem->second.get_raw_lScore()) {
|
| 238 |
+
// bestItem = &b;
|
| 239 |
+
// }
|
| 240 |
+
// }
|
| 241 |
+
if (! beam.empty()) {
|
| 242 |
+
bestItem = &beam[0];
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
if (bestItem != nullptr) {
|
| 246 |
+
auto retList = bestItem->first->ToList();
|
| 247 |
+
|
| 248 |
+
return { move(retList), retScore };
|
| 249 |
+
} else {
|
| 250 |
+
return { {}, NEG_INF };
|
| 251 |
+
}
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
typedef std::pair<Prefix*, float_t> RegBeamItem;
|
| 255 |
+
|
| 256 |
+
bool operator<(const RegBeamItem &a, const RegBeamItem &b) {
|
| 257 |
+
return a.second > b.second;
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
template<typename scalar_t>
|
| 261 |
+
pair<vector<token_t>, float_t>
|
| 262 |
+
reg_beam_decode_impl(const pred_seq_t<scalar_t> &logProbs, const int64_t beamSize,
|
| 263 |
+
scalar_t minProb,
|
| 264 |
+
const LanguageModel &langModel, scalar_t lmWeight)
|
| 265 |
+
{
|
| 266 |
+
const int64_t T = logProbs.size(0);
|
| 267 |
+
const int64_t S = logProbs.size(1);
|
| 268 |
+
|
| 269 |
+
// NOTE: In log space, the following is true:
|
| 270 |
+
// 1. Adding two probabilities: log_sum_exp(l_p_a, l_p_b)
|
| 271 |
+
// 2. Multiplying two probabilities: l_p_a + l_p_b
|
| 272 |
+
// 3. log P(0) = -inf
|
| 273 |
+
// 4. log P(1) = 0
|
| 274 |
+
|
| 275 |
+
// Convert to log-space
|
| 276 |
+
if (minProb > 0) {
|
| 277 |
+
minProb = log(minProb);
|
| 278 |
+
} else {
|
| 279 |
+
minProb = NEG_INF;
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
auto retScore = get_vision_confidence(logProbs, minProb);
|
| 283 |
+
|
| 284 |
+
if (retScore < minProb) {
|
| 285 |
+
return { {}, NEG_INF };
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
PrefixAllocator prefixAlloc;
|
| 289 |
+
|
| 290 |
+
vector<RegBeamItem> beam, nextBeam;
|
| 291 |
+
beam.emplace_back(prefixAlloc.GetPrefix(), 0); // log(1) = 0
|
| 292 |
+
|
| 293 |
+
for (int64_t t = 0; t < T && !beam.empty(); ++t) {
|
| 294 |
+
nextBeam.clear();
|
| 295 |
+
|
| 296 |
+
auto addToBeam = [&nextBeam, beamSize] (const RegBeamItem &rbi) {
|
| 297 |
+
nextBeam.push_back(rbi);
|
| 298 |
+
};
|
| 299 |
+
|
| 300 |
+
// Expand each path in the beam
|
| 301 |
+
for (const RegBeamItem &prevNode : beam) {
|
| 302 |
+
if (prevNode.first->Token == 1) {
|
| 303 |
+
// Move completed paths along without processing further
|
| 304 |
+
addToBeam(prevNode);
|
| 305 |
+
continue;
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
Prefix *prevPrefix = prevNode.first;
|
| 309 |
+
float_t prevScore = prevNode.second;
|
| 310 |
+
|
| 311 |
+
// Loop over vocab
|
| 312 |
+
for (int64_t s = 0; s < S; ++s) {
|
| 313 |
+
float_t lpEmit = logProbs[t][s];
|
| 314 |
+
|
| 315 |
+
if (lpEmit < minProb) {
|
| 316 |
+
// The probability dropped below threshold, so stop processing this path
|
| 317 |
+
continue;
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
auto extendPrefix = prefixAlloc.GetPrefix(s, prevPrefix);
|
| 321 |
+
|
| 322 |
+
float_t lpLang = langModel.ScoreTransition(prevPrefix, s);
|
| 323 |
+
|
| 324 |
+
float_t lpNext = prevScore + lpLang + lpEmit;
|
| 325 |
+
|
| 326 |
+
addToBeam({extendPrefix, lpNext});
|
| 327 |
+
}
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
if (nextBeam.size() > beamSize) {
|
| 331 |
+
// Find the top-k items, and then truncate the rest
|
| 332 |
+
partial_sort(begin(nextBeam), begin(nextBeam) + beamSize, end(nextBeam));
|
| 333 |
+
nextBeam.resize(beamSize);
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
std::swap(beam, nextBeam);
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
if (!beam.empty()) {
|
| 340 |
+
// The highest probability element will always be in the back
|
| 341 |
+
RegBeamItem rbi{ nullptr, NEG_INF };
|
| 342 |
+
for (auto &rb : beam) {
|
| 343 |
+
if (rbi.first == nullptr || rb.second > rbi.second) {
|
| 344 |
+
rbi = rb;
|
| 345 |
+
}
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
auto retList = rbi.first->ToList();
|
| 349 |
+
|
| 350 |
+
return { move(retList), retScore };
|
| 351 |
+
} else {
|
| 352 |
+
return { {}, NEG_INF };
|
| 353 |
+
}
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
template<typename scalar_t>
|
| 359 |
+
void dp_beam_decode_impl(const torch::TensorAccessor<scalar_t, 3> &probsAccess,
|
| 360 |
+
torch::TensorAccessor<int64_t, 2> retAccess,
|
| 361 |
+
torch::TensorAccessor<scalar_t, 1> confAccess,
|
| 362 |
+
int64_t beamSize, int64_t blank,
|
| 363 |
+
scalar_t minProb,
|
| 364 |
+
const LanguageModel *langModel,
|
| 365 |
+
scalar_t lmWeight,
|
| 366 |
+
bool combineDuplicates)
|
| 367 |
+
{
|
| 368 |
+
const int64_t N = probsAccess.size(0);
|
| 369 |
+
|
| 370 |
+
#pragma omp parallel for num_threads(8)
|
| 371 |
+
for (int64_t i = 0; i < N; ++i) {
|
| 372 |
+
vector<token_t> seq;
|
| 373 |
+
float_t lConf;
|
| 374 |
+
if (combineDuplicates) {
|
| 375 |
+
tie(seq, lConf) = ctc_beam_decode_impl(probsAccess[i], beamSize, blank,
|
| 376 |
+
minProb,
|
| 377 |
+
*langModel, lmWeight);
|
| 378 |
+
} else {
|
| 379 |
+
tie(seq, lConf) = reg_beam_decode_impl(probsAccess[i], beamSize,
|
| 380 |
+
minProb,
|
| 381 |
+
*langModel, lmWeight);
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
int64_t sz = min<int64_t>(seq.size(), retAccess.size(1));
|
| 385 |
+
|
| 386 |
+
for (int64_t k = 0; k < sz; ++k) {
|
| 387 |
+
retAccess[i][k] = seq[k];
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
confAccess[i] = exp(lConf);
|
| 391 |
+
}
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
std::tuple<torch::Tensor, torch::Tensor>
|
| 395 |
+
beam_decode(torch::Tensor probs, int64_t beamSize, int64_t blank,
|
| 396 |
+
float minProb,
|
| 397 |
+
const LanguageModel *langModel,
|
| 398 |
+
float lmWeight,
|
| 399 |
+
bool combineDuplicates)
|
| 400 |
+
{
|
| 401 |
+
if (langModel == nullptr) {
|
| 402 |
+
langModel = &NullLanguageModel;
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
auto tStart = chrono::high_resolution_clock::now();
|
| 406 |
+
|
| 407 |
+
probs = probs.contiguous();
|
| 408 |
+
|
| 409 |
+
bool collapse = false;
|
| 410 |
+
if (probs.dim() == 2) {
|
| 411 |
+
// N,T,C
|
| 412 |
+
probs = probs.unsqueeze(0);
|
| 413 |
+
collapse = true;
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
probs = probs.log();
|
| 417 |
+
|
| 418 |
+
torch::Tensor ret = torch::ones({ probs.size(0), probs.size(1) }, torch::kInt64);
|
| 419 |
+
torch::Tensor conf = torch::zeros({ probs.size(0) }, probs.options());
|
| 420 |
+
|
| 421 |
+
auto retAccess = ret.accessor<int64_t, 2>();
|
| 422 |
+
|
| 423 |
+
AT_DISPATCH_FLOATING_TYPES(
|
| 424 |
+
probs.scalar_type(),
|
| 425 |
+
"cpu_beam_decode",
|
| 426 |
+
([&] {
|
| 427 |
+
dp_beam_decode_impl(
|
| 428 |
+
probs.accessor<scalar_t, 3>(),
|
| 429 |
+
retAccess,
|
| 430 |
+
conf.accessor<scalar_t, 1>(),
|
| 431 |
+
beamSize, blank,
|
| 432 |
+
static_cast<scalar_t>(minProb),
|
| 433 |
+
langModel,
|
| 434 |
+
static_cast<scalar_t>(lmWeight),
|
| 435 |
+
combineDuplicates
|
| 436 |
+
);
|
| 437 |
+
})
|
| 438 |
+
);
|
| 439 |
+
|
| 440 |
+
if (collapse) {
|
| 441 |
+
ret = ret.squeeze(0);
|
| 442 |
+
conf = conf[0];
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
auto tEnd = chrono::high_resolution_clock::now();
|
| 446 |
+
|
| 447 |
+
typedef chrono::duration<double, std::milli> tp_t;
|
| 448 |
+
tp_t totalElapsed = tEnd - tStart;
|
| 449 |
+
|
| 450 |
+
cout << "Beam Decode " << probs.size(0) << " - "
|
| 451 |
+
<< "Total: " << totalElapsed.count() << "ms"
|
| 452 |
+
<< endl;
|
| 453 |
+
|
| 454 |
+
return { ret, conf };
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
std::unique_ptr<LanguageModel> create_sbo_lm(const std::string &dataFilePath, token_mapping_t tokenMapping, float_t backoffWeight)
|
| 458 |
+
{
|
| 459 |
+
return make_unique<SBO_LanguageModel>(dataFilePath, move(tokenMapping), backoffWeight);
|
| 460 |
+
}
|
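For readers skimming the decoder: ctc_beam_decode_impl above is a CTC prefix beam search with an optional character language model, run entirely in log space (log_sum_exp to add probabilities, plain addition to multiply them). Its per-timestep updates can be restated as the standard recurrences below; the notation (ℓ for a prefix, y_t for the frame posterior, w for lmWeight) is ours, not from the source.

```
% p_b / p_nb: probability mass of prefix \ell ending / not ending in a blank at time t
\begin{aligned}
p_b^{t}(\ell)      &\mathrel{+}= \bigl(p_b^{t-1}(\ell) + p_{nb}^{t-1}(\ell)\bigr)\, y_t(\varnothing) \\
p_{nb}^{t}(\ell c) &\mathrel{+}= \bigl(p_b^{t-1}(\ell) + p_{nb}^{t-1}(\ell)\bigr)\, y_t(c)\, P_{LM}(c \mid \ell)^{w}
    && \text{if } c \neq \operatorname{last}(\ell) \\
p_{nb}^{t}(\ell c) &\mathrel{+}= p_b^{t-1}(\ell)\, y_t(c)\, P_{LM}(c \mid \ell)^{w},
\quad p_{nb}^{t}(\ell) \mathrel{+}= p_{nb}^{t-1}(\ell)\, y_t(c)
    && \text{if } c = \operatorname{last}(\ell)
\end{aligned}
```

The beam is then pruned to beamSize by log_sum_exp(p_b, p_nb), and the returned confidence is the greedy vision-only score from get_vision_confidence, independent of the LM-guided path.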
nemo-retriever-ocr/cpp/beam_decode/beam_decode.h
ADDED
@@ -0,0 +1,18 @@
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <torch/torch.h>

#include "language_model.h"

std::tuple<torch::Tensor, torch::Tensor>
beam_decode(torch::Tensor probs, int64_t beamSize, int64_t blank,
            float minProb,
            const LanguageModel *langModel,
            float lmWeight,
            bool combineDuplicates);

std::unique_ptr<LanguageModel> create_sbo_lm(const std::string &dataFilePath, token_mapping_t tokenMapping, float_t backoffWeight);
nemo-retriever-ocr/cpp/beam_decode/kn_lm.cpp
ADDED
@@ -0,0 +1,86 @@
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#include "kn_lm.h"

using namespace std;


KN_LanguageModel::KN_LanguageModel(const string &dataFilePath, token_mapping_t tokenMapping, float_t knDelta)
    : NGramLMBase(dataFilePath, move(tokenMapping)), m_knDelta(knDelta)
{
}

float KN_LanguageModel::ScoreTransitionImpl(const std::wstring &prefix, const std::wstring &suffix) const
{
    if (prefix.empty()) {
        return ScoreUnigram(suffix);
    } else {
        return ScoreTransition(prefix, suffix);
    }
}

float_t KN_LanguageModel::ScoreUnigram(const std::wstring &uni) const
{
    auto lIter = m_lookup[1].find(L""s);
    if (lIter == m_lookup[1].end()) {
        throw std::runtime_error("Unigrams not supported by this model!");
    }

    auto uniIter = lIter->second.find(uni);
    float_t ctUni = 1e-8;
    if (uniIter != lIter->second.end()) {
        ctUni = uniIter->second;
    }

    float_t ctSuffixes = GetPrefixSum(L""s);

    return ctUni / ctSuffixes;
}

float_t KN_LanguageModel::ScoreTransition(const std::wstring &prefix, const std::wstring &suffix) const
{
    if (prefix.empty()) {
        // The number of distinct bigrams that end with this token
        auto rlIter = m_reverseLookup.find(suffix);

        float_t ctEndingBigrams = 0;
        if (rlIter != m_reverseLookup.end()) {
            ctEndingBigrams = rlIter->second[2].size();
        }

        float_t ctAllBigrams = m_lookup[2].size();

        return ctEndingBigrams / ctAllBigrams;
    }

    auto lIter = m_lookup[prefix.size() + 1].find(prefix);
    float_t ctUqSuffixes = 0;
    float_t ctSuffixes = 0;
    float_t ctSuffix = 0;
    if (lIter != m_lookup[prefix.size() + 1].end()) {
        ctUqSuffixes = lIter->second.size();

        ctSuffixes = GetPrefixSum(prefix);

        auto sIter = lIter->second.find(suffix);
        if (sIter != lIter->second.end()) {
            ctSuffix = sIter->second;
        }
    }

    float_t factor = 0;
    float_t main = 0;
    if (ctSuffixes != 0) {
        factor = m_knDelta * ctUqSuffixes / ctSuffixes;
        // TODO: Figure out how to make this call without copying the string!
        factor *= ScoreTransition({begin(prefix) + 1, end(prefix)}, suffix);

        main = max<float_t>(ctSuffix - m_knDelta, 0) / ctSuffixes;
    }

    float_t total = main + factor;

    return total;
}
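ScoreTransition above is an interpolated Kneser-Ney estimate with absolute discounting. Restated for clarity (δ is m_knDelta, c(·) are the n-gram counts held in m_lookup/m_reverseLookup, and p_{2:} is the prefix with its oldest character dropped), it computes:

```
P(s \mid p) \;=\; \frac{\max\bigl(c(p\,s) - \delta,\; 0\bigr)}{\sum_{s'} c(p\,s')}
\;+\; \frac{\delta \,\bigl|\{\, s' : c(p\,s') > 0 \,\}\bigr|}{\sum_{s'} c(p\,s')}\; P\bigl(s \mid p_{2:}\bigr),
\qquad
P(s \mid \varepsilon) \;=\; \frac{\bigl|\{\, x : c(x\,s) > 0 \,\}\bigr|}{\#\,\text{distinct bigrams}}
```

so unseen (prefix, suffix) pairs fall back to the lower-order estimate, and the base case uses bigram continuation counts rather than raw unigram frequencies.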
nemo-retriever-ocr/cpp/beam_decode/kn_lm.h
ADDED
@@ -0,0 +1,27 @@
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <unordered_map>
#include <vector>

#include "ngram_lm_base.h"


class KN_LanguageModel
    : public NGramLMBase
{
public:
    KN_LanguageModel(const std::string &dataFilePath, token_mapping_t tokenMapping, float_t knDelta);

protected:
    virtual float_t ScoreTransitionImpl(const std::wstring &prefix, const std::wstring &suffix) const override;

private:
    float_t ScoreUnigram(const std::wstring &uni) const;
    float_t ScoreTransition(const std::wstring &prefix, const std::wstring &suffix) const;

    float_t m_knDelta;
};
nemo-retriever-ocr/cpp/beam_decode/language_model.cpp
ADDED
@@ -0,0 +1,147 @@
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#include "language_model.h"

#include <locale>
#include <codecvt>

using namespace std;

const NullLanguageModel_t NullLanguageModel;

NullLanguageModel_t::NullLanguageModel_t()
    : LanguageModel({})
{
}

TokenMappingWrapper::TokenMappingWrapper(token_mapping_t mapping)
    : token_mapping(move(mapping))
{
    for (const auto &mp : token_mapping) {
        if (mp.second.size() == 1) {
            wchar_t c = mp.second.front();
            reverse_token_mapping.emplace(c, mp.first);
        }
    }
}

TokenMappingWrapper::Ptr create_token_mapping(token_mapping_t tokenMapping)
{
    return make_shared<TokenMappingWrapper>(move(tokenMapping));
}


template<typename token_t>
vector<tuple<wstring, float>>
decode_sequences_impl(torch::Tensor tokens, const TokenMappingWrapper *tokenMapping,
                      c10::optional<torch::Tensor> probs)
{
    const token_mapping_t &mapping = tokenMapping->token_mapping;

    auto tokensAccess = tokens.accessor<token_t, 2>();

    torch::Tensor pTens = probs.value_or(torch::ones({ tokens.size(0) }, torch::kFloat32));
    if (pTens.dim() == 1) {
        pTens = pTens.unsqueeze(1);
    }

    auto probsAccess = pTens.accessor<float, 2>();

    const int64_t B = tokens.size(0);
    const int64_t T = tokens.size(1);

    vector<tuple<wstring, float>> ret;

    for (int64_t b = 0; b < B; ++b) {
        wstring buff;

        float logProb = 0.0f; // log 1
        bool done = false;
        for (int64_t t = 0; t < T && ! done; ++t) {
            typename token_mapping_t::key_type tokIdx = tokensAccess[b][t];

            if (t < probsAccess.size(1)) {
                logProb += log(probsAccess[b][t]);
            }

            switch (tokIdx) {
                case 0:
                    // Blank char
                    continue;
                case 1:
                    // End of sequence char
                    done = true;
                    break;
                case 2:
                    buff.push_back('^');
                    break;
                default:
                    auto iter = mapping.find(tokIdx);
                    if (iter == mapping.end()) {
                        throw std::runtime_error("The token mapping doesn't contain an entry for index " + to_string(tokIdx));
                    }
                    buff += iter->second;
                    break;
            }
        }

        ret.emplace_back(move(buff), exp(logProb));
    }

    return ret;
}

vector<tuple<wstring, float>>
decode_sequences(torch::Tensor tokens, const TokenMappingWrapper *tokenMapping,
                 c10::optional<torch::Tensor> probs)
{
    if (tokens.dim() != 2) {
        throw std::runtime_error("`tokens` must be 2-dimensions of type B,T!");
    }

    if (tokenMapping == nullptr) {
        throw std::runtime_error("Cannot supply a null token mapping!");
    }

    const token_mapping_t &mapping = tokenMapping->token_mapping;

    if (mapping.empty()) {
        throw std::runtime_error("The token mapping hasn't been initialized!");
    }

    if (probs.has_value()) {
        if (probs.value().scalar_type() != torch::kFloat32) {
            throw std::runtime_error("If the probability distribution is specified, then it must be of type `torch.float32`");
        }
        if (probs.value().size(0) != tokens.size(0)) {
            throw std::runtime_error("The probability distribution batch size doesn't match the tokens batch size!");
        }
        if (probs.value().dim() == 2 && probs.value().size(1) != tokens.size(1)) {
            throw std::runtime_error("Invalid probability distribution shape!");
        }
    }

    vector<tuple<wstring, float>> ret;

    AT_DISPATCH_INTEGRAL_TYPES(
        tokens.scalar_type(),
        "decode_sequences_impl",
        ([&] {
            ret = decode_sequences_impl<scalar_t>(tokens, tokenMapping, probs);
        })
    );

    return ret;
}


std::string ws2s(const std::wstring& wstr)
{
    using convert_typeX = std::codecvt_utf8<wchar_t>;
    std::wstring_convert<convert_typeX, wchar_t> converterX;

    return converterX.to_bytes(wstr);
}
nemo-retriever-ocr/cpp/beam_decode/language_model.h
ADDED
@@ -0,0 +1,66 @@
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <memory>

#include <torch/torch.h>

#include "prefix.h"
#include "log_sum_exp.h"

typedef std::unordered_map<int64_t, std::wstring> token_mapping_t;
typedef std::unordered_map<wchar_t, int64_t> reverse_token_mapping_t;


class LanguageModel
{
public:
    virtual ~LanguageModel() {}

    virtual float_t ScoreTransition(const Prefix *p, token_t nextToken) const = 0;

    const token_mapping_t &TokenMapping() const { return m_tokenMapping; }

protected:
    LanguageModel(token_mapping_t tokenMapping)
        : m_tokenMapping(std::move(tokenMapping))
    {}

    token_mapping_t m_tokenMapping;
};


class NullLanguageModel_t
    : public LanguageModel
{
public:
    NullLanguageModel_t();

    virtual float_t ScoreTransition(const Prefix *p, token_t nextToken) const override
    {
        // log P(1)
        // Which means the probability is unchanged
        return 0;
    }
};

extern const NullLanguageModel_t NullLanguageModel;

struct TokenMappingWrapper
{
    typedef std::shared_ptr<TokenMappingWrapper> Ptr;

    TokenMappingWrapper(token_mapping_t mapping);

    token_mapping_t token_mapping;
    reverse_token_mapping_t reverse_token_mapping;
};

TokenMappingWrapper::Ptr create_token_mapping(token_mapping_t tokenMapping);

std::vector<std::tuple<std::wstring, float>>
decode_sequences(torch::Tensor tokens, const TokenMappingWrapper *tokenMapping,
                 c10::optional<torch::Tensor> probs = torch::nullopt);
nemo-retriever-ocr/cpp/beam_decode/log_sum_exp.cpp
ADDED
@@ -0,0 +1,7 @@
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#include "log_sum_exp.h"

const float_t NEG_INF = -std::numeric_limits<float_t>::infinity();
nemo-retriever-ocr/cpp/beam_decode/log_sum_exp.h
ADDED
@@ -0,0 +1,54 @@
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <cmath>
#include <limits>
#include <algorithm>

typedef float float_t;
extern const float_t NEG_INF;

template<typename T>
inline T max_val(T v)
{
    return v;
}

template<typename T, typename ...Args>
inline T max_val(T v, Args... rest)
{
    auto restMax = max_val(rest...);

    return std::max(v, restMax);
}

template<typename T>
inline T sum_exp(T maxVal, T v)
{
    return std::exp(v - maxVal);
}

template<typename T, typename ...Args>
inline T sum_exp(T maxVal, T v, Args... rest)
{
    auto restSum = sum_exp(maxVal, rest...);

    return sum_exp(maxVal, v) + restSum;
}

template<typename T, typename ...Args>
inline T log_sum_exp(T v, Args ...args)
{
    auto maxVal = max_val(v, args...);

    if (maxVal == -std::numeric_limits<T>::infinity()) {
        return -std::numeric_limits<T>::infinity();
    }

    auto sumExp = sum_exp(maxVal, v, args...);

    return maxVal + std::log(sumExp);
}
nemo-retriever-ocr/cpp/beam_decode/ngram_lm_base.cpp
ADDED
@@ -0,0 +1,330 @@
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#include "ngram_lm_base.h"
|
| 6 |
+
|
| 7 |
+
#include <iostream>
|
| 8 |
+
#include <fstream>
|
| 9 |
+
|
| 10 |
+
#if defined( USE_BOOST )
|
| 11 |
+
|
| 12 |
+
#include <boost/archive/binary_oarchive.hpp>
|
| 13 |
+
#include <boost/archive/binary_iarchive.hpp>
|
| 14 |
+
#include <boost/serialization/vector.hpp>
|
| 15 |
+
#include <boost/serialization/string.hpp>
|
| 16 |
+
#include <boost/serialization/unordered_map.hpp>
|
| 17 |
+
|
| 18 |
+
#endif // USE_BOOST
|
| 19 |
+
|
| 20 |
+
using namespace std;
|
| 21 |
+
|
| 22 |
+
const std::wstring WORD_END(1, 2);
|
| 23 |
+
const std::wstring NUMERIC(1, 3);
|
| 24 |
+
const std::wstring UNMODELED(1, 4);
|
| 25 |
+
|
| 26 |
+
struct LMStorage
|
| 27 |
+
{
|
| 28 |
+
lookup_t Lookup;
|
| 29 |
+
reverse_lookup_t ReverseLookup;
|
| 30 |
+
|
| 31 |
+
template<class Archive>
|
| 32 |
+
void serialize(Archive &ar, const unsigned int version) {
|
| 33 |
+
ar & Lookup;
|
| 34 |
+
ar & ReverseLookup;
|
| 35 |
+
}
|
| 36 |
+
};
|
| 37 |
+
|
| 38 |
+
void save_suffix_map(std::fstream& fs, const suffix_map_t& suffix_map)
|
| 39 |
+
{
|
| 40 |
+
// write out number of elements for Lookup
|
| 41 |
+
std::size_t suffix_map_count = suffix_map.size();
|
| 42 |
+
fs.write((char*)(&suffix_map_count), sizeof(suffix_map_count));
|
| 43 |
+
for (suffix_map_t::const_iterator reverse_lookup_it = suffix_map.begin(); reverse_lookup_it != suffix_map.end(); ++reverse_lookup_it)
|
| 44 |
+
{
|
| 45 |
+
// write out the key
|
| 46 |
+
size_t key_len = reverse_lookup_it->first.length();
|
| 47 |
+
fs.write((char*)(&key_len), sizeof(key_len));
|
| 48 |
+
fs.write((char*)(reverse_lookup_it->first.data()), key_len * sizeof(wchar_t));
|
| 49 |
+
|
| 50 |
+
// write out value
|
| 51 |
+
fs.write((char*)(&reverse_lookup_it->second), sizeof(reverse_lookup_it->second));
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
void save_lookup(std::fstream& fs, const lookup_t& lookup)
|
| 56 |
+
{
|
| 57 |
+
// write out number of elements for Lookup
|
| 58 |
+
std::size_t lookup_count = lookup.size();
|
| 59 |
+
fs.write((char*)(&lookup_count), sizeof(lookup_count));
|
| 60 |
+
for (lookup_t::const_iterator lookup_it = lookup.begin(); lookup_it != lookup.end(); ++lookup_it)
|
| 61 |
+
{
|
| 62 |
+
// write out element map size
|
| 63 |
+
std::size_t map_elem_count = lookup_it->size();
|
| 64 |
+
fs.write((char*)(&map_elem_count), sizeof(map_elem_count));
|
| 65 |
+
|
| 66 |
+
for (string_suffix_map_t::const_iterator str_sfx_it = lookup_it->begin(); str_sfx_it != lookup_it->end(); ++str_sfx_it)
|
| 67 |
+
{
|
| 68 |
+
// write out key
|
| 69 |
+
size_t key_len = str_sfx_it->first.length();
|
| 70 |
+
fs.write((char*)(&key_len), sizeof(key_len));
|
| 71 |
+
fs.write((char*)(str_sfx_it->first.data()), key_len * sizeof(wchar_t));
|
| 72 |
+
save_suffix_map(fs, str_sfx_it->second);
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
void save_reverse_lookup(std::fstream& fs, const reverse_lookup_t& reverse_lookup)
|
| 78 |
+
{
|
| 79 |
+
// write out number of elements for Lookup
|
| 80 |
+
std::size_t reverse_lookup_count = reverse_lookup.size();
|
| 81 |
+
fs.write((char*)(&reverse_lookup_count), sizeof(reverse_lookup_count));
|
| 82 |
+
for (reverse_lookup_t::const_iterator reverse_lookup_it = reverse_lookup.begin(); reverse_lookup_it != reverse_lookup.end(); ++reverse_lookup_it)
|
| 83 |
+
{
|
| 84 |
+
// write out the key
|
| 85 |
+
size_t key_len = reverse_lookup_it->first.length();
|
| 86 |
+
fs.write((char*)(&key_len), sizeof(key_len));
|
| 87 |
+
fs.write((char*)(reverse_lookup_it->first.data()), key_len * sizeof(wchar_t));
|
| 88 |
+
|
| 89 |
+
// write out value vector length
|
| 90 |
+
size_t val_vec_len = reverse_lookup_it->second.size();
|
| 91 |
+
fs.write((char*)(&val_vec_len), sizeof(val_vec_len));
|
| 92 |
+
|
| 93 |
+
for (suffix_map_vec_t::const_iterator val_vec_it = reverse_lookup_it->second.begin();
|
| 94 |
+
val_vec_it != reverse_lookup_it->second.end();
|
| 95 |
+
++val_vec_it)
|
| 96 |
+
{
|
| 97 |
+
save_suffix_map(fs, *val_vec_it);
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
void load_suffix_map(std::fstream& fs, suffix_map_t& suffix_map)
|
| 103 |
+
{
|
| 104 |
+
// read in number of elements
|
| 105 |
+
std::size_t suffix_map_count = 0;
|
| 106 |
+
fs.read((char*)(&suffix_map_count), sizeof(suffix_map_count));
|
| 107 |
+
for (size_t suffix_map_index = 0; suffix_map_index < suffix_map_count; ++suffix_map_index )
|
| 108 |
+
{
|
| 109 |
+
// read in key
|
| 110 |
+
std::size_t key_len = 0;
|
| 111 |
+
fs.read((char*)(&key_len), sizeof(key_len));
|
| 112 |
+
|
| 113 |
+
std::wstring wkey(key_len, 0);
|
| 114 |
+
fs.read((char*)(wkey.data()), key_len * sizeof(wchar_t));
|
| 115 |
+
uint32_t value = 0;
|
| 116 |
+
fs.read((char*)(&value), sizeof(value));
|
| 117 |
+
|
| 118 |
+
suffix_map.insert(std::make_pair(wkey, value));
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
void load_lookup(std::fstream& fs, lookup_t& lookup)
|
| 123 |
+
{
|
| 124 |
+
// read in number of elements
|
| 125 |
+
std::size_t lookup_count = 0;
|
| 126 |
+
fs.read((char*)(&lookup_count), sizeof(lookup_count));
|
| 127 |
+
for (size_t lookup_index = 0; lookup_index < lookup_count; ++lookup_index)
|
| 128 |
+
{
|
| 129 |
+
std::size_t map_elem_count = 0;
|
| 130 |
+
fs.read((char*)(&map_elem_count), sizeof(map_elem_count));
|
| 131 |
+
|
| 132 |
+
lookup.push_back(string_suffix_map_t());
|
| 133 |
+
string_suffix_map_t& str_sfx_map = lookup.back();
|
| 134 |
+
|
| 135 |
+
for (size_t str_sfx_map_index = 0; str_sfx_map_index < map_elem_count; ++str_sfx_map_index)
|
| 136 |
+
{
|
| 137 |
+
std::size_t key_len = 0;
|
| 138 |
+
fs.read((char*)(&key_len), sizeof(key_len));
|
| 139 |
+
|
| 140 |
+
std::wstring wkey(key_len, 0);
|
| 141 |
+
fs.read((char*)(wkey.data()), key_len * sizeof(wchar_t));
|
| 142 |
+
str_sfx_map.insert(std::make_pair<wstring, suffix_map_t>(std::wstring(wkey), suffix_map_t()));
|
| 143 |
+
suffix_map_t& suffix_map = str_sfx_map[wkey];
|
| 144 |
+
|
| 145 |
+
load_suffix_map(fs, suffix_map);
|
| 146 |
+
}
|
| 147 |
+
}
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
void load_reverse_lookup(std::fstream& fs, reverse_lookup_t& reverse_lookup)
|
| 151 |
+
{
|
| 152 |
+
// read in number of elements
|
| 153 |
+
std::size_t reverse_lookup_count = 0;
|
| 154 |
+
fs.read((char*)(&reverse_lookup_count), sizeof(reverse_lookup_count));
|
| 155 |
+
for (size_t rev_lookup_index = 0; rev_lookup_index < reverse_lookup_count; ++rev_lookup_index )
|
| 156 |
+
{
|
| 157 |
+
// read in the key
|
| 158 |
+
std::size_t key_len = 0;
|
| 159 |
+
fs.read((char*)(&key_len), sizeof(key_len));
|
| 160 |
+
|
| 161 |
+
std::wstring wkey(key_len, 0);
|
| 162 |
+
fs.read((char*)(wkey.data()), key_len * sizeof(wchar_t));
|
| 163 |
+
reverse_lookup.insert(std::make_pair(wkey, suffix_map_vec_t()));
|
| 164 |
+
suffix_map_vec_t& val_vec = reverse_lookup[wkey];
|
| 165 |
+
|
| 166 |
+
std::size_t val_vec_len = 0;
|
| 167 |
+
fs.read((char*)(&val_vec_len), sizeof(val_vec_len));
|
| 168 |
+
|
| 169 |
+
for (size_t val_vec_index = 0; val_vec_index < val_vec_len; ++val_vec_index)
|
| 170 |
+
{
|
| 171 |
+
val_vec.push_back(suffix_map_t());
|
| 172 |
+
suffix_map_t& suffix_map = val_vec.back();
|
| 173 |
+
load_suffix_map(fs, suffix_map);
|
| 174 |
+
}
|
| 175 |
+
}
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
#if ! defined( USE_BOOST )
|
| 179 |
+
|
| 180 |
+
NGramLMBase::NGramLMBase(const string &dataFilePath, token_mapping_t tokenMapping)
|
| 181 |
+
: LanguageModel(move(tokenMapping))
|
| 182 |
+
{
|
| 183 |
+
std::fstream in(dataFilePath, std::ios::in | std::ios::binary);
|
| 184 |
+
load_lookup(in, m_lookup);
|
| 185 |
+
load_reverse_lookup(in, m_reverseLookup);
|
| 186 |
+
|
| 187 |
+
if (m_lookup.size() >= 10) {
|
| 188 |
+
throw runtime_error("Only N-Grams of 9 or less are supported!");
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
for (auto &ngLevel : m_lookup) {
|
| 192 |
+
for (auto &kvPrefixLevel : ngLevel) {
|
| 193 |
+
uint32_t ct = 0;
|
| 194 |
+
for (auto &kvSfx : kvPrefixLevel.second) {
|
| 195 |
+
ct += kvSfx.second;
|
| 196 |
+
}
|
| 197 |
+
m_prefixSumLookup.emplace(kvPrefixLevel.first, ct);
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
void save_ngram_data_file(const lookup_t& lookup, const reverse_lookup_t& reverseLookup, const std::string &outputPath)
|
| 203 |
+
{
|
| 204 |
+
std::fstream out(outputPath, std::ios::out | std::ios::binary);
|
| 205 |
+
|
| 206 |
+
save_lookup(out, lookup);
|
| 207 |
+
save_reverse_lookup(out, reverseLookup);
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
#else // USE_BOOST
|
| 211 |
+
|
| 212 |
+
NGramLMBase::NGramLMBase(const string &dataFilePath, token_mapping_t tokenMapping)
|
| 213 |
+
: LanguageModel(move(tokenMapping))
|
| 214 |
+
{
|
| 215 |
+
{
|
| 216 |
+
ifstream dfStr(dataFilePath, ios_base::in | ios_base::binary);
|
| 217 |
+
boost::archive::binary_iarchive ia(dfStr);
|
| 218 |
+
|
| 219 |
+
LMStorage s;
|
| 220 |
+
ia >> s;
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
m_lookup = move(s.Lookup);
|
| 224 |
+
|
| 225 |
+
m_reverseLookup = move(s.ReverseLookup);
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
if (m_lookup.size() >= 10) {
|
| 229 |
+
throw runtime_error("Only N-Grams of 9 or less are supported!");
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
for (auto &ngLevel : m_lookup) {
|
| 233 |
+
for (auto &kvPrefixLevel : ngLevel) {
|
| 234 |
+
uint32_t ct = 0;
|
| 235 |
+
for (auto &kvSfx : kvPrefixLevel.second) {
|
| 236 |
+
ct += kvSfx.second;
|
| 237 |
+
}
|
| 238 |
+
m_prefixSumLookup.emplace(kvPrefixLevel.first, ct);
|
| 239 |
+
}
|
| 240 |
+
}
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
void save_ngram_data_file(lookup_t lookup, reverse_lookup_t reverseLookup, const std::string &outputPath)
|
| 244 |
+
{
|
| 245 |
+
ofstream ofs(outputPath, ios_base::out | ios_base::binary);
|
| 246 |
+
|
| 247 |
+
LMStorage s;
|
| 248 |
+
s.Lookup = move(lookup);
|
| 249 |
+
s.ReverseLookup = move(reverseLookup);
|
| 250 |
+
|
| 251 |
+
boost::archive::binary_oarchive oa(ofs);
|
| 252 |
+
oa << s;
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
#endif // USE_BOOST
|
| 256 |
+
|
| 257 |
+
float_t NGramLMBase::ScoreTransition(const Prefix *p, token_t nextToken) const
|
| 258 |
+
{
|
| 259 |
+
std::wstring prefix;
|
| 260 |
+
if (! ConvertToString(p, prefix)) {
|
| 261 |
+
return NEG_INF;
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
const std::wstring *pSuffix = nullptr;
|
| 265 |
+
|
| 266 |
+
if (nextToken != 1) {
|
| 267 |
+
auto iter = m_tokenMapping.find(nextToken);
|
| 268 |
+
if (iter == m_tokenMapping.end()) {
|
| 269 |
+
pSuffix = &UNMODELED;
|
| 270 |
+
} else {
|
| 271 |
+
pSuffix = &iter->second;
|
| 272 |
+
|
| 273 |
+
if (iswdigit(pSuffix->at(0))) {
|
| 274 |
+
pSuffix = &NUMERIC;
|
| 275 |
+
}
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
} else {
|
| 279 |
+
pSuffix = &WORD_END;
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
float_t ret = ScoreTransitionImpl(prefix, *pSuffix);
|
| 283 |
+
|
| 284 |
+
if (ret > 0) {
|
| 285 |
+
return log(ret);
|
| 286 |
+
} else {
|
| 287 |
+
return NEG_INF;
|
| 288 |
+
}
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
bool NGramLMBase::ConvertToString(const Prefix *p, std::wstring &prefix) const
|
| 292 |
+
{
|
| 293 |
+
const Prefix *stk[10];
|
| 294 |
+
int32_t sz = -1;
|
| 295 |
+
const Prefix *curr = p;
|
| 296 |
+
decltype(sz) mlSz{(int)m_lookup.size() - 2};
|
| 297 |
+
while (curr && sz < mlSz) {
|
| 298 |
+
stk[++sz] = curr;
|
| 299 |
+
curr = curr->Parent;
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
// Either blank or empty prefix
|
| 303 |
+
if (sz < 1) { return true; }
|
| 304 |
+
|
| 305 |
+
--sz;
|
| 306 |
+
for (; sz >= 0; --sz) {
|
| 307 |
+
token_t tok = stk[sz]->Token;
|
| 308 |
+
// End of word token, which maps to the null character
|
| 309 |
+
if (tok == 1) {
|
| 310 |
+
prefix.push_back(WORD_END[0]);
|
| 311 |
+
} else if (tok == 0) {
|
| 312 |
+
// Do nothing
|
| 313 |
+
} else {
|
| 314 |
+
auto iter = m_tokenMapping.find(tok);
|
| 315 |
+
if (iter == m_tokenMapping.end()) {
|
| 316 |
+
prefix += UNMODELED;
|
| 317 |
+
} else {
|
| 318 |
+
const std::wstring &wChar = iter->second;
|
| 319 |
+
|
| 320 |
+
if (iswdigit(wChar[0])) {
|
| 321 |
+
prefix += NUMERIC;
|
| 322 |
+
} else {
|
| 323 |
+
prefix += wChar;
|
| 324 |
+
}
|
| 325 |
+
}
|
| 326 |
+
}
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
return true;
|
| 330 |
+
}
|
nemo-retriever-ocr/cpp/beam_decode/ngram_lm_base.h
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <unordered_map>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
#include "language_model.h"
|
| 11 |
+
|
| 12 |
+
// #define USE_BOOST 1
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
typedef std::unordered_map<std::wstring, uint32_t> suffix_map_t;
|
| 16 |
+
|
| 17 |
+
/* Tells us the number of suffixes for a given ngram of order K
|
| 18 |
+
Keys:
|
| 19 |
+
1. NGram Order
|
| 20 |
+
2. Prefix
|
| 21 |
+
3. Suffix
|
| 22 |
+
Value:
|
| 23 |
+
Count
|
| 24 |
+
*/
|
| 25 |
+
typedef std::unordered_map<std::wstring, suffix_map_t> string_suffix_map_t;
|
| 26 |
+
typedef std::vector<string_suffix_map_t> lookup_t;
|
| 27 |
+
/* Tells us the number of K-gram prefixes found for a given suffix
|
| 28 |
+
Keys:
|
| 29 |
+
1. Suffix
|
| 30 |
+
2. NGram Order
|
| 31 |
+
3. Prefix
|
| 32 |
+
Values:
|
| 33 |
+
Count
|
| 34 |
+
*/
|
| 35 |
+
typedef std::vector<suffix_map_t> suffix_map_vec_t;
|
| 36 |
+
typedef std::unordered_map<std::wstring, suffix_map_vec_t> reverse_lookup_t;
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
extern const std::wstring WORD_END;
|
| 41 |
+
extern const std::wstring NUMERIC;
|
| 42 |
+
extern const std::wstring UNMODELED;
|
| 43 |
+
|
| 44 |
+
class NGramLMBase
|
| 45 |
+
: public LanguageModel
|
| 46 |
+
{
|
| 47 |
+
public:
|
| 48 |
+
virtual float_t ScoreTransition(const Prefix *p, token_t nextToken) const override;
|
| 49 |
+
|
| 50 |
+
protected:
|
| 51 |
+
NGramLMBase(const std::string &dataFilePath, token_mapping_t tokenMapping);
|
| 52 |
+
|
| 53 |
+
virtual float_t ScoreTransitionImpl(const std::wstring &prefix, const std::wstring &suffix) const = 0;
|
| 54 |
+
|
| 55 |
+
bool ConvertToString(const Prefix *p, std::wstring &prefix) const;
|
| 56 |
+
|
| 57 |
+
float_t GetPrefixSum(const std::wstring &prefix) const;
|
| 58 |
+
|
| 59 |
+
lookup_t m_lookup;
|
| 60 |
+
reverse_lookup_t m_reverseLookup;
|
| 61 |
+
|
| 62 |
+
std::unordered_map<std::wstring, uint32_t> m_prefixSumLookup;
|
| 63 |
+
};
|
| 64 |
+
|
| 65 |
+
#if ! defined( USE_BOOST )
|
| 66 |
+
void save_ngram_data_file(const lookup_t& lookup, const reverse_lookup_t& reverseLookup, const std::string &output_path);
|
| 67 |
+
#else // USE_BOOST
|
| 68 |
+
void save_ngram_data_file(lookup_t lookup, reverse_lookup_t reverseLookup, const std::string &output_path);
|
| 69 |
+
#endif // USE_BOOST
|
| 70 |
+
|
| 71 |
+
inline float_t NGramLMBase::GetPrefixSum(const std::wstring &prefix) const
|
| 72 |
+
{
|
| 73 |
+
auto iter = m_prefixSumLookup.find(prefix);
|
| 74 |
+
|
| 75 |
+
if (iter == m_prefixSumLookup.end()) {
|
| 76 |
+
return 0;
|
| 77 |
+
} else {
|
| 78 |
+
return iter->second;
|
| 79 |
+
}
|
| 80 |
+
}
|
nemo-retriever-ocr/cpp/beam_decode/prefix.cpp
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#include "prefix.h"
|
| 6 |
+
|
| 7 |
+
using namespace std;
|
| 8 |
+
|
| 9 |
+
vector<token_t> Prefix::ToList() const
|
| 10 |
+
{
|
| 11 |
+
vector<token_t> ret;
|
| 12 |
+
|
| 13 |
+
auto curr = this;
|
| 14 |
+
|
| 15 |
+
while (curr) {
|
| 16 |
+
if (curr->Token != 0) {
|
| 17 |
+
ret.push_back(curr->Token);
|
| 18 |
+
}
|
| 19 |
+
curr = curr->Parent;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
return { rbegin(ret), rend(ret) };
|
| 23 |
+
}
|
nemo-retriever-ocr/cpp/beam_decode/prefix.h
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <cstdlib>
|
| 8 |
+
#include <memory>
|
| 9 |
+
#include <vector>
|
| 10 |
+
#include <unordered_map>
|
| 11 |
+
#include <list>
|
| 12 |
+
|
| 13 |
+
typedef int32_t token_t;
|
| 14 |
+
|
| 15 |
+
class Prefix;
|
| 16 |
+
|
| 17 |
+
// typedef std::shared_ptr<Prefix> PrefixPtr;
|
| 18 |
+
|
| 19 |
+
class Prefix
|
| 20 |
+
{
|
| 21 |
+
public:
|
| 22 |
+
token_t Token;
|
| 23 |
+
Prefix *Parent;
|
| 24 |
+
|
| 25 |
+
Prefix(token_t token = 0 /* blank */, Prefix *parent = nullptr)
|
| 26 |
+
: Token(token), Parent(parent)
|
| 27 |
+
{}
|
| 28 |
+
|
| 29 |
+
std::vector<token_t> ToList() const;
|
| 30 |
+
|
| 31 |
+
size_t size() const;
|
| 32 |
+
};
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
///// Borrowed from Boost libraries
|
| 36 |
+
template<typename T>
|
| 37 |
+
void hash_combine(size_t & seed, T const& v)
|
| 38 |
+
{
|
| 39 |
+
seed ^= std::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
| 40 |
+
}
|
| 41 |
+
/////
|
| 42 |
+
|
| 43 |
+
namespace std {
|
| 44 |
+
template<>
|
| 45 |
+
struct hash<Prefix*>
|
| 46 |
+
{
|
| 47 |
+
size_t operator()(const Prefix *p) const noexcept
|
| 48 |
+
{
|
| 49 |
+
size_t seed = 0;
|
| 50 |
+
|
| 51 |
+
while (p) {
|
| 52 |
+
if (p->Token != 0) {
|
| 53 |
+
hash_combine(seed, p->Token);
|
| 54 |
+
}
|
| 55 |
+
p = p->Parent;
|
| 56 |
+
}
|
| 57 |
+
return seed;
|
| 58 |
+
}
|
| 59 |
+
};
|
| 60 |
+
|
| 61 |
+
template<>
|
| 62 |
+
struct hash<tuple<Prefix*, token_t>>
|
| 63 |
+
{
|
| 64 |
+
size_t operator()(const tuple<Prefix*, token_t> &t) const noexcept
|
| 65 |
+
{
|
| 66 |
+
size_t seed = 0;
|
| 67 |
+
hash_combine(seed, get<0>(t));
|
| 68 |
+
hash_combine(seed, get<1>(t));
|
| 69 |
+
return seed;
|
| 70 |
+
}
|
| 71 |
+
};
|
| 72 |
+
|
| 73 |
+
template<>
|
| 74 |
+
struct equal_to<Prefix*>
|
| 75 |
+
{
|
| 76 |
+
bool operator()(const Prefix *a, const Prefix *b) const noexcept
|
| 77 |
+
{
|
| 78 |
+
while (a != nullptr && b != nullptr) {
|
| 79 |
+
if (a->Token != b->Token) {
|
| 80 |
+
return false;
|
| 81 |
+
}
|
| 82 |
+
a = a->Parent;
|
| 83 |
+
b = b->Parent;
|
| 84 |
+
}
|
| 85 |
+
// If one chain is shorter than the other
|
| 86 |
+
return a == b;
|
| 87 |
+
}
|
| 88 |
+
};
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
inline size_t Prefix::size() const
|
| 92 |
+
{
|
| 93 |
+
size_t ret = 0;
|
| 94 |
+
auto p = this;
|
| 95 |
+
while (p != nullptr) {
|
| 96 |
+
ret += 1;
|
| 97 |
+
p = p->Parent;
|
| 98 |
+
}
|
| 99 |
+
return ret;
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class PrefixAllocator
|
| 104 |
+
{
|
| 105 |
+
public:
|
| 106 |
+
PrefixAllocator() = default;
|
| 107 |
+
~PrefixAllocator();
|
| 108 |
+
|
| 109 |
+
template<typename ...Args>
|
| 110 |
+
Prefix *GetPrefix(Args&& ...ctorArgs);
|
| 111 |
+
|
| 112 |
+
private:
|
| 113 |
+
void AllocateNextBuffer();
|
| 114 |
+
|
| 115 |
+
std::list<Prefix*> m_buffers;
|
| 116 |
+
size_t m_allocSize = 0;
|
| 117 |
+
size_t m_currOff = 0;
|
| 118 |
+
};
|
| 119 |
+
|
| 120 |
+
inline PrefixAllocator::~PrefixAllocator()
|
| 121 |
+
{
|
| 122 |
+
for (auto p : m_buffers) {
|
| 123 |
+
// Prefix is a POD, and are allocated without initializing
|
| 124 |
+
// to prevent redundant work upfront
|
| 125 |
+
// delete[] p;
|
| 126 |
+
free(p);
|
| 127 |
+
}
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
inline void PrefixAllocator::AllocateNextBuffer()
|
| 131 |
+
{
|
| 132 |
+
size_t nextSize = m_allocSize == 0 ? 1000 : 2 * m_allocSize;
|
| 133 |
+
|
| 134 |
+
// Using malloc here to prevent the ctor of Prefix being called for each item.
|
| 135 |
+
// Instead, the ctor will be called upon first access using GetPrefix
|
| 136 |
+
auto pBuff = reinterpret_cast<Prefix*>(malloc(sizeof(Prefix) * nextSize));
|
| 137 |
+
|
| 138 |
+
m_buffers.push_back(pBuff);
|
| 139 |
+
|
| 140 |
+
m_allocSize = nextSize;
|
| 141 |
+
m_currOff = 0;
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
template<typename ...Args>
|
| 145 |
+
Prefix *PrefixAllocator::GetPrefix(Args&& ...ctorArgs)
|
| 146 |
+
{
|
| 147 |
+
if (m_currOff == m_allocSize) {
|
| 148 |
+
AllocateNextBuffer();
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
auto buff = m_buffers.back() + m_currOff;
|
| 152 |
+
|
| 153 |
+
auto ret = new (buff) Prefix(std::forward<Args>(ctorArgs)...);
|
| 154 |
+
|
| 155 |
+
++m_currOff;
|
| 156 |
+
|
| 157 |
+
return ret;
|
| 158 |
+
}
|
nemo-retriever-ocr/cpp/beam_decode/sbo_lm.cpp
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#include "sbo_lm.h"
|
| 6 |
+
|
| 7 |
+
#include <assert.h>
|
| 8 |
+
|
| 9 |
+
// Reference paper: https://www.aclweb.org/anthology/D07-1090.pdf
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
SBO_LanguageModel::SBO_LanguageModel(const std::string &dataFilePath, token_mapping_t tokenMapping, float_t backoff)
|
| 13 |
+
: NGramLMBase(dataFilePath, move(tokenMapping)), m_backoff(backoff)
|
| 14 |
+
{
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
float SBO_LanguageModel::ScoreTransitionImpl(const std::wstring &prefix, const std::wstring &suffix) const
|
| 18 |
+
{
|
| 19 |
+
auto lIter = m_lookup[prefix.size() + 1].find(prefix);
|
| 20 |
+
|
| 21 |
+
// This prefix doesn't exist. Shrink it!
|
| 22 |
+
if (lIter == m_lookup[prefix.size() + 1].end()) {
|
| 23 |
+
return m_backoff * ScoreTransitionImpl({ begin(prefix) + 1, end(prefix) }, suffix);
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
const suffix_map_t &suffixMap = lIter->second;
|
| 27 |
+
|
| 28 |
+
auto sfIter = suffixMap.find(suffix);
|
| 29 |
+
|
| 30 |
+
if (sfIter == suffixMap.end()) {
|
| 31 |
+
// This is a novel character entirely!
|
| 32 |
+
if (prefix.empty()) {
|
| 33 |
+
return 1e-8;
|
| 34 |
+
} else {
|
| 35 |
+
return m_backoff * ScoreTransitionImpl({ begin(prefix) + 1, end(prefix) }, suffix);
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
float_t ctSuffix = sfIter->second;
|
| 40 |
+
float_t ctNgram = GetPrefixSum(prefix);
|
| 41 |
+
|
| 42 |
+
float_t score = ctSuffix / ctNgram;
|
| 43 |
+
|
| 44 |
+
assert(score >= 0 && score <= 1);
|
| 45 |
+
|
| 46 |
+
return score;
|
| 47 |
+
}
|
nemo-retriever-ocr/cpp/beam_decode/sbo_lm.h
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include "kn_lm.h"
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SBO_LanguageModel
|
| 11 |
+
: public NGramLMBase
|
| 12 |
+
{
|
| 13 |
+
public:
|
| 14 |
+
SBO_LanguageModel(const std::string &dataFilePath, token_mapping_t tokenMapping, float_t backoff);
|
| 15 |
+
|
| 16 |
+
protected:
|
| 17 |
+
virtual float_t ScoreTransitionImpl(const std::wstring &prefix, const std::wstring &suffix) const override;
|
| 18 |
+
|
| 19 |
+
private:
|
| 20 |
+
float_t m_backoff;
|
| 21 |
+
};
|
nemo-retriever-ocr/cpp/better_grid_sample/cpu_indirect_grid_sample.cpp
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#include "grid_sample.h"
|
| 6 |
+
#include "gpu_grid_sample_utils.cuh"
|
| 7 |
+
|
| 8 |
+
template<typename T>
|
| 9 |
+
void indirect_grid_sample_forward_bilinear(torch::TensorAccessor<T, 4> input,
|
| 10 |
+
torch::TensorAccessor<T, 4> grid,
|
| 11 |
+
torch::TensorAccessor<int64_t, 1> inputIndices,
|
| 12 |
+
torch::TensorAccessor<T, 4> output)
|
| 13 |
+
{
|
| 14 |
+
const int64_t N = inputIndices.size(0);
|
| 15 |
+
const int64_t C = output.size(1);
|
| 16 |
+
|
| 17 |
+
T fInputHeight = input.size(2);
|
| 18 |
+
T fInputWidth = input.size(3);
|
| 19 |
+
int64_t outputHeight = output.size(2);
|
| 20 |
+
int64_t outputWidth = output.size(3);
|
| 21 |
+
|
| 22 |
+
#pragma omp parallel for num_threads(8)
|
| 23 |
+
for (int64_t i = 0; i < N; ++i) {
|
| 24 |
+
int64_t inputIdx = inputIndices[i];
|
| 25 |
+
|
| 26 |
+
for (int64_t c = 0; c < C; ++c) {
|
| 27 |
+
for (int64_t outY = 0; outY < outputHeight; ++outY) {
|
| 28 |
+
for (int64_t outX = 0; outX < outputWidth; ++outX) {
|
| 29 |
+
T u = grid[i][outY][outX][0];
|
| 30 |
+
T v = grid[i][outY][outX][1];
|
| 31 |
+
|
| 32 |
+
if (u < -1 || u > 1 || v < -1 || v > 1) {
|
| 33 |
+
output[i][c][outY][outX] = 0;
|
| 34 |
+
continue;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// Denormalize the coordinates
|
| 38 |
+
u = (u + 1) * ((fInputWidth - 1) / 2);
|
| 39 |
+
v = (v + 1) * ((fInputHeight - 1) / 2);
|
| 40 |
+
|
| 41 |
+
// Calculate coordinates
|
| 42 |
+
const T inX = u;
|
| 43 |
+
const T inXint = std::floor(inX);
|
| 44 |
+
const T inXfrac = inX - inXint;
|
| 45 |
+
|
| 46 |
+
const T inY = v;
|
| 47 |
+
const T inYint = std::floor(inY);
|
| 48 |
+
const T inYfrac = inY - inYint;
|
| 49 |
+
|
| 50 |
+
T ps[] = { 1 - inXfrac, inXfrac };
|
| 51 |
+
T rs[] = { 1 - inYfrac, inYfrac };
|
| 52 |
+
T opVal = 0;
|
| 53 |
+
|
| 54 |
+
#pragma unroll
|
| 55 |
+
for (int64_t row = 0; row < 2; ++row) {
|
| 56 |
+
#pragma unroll
|
| 57 |
+
for (int64_t col = 0; col < 2; ++col) {
|
| 58 |
+
T Tpx = utils::get_pixel_clamped(input, inputIdx, c, inXint + col, inYint + row);
|
| 59 |
+
opVal += rs[row] * ps[col] * Tpx;
|
| 60 |
+
}
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
output[i][c][outY][outX] = opVal;
|
| 64 |
+
}
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
torch::Tensor cpu_indirect_grid_sample_forward(torch::Tensor input, torch::Tensor grid,
|
| 71 |
+
torch::Tensor inputIndices, const std::string &method)
|
| 72 |
+
{
|
| 73 |
+
auto output = input.new_empty({ inputIndices.size(0), input.size(1), grid.size(1), grid.size(2) });
|
| 74 |
+
|
| 75 |
+
AT_DISPATCH_FLOATING_TYPES(
|
| 76 |
+
input.scalar_type(),
|
| 77 |
+
"cpu_indirect_grid_sample_forward_impl",
|
| 78 |
+
([&] {
|
| 79 |
+
typedef scalar_t T;
|
| 80 |
+
if (method == "bilinear") {
|
| 81 |
+
indirect_grid_sample_forward_bilinear(
|
| 82 |
+
input.accessor<T, 4>(),
|
| 83 |
+
grid.accessor<T, 4>(),
|
| 84 |
+
inputIndices.accessor<int64_t, 1>(),
|
| 85 |
+
output.accessor<T, 4>()
|
| 86 |
+
);
|
| 87 |
+
} else {
|
| 88 |
+
throw std::runtime_error("Unsupported resample method: " + method);
|
| 89 |
+
}
|
| 90 |
+
})
|
| 91 |
+
);
|
| 92 |
+
|
| 93 |
+
return output;
|
| 94 |
+
}
|
nemo-retriever-ocr/cpp/better_grid_sample/gpu_grid_sample_utils.cuh
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <torch/torch.h>
|
| 8 |
+
|
| 9 |
+
#include "../cuda_intellisense.cuh"
|
| 10 |
+
|
| 11 |
+
#ifndef __NVCC__
|
| 12 |
+
#include <algorithm>
|
| 13 |
+
#define __device__
|
| 14 |
+
#endif
|
| 15 |
+
|
| 16 |
+
namespace utils {
|
| 17 |
+
|
| 18 |
+
#ifdef __NVCC__
|
| 19 |
+
|
| 20 |
+
template<typename T>
|
| 21 |
+
__device__ __lib_inline__
|
| 22 |
+
T clamp(T val, T minVal, T maxVal)
|
| 23 |
+
{
|
| 24 |
+
return max(minVal, min(val, maxVal));
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
#else
|
| 28 |
+
using std::clamp;
|
| 29 |
+
#endif
|
| 30 |
+
|
| 31 |
+
template<typename accessor_t>
|
| 32 |
+
__device__ __lib_inline__
|
| 33 |
+
auto &get_pixel_clamped(accessor_t &inputs,
|
| 34 |
+
int64_t n, int64_t c, int64_t x, int64_t y)
|
| 35 |
+
{
|
| 36 |
+
x = clamp<decltype(x)>(x, 0, inputs.size(3) - 1);
|
| 37 |
+
y = clamp<decltype(y)>(y, 0, inputs.size(2) - 1);
|
| 38 |
+
|
| 39 |
+
return inputs[n][c][y][x];
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
}
|
nemo-retriever-ocr/cpp/better_grid_sample/gpu_indirect_grid_sample.cu
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#include "grid_sample.h"
|
| 6 |
+
|
| 7 |
+
#include "../cuda_intellisense.cuh"
|
| 8 |
+
#include "../half_ops.cuh"
|
| 9 |
+
#include "gpu_grid_sample_utils.cuh"
|
| 10 |
+
|
| 11 |
+
using namespace std;
|
| 12 |
+
|
| 13 |
+
template<typename accessor_t, typename index_t>
|
| 14 |
+
__device__ __lib_inline__
|
| 15 |
+
auto &my_get_pixel_clamped(accessor_t &inputs, index_t x, index_t y)
|
| 16 |
+
{
|
| 17 |
+
x = utils::clamp(x, 0, inputs.size(1) - 1);
|
| 18 |
+
y = utils::clamp(y, 0, inputs.size(0) - 1);
|
| 19 |
+
|
| 20 |
+
return inputs[y][x];
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
__global__
|
| 24 |
+
void single_ex_grid_sample_bilinear_kernel(const float *pInputImage,
|
| 25 |
+
uint32_t imgHeight, uint32_t imgWidth, uint32_t numChannels,
|
| 26 |
+
const float2 *pGrid,
|
| 27 |
+
uint32_t numGridCells,
|
| 28 |
+
float *pOutputImage)
|
| 29 |
+
{
|
| 30 |
+
const uint32_t z = blockDim.x * blockIdx.x + threadIdx.x;
|
| 31 |
+
const uint32_t c = blockDim.y * blockIdx.y + threadIdx.y;
|
| 32 |
+
|
| 33 |
+
if (c >= numChannels || z >= numGridCells) {
|
| 34 |
+
return;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
const uint32_t g = blockIdx.z;
|
| 38 |
+
|
| 39 |
+
const float2 uv = pGrid[g * numGridCells + z];
|
| 40 |
+
|
| 41 |
+
float &outPx = pOutputImage[(g * numChannels + c) * numGridCells + z];
|
| 42 |
+
if (abs(uv.x) > 1.0f || abs(uv.y) > 1.0f) {
|
| 43 |
+
outPx = 0.0f;
|
| 44 |
+
} else {
|
| 45 |
+
const uint32_t maxX = imgWidth - 1;
|
| 46 |
+
const uint32_t maxY = imgHeight - 1;
|
| 47 |
+
|
| 48 |
+
const float u = (uv.x + 1.0f) * maxX * 0.5f;
|
| 49 |
+
const float v = (uv.y + 1.0f) * maxY * 0.5f;
|
| 50 |
+
|
| 51 |
+
// calculate coordinates
|
| 52 |
+
const float inX = u;
|
| 53 |
+
const uint32_t inXint = inX;
|
| 54 |
+
const float inXfrac = inX - inXint;
|
| 55 |
+
|
| 56 |
+
const float inY = v;
|
| 57 |
+
const uint32_t inYint = inY;
|
| 58 |
+
const float inYfrac = inY - inYint;
|
| 59 |
+
|
| 60 |
+
const float *pChanImage = pInputImage + c * imgHeight * imgWidth;
|
| 61 |
+
|
| 62 |
+
// By being in this conditional block, we know that u and v are >= 0, which means
|
| 63 |
+
// that their truncated value is also >= 0. Instead of clamping the value to within the buffer,
|
| 64 |
+
// we set the multiplication factor to be 0 if the interpolated value is outside the buffer
|
| 65 |
+
const float ps[] = { 1.0f - inXfrac, inXfrac * (inXint < maxX) };
|
| 66 |
+
const float rs[] = { 1.0f - inYfrac, inYfrac * (inYint < maxY) };
|
| 67 |
+
float opVal = 0.0f;
|
| 68 |
+
#pragma unroll
|
| 69 |
+
for (uint32_t row = 0; row < 2; ++row) {
|
| 70 |
+
const float *pRowImage = pChanImage + (inYint + row) * imgWidth;
|
| 71 |
+
|
| 72 |
+
#pragma unroll
|
| 73 |
+
for (uint32_t col = 0; col < 2; ++col) {
|
| 74 |
+
const float px = pRowImage[inXint + col];
|
| 75 |
+
opVal += rs[row] * ps[col] * px;
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
outPx = opVal;
|
| 80 |
+
}
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
template<typename T>
|
| 84 |
+
__global__
|
| 85 |
+
void indirect_grid_sample_forward_bilinear_kernel(torch::PackedTensorAccessor32<T, 4> inputs,
|
| 86 |
+
torch::PackedTensorAccessor32<T, 4> grid,
|
| 87 |
+
torch::PackedTensorAccessor32<int64_t, 1> inputIndices,
|
| 88 |
+
torch::PackedTensorAccessor32<T, 4> outputs)
|
| 89 |
+
{
|
| 90 |
+
static_assert(std::is_same<T, float>::value, "Currently only float32 is supported!");
|
| 91 |
+
//typedef typename fp_promote<T>::type accum_t;
|
| 92 |
+
typedef float accum_t;
|
| 93 |
+
constexpr T NEG_ONE = -1;
|
| 94 |
+
constexpr T ONE = 1;
|
| 95 |
+
constexpr T ZERO = 0;
|
| 96 |
+
constexpr T TWO = 2;
|
| 97 |
+
constexpr T ZERO_PT_5 = 0.5;
|
| 98 |
+
typedef decltype(inputs.stride(0)) index_t;
|
| 99 |
+
|
| 100 |
+
const index_t n = blockDim.z * blockIdx.z + threadIdx.z;
|
| 101 |
+
|
| 102 |
+
if (n >= inputIndices.size(0)) return;
|
| 103 |
+
|
| 104 |
+
const index_t c = blockDim.y * blockIdx.y + threadIdx.y;
|
| 105 |
+
|
| 106 |
+
const index_t z = blockDim.x * blockIdx.x + threadIdx.x;
|
| 107 |
+
|
| 108 |
+
const accum_t inputHeight = inputs.size(2);
|
| 109 |
+
const accum_t inputWidth = inputs.size(3);
|
| 110 |
+
const index_t outputHeight = outputs.size(2);
|
| 111 |
+
const index_t outputWidth = outputs.size(3);
|
| 112 |
+
|
| 113 |
+
const index_t outY = z / outputWidth;
|
| 114 |
+
//const index_t outX = z % outputWidth;
|
| 115 |
+
const index_t outX = z - (outY * outputWidth);
|
| 116 |
+
|
| 117 |
+
if (outY >= outputHeight) return;
|
| 118 |
+
|
| 119 |
+
index_t inputIdx = inputIndices[n];
|
| 120 |
+
const float2 f2uv = *reinterpret_cast<const float2*>(grid[n][outY][outX].data());
|
| 121 |
+
float u = f2uv.x;
|
| 122 |
+
float v = f2uv.y;
|
| 123 |
+
|
| 124 |
+
if (u < NEG_ONE || u > ONE || v < NEG_ONE || v > ONE) {
|
| 125 |
+
outputs[n][c][outY][outX] = ZERO;
|
| 126 |
+
return;
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
// Denormalize the coordinates
|
| 130 |
+
u = (u + ONE) * ((inputWidth - ONE) * ZERO_PT_5);
|
| 131 |
+
v = (v + ONE) * ((inputHeight - ONE) * ZERO_PT_5);
|
| 132 |
+
|
| 133 |
+
// calculate coordinates
|
| 134 |
+
const accum_t inX = u;
|
| 135 |
+
const index_t inXint = inX;
|
| 136 |
+
const accum_t inXfrac = inX - inXint;
|
| 137 |
+
|
| 138 |
+
const accum_t inY = v;
|
| 139 |
+
const index_t inYint = inY;
|
| 140 |
+
const accum_t inYfrac = inY - inYint;
|
| 141 |
+
|
| 142 |
+
accum_t ps[] = { ONE - inXfrac, inXfrac };
|
| 143 |
+
accum_t rs[] = { ONE - inYfrac, inYfrac };
|
| 144 |
+
accum_t opVal = ZERO;
|
| 145 |
+
|
| 146 |
+
auto localInputs = inputs[inputIdx][c];
|
| 147 |
+
|
| 148 |
+
#pragma unroll
|
| 149 |
+
for (index_t row = 0; row < 2; ++row) {
|
| 150 |
+
#pragma unroll
|
| 151 |
+
for (index_t col = 0; col < 2; ++col) {
|
| 152 |
+
T Tpx = my_get_pixel_clamped(localInputs, inXint + col, inYint + row);
|
| 153 |
+
opVal += rs[row] * ps[col] * Convert<T, accum_t>::LeftToRight(Tpx);
|
| 154 |
+
}
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
outputs[n][c][outY][outX] = Convert<T, accum_t>::RightToLeft(opVal);
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
template<typename T>
|
| 161 |
+
__global__
|
| 162 |
+
void indirect_grid_sample_backward_bilinear_kernel(torch::PackedTensorAccessor64<T, 4> inputs,
|
| 163 |
+
torch::PackedTensorAccessor64<T, 4> grid,
|
| 164 |
+
torch::PackedTensorAccessor64<int64_t, 1> inputIndices,
|
| 165 |
+
torch::PackedTensorAccessor64<T, 4> gradOutput,
|
| 166 |
+
torch::PackedTensorAccessor64<T, 4> gradInput,
|
| 167 |
+
torch::PackedTensorAccessor64<T, 4> gradGrid)
|
| 168 |
+
{
|
| 169 |
+
typedef typename fp_promote<T>::type accum_t;
|
| 170 |
+
constexpr T NEG_ONE = -1;
|
| 171 |
+
constexpr T ONE = 1;
|
| 172 |
+
|
| 173 |
+
const int64_t n = blockDim.z * blockIdx.z + threadIdx.z;
|
| 174 |
+
|
| 175 |
+
if (n >= inputIndices.size(0)) return;
|
| 176 |
+
|
| 177 |
+
const int64_t c = blockDim.y * blockIdx.y + threadIdx.y;
|
| 178 |
+
|
| 179 |
+
const int64_t z = blockDim.x * blockIdx.x + threadIdx.x;
|
| 180 |
+
|
| 181 |
+
const accum_t inputHeight = inputs.size(2);
|
| 182 |
+
const accum_t inputWidth = inputs.size(3);
|
| 183 |
+
const int64_t outputHeight = gradOutput.size(2);
|
| 184 |
+
const int64_t outputWidth = gradOutput.size(3);
|
| 185 |
+
|
| 186 |
+
const int64_t outY = z / outputWidth;
|
| 187 |
+
const int64_t outX = z % outputWidth;
|
| 188 |
+
|
| 189 |
+
if (outY >= outputHeight) return;
|
| 190 |
+
|
| 191 |
+
int64_t inputIdx = inputIndices[n];
|
| 192 |
+
const float2 f2uv = *reinterpret_cast<const float2*>(grid[n][outY][outX].data());
|
| 193 |
+
float u = f2uv.x;
|
| 194 |
+
float v = f2uv.y;
|
| 195 |
+
|
| 196 |
+
// No output gradient contribution from this position
|
| 197 |
+
if (u < NEG_ONE || u > ONE || v < NEG_ONE || v > ONE) {
|
| 198 |
+
return;
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
// Denormalize the coordinates
|
| 202 |
+
u = (u + 1) * ((inputWidth - 1) / 2);
|
| 203 |
+
v = (v + 1) * ((inputHeight - 1) / 2);
|
| 204 |
+
|
| 205 |
+
// calculate coordinates
|
| 206 |
+
const accum_t inX = u;
|
| 207 |
+
const accum_t inXint = floor(inX);
|
| 208 |
+
const accum_t inXfrac = inX - inXint;
|
| 209 |
+
|
| 210 |
+
const accum_t inY = v;
|
| 211 |
+
const accum_t inYint = floor(inY);
|
| 212 |
+
const accum_t inYfrac = inY - inYint;
|
| 213 |
+
|
| 214 |
+
accum_t ps[] = { 1 - inXfrac, inXfrac };
|
| 215 |
+
accum_t rs[] = { 1 - inYfrac, inYfrac };
|
| 216 |
+
|
| 217 |
+
const accum_t gOut = Convert<T, accum_t>::LeftToRight(gradOutput[n][c][outY][outX]);
|
| 218 |
+
|
| 219 |
+
#pragma unroll
|
| 220 |
+
for (size_t row = 0; row < 2; ++row) {
|
| 221 |
+
#pragma unroll
|
| 222 |
+
for (size_t col = 0; col < 2; ++col) {
|
| 223 |
+
T &gIn = utils::get_pixel_clamped(gradInput, inputIdx, c, inXint + col, inYint + row);
|
| 224 |
+
|
| 225 |
+
T gContrib = Convert<T, accum_t>::RightToLeft(rs[row] * ps[col] * gOut);
|
| 226 |
+
|
| 227 |
+
atomicAdd(&gIn, gContrib);
|
| 228 |
+
}
|
| 229 |
+
}
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
torch::Tensor gpu_indirect_grid_sample_forward(torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method)
|
| 233 |
+
{
|
| 234 |
+
auto output = input.new_empty({ inputIndices.size(0), input.size(1), grid.size(1), grid.size(2) });
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
if (method != "bilinear"s) {
|
| 238 |
+
throw runtime_error("Only 'bilinear' sampling is currently supported!");
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
if (input.size(0) == 1 && input.is_contiguous() && grid.is_contiguous()) {
|
| 242 |
+
uint32_t gridNumCells = grid.size(1) * grid.size(2);
|
| 243 |
+
dim3 blockDim(32, 3, 1);
|
| 244 |
+
dim3 gridDim(div_up(gridNumCells, blockDim.x),
|
| 245 |
+
div_up(input.size(1), blockDim.y),
|
| 246 |
+
div_up(grid.size(0), blockDim.z));
|
| 247 |
+
single_ex_grid_sample_bilinear_kernel KERNEL_ARG2(gridDim, blockDim) (
|
| 248 |
+
input.data_ptr<float>(),
|
| 249 |
+
input.size(2), input.size(3), input.size(1),
|
| 250 |
+
reinterpret_cast<const float2*>(grid.data_ptr()),
|
| 251 |
+
gridNumCells,
|
| 252 |
+
output.data_ptr<float>()
|
| 253 |
+
);
|
| 254 |
+
|
| 255 |
+
} else {
|
| 256 |
+
// z is batch idx
|
| 257 |
+
// y is channel
|
| 258 |
+
// x is w*h
|
| 259 |
+
dim3 blockDim(32, 1, 3);
|
| 260 |
+
dim3 gridDim(div_up(grid.size(1) * grid.size(2), blockDim.x),
|
| 261 |
+
div_up(input.size(1), blockDim.y),
|
| 262 |
+
div_up(inputIndices.size(0), blockDim.z));
|
| 263 |
+
indirect_grid_sample_forward_bilinear_kernel KERNEL_ARG2(gridDim, blockDim) (
|
| 264 |
+
input.packed_accessor32<float, 4>(),
|
| 265 |
+
grid.packed_accessor32<float, 4>(),
|
| 266 |
+
inputIndices.packed_accessor32<int64_t, 1>(),
|
| 267 |
+
output.packed_accessor32<float, 4>()
|
| 268 |
+
);
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
//AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
| 272 |
+
// input.scalar_type(),
|
| 273 |
+
// "gpu_indirect_grid_sample_forward",
|
| 274 |
+
// ([&] {
|
| 275 |
+
// typedef typename remap_half<scalar_t>::type T;
|
| 276 |
+
// // typedef scalar_t T;
|
| 277 |
+
// if (method == "bilinear") {
|
| 278 |
+
// indirect_grid_sample_forward_bilinear_kernel KERNEL_ARG2(gridDim, blockDim) (
|
| 279 |
+
// input.packed_accessor64<T, 4>(),
|
| 280 |
+
// grid.packed_accessor64<T, 4>(),
|
| 281 |
+
// inputIndices.packed_accessor64<int64_t, 1>(),
|
| 282 |
+
// output.packed_accessor64<T, 4>()
|
| 283 |
+
// );
|
| 284 |
+
// } else {
|
| 285 |
+
// throw runtime_error("Unsupported resample method: " + method);
|
| 286 |
+
// }
|
| 287 |
+
// })
|
| 288 |
+
//);
|
| 289 |
+
|
| 290 |
+
return output;
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
std::vector<torch::Tensor> gpu_indirect_grad_sample_backward(torch::Tensor gradOutput, torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method)
|
| 294 |
+
{
|
| 295 |
+
auto gradInput = torch::zeros_like(input);
|
| 296 |
+
auto gradGrid = torch::zeros_like(grid);
|
| 297 |
+
|
| 298 |
+
// z is batch idx
|
| 299 |
+
// y is channel
|
| 300 |
+
// x is w*h
|
| 301 |
+
dim3 blockDim(32, 1, 1);
|
| 302 |
+
dim3 gridDim(div_up(grid.size(1) * grid.size(2), blockDim.x),
|
| 303 |
+
div_up(input.size(1), blockDim.y),
|
| 304 |
+
div_up(inputIndices.size(0), blockDim.z));
|
| 305 |
+
|
| 306 |
+
AT_DISPATCH_FLOATING_TYPES(
|
| 307 |
+
input.scalar_type(),
|
| 308 |
+
"gpu_indirect_grid_sample_backward",
|
| 309 |
+
([&] {
|
| 310 |
+
typedef typename remap_half<scalar_t>::type T;
|
| 311 |
+
// typedef scalar_t T;
|
| 312 |
+
if (method == "bilinear") {
|
| 313 |
+
indirect_grid_sample_backward_bilinear_kernel KERNEL_ARG2(gridDim, blockDim) (
|
| 314 |
+
input.packed_accessor64<T, 4>(),
|
| 315 |
+
grid.packed_accessor64<T, 4>(),
|
| 316 |
+
inputIndices.packed_accessor64<int64_t, 1>(),
|
| 317 |
+
gradOutput.packed_accessor64<T, 4>(),
|
| 318 |
+
gradInput.packed_accessor64<T, 4>(),
|
| 319 |
+
gradGrid.packed_accessor64<T, 4>()
|
| 320 |
+
);
|
| 321 |
+
} else {
|
| 322 |
+
throw runtime_error("Unsupported resample method: " + method);
|
| 323 |
+
}
|
| 324 |
+
})
|
| 325 |
+
);
|
| 326 |
+
|
| 327 |
+
return { gradInput, gradGrid };
|
| 328 |
+
}
|
nemo-retriever-ocr/cpp/better_grid_sample/grid_sample.h
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <torch/torch.h>
|
| 8 |
+
|
| 9 |
+
inline
|
| 10 |
+
torch::Tensor region_counts_to_indices(torch::Tensor regionCounts, int64_t numOutputs)
|
| 11 |
+
{
|
| 12 |
+
// If there's only one example, we can trivially return idx 0 for all
|
| 13 |
+
if (regionCounts.size(0) == 1) {
|
| 14 |
+
return torch::zeros({ numOutputs }, regionCounts.options().dtype(torch::kInt64));
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
// regionCounts will be some tensor like [ 5, 1, 10, 2 ] which means that the first 5 outputs
|
| 18 |
+
// correspond to the first input, the next output to the second input, 10 to the third, and so on.
|
| 19 |
+
|
| 20 |
+
// We want to convert this to instead have an entry for each output which specifies the index of the corresponding input.
|
| 21 |
+
// To do this, we can count the number of times the output index exceeds the cumulative input counts.
|
| 22 |
+
// e.g. the cumulative region count for the above tensor is [ 5, 6, 16, 18 ].
|
| 23 |
+
// The output indices 0-4 are not greater than or equal to any cumulative count, so they get the input index of 0.
|
| 24 |
+
// The output index 5 is equal to a single count, therefore index 1.
|
| 25 |
+
// The outputs 6-15 are all greater than or equal to two cumulative counts, therefore index 2.
|
| 26 |
+
// And so on.
|
| 27 |
+
|
| 28 |
+
auto indices = torch::arange(regionCounts.size(0), regionCounts.options().dtype(torch::kInt64));
|
| 29 |
+
|
| 30 |
+
auto outputIndices = torch::repeat_interleave(indices, regionCounts, /*dim=*/ 0, /*output_size=*/ numOutputs);
|
| 31 |
+
|
| 32 |
+
return outputIndices;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
torch::Tensor gpu_indirect_grid_sample_forward(torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method);
|
| 36 |
+
torch::Tensor cpu_indirect_grid_sample_forward(torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method);
|
| 37 |
+
std::vector<torch::Tensor> gpu_indirect_grad_sample_backward(torch::Tensor gradOutput, torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method);
|
| 38 |
+
|
| 39 |
+
inline
|
| 40 |
+
torch::Tensor indirect_grid_sample_forward(torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method)
|
| 41 |
+
{
|
| 42 |
+
if (input.is_cuda() != grid.is_cuda() || input.is_cuda() != inputIndices.is_cuda()) {
|
| 43 |
+
throw std::runtime_error("Input tensors must all be on the same device!");
|
| 44 |
+
}
|
| 45 |
+
if (inputIndices.size(0) != grid.size(0)) {
|
| 46 |
+
throw std::runtime_error("The batch dimensions must match!");
|
| 47 |
+
}
|
| 48 |
+
if (grid.size(-1) != 2) {
|
| 49 |
+
throw std::runtime_error("The final grid dimension must be 2.");
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
if (input.is_cuda()) {
|
| 53 |
+
return gpu_indirect_grid_sample_forward(std::move(input), std::move(grid), std::move(inputIndices), method);
|
| 54 |
+
} else {
|
| 55 |
+
return cpu_indirect_grid_sample_forward(std::move(input), std::move(grid), std::move(inputIndices), method);
|
| 56 |
+
}
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
inline
|
| 60 |
+
std::vector<torch::Tensor> indirect_grad_sample_backward(torch::Tensor gradOutput, torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method)
|
| 61 |
+
{
|
| 62 |
+
if (gradOutput.is_cuda()) {
|
| 63 |
+
return gpu_indirect_grad_sample_backward(std::move(gradOutput), std::move(input), std::move(grid), std::move(inputIndices), method);
|
| 64 |
+
} else {
|
| 65 |
+
throw std::runtime_error("Not implemented!");
|
| 66 |
+
}
|
| 67 |
+
}
|
nemo-retriever-ocr/cpp/common.cpp
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#include "common.h"
|
| 6 |
+
|
| 7 |
+
#include <sstream>
|
| 8 |
+
|
| 9 |
+
using namespace std;
|
| 10 |
+
|
| 11 |
+
void print_tensor(const torch::Tensor &t) {
|
| 12 |
+
cout << t << endl;
|
| 13 |
+
}
|
nemo-retriever-ocr/cpp/common.h
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <ostream>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
#include <torch/torch.h>
|
| 11 |
+
|
| 12 |
+
template<typename T>
|
| 13 |
+
inline
|
| 14 |
+
std::ostream &operator<<(std::ostream &os, const std::vector<T> &v) {
|
| 15 |
+
os << "[";
|
| 16 |
+
if (! v.empty()) {
|
| 17 |
+
os << v[0];
|
| 18 |
+
for (size_t i = 1; i < v.size(); ++i) {
|
| 19 |
+
os << ", " << v[i];
|
| 20 |
+
}
|
| 21 |
+
}
|
| 22 |
+
os << "]";
|
| 23 |
+
return os;
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
template<int Counter, typename ...Args>
|
| 27 |
+
struct _inner_tuple_print
|
| 28 |
+
{
|
| 29 |
+
inline
|
| 30 |
+
static std::ostream &print(std::ostream &os, const std::tuple<Args...> &t) {
|
| 31 |
+
_inner_tuple_print<Counter - 1, Args...>::print(os, t);
|
| 32 |
+
|
| 33 |
+
os << ", " << std::get<Counter>(t);
|
| 34 |
+
return os;
|
| 35 |
+
}
|
| 36 |
+
};
|
| 37 |
+
|
| 38 |
+
template<typename ...Args>
|
| 39 |
+
struct _inner_tuple_print<0, Args...>
|
| 40 |
+
{
|
| 41 |
+
inline
|
| 42 |
+
static std::ostream &print(std::ostream &os, const std::tuple<Args...> &t) {
|
| 43 |
+
os << std::get<0>(t);
|
| 44 |
+
return os;
|
| 45 |
+
}
|
| 46 |
+
};
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
template<typename... Args>
|
| 50 |
+
inline
|
| 51 |
+
std::ostream &operator<<(std::ostream &os, const std::tuple<Args...> &t) {
|
| 52 |
+
os << "(";
|
| 53 |
+
_inner_tuple_print<sizeof...(Args) - 1, Args...>::print(os, t);
|
| 54 |
+
os << ")";
|
| 55 |
+
return os;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
void print_tensor(const torch::Tensor &t);
|
nemo-retriever-ocr/cpp/cuda_intellisense.cuh
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#if defined(__INTELLISENSE__) || !defined(__NVCC__)
|
| 8 |
+
#ifndef KERNEL_ARG2
|
| 9 |
+
#define KERNEL_ARG2(grid, block)
|
| 10 |
+
#define KERNEL_ARG3(grid, block, sh_mem)
|
| 11 |
+
#define KERNEL_ARG4(grid, block, sh_mem, stream)
|
| 12 |
+
#define __global__
|
| 13 |
+
#define __device__
|
| 14 |
+
#define __host__
|
| 15 |
+
#endif
|
| 16 |
+
#endif
|
| 17 |
+
|
| 18 |
+
#ifdef __INTELLISENSE__
|
| 19 |
+
#define __CUDACC__
|
| 20 |
+
#include <cuda_runtime.h>
|
| 21 |
+
|
| 22 |
+
void __syncthreads(); // workaround __syncthreads warning
|
| 23 |
+
|
| 24 |
+
dim3 threadIdx;
|
| 25 |
+
dim3 blockIdx;
|
| 26 |
+
dim3 blockDim;
|
| 27 |
+
dim3 gridDim;
|
| 28 |
+
|
| 29 |
+
#else
|
| 30 |
+
#ifndef KERNEL_ARG2
|
| 31 |
+
#define KERNEL_ARG2(grid, block) <<< grid, block >>>
|
| 32 |
+
#define KERNEL_ARG3(grid, block, sh_mem) <<< grid, block, sh_mem >>>
|
| 33 |
+
#define KERNEL_ARG4(grid, block, sh_mem, stream) <<< grid, block, sh_mem, stream >>>
|
| 34 |
+
#endif
|
| 35 |
+
#endif
|
| 36 |
+
|
| 37 |
+
#define __any_device__ __host__ __device__
|
| 38 |
+
|
| 39 |
+
#ifdef __NVCC__
|
| 40 |
+
#define __lib_inline__ __forceinline__
|
| 41 |
+
|
| 42 |
+
#else
|
| 43 |
+
#define __lib_inline__ inline
|
| 44 |
+
#endif
|
| 45 |
+
|
| 46 |
+
template<typename T1, typename T2>
|
| 47 |
+
__any_device__
|
| 48 |
+
inline auto div_up(T1 n, T2 d)
|
| 49 |
+
{
|
| 50 |
+
return (n + d - 1) / d;
|
| 51 |
+
}
|
nemo-retriever-ocr/cpp/geometry.h
ADDED
|
@@ -0,0 +1,1101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <algorithm>
|
| 8 |
+
#include <cmath>
|
| 9 |
+
#include <iostream>
|
| 10 |
+
#include <type_traits>
|
| 11 |
+
|
| 12 |
+
#ifndef _GEOMETRY_NO_TORCH
|
| 13 |
+
#include <torch/torch.h>
|
| 14 |
+
#endif
|
| 15 |
+
|
| 16 |
+
#include "cuda_intellisense.cuh"
|
| 17 |
+
|
| 18 |
+
#ifndef __NVCC__
|
| 19 |
+
#define SORT_ALGO std::sort
|
| 20 |
+
#define SWAP std::swap
|
| 21 |
+
|
| 22 |
+
template<typename ...Args>
|
| 23 |
+
using tuple_t = std::tuple<Args...>;
|
| 24 |
+
|
| 25 |
+
#else
|
| 26 |
+
|
| 27 |
+
#include <thrust/sort.h>
|
| 28 |
+
#include <thrust/tuple.h>
|
| 29 |
+
|
| 30 |
+
#define SORT_ALGO thrust::sort
|
| 31 |
+
#define SWAP thrust::swap
|
| 32 |
+
|
| 33 |
+
template<typename ...Args>
|
| 34 |
+
using tuple_t = thrust::tuple<Args...>;
|
| 35 |
+
#endif
|
| 36 |
+
|
| 37 |
+
template<typename T>
|
| 38 |
+
struct Point_ {
|
| 39 |
+
typedef T inner_type;
|
| 40 |
+
|
| 41 |
+
T X, Y;
|
| 42 |
+
|
| 43 |
+
Point_() = default;
|
| 44 |
+
|
| 45 |
+
__any_device__
|
| 46 |
+
Point_(T x, T y) : X(x), Y(y) {}
|
| 47 |
+
|
| 48 |
+
__any_device__
|
| 49 |
+
Point_(T *ptr) : X(ptr[0]), Y(ptr[1]) {}
|
| 50 |
+
|
| 51 |
+
#ifndef _GEOMETRY_NO_TORCH
|
| 52 |
+
template<typename T2>
|
| 53 |
+
__any_device__
|
| 54 |
+
Point_(const torch::TensorAccessor<T2, 1> &accessor) : X(accessor[0]), Y(accessor[1]) {}
|
| 55 |
+
|
| 56 |
+
template<typename T2>
|
| 57 |
+
__any_device__
|
| 58 |
+
Point_(const torch::PackedTensorAccessor64<T2, 1> &accessor) : X(accessor[0]), Y(accessor[1]) {}
|
| 59 |
+
#endif
|
| 60 |
+
|
| 61 |
+
__any_device__
|
| 62 |
+
Point_ &operator+=(const Point_ &other);
|
| 63 |
+
|
| 64 |
+
__any_device__
|
| 65 |
+
Point_ &operator-=(const Point_ &other);
|
| 66 |
+
|
| 67 |
+
__any_device__
|
| 68 |
+
Point_ &operator*=(const Point_ &other);
|
| 69 |
+
|
| 70 |
+
__any_device__
|
| 71 |
+
Point_ &operator/=(const Point_ &other);
|
| 72 |
+
|
| 73 |
+
template<typename W>
|
| 74 |
+
__any_device__
|
| 75 |
+
Point_ &operator/=(W w);
|
| 76 |
+
|
| 77 |
+
template<typename W>
|
| 78 |
+
__any_device__
|
| 79 |
+
Point_ &operator*=(W w);
|
| 80 |
+
|
| 81 |
+
__any_device__
|
| 82 |
+
Point_ operator-() {
|
| 83 |
+
return { -X, -Y };
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
__any_device__
|
| 87 |
+
T Sum() const { return X + Y; }
|
| 88 |
+
|
| 89 |
+
__any_device__
|
| 90 |
+
T Angle() const;
|
| 91 |
+
|
| 92 |
+
__any_device__
|
| 93 |
+
void swap(Point_ &other) noexcept {
|
| 94 |
+
SWAP(X, other.X);
|
| 95 |
+
SWAP(Y, other.Y);
|
| 96 |
+
}
|
| 97 |
+
};
|
| 98 |
+
|
| 99 |
+
template<typename T>
|
| 100 |
+
__lib_inline__ __any_device__
|
| 101 |
+
void swap(Point_<T> &a, Point_<T> &b) {
|
| 102 |
+
a.swap(b);
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
template<typename T>
|
| 107 |
+
__any_device__
|
| 108 |
+
__lib_inline__ T Point_<T>::Angle() const {
|
| 109 |
+
#ifndef __NVCC__
|
| 110 |
+
using std::atan2;
|
| 111 |
+
#endif
|
| 112 |
+
return atan2(Y, X);
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
template<typename T>
|
| 116 |
+
__any_device__
|
| 117 |
+
__lib_inline__ Point_<T> min(const Point_<T> &a, const Point_<T> &b) {
|
| 118 |
+
#ifndef __NVCC__
|
| 119 |
+
using std::min;
|
| 120 |
+
#endif
|
| 121 |
+
return {
|
| 122 |
+
min(a.X, b.X),
|
| 123 |
+
min(a.Y, b.Y)
|
| 124 |
+
};
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
template<typename T>
|
| 128 |
+
__any_device__
|
| 129 |
+
__lib_inline__ Point_<T> max(const Point_<T> &a, const Point_<T> &b) {
|
| 130 |
+
#ifndef __NVCC__
|
| 131 |
+
using std::max;
|
| 132 |
+
#endif
|
| 133 |
+
return {
|
| 134 |
+
max(a.X, b.X),
|
| 135 |
+
max(a.Y, b.Y)
|
| 136 |
+
};
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
template<typename T>
|
| 140 |
+
struct AABB_ {
|
| 141 |
+
typedef T inner_type;
|
| 142 |
+
|
| 143 |
+
T X;
|
| 144 |
+
T Y;
|
| 145 |
+
T MaxX;
|
| 146 |
+
T MaxY;
|
| 147 |
+
|
| 148 |
+
AABB_() = default;
|
| 149 |
+
__any_device__
|
| 150 |
+
AABB_(T x, T y, T maxX, T maxY)
|
| 151 |
+
: X(x), Y(y), MaxX(maxX), MaxY(maxY) {}
|
| 152 |
+
|
| 153 |
+
__any_device__
|
| 154 |
+
bool Contains(const Point_<T> &p) const {
|
| 155 |
+
return p.X >= X && p.X < MaxX &&
|
| 156 |
+
p.Y >= Y && p.Y < MaxY;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
__any_device__ __lib_inline__
|
| 160 |
+
AABB_ Union(const AABB_ &other) const {
|
| 161 |
+
#ifndef __NVCC__
|
| 162 |
+
using std::min;
|
| 163 |
+
using std::max;
|
| 164 |
+
#endif
|
| 165 |
+
T minX = min(X, other.X);
|
| 166 |
+
T maxX = max(MaxX, other.MaxX);
|
| 167 |
+
T minY = min(Y, other.Y);
|
| 168 |
+
T maxY = max(MaxY, other.MaxY);
|
| 169 |
+
|
| 170 |
+
return { minX, minY, maxX, maxY };
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
__any_device__
|
| 174 |
+
AABB_ &operator-=(const Point_<T> &offset) {
|
| 175 |
+
X -= offset.X;
|
| 176 |
+
MaxX -= offset.X;
|
| 177 |
+
Y -= offset.Y;
|
| 178 |
+
MaxY -= offset.Y;
|
| 179 |
+
return *this;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
__any_device__
|
| 183 |
+
__lib_inline__ T Width() const { return MaxX - X; }
|
| 184 |
+
__any_device__
|
| 185 |
+
__lib_inline__ T Height() const { return MaxY - Y; }
|
| 186 |
+
__any_device__
|
| 187 |
+
__lib_inline__ T Area() const { return Width() * Height(); }
|
| 188 |
+
|
| 189 |
+
__lib_inline__ T &operator[] (int64_t idx)
|
| 190 |
+
{
|
| 191 |
+
static_assert(std::is_standard_layout<AABB_<T>>::value, "This function is only valid for standard layout");
|
| 192 |
+
return (&X)[idx];
|
| 193 |
+
}
|
| 194 |
+
__lib_inline__ T operator[] (int64_t idx) const
|
| 195 |
+
{
|
| 196 |
+
static_assert(std::is_standard_layout<AABB_<T>>::value, "This function is only valid for standard layout");
|
| 197 |
+
return (&X)[idx];
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
__any_device__ __lib_inline__
|
| 201 |
+
AABB_ Intersection(const AABB_ &other) const {
|
| 202 |
+
#ifndef __NVCC__
|
| 203 |
+
using std::min;
|
| 204 |
+
using std::max;
|
| 205 |
+
#endif
|
| 206 |
+
T minX = max(X, other.X);
|
| 207 |
+
T minY = max(Y, other.Y);
|
| 208 |
+
T maxX = min(MaxX, other.MaxX);
|
| 209 |
+
T maxY = min(MaxY, other.MaxY);
|
| 210 |
+
// Prevent negative area
|
| 211 |
+
minX = min(minX, maxX);
|
| 212 |
+
minY = min(minY, maxY);
|
| 213 |
+
return { minX, minY, maxX, maxY };
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
__any_device__ __lib_inline__
|
| 217 |
+
T IntersectionArea(const AABB_ &other) const { return Intersection(other).Area(); }
|
| 218 |
+
};
|
| 219 |
+
|
| 220 |
+
template<typename T, typename Derived>
|
| 221 |
+
struct QuadBase_ {
|
| 222 |
+
typedef T inner_type;
|
| 223 |
+
|
| 224 |
+
__any_device__
|
| 225 |
+
AABB_<T> Bounds() const;
|
| 226 |
+
|
| 227 |
+
__any_device__
|
| 228 |
+
bool Contains(const Point_<T> &p) const;
|
| 229 |
+
|
| 230 |
+
__any_device__
|
| 231 |
+
T Area() const;
|
| 232 |
+
|
| 233 |
+
__any_device__
|
| 234 |
+
T Height() const;
|
| 235 |
+
|
| 236 |
+
__any_device__
|
| 237 |
+
T Width() const;
|
| 238 |
+
|
| 239 |
+
template<typename Derived2>
|
| 240 |
+
__any_device__
|
| 241 |
+
T IntersectionArea(const QuadBase_<T, Derived2> &other) const;
|
| 242 |
+
|
| 243 |
+
template<typename Derived2>
|
| 244 |
+
__any_device__
|
| 245 |
+
T IOU(const QuadBase_<T, Derived2> &other) const;
|
| 246 |
+
|
| 247 |
+
template<typename Derived2>
|
| 248 |
+
__any_device__
|
| 249 |
+
T IOU_UpperBound(const QuadBase_<T, Derived2> &other) const;
|
| 250 |
+
|
| 251 |
+
__any_device__
|
| 252 |
+
Point_<T> Center() const;
|
| 253 |
+
|
| 254 |
+
template<typename Derived2>
|
| 255 |
+
__any_device__
|
| 256 |
+
/*
|
| 257 |
+
Returns 3 geometric associations between the two quads:
|
| 258 |
+
0: The percent shared area between this and other relative to this (e.g. if other contains this, then it returns 1)
|
| 259 |
+
1: The percent shared area between other and this relative to other (e.g. if this contains other, then it returns 1)
|
| 260 |
+
2: The IOU of the two quads
|
| 261 |
+
*/
|
| 262 |
+
tuple_t<T, T, T> RegionSizes(const QuadBase_<T, Derived2> &other) const;
|
| 263 |
+
|
| 264 |
+
template<typename Derived2>
|
| 265 |
+
__any_device__
|
| 266 |
+
tuple_t<T, T, T> RegionSizes_UpperBound(const QuadBase_<T, Derived2> &other) const;
|
| 267 |
+
|
| 268 |
+
__any_device__
|
| 269 |
+
Derived &operator/=(T val) {
|
| 270 |
+
auto rcp = 1 / val;
|
| 271 |
+
return *this *= rcp;
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
__any_device__
|
| 275 |
+
Derived &operator*=(T val) {
|
| 276 |
+
auto dThis = static_cast<Derived*>(this);
|
| 277 |
+
#pragma unroll
|
| 278 |
+
for (size_t i = 0; i < 4; ++i) {
|
| 279 |
+
dThis->Vertices[i] *= val;
|
| 280 |
+
}
|
| 281 |
+
return *dThis;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
friend auto begin(const QuadBase_ &q) { return static_cast<const Derived&>(q).Vertices; }
|
| 285 |
+
friend auto begin(QuadBase_& q) { return static_cast<const Derived&>(q).Vertices; }
|
| 286 |
+
friend auto end(const QuadBase_ &q) { return static_cast<const Derived&>(q).Vertices + 4; }
|
| 287 |
+
friend auto end(QuadBase_ &q) { return static_cast<const Derived&>(q).Vertices + 4; }
|
| 288 |
+
};
|
| 289 |
+
|
| 290 |
+
template<typename T>
|
| 291 |
+
struct Quad_ : QuadBase_<T, Quad_<T>> {
|
| 292 |
+
Point_<T> *Vertices = nullptr;
|
| 293 |
+
|
| 294 |
+
Quad_() = default;
|
| 295 |
+
__any_device__
|
| 296 |
+
Quad_(T *dataPtr)
|
| 297 |
+
: Vertices(reinterpret_cast<Point_<T>*>(dataPtr)) {}
|
| 298 |
+
__any_device__
|
| 299 |
+
Quad_(Point_<T> *dataPtr)
|
| 300 |
+
: Vertices(dataPtr) {}
|
| 301 |
+
|
| 302 |
+
template<typename index_t>
|
| 303 |
+
__any_device__ __lib_inline__
|
| 304 |
+
const Point_<T> &operator[](index_t offset) const { return Vertices[offset]; }
|
| 305 |
+
template<typename index_t>
|
| 306 |
+
__any_device__ __lib_inline__
|
| 307 |
+
Point_<T> &operator[](index_t offset) { return Vertices[offset]; }
|
| 308 |
+
};
|
| 309 |
+
|
| 310 |
+
template<typename T>
|
| 311 |
+
struct InPlaceQuad_ : public QuadBase_<T, InPlaceQuad_<T>> {
|
| 312 |
+
Point_<T> Vertices[4];
|
| 313 |
+
|
| 314 |
+
InPlaceQuad_() = default;
|
| 315 |
+
__any_device__
|
| 316 |
+
InPlaceQuad_(const T *dataPtr)
|
| 317 |
+
{
|
| 318 |
+
#if defined(__NVCC__)
|
| 319 |
+
T *pVals = reinterpret_cast<T*>(Vertices);
|
| 320 |
+
#pragma unroll
|
| 321 |
+
for (uint32_t i = 0; i < 8; ++i) {
|
| 322 |
+
pVals[i] = dataPtr[i];
|
| 323 |
+
}
|
| 324 |
+
#else
|
| 325 |
+
using std::copy;
|
| 326 |
+
copy(dataPtr, dataPtr + 8, reinterpret_cast<T*>(Vertices));
|
| 327 |
+
#endif
|
| 328 |
+
}
|
| 329 |
+
__any_device__
|
| 330 |
+
InPlaceQuad_(const Point_<T> *dataPtr)
|
| 331 |
+
{
|
| 332 |
+
#if defined(__NVCC__)
|
| 333 |
+
#pragma unroll
|
| 334 |
+
for (uint32_t i = 0; i < 4; ++i) {
|
| 335 |
+
Vertices[i] = dataPtr[i];
|
| 336 |
+
}
|
| 337 |
+
#else
|
| 338 |
+
using std::copy;
|
| 339 |
+
copy(dataPtr, dataPtr + 4, Vertices);
|
| 340 |
+
#endif
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
template<typename index_t>
|
| 344 |
+
__any_device__ __lib_inline__
|
| 345 |
+
const Point_<T> &operator[](index_t v) const { return Vertices[v]; }
|
| 346 |
+
|
| 347 |
+
template<typename index_t>
|
| 348 |
+
__any_device__ __lib_inline__
|
| 349 |
+
Point_<T> &operator[](index_t v) { return Vertices[v]; }
|
| 350 |
+
};
|
| 351 |
+
|
| 352 |
+
template<typename T, typename Derived>
|
| 353 |
+
struct PolygonBase_ {
|
| 354 |
+
typedef T inner_type;
|
| 355 |
+
|
| 356 |
+
__any_device__
|
| 357 |
+
AABB_<T> Bounds() const;
|
| 358 |
+
|
| 359 |
+
__any_device__
|
| 360 |
+
bool Contains(const Point_<T> &p) const;
|
| 361 |
+
|
| 362 |
+
__any_device__
|
| 363 |
+
T EdgeLength() const;
|
| 364 |
+
|
| 365 |
+
__any_device__
|
| 366 |
+
Point_<T> Center() const;
|
| 367 |
+
|
| 368 |
+
__any_device__
|
| 369 |
+
T Area() const;
|
| 370 |
+
};
|
| 371 |
+
|
| 372 |
+
template<typename T>
|
| 373 |
+
struct Polygon_ : PolygonBase_<T, Polygon_<T>> {
|
| 374 |
+
Point_<T> *Vertices = nullptr;
|
| 375 |
+
size_t Count = 0;
|
| 376 |
+
|
| 377 |
+
Polygon_() = default;
|
| 378 |
+
__any_device__
|
| 379 |
+
Polygon_(T *dataPtr, size_t vertexCount)
|
| 380 |
+
: Vertices(reinterpret_cast<Point_<T>*>(dataPtr)), Count(vertexCount) {}
|
| 381 |
+
__any_device__
|
| 382 |
+
Polygon_(Point_<T> *dataPtr, size_t vertexCount)
|
| 383 |
+
: Vertices(dataPtr), Count(vertexCount) {}
|
| 384 |
+
|
| 385 |
+
__any_device__
|
| 386 |
+
const Point_<T> &operator[](size_t offset) const { return Vertices[offset]; }
|
| 387 |
+
__any_device__
|
| 388 |
+
Point_<T> &operator[](size_t offset) { return Vertices[offset]; }
|
| 389 |
+
};
|
| 390 |
+
|
| 391 |
+
template<typename T>
|
| 392 |
+
struct Segment_ {
|
| 393 |
+
Point_<T> A, B;
|
| 394 |
+
|
| 395 |
+
Segment_() = default;
|
| 396 |
+
__any_device__
|
| 397 |
+
Segment_(const Point_<T> &a, const Point_<T> &b) : A(a), B(b) {}
|
| 398 |
+
|
| 399 |
+
__any_device__
|
| 400 |
+
T Length() const;
|
| 401 |
+
__any_device__
|
| 402 |
+
T LengthSq() const;
|
| 403 |
+
__any_device__
|
| 404 |
+
bool Intersection(const Segment_<T> &other, Point_<T> &out_ptAlong) const;
|
| 405 |
+
};
|
| 406 |
+
|
| 407 |
+
template<typename T>
|
| 408 |
+
__any_device__
|
| 409 |
+
__lib_inline__ Point_<T> operator+(const Point_<T> &a, const Point_<T> &b) {
|
| 410 |
+
return { a.X + b.X, a.Y + b.Y };
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
template<typename T>
|
| 414 |
+
__any_device__
|
| 415 |
+
__lib_inline__ Point_<T> operator-(const Point_<T> &a, const Point_<T> &b) {
|
| 416 |
+
return { a.X - b.X, a.Y - b.Y };
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
template<typename T, typename W>
|
| 420 |
+
__any_device__
|
| 421 |
+
__lib_inline__ Point_<T> operator*(W scale, const Point_<T> &p) {
|
| 422 |
+
return { scale * p.X, scale * p.Y };
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
template<typename T, typename W>
|
| 426 |
+
__any_device__
|
| 427 |
+
__lib_inline__ Point_<T> operator*(const Point_<T> &p, W scale) {
|
| 428 |
+
return { scale * p.X, scale * p.Y };
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
template<typename T, typename W>
|
| 432 |
+
__any_device__
|
| 433 |
+
__lib_inline__ Point_<T> operator/(const Point_<T> &p, W divisor) {
|
| 434 |
+
return { p.X / divisor, p.Y / divisor };
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
template<typename T>
|
| 438 |
+
__any_device__
|
| 439 |
+
__lib_inline__ Point_<T> operator*(const Point_<T> &a, const Point_<T> &b) {
|
| 440 |
+
return { a.X * b.X, a.Y * b.Y };
|
| 441 |
+
}
|
| 442 |
+
|
| 443 |
+
template<typename T, typename W>
|
| 444 |
+
__any_device__
|
| 445 |
+
__lib_inline__ Point_<T> operator-(const Point_<T> &p, W v) {
|
| 446 |
+
return { p.X - v, p.Y - v };
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
template<typename T>
|
| 450 |
+
__any_device__
|
| 451 |
+
__lib_inline__ Point_<T> &Point_<T>::operator+=(const Point_<T> &p) {
|
| 452 |
+
X = X + p.X;
|
| 453 |
+
Y = Y + p.Y;
|
| 454 |
+
return *this;
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
template<typename T>
|
| 458 |
+
__any_device__
|
| 459 |
+
__lib_inline__ Point_<T> &Point_<T>::operator-=(const Point_<T> &p) {
|
| 460 |
+
X = X - p.X;
|
| 461 |
+
Y = Y - p.Y;
|
| 462 |
+
return *this;
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
template<typename T>
|
| 466 |
+
__any_device__
|
| 467 |
+
__lib_inline__ Point_<T> &Point_<T>::operator*=(const Point_<T> &p) {
|
| 468 |
+
X = X * p.X;
|
| 469 |
+
Y = Y * p.Y;
|
| 470 |
+
return *this;
|
| 471 |
+
}
|
| 472 |
+
|
| 473 |
+
template<typename T>
|
| 474 |
+
__any_device__
|
| 475 |
+
__lib_inline__ Point_<T> &Point_<T>::operator/=(const Point_<T> &p) {
|
| 476 |
+
X = X / p.X;
|
| 477 |
+
Y = Y / p.Y;
|
| 478 |
+
return *this;
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
+
template<typename T>
|
| 482 |
+
template<typename W>
|
| 483 |
+
__any_device__
|
| 484 |
+
__lib_inline__ Point_<T> &Point_<T>::operator/=(W val) {
|
| 485 |
+
// TODO: This can be more efficient for float types by computing the reciprocal
|
| 486 |
+
X /= val;
|
| 487 |
+
Y /= val;
|
| 488 |
+
return *this;
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
template<typename T>
|
| 492 |
+
template<typename W>
|
| 493 |
+
__any_device__
|
| 494 |
+
__lib_inline__ Point_<T> &Point_<T>::operator*=(W val) {
|
| 495 |
+
X *= val;
|
| 496 |
+
Y *= val;
|
| 497 |
+
return *this;
|
| 498 |
+
}
|
| 499 |
+
|
| 500 |
+
template<typename T>
|
| 501 |
+
__any_device__
|
| 502 |
+
__lib_inline__ T dot(const Point_<T> &a, const Point_<T> &b) {
|
| 503 |
+
return a.X * b.X + a.Y * b.Y;
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
template<typename T>
|
| 507 |
+
__any_device__
|
| 508 |
+
__lib_inline__ T dot(const Point_<T> &p) {
|
| 509 |
+
return dot(p, p);
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
template<typename T>
|
| 513 |
+
__any_device__
|
| 514 |
+
__lib_inline__ T length(const Point_<T> &p) {
|
| 515 |
+
#ifndef __NVCC__
|
| 516 |
+
using std::sqrt;
|
| 517 |
+
#endif
|
| 518 |
+
return sqrt(dot(p));
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
template<typename T>
|
| 522 |
+
__any_device__
|
| 523 |
+
__lib_inline__ Point_<T> normalize(const Point_<T> &p) {
|
| 524 |
+
static constexpr T epsilon = std::numeric_limits<T>::epsilon();
|
| 525 |
+
auto len = length(p) + epsilon;
|
| 526 |
+
return { p.X / len, p.Y / len };
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
template<typename T>
|
| 530 |
+
__any_device__
|
| 531 |
+
__lib_inline__ Point_<T> ortho_2d(const Point_<T> &p) {
|
| 532 |
+
return { -p.Y, p.X };
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
template<typename T>
|
| 536 |
+
__host__
|
| 537 |
+
__lib_inline__ std::ostream &operator<<(std::ostream &os, const Point_<T> &p) {
|
| 538 |
+
return os << "(" << p.X << ", " << p.Y << ")";
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
template<typename T>
|
| 542 |
+
__host__
|
| 543 |
+
__lib_inline__ std::ostream &operator<<(std::ostream &os, const AABB_<T> &b) {
|
| 544 |
+
return os << "[(" << b.X << ", " << b.Y << "), (" << b.MaxX << ", " << b.MaxY << ")]";
|
| 545 |
+
}
|
| 546 |
+
|
| 547 |
+
template<typename T>
|
| 548 |
+
__host__
|
| 549 |
+
__lib_inline__ std::ostream &operator<<(std::ostream &os, const Segment_<T> &s) {
|
| 550 |
+
return os << "[(" << s.A.X << ", " << s.A.Y << "), (" << s.B.X << ", " << s.B.Y << ")]";
|
| 551 |
+
}
|
| 552 |
+
|
| 553 |
+
template<typename T>
|
| 554 |
+
__host__
|
| 555 |
+
__lib_inline__ std::ostream &operator<<(std::ostream &os, const Quad_<T> &q) {
|
| 556 |
+
os << "[" << q.Vertices[0];
|
| 557 |
+
for (size_t i = 1; i < 4; ++i) {
|
| 558 |
+
os << ", " << q.Vertices[i];
|
| 559 |
+
}
|
| 560 |
+
return os << "]";
|
| 561 |
+
}
|
| 562 |
+
|
| 563 |
+
template<typename T>
|
| 564 |
+
__any_device__
|
| 565 |
+
__lib_inline__ int _signum(T val) {
|
| 566 |
+
return (T(0) < val) - (val < T(0));
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
template<typename T>
|
| 570 |
+
__any_device__
|
| 571 |
+
__lib_inline__ T sign(const Point_<T> &p1, const Point_<T> &p2, const Point_<T> &p3) {
|
| 572 |
+
T ret = (p1.X - p3.X) * (p2.Y - p3.Y) - (p2.X - p3.X) * (p1.Y - p3.Y);
|
| 573 |
+
auto sgn = _signum(ret);
|
| 574 |
+
return sgn;
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
template<typename T>
|
| 578 |
+
__any_device__
|
| 579 |
+
__lib_inline__ T Segment_<T>::Length() const
|
| 580 |
+
{
|
| 581 |
+
#ifndef __NVCC__
|
| 582 |
+
using std::sqrt;
|
| 583 |
+
#endif
|
| 584 |
+
return sqrt(LengthSq());
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
template<typename T>
|
| 588 |
+
__any_device__
|
| 589 |
+
__lib_inline__ T Segment_<T>::LengthSq() const
|
| 590 |
+
{
|
| 591 |
+
return dot(B - A);
|
| 592 |
+
}
|
| 593 |
+
|
| 594 |
+
template<typename T>
|
| 595 |
+
__any_device__
|
| 596 |
+
inline bool Segment_<T>::Intersection(const Segment_<T> &other, Point_<T> &out_ptAlong) const
|
| 597 |
+
{
|
| 598 |
+
auto p1 = A, p2 = B, p3 = other.A, p4 = other.B;
|
| 599 |
+
|
| 600 |
+
auto denom = (p4.Y - p3.Y) * (p2.X - p1.X) - (p4.X - p3.X) * (p2.Y - p1.Y);
|
| 601 |
+
|
| 602 |
+
if (abs(denom) < 1e-8) {
|
| 603 |
+
return false;
|
| 604 |
+
}
|
| 605 |
+
|
| 606 |
+
auto numer = (p4.X - p3.X) * (p1.Y - p3.Y) - (p4.Y - p3.Y) * (p1.X - p3.X);
|
| 607 |
+
|
| 608 |
+
auto t = numer / denom;
|
| 609 |
+
|
| 610 |
+
auto Bnumer = (p2.X - p1.X) * (p1.Y - p3.Y) - (p2.Y - p1.Y) * (p1.X - p3.X);
|
| 611 |
+
|
| 612 |
+
auto Bt = Bnumer / denom;
|
| 613 |
+
|
| 614 |
+
if (t < 0 || t > 1 || Bt < 0 || Bt > 1) {
|
| 615 |
+
return false;
|
| 616 |
+
}
|
| 617 |
+
|
| 618 |
+
out_ptAlong = A + t * (B - A);
|
| 619 |
+
|
| 620 |
+
return true;
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
template<typename quad_t>
|
| 624 |
+
__any_device__
|
| 625 |
+
auto quad_center(const quad_t &quad) -> Point_<typename quad_t::inner_type>
|
| 626 |
+
{
|
| 627 |
+
typedef typename quad_t::inner_type T;
|
| 628 |
+
|
| 629 |
+
Point_<T> center = quad[0];
|
| 630 |
+
for (size_t i = 1; i < 4; ++i) {
|
| 631 |
+
center += quad[i];
|
| 632 |
+
}
|
| 633 |
+
|
| 634 |
+
return center / T{ 4 };
|
| 635 |
+
}
|
| 636 |
+
|
| 637 |
+
template<typename T, typename Derived>
|
| 638 |
+
__any_device__
|
| 639 |
+
Point_<T> QuadBase_<T, Derived>::Center() const {
|
| 640 |
+
return quad_center(static_cast<const Derived&>(*this));
|
| 641 |
+
}
|
| 642 |
+
|
| 643 |
+
template<typename quad_t>
|
| 644 |
+
__any_device__
|
| 645 |
+
auto quad_bounds(const quad_t &quad) -> AABB_<typename quad_t::inner_type>
|
| 646 |
+
{
|
| 647 |
+
#ifndef __NVCC__
|
| 648 |
+
using std::min;
|
| 649 |
+
using std::max;
|
| 650 |
+
#endif
|
| 651 |
+
auto minP = quad[0];
|
| 652 |
+
auto maxP = minP;
|
| 653 |
+
for (size_t i = 1; i < 4; ++i) {
|
| 654 |
+
auto qp = quad[i];
|
| 655 |
+
minP = min(minP, qp);
|
| 656 |
+
maxP = max(maxP, qp);
|
| 657 |
+
}
|
| 658 |
+
return { minP.X, minP.Y, maxP.X, maxP.Y };
|
| 659 |
+
}
|
| 660 |
+
|
| 661 |
+
template<typename T, typename Derived>
|
| 662 |
+
__any_device__
|
| 663 |
+
AABB_<T> QuadBase_<T, Derived>::Bounds() const {
|
| 664 |
+
return quad_bounds(static_cast<const Derived&>(*this));
|
| 665 |
+
}
|
| 666 |
+
|
| 667 |
+
template<typename Quad_t, typename point_t>
|
| 668 |
+
__any_device__
|
| 669 |
+
inline bool quad_contains(const Quad_t &quad, const point_t &pt)
|
| 670 |
+
{
|
| 671 |
+
#ifndef __NVCC__
|
| 672 |
+
using std::abs;
|
| 673 |
+
#endif
|
| 674 |
+
|
| 675 |
+
// Checks that the point lies on the same side of all four edges (every half-plane test agrees)
|
| 676 |
+
auto d1 = sign(pt, quad[0], quad[1]);
|
| 677 |
+
auto d2 = sign(pt, quad[1], quad[2]);
|
| 678 |
+
auto d3 = sign(pt, quad[2], quad[3]);
|
| 679 |
+
auto d4 = sign(pt, quad[3], quad[0]);
|
| 680 |
+
|
| 681 |
+
// bool has_neg = (d1 < 0) || (d2 < 0) || (d3 < 0) || (d4 < 0);
|
| 682 |
+
// bool has_pos = (d1 > 0) || (d2 > 0) || (d3 > 0) || (d4 > 0);
|
| 683 |
+
int tot = d1 + d2 + d3 + d4;
|
| 684 |
+
|
| 685 |
+
// return !(has_neg && has_pos);
|
| 686 |
+
return abs(tot) == 4;
|
| 687 |
+
}
|
| 688 |
+
|
| 689 |
+
template<typename T, typename Derived>
|
| 690 |
+
__any_device__
|
| 691 |
+
__lib_inline__ bool QuadBase_<T, Derived>::Contains(const Point_<T> &pt) const
|
| 692 |
+
{
|
| 693 |
+
return quad_contains(static_cast<const Derived&>(*this), pt);
|
| 694 |
+
}
|
| 695 |
+
|
| 696 |
+
template<typename PtList>
|
| 697 |
+
__any_device__
|
| 698 |
+
inline auto shoelace_area(const PtList &points, size_t numPts, bool isSigned=false) -> decltype(points[0].X)
|
| 699 |
+
{
|
| 700 |
+
#ifndef __NVCC__
|
| 701 |
+
using std::abs;
|
| 702 |
+
#endif
|
| 703 |
+
|
| 704 |
+
decltype(points[0].X) area = 0;
|
| 705 |
+
|
| 706 |
+
size_t j = numPts - 1;
|
| 707 |
+
for (size_t i = 0; i < numPts; ++i) {
|
| 708 |
+
auto Pi = points[i];
|
| 709 |
+
auto Pj = points[j];
|
| 710 |
+
|
| 711 |
+
area += (Pj.X + Pi.X) * (Pj.Y - Pi.Y);
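// Shoelace formula: accumulate twice the signed trapezoid area for the edge from vertex j to vertex i; halved below.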
|
| 712 |
+
j = i;
|
| 713 |
+
}
|
| 714 |
+
|
| 715 |
+
area = area / 2;
|
| 716 |
+
|
| 717 |
+
if (! isSigned) {
|
| 718 |
+
area = abs(area);
|
| 719 |
+
}
|
| 720 |
+
|
| 721 |
+
return area;
|
| 722 |
+
}
|
| 723 |
+
|
| 724 |
+
template<typename T, typename Derived>
|
| 725 |
+
__any_device__
|
| 726 |
+
__lib_inline__ T QuadBase_<T, Derived>::Height() const
|
| 727 |
+
{
|
| 728 |
+
auto &d = static_cast<const Derived&>(*this);
|
| 729 |
+
auto h1 = Segment_<T>(d[1], d[2]).Length();
|
| 730 |
+
auto h2 = Segment_<T>(d[3], d[0]).Length();
|
| 731 |
+
return (h1 + h2) / 2;
|
| 732 |
+
}
|
| 733 |
+
|
| 734 |
+
template<typename T, typename Derived>
|
| 735 |
+
__any_device__
|
| 736 |
+
__lib_inline__ T QuadBase_<T, Derived>::Width() const
|
| 737 |
+
{
|
| 738 |
+
auto &d = static_cast<const Derived&>(*this);
|
| 739 |
+
auto w1 = Segment_<T>(d[0], d[1]).Length();
|
| 740 |
+
auto w2 = Segment_<T>(d[3], d[2]).Length();
|
| 741 |
+
return (w1 + w2) / 2;
|
| 742 |
+
}
|
| 743 |
+
|
| 744 |
+
// The area of a quad equals the sum of the areas of two triangles; here it is computed directly with the shoelace formula
|
| 745 |
+
template<typename T, typename Derived>
|
| 746 |
+
__any_device__
|
| 747 |
+
inline T QuadBase_<T, Derived>::Area() const
|
| 748 |
+
{
|
| 749 |
+
// auto vertices = static_cast<const Derived *>(this)->Vertices;
|
| 750 |
+
return shoelace_area(static_cast<const Derived&>(*this), 4);
|
| 751 |
+
}
|
| 752 |
+
|
| 753 |
+
template<typename Quad_t1, typename Quad_t2>
|
| 754 |
+
__any_device__
|
| 755 |
+
inline auto intersection_area(const Quad_t1 &quadsA, const Quad_t2 &quadsB) -> typename Quad_t1::inner_type
|
| 756 |
+
{
|
| 757 |
+
#ifndef __NVCC__
|
| 758 |
+
using std::atan2;
|
| 759 |
+
#endif
|
| 760 |
+
|
| 761 |
+
typedef typename Quad_t1::inner_type T;
|
| 762 |
+
|
| 763 |
+
static const size_t MAX_PTS = 32;
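// Convex intersection: collect each quad's vertices that lie inside the other quad plus all edge-edge crossing points, order them by angle about their centroid, then measure the resulting polygon with the shoelace formula.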
|
| 764 |
+
|
| 765 |
+
Point_<T> points[MAX_PTS], sortedPoints[MAX_PTS];
|
| 766 |
+
T angles[MAX_PTS];
|
| 767 |
+
size_t indices[MAX_PTS];
|
| 768 |
+
size_t numPts = 0;
|
| 769 |
+
|
| 770 |
+
auto addPt = [&] (const Point_<T> &p) {
|
| 771 |
+
points[numPts] = p;
|
| 772 |
+
++numPts;
|
| 773 |
+
};
|
| 774 |
+
|
| 775 |
+
for (size_t i = 0; i < 4; ++i) {
|
| 776 |
+
Point_<T> aPt = quadsA[i];
|
| 777 |
+
Point_<T> bPt = quadsB[i];
|
| 778 |
+
|
| 779 |
+
if (quadsA.Contains(bPt)) {
|
| 780 |
+
addPt(bPt);
|
| 781 |
+
}
|
| 782 |
+
if (quadsB.Contains(aPt)) {
|
| 783 |
+
addPt(aPt);
|
| 784 |
+
}
|
| 785 |
+
}
|
| 786 |
+
|
| 787 |
+
for (size_t i = 0; i < 4; ++i) {
|
| 788 |
+
Segment_<T> segA{ quadsA[i], quadsA[(i + 1) % 4] };
|
| 789 |
+
|
| 790 |
+
for (size_t j = 0; j < 4; ++j) {
|
| 791 |
+
Segment_<T> segB{ quadsB[j], quadsB[(j + 1) % 4] };
|
| 792 |
+
|
| 793 |
+
Point_<T> ptAlong;
|
| 794 |
+
if (segA.Intersection(segB, ptAlong)) {
|
| 795 |
+
addPt(ptAlong);
|
| 796 |
+
}
|
| 797 |
+
}
|
| 798 |
+
}
|
| 799 |
+
|
| 800 |
+
if (numPts == 0) {
|
| 801 |
+
return 0;
|
| 802 |
+
}
|
| 803 |
+
|
| 804 |
+
Point_<T> center{ 0, 0 };
|
| 805 |
+
for (size_t i = 0; i < numPts; ++i) {
|
| 806 |
+
center += points[i];
|
| 807 |
+
}
|
| 808 |
+
center /= numPts;
|
| 809 |
+
|
| 810 |
+
for (size_t i = 0; i < numPts; ++i) {
|
| 811 |
+
points[i] -= center;
|
| 812 |
+
|
| 813 |
+
angles[i] = atan2(points[i].Y, points[i].X);
|
| 814 |
+
|
| 815 |
+
indices[i] = i;
|
| 816 |
+
}
|
| 817 |
+
|
| 818 |
+
// Perform an argsort over the angles
|
| 819 |
+
SORT_ALGO(indices, indices + numPts,
|
| 820 |
+
[&] (size_t a, size_t b) {
|
| 821 |
+
return angles[a] < angles[b];
|
| 822 |
+
}
|
| 823 |
+
);
|
| 824 |
+
|
| 825 |
+
for (size_t i = 0; i < numPts; ++i) {
|
| 826 |
+
sortedPoints[i] = points[indices[i]];
|
| 827 |
+
}
|
| 828 |
+
|
| 829 |
+
// Finally, we can compute the area of this polygon using the shoelace formula
|
| 830 |
+
T area = shoelace_area(sortedPoints, numPts);
|
| 831 |
+
|
| 832 |
+
return area;
|
| 833 |
+
}
|
| 834 |
+
|
| 835 |
+
template<typename T, typename Derived>
|
| 836 |
+
template<typename Derived2>
|
| 837 |
+
__any_device__
|
| 838 |
+
__lib_inline__ T QuadBase_<T, Derived>::IntersectionArea(const QuadBase_<T, Derived2> &other) const
|
| 839 |
+
{
|
| 840 |
+
return intersection_area(
|
| 841 |
+
static_cast<const Derived&>(*this),
|
| 842 |
+
static_cast<const Derived2&>(other)
|
| 843 |
+
);
|
| 844 |
+
}
|
| 845 |
+
|
| 846 |
+
template<typename T1, typename T2>
|
| 847 |
+
__any_device__
|
| 848 |
+
__lib_inline__ auto geometry_iou(const T1 &a, const T2 &b) -> decltype(a.Area())
|
| 849 |
+
{
|
| 850 |
+
auto aArea = a.Area();
|
| 851 |
+
auto bArea = b.Area();
|
| 852 |
+
auto ixArea = a.IntersectionArea(b);
|
| 853 |
+
|
| 854 |
+
auto unionArea = aArea + bArea - ixArea;
|
| 855 |
+
|
| 856 |
+
return ixArea / unionArea;
|
| 857 |
+
}
|
| 858 |
+
|
| 859 |
+
template<typename T, typename Derived>
|
| 860 |
+
template<typename Derived2>
|
| 861 |
+
__any_device__
|
| 862 |
+
__lib_inline__ T QuadBase_<T, Derived>::IOU(const QuadBase_<T, Derived2> &other) const
|
| 863 |
+
{
|
| 864 |
+
return geometry_iou(
|
| 865 |
+
static_cast<const Derived&>(*this),
|
| 866 |
+
static_cast<const Derived2&>(other)
|
| 867 |
+
);
|
| 868 |
+
}
|
| 869 |
+
|
| 870 |
+
template<typename T, typename Derived>
|
| 871 |
+
template<typename Derived2>
|
| 872 |
+
__any_device__
|
| 873 |
+
__lib_inline__ T QuadBase_<T, Derived>::IOU_UpperBound(const QuadBase_<T, Derived2> &other) const
|
| 874 |
+
{
|
| 875 |
+
return geometry_iou(
|
| 876 |
+
Bounds(),
|
| 877 |
+
other.Bounds()
|
| 878 |
+
);
|
| 879 |
+
}
|
| 880 |
+
|
| 881 |
+
template<typename T1, typename T2>
|
| 882 |
+
__any_device__ __lib_inline__
|
| 883 |
+
auto geometry_region_sizes(const T1 &a, const T2 &b) -> tuple_t<decltype(a.Area()), decltype(a.Area()), decltype(a.IntersectionArea(b))>
|
| 884 |
+
{
|
| 885 |
+
auto aArea = a.Area();
|
| 886 |
+
auto bArea = b.Area();
|
| 887 |
+
auto ixArea = a.IntersectionArea(b);
|
| 888 |
+
|
| 889 |
+
auto unionArea = aArea + bArea - ixArea;
|
| 890 |
+
auto iou = ixArea / unionArea;
|
| 891 |
+
|
| 892 |
+
return { ixArea / aArea, ixArea / bArea, iou };
|
| 893 |
+
}
|
| 894 |
+
|
| 895 |
+
|
| 896 |
+
template<typename T, typename Derived>
|
| 897 |
+
template<typename Derived2>
|
| 898 |
+
__any_device__ __lib_inline__
|
| 899 |
+
tuple_t<T, T, T> QuadBase_<T, Derived>::RegionSizes(const QuadBase_<T, Derived2> &other) const
|
| 900 |
+
{
|
| 901 |
+
return geometry_region_sizes(
|
| 902 |
+
static_cast<const Derived&>(*this),
|
| 903 |
+
static_cast<const Derived2&>(other)
|
| 904 |
+
);
|
| 905 |
+
}
|
| 906 |
+
|
| 907 |
+
template<typename T, typename Derived>
|
| 908 |
+
template<typename Derived2>
|
| 909 |
+
__any_device__ __lib_inline__
|
| 910 |
+
tuple_t<T, T, T> QuadBase_<T, Derived>::RegionSizes_UpperBound(const QuadBase_<T, Derived2> &other) const
|
| 911 |
+
{
|
| 912 |
+
return geometry_region_sizes(
|
| 913 |
+
Bounds(),
|
| 914 |
+
other.Bounds()
|
| 915 |
+
);
|
| 916 |
+
}
|
| 917 |
+
|
| 918 |
+
template<typename polygon_t>
|
| 919 |
+
__any_device__
|
| 920 |
+
auto polygon_bounds(const polygon_t &poly) -> AABB_<typename polygon_t::inner_type>
|
| 921 |
+
{
|
| 922 |
+
#ifndef __NVCC__
|
| 923 |
+
using std::min;
|
| 924 |
+
using std::max;
|
| 925 |
+
#endif
|
| 926 |
+
auto minP = poly[0];
|
| 927 |
+
auto maxP = minP;
|
| 928 |
+
for (size_t i = 1; i < poly.Count; ++i) {
|
| 929 |
+
auto qp = poly[i];
|
| 930 |
+
minP = min(minP, qp);
|
| 931 |
+
maxP = max(maxP, qp);
|
| 932 |
+
}
|
| 933 |
+
return { minP.X, minP.Y, maxP.X, maxP.Y };
|
| 934 |
+
}
|
| 935 |
+
|
| 936 |
+
template<typename T, typename Derived>
|
| 937 |
+
__any_device__
|
| 938 |
+
AABB_<T> PolygonBase_<T, Derived>::Bounds() const {
|
| 939 |
+
return polygon_bounds(static_cast<const Derived&>(*this));
|
| 940 |
+
}
|
| 941 |
+
|
| 942 |
+
template<typename polygon_t, typename point_t>
|
| 943 |
+
__any_device__
|
| 944 |
+
bool polygon_contains(const polygon_t &poly, const point_t &pt)
|
| 945 |
+
{
|
| 946 |
+
typedef typename polygon_t::inner_type T;
|
| 947 |
+
|
| 948 |
+
// Some arbitrary segment. Technically this should be a ray, but functionally this will work
|
| 949 |
+
Segment_<T> testSeg{ pt, { -1e6, -2e6 }};
|
| 950 |
+
Point_<T> trash;
|
| 951 |
+
|
| 952 |
+
int32_t ixCount = 0;
|
| 953 |
+
for (size_t i = 0; i < poly.Count; ++i) {
|
| 954 |
+
Segment_<T> polySeg{ poly[i], poly[(i + 1) % poly.Count] };
|
| 955 |
+
|
| 956 |
+
if (testSeg.Intersection(polySeg, trash)) {
|
| 957 |
+
++ixCount;
|
| 958 |
+
}
|
| 959 |
+
}
|
| 960 |
+
|
| 961 |
+
// If there are an odd number of intersections, then the point is inside
|
| 962 |
+
return (ixCount % 2) == 1;
|
| 963 |
+
}
|
| 964 |
+
|
| 965 |
+
template<typename T, typename Derived>
|
| 966 |
+
__any_device__
|
| 967 |
+
bool PolygonBase_<T, Derived>::Contains(const Point_<T> &pt) const {
|
| 968 |
+
return polygon_contains(static_cast<const Derived&>(*this), pt);
|
| 969 |
+
}
|
| 970 |
+
|
| 971 |
+
template<typename polygon_t>
|
| 972 |
+
__any_device__
|
| 973 |
+
auto polygon_edge_length(const polygon_t &poly) -> typename polygon_t::inner_type
|
| 974 |
+
{
|
| 975 |
+
typedef typename polygon_t::inner_type T;
|
| 976 |
+
|
| 977 |
+
T ret = 0;
|
| 978 |
+
|
| 979 |
+
for (size_t i = 0; i < poly.Count; ++i) {
|
| 980 |
+
Segment_<T> seg{ poly[i], poly[(i + 1) % poly.Count] };
|
| 981 |
+
|
| 982 |
+
ret += seg.Length();
|
| 983 |
+
}
|
| 984 |
+
|
| 985 |
+
return ret;
|
| 986 |
+
}
|
| 987 |
+
|
| 988 |
+
template<typename T, typename Derived>
|
| 989 |
+
__any_device__
|
| 990 |
+
T PolygonBase_<T, Derived>::EdgeLength() const {
|
| 991 |
+
return polygon_edge_length(static_cast<const Derived&>(*this));
|
| 992 |
+
}
|
| 993 |
+
|
| 994 |
+
template<typename polygon_t>
|
| 995 |
+
__any_device__
|
| 996 |
+
auto polygon_center(const polygon_t &poly) -> Point_<typename polygon_t::inner_type>
|
| 997 |
+
{
|
| 998 |
+
typedef typename polygon_t::inner_type T;
|
| 999 |
+
|
| 1000 |
+
T cx = 0, cy = 0, a = 0;
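// Polygon centroid via the shoelace-weighted sum: each edge contributes its cross-product term ("common") as a weight, and "a" accumulates twice the signed area.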
|
| 1001 |
+
size_t j = poly.Count - 1;
|
| 1002 |
+
for (size_t i = 0; i < poly.Count; ++i) {
|
| 1003 |
+
Point_<T> p0 = poly[i];
|
| 1004 |
+
Point_<T> p1 = poly[j];
|
| 1005 |
+
|
| 1006 |
+
T common = (p0.X * p1.Y - p1.X * p0.Y);
|
| 1007 |
+
cx += (p0.X + p1.X) * common;
|
| 1008 |
+
cy += (p0.Y + p1.Y) * common;
|
| 1009 |
+
a += common;
|
| 1010 |
+
|
| 1011 |
+
j = i;
|
| 1012 |
+
}
|
| 1013 |
+
|
| 1014 |
+
a /= 2;
|
| 1015 |
+
|
| 1016 |
+
Point_<T> center{ cx / (6 * a), cy / (6 * a) };
|
| 1017 |
+
|
| 1018 |
+
return center;
|
| 1019 |
+
}
|
| 1020 |
+
|
| 1021 |
+
template<typename T, typename Derived>
|
| 1022 |
+
__any_device__
|
| 1023 |
+
Point_<T> PolygonBase_<T, Derived>::Center() const {
|
| 1024 |
+
return polygon_center(static_cast<const Derived&>(*this));
|
| 1025 |
+
}
|
| 1026 |
+
|
| 1027 |
+
template<typename T, typename Derived>
|
| 1028 |
+
__any_device__
|
| 1029 |
+
T PolygonBase_<T, Derived>::Area() const {
|
| 1030 |
+
const Derived &dThis = static_cast<const Derived&>(*this);
|
| 1031 |
+
return shoelace_area(dThis, dThis.Count);
|
| 1032 |
+
}
|
| 1033 |
+
|
| 1034 |
+
|
| 1035 |
+
template<typename T>
|
| 1036 |
+
__any_device__
|
| 1037 |
+
Point_<T> nearest_point_on_segment(const Point_<T> &pt, const Segment_<T> &seg)
|
| 1038 |
+
{
|
| 1039 |
+
#ifndef __NVCC__
|
| 1040 |
+
using std::max;
|
| 1041 |
+
using std::min;
|
| 1042 |
+
#endif
|
| 1043 |
+
|
| 1044 |
+
const T l2 = seg.LengthSq();
|
| 1045 |
+
|
| 1046 |
+
if (l2 == 0.0) {
|
| 1047 |
+
return seg.A;
|
| 1048 |
+
}
|
| 1049 |
+
|
| 1050 |
+
const auto v = seg.A;
|
| 1051 |
+
const auto w = seg.B;
|
| 1052 |
+
// Consider the line extending the segment, parameterized as v + t*(w-v)
|
| 1053 |
+
// Find projection of point p onto the line
|
| 1054 |
+
auto t = dot(pt - v, w - v) / l2;
|
| 1055 |
+
|
| 1056 |
+
// Clamp between t=0 and t=1
|
| 1057 |
+
t = max(static_cast<T>(0), min(static_cast<T>(1), t));
|
| 1058 |
+
|
| 1059 |
+
const auto projection = v + t * (w - v);
|
| 1060 |
+
|
| 1061 |
+
return projection;
|
| 1062 |
+
}
|
| 1063 |
+
|
| 1064 |
+
|
| 1065 |
+
template<typename T>
|
| 1066 |
+
__any_device__
|
| 1067 |
+
Segment_<T> shortest_line_between_segments(const Segment_<T> &a, const Segment_<T> &b)
|
| 1068 |
+
{
|
| 1069 |
+
Segment_<T> segs[] = {
|
| 1070 |
+
{ a.A, nearest_point_on_segment(a.A, b) },
|
| 1071 |
+
{ a.B, nearest_point_on_segment(a.B, b) },
|
| 1072 |
+
{ nearest_point_on_segment(b.A, a), b.A },
|
| 1073 |
+
{ nearest_point_on_segment(b.B, a), b.B }
|
| 1074 |
+
};
|
| 1075 |
+
|
| 1076 |
+
T minDist = std::numeric_limits<T>::max();
|
| 1077 |
+
size_t idx;
|
| 1078 |
+
|
| 1079 |
+
#pragma unroll
|
| 1080 |
+
for (size_t i = 0; i < 4; ++i) {
|
| 1081 |
+
T dist = segs[i].LengthSq();
|
| 1082 |
+
if (dist < minDist) {
|
| 1083 |
+
minDist = dist;
|
| 1084 |
+
idx = i;
|
| 1085 |
+
}
|
| 1086 |
+
}
|
| 1087 |
+
|
| 1088 |
+
return segs[idx];
|
| 1089 |
+
}
|
| 1090 |
+
|
| 1091 |
+
// Find the distance between a point and the nearest point along the specified segment
|
| 1092 |
+
template<typename T>
|
| 1093 |
+
__any_device__
|
| 1094 |
+
T distance_to_segment(const Point_<T> &pt, const Segment_<T> &seg)
|
| 1095 |
+
{
|
| 1096 |
+
auto projection = nearest_point_on_segment(pt, seg);
|
| 1097 |
+
|
| 1098 |
+
auto dist = length(pt - projection);
|
| 1099 |
+
|
| 1100 |
+
return dist;
|
| 1101 |
+
}
|
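Quick orientation for the header above (illustrative only, not part of the commit): a minimal host-only sketch of the quad types, assuming the header builds standalone with _GEOMETRY_NO_TORCH defined and cuda_intellisense.cuh on the include path.

#define _GEOMETRY_NO_TORCH  // assumed host-only build, skips the torch accessors
#include "geometry.h"
#include <iostream>

int main() {
    // Two unit squares offset by 0.5 along x, each given as four (x, y) vertices.
    float a[8] = { 0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f };
    float b[8] = { 0.5f, 0.0f, 1.5f, 0.0f, 1.5f, 1.0f, 0.5f, 1.0f };

    InPlaceQuad_<float> qa(a), qb(b);

    // Expected: area = 1, intersection = 0.5, IOU = 1/3.
    std::cout << "area=" << qa.Area()
              << " ix=" << qa.IntersectionArea(qb)
              << " iou=" << qa.IOU(qb) << std::endl;
    return 0;
}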
nemo-retriever-ocr/cpp/geometry_api/calc_poly_min_rrect.cpp
ADDED
|
@@ -0,0 +1,165 @@
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#include "geometry_api.h"
|
| 6 |
+
|
| 7 |
+
#include "../graph_detection/encode_util.h"
|
| 8 |
+
|
| 9 |
+
#include "../geometry.h"
|
| 10 |
+
#include "matrix2x2.h"
|
| 11 |
+
|
| 12 |
+
using namespace std;
|
| 13 |
+
|
| 14 |
+
template<typename T>
|
| 15 |
+
void _calc_poly_min_rrect(const torch::TensorAccessor<T, 2> vertices, torch::TensorAccessor<T, 2> outRRect);
|
| 16 |
+
template<typename T>
|
| 17 |
+
void _calc_quad_min_rrect(const torch::TensorAccessor<T, 2> vertices, torch::TensorAccessor<T, 2> outRRect);
|
| 18 |
+
|
| 19 |
+
torch::Tensor calc_poly_min_rrect(torch::Tensor vertices)
|
| 20 |
+
{
|
| 21 |
+
if (vertices.size(0) < 3) {
|
| 22 |
+
throw runtime_error("Invalid polygon! Expected >= 3 vertices, got " + to_string(vertices.size(0)));
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
auto ret = torch::empty({ 4, 2 }, vertices.options());
|
| 26 |
+
|
| 27 |
+
auto retAcc = ret.accessor<float, 2>();
|
| 28 |
+
|
| 29 |
+
if (vertices.size(0) != 4) {
|
| 30 |
+
// OpenCV requires this to be a contiguous buffer
|
| 31 |
+
vertices = vertices.contiguous();
|
| 32 |
+
_calc_poly_min_rrect(vertices.accessor<float, 2>(), retAcc);
|
| 33 |
+
} else {
|
| 34 |
+
_calc_quad_min_rrect(vertices.accessor<float, 2>(), retAcc);
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
return ret;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
template<typename T>
|
| 42 |
+
void _calc_bounds(const torch::TensorAccessor<T, 2> &vertices, torch::TensorAccessor<T, 2> &outRRect,
|
| 43 |
+
const Point_<T> &leftCenter, const Point_<T> &rightCenter)
|
| 44 |
+
{
|
| 45 |
+
typedef Point_<T> Pointf;
|
| 46 |
+
|
| 47 |
+
Pointf vecAlong = rightCenter - leftCenter;
|
| 48 |
+
auto alongMag = length(vecAlong);
|
| 49 |
+
|
| 50 |
+
if (alongMag == 0.0f) {
|
| 51 |
+
throw runtime_error("Invalid polygon!");
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
vecAlong /= alongMag;
|
| 55 |
+
|
| 56 |
+
Pointf dOrtho{ -vecAlong.Y, vecAlong.X };
|
| 57 |
+
|
| 58 |
+
Pointf center = (leftCenter + rightCenter) / 2.0f;
|
| 59 |
+
|
| 60 |
+
Matrix2x2<T> rotMat{ vecAlong, dOrtho };
|
| 61 |
+
|
| 62 |
+
auto get_fn = [&vertices, ¢er] (int64_t i) {
|
| 63 |
+
return Pointf{ vertices[i] } - center;
|
| 64 |
+
};
|
| 65 |
+
|
| 66 |
+
// All we care about is getting the bounds in the normalized space, so this saves
|
| 67 |
+
// us from having to do any memory allocation
|
| 68 |
+
Pointf minPt{ 0, 0 }, maxPt{ 0, 0 };
|
| 69 |
+
auto tx_fn = [&minPt, &maxPt] (int64_t i, const Pointf &pt) {
|
| 70 |
+
minPt = min(minPt, pt);
|
| 71 |
+
maxPt = max(maxPt, pt);
|
| 72 |
+
};
|
| 73 |
+
|
| 74 |
+
matmul_fn(vertices.size(0), get_fn, rotMat, tx_fn, transpose_tag{});
|
| 75 |
+
|
| 76 |
+
Pointf rotBox[4] = {
|
| 77 |
+
minPt,
|
| 78 |
+
{ maxPt.X, minPt.Y },
|
| 79 |
+
maxPt,
|
| 80 |
+
{ minPt.X, maxPt.Y }
|
| 81 |
+
};
|
| 82 |
+
|
| 83 |
+
auto get_fn2 = [&rotBox] (int64_t i) {
|
| 84 |
+
return rotBox[i];
|
| 85 |
+
};
|
| 86 |
+
|
| 87 |
+
auto assign_fn = [¢er, &outRRect] (int64_t i, const Pointf &pt) {
|
| 88 |
+
outRRect[i][0] = pt.X + center.X;
|
| 89 |
+
outRRect[i][1] = pt.Y + center.Y;
|
| 90 |
+
};
|
| 91 |
+
|
| 92 |
+
matmul_fn(4, get_fn2, rotMat, assign_fn, contiguous_tag{});
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
template<typename T>
|
| 97 |
+
void _calc_poly_min_rrect(const torch::TensorAccessor<T, 2> vertices, torch::TensorAccessor<T, 2> outRRect)
|
| 98 |
+
{
|
| 99 |
+
typedef Point_<T> Pointf;
|
| 100 |
+
typedef Polygon_<T> Polygonf;
|
| 101 |
+
|
| 102 |
+
Polygonf poly{ vertices.data(), vertices.size(0) };
|
| 103 |
+
|
| 104 |
+
vector<graph_detection::Edge> bottoms = graph_detection::find_bottom(poly, false);
|
| 105 |
+
|
| 106 |
+
if (bottoms.size() != 2) {
|
| 107 |
+
throw runtime_error("Invalid polygon!");
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
vector<graph_detection::Edge> longEdges[2];
|
| 111 |
+
graph_detection::find_long_edges(poly, bottoms.data(), longEdges[0], longEdges[1]);
|
| 112 |
+
|
| 113 |
+
////
|
| 114 |
+
// Determine which edge is above the other
|
| 115 |
+
Pointf cpts[2];
|
| 116 |
+
for (size_t i = 0; i < 2; ++i) {
|
| 117 |
+
auto &pedge = longEdges[i];
|
| 118 |
+
|
| 119 |
+
cpts[i] = Pointf{0.0f, 0.0f};
|
| 120 |
+
float ct = 0;
|
| 121 |
+
for (size_t z = 0; z < pedge.size(); ++z) {
|
| 122 |
+
auto edge = pedge[z];
|
| 123 |
+
Pointf p1 = poly[edge.A];
|
| 124 |
+
Pointf p2 = poly[edge.B];
|
| 125 |
+
cpts[i] += (p1 + p2) / 2.0f;
|
| 126 |
+
ct += 1.0f;
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
if (ct < 1.0f) {
|
| 130 |
+
throw runtime_error("Edge was empty!");
|
| 131 |
+
}
|
| 132 |
+
cpts[i] /= ct;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
float vpp = graph_detection::vector_sin(cpts[0] - cpts[1]);
|
| 136 |
+
if (vpp >= 0) {
|
| 137 |
+
swap(bottoms[0], bottoms[1]);
|
| 138 |
+
}
|
| 139 |
+
////
|
| 140 |
+
|
| 141 |
+
Pointf edge1[2] = { poly[bottoms[0].A], poly[bottoms[0].B] };
|
| 142 |
+
Pointf edge2[2] = { poly[bottoms[1].A], poly[bottoms[1].B] };
|
| 143 |
+
|
| 144 |
+
Pointf c0 = (edge1[0] + edge1[1]) / 2.0f;
|
| 145 |
+
Pointf c1 = (edge2[0] + edge2[1]) / 2.0f;
|
| 146 |
+
|
| 147 |
+
_calc_bounds(vertices, outRRect, c0, c1);
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
template<typename T>
|
| 151 |
+
void _calc_quad_min_rrect(const torch::TensorAccessor<T, 2> vertices, torch::TensorAccessor<T, 2> outRRect)
|
| 152 |
+
{
|
| 153 |
+
typedef Point_<T> Pointf;
|
| 154 |
+
|
| 155 |
+
// Instead of finding an arbitrary rotated box, find a reasonable
|
| 156 |
+
// fit for the quadrangle
|
| 157 |
+
Pointf pts[4] = {
|
| 158 |
+
vertices[0], vertices[1], vertices[2], vertices[3]
|
| 159 |
+
};
|
| 160 |
+
|
| 161 |
+
Pointf c0 = (pts[0] + pts[3]) / 2.0f;
|
| 162 |
+
Pointf c1 = (pts[1] + pts[2]) / 2.0f;
|
| 163 |
+
|
| 164 |
+
_calc_bounds(vertices, outRRect, c0, c1);
|
| 165 |
+
}
|
nemo-retriever-ocr/cpp/geometry_api/geometry_api.cpp
ADDED
|
@@ -0,0 +1,101 @@
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#include "geometry_api.h"
|
| 6 |
+
|
| 7 |
+
#include "geometry_api_common.h"
|
| 8 |
+
|
| 9 |
+
using namespace std;
|
| 10 |
+
|
| 11 |
+
torch::Tensor rrect_to_quads_gpu(torch::Tensor rrects, float cellSize);
|
| 12 |
+
|
| 13 |
+
template<typename T>
|
| 14 |
+
torch::Tensor rrect_to_quads_impl(torch::Tensor rrects, T cellSize)
|
| 15 |
+
{
|
| 16 |
+
// BHW(5)
|
| 17 |
+
auto rrectAccess = rrects.accessor<T, 4>();
|
| 18 |
+
|
| 19 |
+
T cellOff = cellSize / 2;
|
| 20 |
+
|
| 21 |
+
auto quads = torch::empty({ rrects.size(0), rrects.size(1), rrects.size(2), 4, 2 }, rrects.options());
|
| 22 |
+
|
| 23 |
+
auto quadsAccess = quads.accessor<T, 5>();
|
| 24 |
+
|
| 25 |
+
for (long b = 0; b < rrects.size(0); ++b) {
|
| 26 |
+
for (long y = 0; y < rrects.size(1); ++y) {
|
| 27 |
+
for (long x = 0; x < rrects.size(2); ++x) {
|
| 28 |
+
auto rrect = rrectAccess[b][y][x];
|
| 29 |
+
|
| 30 |
+
auto quad = quadsAccess[b][y][x];
|
| 31 |
+
|
| 32 |
+
assign_rrect_to_quad(rrect, quad, cellSize, cellOff,
|
| 33 |
+
static_cast<T>(x),
|
| 34 |
+
static_cast<T>(y));
|
| 35 |
+
}
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
return quads;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
torch::Tensor rrect_to_quads(torch::Tensor rrects, float cellSize)
|
| 43 |
+
{
|
| 44 |
+
if (rrects.is_cuda()) {
|
| 45 |
+
return rrect_to_quads_gpu(rrects, cellSize);
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
torch::Tensor quads;
|
| 49 |
+
AT_DISPATCH_FLOATING_TYPES(
|
| 50 |
+
rrects.scalar_type(),
|
| 51 |
+
"rrect_to_quads_impl",
|
| 52 |
+
([&] {
|
| 53 |
+
quads = rrect_to_quads_impl<scalar_t>(rrects, scalar_t(cellSize));
|
| 54 |
+
})
|
| 55 |
+
);
|
| 56 |
+
|
| 57 |
+
return quads;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
template<typename T>
|
| 62 |
+
torch::Tensor rrect_to_quads_backward_impl(torch::Tensor rrects, torch::Tensor gradOutput)
|
| 63 |
+
{
|
| 64 |
+
// BHW(5)
|
| 65 |
+
auto gradInput = torch::empty_like(rrects);
|
| 66 |
+
|
| 67 |
+
auto rrectAccess = rrects.accessor<T, 4>();
|
| 68 |
+
// BHW42
|
| 69 |
+
auto gradOutputAccess = gradOutput.accessor<T, 5>();
|
| 70 |
+
auto gradInputAccess = gradInput.accessor<T, 4>();
|
| 71 |
+
|
| 72 |
+
for (long b = 0; b < rrects.size(0); ++b) {
|
| 73 |
+
for (long y = 0; y < rrects.size(1); ++y) {
|
| 74 |
+
for (long x = 0; x < rrects.size(2); ++x) {
|
| 75 |
+
assign_grad_rrect_to_quad<T>(rrectAccess[b][y][x], gradOutputAccess[b][y][x], gradInputAccess[b][y][x]);
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
return gradInput;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
torch::Tensor rrect_to_quads_backward_gpu(torch::Tensor rrects, torch::Tensor gradOutput);
|
| 84 |
+
|
| 85 |
+
torch::Tensor rrect_to_quads_backward(torch::Tensor rrects, torch::Tensor gradOutput)
|
| 86 |
+
{
|
| 87 |
+
if (rrects.is_cuda()) {
|
| 88 |
+
return rrect_to_quads_backward_gpu(rrects, gradOutput);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
torch::Tensor gradInput;
|
| 92 |
+
AT_DISPATCH_FLOATING_TYPES(
|
| 93 |
+
rrects.scalar_type(),
|
| 94 |
+
"rrect_to_quads_backward_impl",
|
| 95 |
+
([&] {
|
| 96 |
+
gradInput = rrect_to_quads_backward_impl<scalar_t>(rrects, gradOutput);
|
| 97 |
+
})
|
| 98 |
+
);
|
| 99 |
+
|
| 100 |
+
return gradInput;
|
| 101 |
+
}
|
nemo-retriever-ocr/cpp/geometry_api/geometry_api.h
ADDED
|
@@ -0,0 +1,16 @@
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <torch/torch.h>
|
| 8 |
+
|
| 9 |
+
torch::Tensor rrect_to_quads(torch::Tensor rrects, float cellSize);
|
| 10 |
+
torch::Tensor rrect_to_quads_backward(torch::Tensor rrects, torch::Tensor gradOutput);
|
| 11 |
+
|
| 12 |
+
torch::Tensor calc_poly_min_rrect(torch::Tensor vertices);
|
| 13 |
+
|
| 14 |
+
float get_rel_continuation_cos(torch::Tensor rrectA, torch::Tensor rrectB);
|
| 15 |
+
|
| 16 |
+
torch::Tensor get_poly_bounds_quad(torch::Tensor poly);
|
nemo-retriever-ocr/cpp/geometry_api/geometry_api_common.h
ADDED
|
@@ -0,0 +1,121 @@
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <torch/torch.h>
|
| 8 |
+
|
| 9 |
+
#include "../cuda_intellisense.cuh"
|
| 10 |
+
#include "../geometry.h"
|
| 11 |
+
|
| 12 |
+
#if defined(__NVCC__)
|
| 13 |
+
#include <math_constants.h>
|
| 14 |
+
#define GEO_PI CUDART_PI_F
|
| 15 |
+
#else
|
| 16 |
+
#include <math.h>
|
| 17 |
+
#define GEO_PI M_PI
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
template<typename access_t, typename point_t>
|
| 22 |
+
__device__
|
| 23 |
+
inline
|
| 24 |
+
void pt_assign(access_t acc, const point_t &p) {
|
| 25 |
+
acc[0] = p.X;
|
| 26 |
+
acc[1] = p.Y;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
template<typename T, typename rrect_access_t>
|
| 30 |
+
__device__ __lib_inline__
|
| 31 |
+
InPlaceQuad_<T> cvt_rrect_to_quad(const rrect_access_t &rrect, T cellSize, T cellOff, T x, T y)
|
| 32 |
+
{
|
| 33 |
+
typedef Point_<T> Pointf;
|
| 34 |
+
|
| 35 |
+
Pointf prior{
|
| 36 |
+
x * cellSize + cellOff,
|
| 37 |
+
y * cellSize + cellOff
|
| 38 |
+
};
|
| 39 |
+
|
| 40 |
+
T dTop = rrect[0];
|
| 41 |
+
T dRight = rrect[1];
|
| 42 |
+
T dBottom = rrect[2];
|
| 43 |
+
T dLeft = rrect[3];
|
| 44 |
+
T theta = rrect[4];
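// The five rrect channels are distances from the cell-center prior to the top, right, bottom and left edges, plus the box rotation angle theta (radians).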
|
| 45 |
+
|
| 46 |
+
T piOver2{GEO_PI / 2.0f};
|
| 47 |
+
Pointf vX{ cos(theta), sin(theta) };
|
| 48 |
+
Pointf vY{ cos(theta - piOver2), sin(theta - piOver2) };
|
| 49 |
+
|
| 50 |
+
InPlaceQuad_<T> ret;
|
| 51 |
+
|
| 52 |
+
ret[0] = prior - vX * dLeft + vY * dTop;
|
| 53 |
+
ret[1] = prior + vX * dRight + vY * dTop;
|
| 54 |
+
ret[2] = prior + vX * dRight - vY * dBottom;
|
| 55 |
+
ret[3] = prior - vX * dLeft - vY * dBottom;
|
| 56 |
+
|
| 57 |
+
return ret;
|
| 58 |
+
}
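// Worked example: with theta = 0, vX = (1, 0) and vY = (0, -1), so a prior at (cx, cy) yields the axis-aligned corners
// (cx - dLeft, cy - dTop), (cx + dRight, cy - dTop), (cx + dRight, cy + dBottom), (cx - dLeft, cy + dBottom) in image coordinates (y down).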
|
| 59 |
+
|
| 60 |
+
template<typename rrect_access_t, typename quad_access_t, typename T>
|
| 61 |
+
__device__ __lib_inline__
|
| 62 |
+
void assign_rrect_to_quad(const rrect_access_t &rrect, quad_access_t &quad,
|
| 63 |
+
T cellSize, T cellOff, T x, T y)
|
| 64 |
+
{
|
| 65 |
+
const InPlaceQuad_<T> cvQuad = cvt_rrect_to_quad<T>(rrect, cellSize, cellOff, x, y);
|
| 66 |
+
|
| 67 |
+
const T *pInQuad = reinterpret_cast<const T*>(&cvQuad);
|
| 68 |
+
T *pOutQuad = reinterpret_cast<T*>(quad.data());
|
| 69 |
+
|
| 70 |
+
#pragma unroll
|
| 71 |
+
for (uint32_t i = 0; i < 8; ++i) {
|
| 72 |
+
pOutQuad[i] = pInQuad[i];
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
template<typename T, typename rrect_access_t, typename quad_access_t>
|
| 77 |
+
__device__
|
| 78 |
+
inline
|
| 79 |
+
void assign_grad_rrect_to_quad(const rrect_access_t &rrect,
|
| 80 |
+
const quad_access_t &gradOutput,
|
| 81 |
+
rrect_access_t gradInput)
|
| 82 |
+
{
|
| 83 |
+
typedef Point_<T> Pointf;
|
| 84 |
+
|
| 85 |
+
T Top = rrect[0];
|
| 86 |
+
T Right = rrect[1];
|
| 87 |
+
T Bottom = rrect[2];
|
| 88 |
+
T Left = rrect[3];
|
| 89 |
+
T theta = rrect[4];
|
| 90 |
+
|
| 91 |
+
T piOver2{GEO_PI / 2.0f};
|
| 92 |
+
Pointf vX{ cos(theta), sin(theta) };
|
| 93 |
+
Pointf vY{ cos(theta - piOver2), sin(theta - piOver2) };
|
| 94 |
+
|
| 95 |
+
Pointf dVX{ -vX.Y, vX.X };
|
| 96 |
+
Pointf dVY{ -vY.Y, vY.X };
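// dVX and dVY are d(vX)/d(theta) and d(vY)/d(theta); they drive the gradient of the corners with respect to the angle below.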
|
| 97 |
+
|
| 98 |
+
Pointf gP0 = gradOutput[0],
|
| 99 |
+
gP1 = gradOutput[1],
|
| 100 |
+
gP2 = gradOutput[2],
|
| 101 |
+
gP3 = gradOutput[3];
|
| 102 |
+
|
| 103 |
+
// Top
|
| 104 |
+
gradInput[0] = (gP0 * vY + gP1 * vY).Sum();
|
| 105 |
+
// Right
|
| 106 |
+
gradInput[1] = (gP1 * vX + gP2 * vX).Sum();
|
| 107 |
+
// Bottom
|
| 108 |
+
gradInput[2] = -(gP2 * vY + gP3 * vY).Sum();
|
| 109 |
+
// Left
|
| 110 |
+
gradInput[3] = -(gP0 * vX + gP3 * vX).Sum();
|
| 111 |
+
|
| 112 |
+
// Theta
|
| 113 |
+
gradInput[4] = (
|
| 114 |
+
gP0 * (-Left * dVX + Top * dVY) +
|
| 115 |
+
gP1 * (Right * dVX + Top * dVY) +
|
| 116 |
+
gP2 * (Right * dVX - Bottom * dVY) +
|
| 117 |
+
gP3 * (-Left * dVX - Bottom * dVY)
|
| 118 |
+
).Sum();
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
#undef GEO_PI
|
nemo-retriever-ocr/cpp/geometry_api/geometry_api_gpu.cu
ADDED
|
@@ -0,0 +1,142 @@
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#include "geometry_api.h"
|
| 6 |
+
|
| 7 |
+
#include "../geometry.h"
|
| 8 |
+
#include "../cuda_intellisense.cuh"
|
| 9 |
+
#include "geometry_api_common.h"
|
| 10 |
+
|
| 11 |
+
#include <trove/ptr.h>
|
| 12 |
+
|
| 13 |
+
using namespace std;
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
template<typename T>
|
| 17 |
+
struct RRect_ {
|
| 18 |
+
T Data[5];
|
| 19 |
+
|
| 20 |
+
template<typename index_t>
|
| 21 |
+
__device__
|
| 22 |
+
const T &operator[](index_t i) const { return Data[i]; }
|
| 23 |
+
template<typename index_t>
|
| 24 |
+
__device__
|
| 25 |
+
T &operator[](index_t i) { return Data[i]; }
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
template<typename T>
|
| 29 |
+
__global__
|
| 30 |
+
void device_rrect_to_quads_gpu(torch::PackedTensorAccessor64<T, 2> rrectAccess,
|
| 31 |
+
torch::PackedTensorAccessor64<T, 3> quadsAccess,
|
| 32 |
+
int64_t numRows, int64_t numCols,
|
| 33 |
+
T cellSize)
|
| 34 |
+
{
|
| 35 |
+
typedef Point_<T> Pointf;
|
| 36 |
+
typedef RRect_<T> RRectf;
|
| 37 |
+
typedef InPlaceQuad_<T> Quadf;
|
| 38 |
+
constexpr T TWO = 2;
|
| 39 |
+
|
| 40 |
+
const int64_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 41 |
+
|
| 42 |
+
if (jobIdx >= rrectAccess.size(0)) {
|
| 43 |
+
return;
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
int64_t row = jobIdx / numCols;
|
| 47 |
+
const int64_t col = jobIdx - (row * numCols);
|
| 48 |
+
row = row % numRows;
|
| 49 |
+
|
| 50 |
+
auto rawRRect = reinterpret_cast<RRectf*>(rrectAccess.data());
|
| 51 |
+
auto rawQuad = reinterpret_cast<Quadf*>(quadsAccess.data());
|
| 52 |
+
#if defined(NDEBUG)
|
| 53 |
+
trove::coalesced_ptr<RRectf> pRRect(rawRRect);
|
| 54 |
+
trove::coalesced_ptr<Quadf> pQuad(rawQuad);
|
| 55 |
+
#else
|
| 56 |
+
auto pRRect = rawRRect;
|
| 57 |
+
auto pQuad = rawQuad;
|
| 58 |
+
#endif
|
| 59 |
+
|
| 60 |
+
RRectf rrect = pRRect[jobIdx];
|
| 61 |
+
|
| 62 |
+
T cellOff = cellSize / TWO;
|
| 63 |
+
Quadf cvQuad = cvt_rrect_to_quad<T>(rrect, cellSize, cellOff, col, row);
|
| 64 |
+
|
| 65 |
+
pQuad[jobIdx] = cvQuad;
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
torch::Tensor rrect_to_quads_gpu(torch::Tensor rrects, float cellSize)
|
| 69 |
+
{
|
| 70 |
+
if (!rrects.is_contiguous()) {
|
| 71 |
+
throw std::runtime_error("Expected the rrects to be contiguous!");
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
torch::Tensor quads = torch::empty({ rrects.size(0), rrects.size(1), rrects.size(2), 4, 2 }, rrects.options());
|
| 75 |
+
|
| 76 |
+
auto rrFlat = rrects.flatten(0, 2);
|
| 77 |
+
auto qFlat = quads.flatten(0, 2);
|
| 78 |
+
|
| 79 |
+
dim3 blockSize(96);
|
| 80 |
+
dim3 gridSize(div_up(qFlat.size(0), blockSize.x));
|
| 81 |
+
|
| 82 |
+
if (quads.numel() > 0) {
|
| 83 |
+
AT_DISPATCH_FLOATING_TYPES(
|
| 84 |
+
quads.scalar_type(),
|
| 85 |
+
"cuda_rrect_to_quads",
|
| 86 |
+
([&] {
|
| 87 |
+
|
| 88 |
+
device_rrect_to_quads_gpu<scalar_t> KERNEL_ARG2(gridSize, blockSize) (
|
| 89 |
+
rrFlat.packed_accessor64<scalar_t, 2>(),
|
| 90 |
+
qFlat.packed_accessor64<scalar_t, 3>(),
|
| 91 |
+
rrects.size(1), rrects.size(2),
|
| 92 |
+
cellSize
|
| 93 |
+
);
|
| 94 |
+
|
| 95 |
+
})
|
| 96 |
+
);
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
return quads;
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
template<typename scalar_t>
|
| 103 |
+
__global__
|
| 104 |
+
void device_rrect_to_quads_backward_gpu(torch::PackedTensorAccessor64<scalar_t, 2> rrect,
|
| 105 |
+
torch::PackedTensorAccessor64<scalar_t, 3> gradOutput,
|
| 106 |
+
torch::PackedTensorAccessor64<scalar_t, 2> gradInput)
|
| 107 |
+
{
|
| 108 |
+
const int64_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 109 |
+
|
| 110 |
+
if (jobIdx >= rrect.size(0)) return;
|
| 111 |
+
|
| 112 |
+
assign_grad_rrect_to_quad<scalar_t>(rrect[jobIdx], gradOutput[jobIdx], gradInput[jobIdx]);
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
torch::Tensor rrect_to_quads_backward_gpu(torch::Tensor rrects, torch::Tensor gradOutput)
|
| 117 |
+
{
|
| 118 |
+
auto gradInput = torch::empty_like(rrects);
|
| 119 |
+
|
| 120 |
+
auto flatRRects = rrects.reshape({ -1, 5 });
|
| 121 |
+
auto flatGradOutput = gradOutput.reshape({ -1, 4, 2 });
|
| 122 |
+
auto flatGradInput = gradInput.reshape({ -1, 5 });
|
| 123 |
+
|
| 124 |
+
dim3 blockSize(32);
|
| 125 |
+
dim3 gridSize(div_up(rrects.size(0) * rrects.size(1) * rrects.size(2), blockSize.x));
|
| 126 |
+
|
| 127 |
+
if (rrects.numel() > 0) {
|
| 128 |
+
AT_DISPATCH_FLOATING_TYPES(
|
| 129 |
+
rrects.scalar_type(),
|
| 130 |
+
"cuda_rrect_to_quads_backward",
|
| 131 |
+
([&] {
|
| 132 |
+
device_rrect_to_quads_backward_gpu KERNEL_ARG2(gridSize, blockSize) (
|
| 133 |
+
flatRRects.packed_accessor64<scalar_t, 2>(),
|
| 134 |
+
flatGradOutput.packed_accessor64<scalar_t, 3>(),
|
| 135 |
+
flatGradInput.packed_accessor64<scalar_t, 2>()
|
| 136 |
+
);
|
| 137 |
+
})
|
| 138 |
+
);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
return gradInput;
|
| 142 |
+
}
|
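`rrect_to_quads_gpu` above expects a contiguous [B, H, W, 5] tensor of rotated-rect parameters and produces the matching [B, H, W, 4, 2] corner quads, one grid cell per thread. A hedged usage sketch from C++, assuming a CUDA build and that the declaration is available through geometry_api.h (module.cpp below exposes a `rrect_to_quads` binding on the Python side):

#include <torch/torch.h>
#include "geometry_api.h"

torch::Tensor example_rrect_to_quads()
{
    // Batch of 2 feature maps, 32x32 grid cells, 5 rrect parameters per cell.
    auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto rrects = torch::rand({ 2, 32, 32, 5 }, opts);

    // cellSize is the stride, in pixels, between adjacent grid cells.
    return rrect_to_quads_gpu(rrects.contiguous(), /*cellSize=*/4.0f);   // -> [2, 32, 32, 4, 2]
}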
nemo-retriever-ocr/cpp/geometry_api/get_rel_continuation_cos.cpp
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#include "geometry_api.h"
|
| 6 |
+
|
| 7 |
+
#include "../geometry.h"
|
| 8 |
+
|
| 9 |
+
using namespace std;
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
float get_rel_continuation_cos(torch::Tensor rrectATensor, torch::Tensor rrectBTensor)
|
| 13 |
+
{
|
| 14 |
+
typedef Point_<float> Pointf;
|
| 15 |
+
|
| 16 |
+
if (rrectATensor.size(0) != 4 || rrectBTensor.size(0) != 4) {
|
| 17 |
+
throw runtime_error("Invalid rrect arguments. Both must have 4 vertices! A=" +
|
| 18 |
+
to_string(rrectATensor.size(0)) + ", B=" + to_string(rrectBTensor.size(0)));
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
auto rrectA = rrectATensor.accessor<float, 2>();
|
| 22 |
+
auto rrectB = rrectBTensor.accessor<float, 2>();
|
| 23 |
+
|
| 24 |
+
Pointf aPts[4] = {
|
| 25 |
+
rrectA[0], rrectA[1], rrectA[2], rrectA[3]
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
auto c1 = (aPts[0] + aPts[3]) / 2.0f;
|
| 29 |
+
auto c2 = (aPts[1] + aPts[2]) / 2.0f;
|
| 30 |
+
|
| 31 |
+
auto aDir = c2 - c1;
|
| 32 |
+
auto aLen = length(aDir);
|
| 33 |
+
|
| 34 |
+
if (aLen > 0) {
|
| 35 |
+
aDir /= aLen;
|
| 36 |
+
} else {
|
| 37 |
+
aDir = Pointf{ 1, 0 };
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
auto centerA = (c1 + c2) / 2.0f;
|
| 41 |
+
|
| 42 |
+
Pointf bPts[4] = {
|
| 43 |
+
rrectB[0], rrectB[1], rrectB[2], rrectB[3]
|
| 44 |
+
};
|
| 45 |
+
|
| 46 |
+
auto centerB = (bPts[0] + bPts[1] + bPts[2] + bPts[3]) / 4.0f;
|
| 47 |
+
|
| 48 |
+
auto connDir = centerB - centerA;
|
| 49 |
+
auto connLen = length(connDir);
|
| 50 |
+
|
| 51 |
+
if (connLen == 0.0f) {
|
| 52 |
+
return 1.0f;
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
connDir /= connLen;
|
| 56 |
+
|
| 57 |
+
auto cosT = dot(aDir, connDir);
|
| 58 |
+
|
| 59 |
+
return cosT;
|
| 60 |
+
}
|
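`get_rel_continuation_cos` measures how well region B lies "ahead of" region A along A's reading direction: it takes the direction from the midpoint of A's left edge (vertices 0 and 3) to the midpoint of its right edge (vertices 1 and 2), the direction from A's centre to B's centre, and returns the cosine between the two. A small sanity-check sketch with CPU tensors, assuming the declaration from geometry_api.h (the function and tensor values below are illustrative only):

#include <torch/torch.h>
#include "geometry_api.h"

void continuation_cos_examples()
{
    // Unit square, vertices ordered top-left, top-right, bottom-right, bottom-left.
    auto a = torch::tensor({ 0.0f, 0.0f,  1.0f, 0.0f,  1.0f, 1.0f,  0.0f, 1.0f }).reshape({ 4, 2 });

    // Shifted to the right: B continues A's reading direction, so the cosine is ~1.
    float cRight = get_rel_continuation_cos(a, a + torch::tensor({ 2.0f, 0.0f }));

    // Shifted straight down: perpendicular to A's direction, so the cosine is ~0.
    float cBelow = get_rel_continuation_cos(a, a + torch::tensor({ 0.0f, 2.0f }));

    (void)cRight; (void)cBelow;
}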
nemo-retriever-ocr/cpp/geometry_api/matrix2x2.h
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include "../geometry.h"
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
struct contiguous_tag{};
|
| 11 |
+
|
| 12 |
+
struct transpose_tag{};
|
| 13 |
+
|
| 14 |
+
template<typename layout_t, uint32_t R, uint32_t C>
|
| 15 |
+
struct Matrix2x2_Offset;
|
| 16 |
+
|
| 17 |
+
template<uint32_t R, uint32_t C>
|
| 18 |
+
struct Matrix2x2_Offset<contiguous_tag, R, C>
|
| 19 |
+
{
|
| 20 |
+
static const uint32_t OFFSET = R * 2 + C;
|
| 21 |
+
};
|
| 22 |
+
|
| 23 |
+
template<uint32_t R, uint32_t C>
|
| 24 |
+
struct Matrix2x2_Offset<transpose_tag, R, C>
|
| 25 |
+
{
|
| 26 |
+
static const uint32_t OFFSET = C * 2 + R;
|
| 27 |
+
};
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
template<typename T, typename layout_t, uint32_t R, uint32_t C>
|
| 31 |
+
struct Matrix2x2_Indexor
|
| 32 |
+
{
|
| 33 |
+
static const uint32_t OFFSET = Matrix2x2_Offset<layout_t, R, C>::OFFSET;
|
| 34 |
+
|
| 35 |
+
static T &get(T *data) { return data[OFFSET]; }
|
| 36 |
+
static const T get(const T *data) { return data[OFFSET]; }
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
template<typename T>
|
| 41 |
+
struct Matrix2x2
|
| 42 |
+
{
|
| 43 |
+
Matrix2x2() = default;
|
| 44 |
+
Matrix2x2(T r0c0, T r0c1, T r1c0, T r1c1)
|
| 45 |
+
: m_data{ r0c0, r0c1, r1c0, r1c1 }
|
| 46 |
+
{
|
| 47 |
+
}
|
| 48 |
+
Matrix2x2(const Point_<T> &r0, const Point_<T> &r1)
|
| 49 |
+
: m_data{ r0.X, r0.Y, r1.X, r1.Y }
|
| 50 |
+
{
|
| 51 |
+
}
|
| 52 |
+
Matrix2x2(const Point_<T> &r0, const Point_<T> &r1, transpose_tag)
|
| 53 |
+
: m_data{ r0.X, r1.X, r0.Y, r1.Y }
|
| 54 |
+
{
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
inline T &operator[](uint32_t i) { return m_data[i]; }
|
| 58 |
+
inline const T operator[](uint32_t i) const { return m_data[i]; }
|
| 59 |
+
|
| 60 |
+
T m_data[4];
|
| 61 |
+
};
|
| 62 |
+
|
| 63 |
+
template<typename T, typename layout_t>
|
| 64 |
+
struct Matrix2x2_View
|
| 65 |
+
{
|
| 66 |
+
Matrix2x2_View(const Matrix2x2<T> &m) : m_data(m.m_data) {}
|
| 67 |
+
|
| 68 |
+
const T *m_data;
|
| 69 |
+
};
|
| 70 |
+
|
| 71 |
+
template<uint32_t R, uint32_t C, typename T, typename layout_t>
|
| 72 |
+
const T get(const Matrix2x2_View<T, layout_t> &m)
|
| 73 |
+
{
|
| 74 |
+
return Matrix2x2_Indexor<T, layout_t, R, C>::get(m.m_data);
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
template<typename T, typename get_pt_t, typename callback_t, typename layout_t = contiguous_tag>
|
| 78 |
+
inline
|
| 79 |
+
void matmul_fn(int64_t N, const get_pt_t &get_fn, const Matrix2x2<T> &mat, const callback_t &callback,
|
| 80 |
+
layout_t lt = layout_t{})
|
| 81 |
+
{
|
| 82 |
+
Matrix2x2_View<T, layout_t> m{ mat };
|
| 83 |
+
|
| 84 |
+
#pragma omp simd
|
| 85 |
+
for (int64_t i = 0; i < N; ++i) {
|
| 86 |
+
Point_<T> pt = get_fn(i);
|
| 87 |
+
|
| 88 |
+
T x = pt.X * get<0, 0>(m) + pt.Y * get<1, 0>(m);
|
| 89 |
+
T y = pt.X * get<0, 1>(m) + pt.Y * get<1, 1>(m);
|
| 90 |
+
|
| 91 |
+
callback(i, Point_<T>{ x, y });
|
| 92 |
+
}
|
| 93 |
+
}
|
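`matmul_fn` streams N points through a 2x2 matrix, reading each point via `get_fn` and handing the transformed point to `callback`; points are treated as row vectors, so the product is p' = p * M with M stored row-major (r0c0, r0c1, r1c0, r1c1). A small usage sketch that rotates a point buffer in place (the `rotate_points` helper is hypothetical, for illustration only):

#include <cmath>
#include <vector>
#include "matrix2x2.h"

void rotate_points(std::vector<Point_<float>> &pts, float theta)
{
    const float c = std::cos(theta), s = std::sin(theta);

    // Row-vector convention: [x' y'] = [x y] * [[c, s], [-s, c]].
    Matrix2x2<float> rot{ c,  s,
                          -s, c };

    matmul_fn<float>(static_cast<int64_t>(pts.size()),
                     [&] (int64_t i) { return pts[i]; },                 // point source
                     rot,
                     [&] (int64_t i, Point_<float> p) { pts[i] = p; });  // transformed sink
}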
nemo-retriever-ocr/cpp/geometry_api/poly_bounds_quad.cpp
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#include "geometry_api.h"
|
| 6 |
+
|
| 7 |
+
using namespace std;
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
template<typename T>
|
| 11 |
+
void pt_assign(torch::TensorAccessor<T, 1> acc, T x, T y)
|
| 12 |
+
{
|
| 13 |
+
acc[0] = x;
|
| 14 |
+
acc[1] = y;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
template<typename T>
|
| 19 |
+
void poly_bounds_quad_impl(torch::TensorAccessor<T, 2> poly, torch::TensorAccessor<T, 2> outBounds)
|
| 20 |
+
{
|
| 21 |
+
T minX = poly[0][0],
|
| 22 |
+
minY = poly[0][1],
|
| 23 |
+
maxX = poly[0][0],
|
| 24 |
+
maxY = poly[0][1];
|
| 25 |
+
|
| 26 |
+
const int64_t numVertices = poly.size(0);
|
| 27 |
+
|
| 28 |
+
for (int64_t i = 0; i < numVertices; ++i) {
|
| 29 |
+
auto vert = poly[i];
|
| 30 |
+
|
| 31 |
+
minX = min(minX, vert[0]);
|
| 32 |
+
maxX = max(maxX, vert[0]);
|
| 33 |
+
|
| 34 |
+
minY = min(minY, vert[1]);
|
| 35 |
+
maxY = max(maxY, vert[1]);
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
pt_assign(outBounds[0], minX, minY);
|
| 39 |
+
pt_assign(outBounds[1], maxX, minY);
|
| 40 |
+
pt_assign(outBounds[2], maxX, maxY);
|
| 41 |
+
pt_assign(outBounds[3], minX, maxY);
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
torch::Tensor get_poly_bounds_quad(torch::Tensor poly)
|
| 46 |
+
{
|
| 47 |
+
auto ret = torch::empty({ 4, 2 }, poly.options());
|
| 48 |
+
|
| 49 |
+
AT_DISPATCH_FLOATING_TYPES(
|
| 50 |
+
poly.scalar_type(),
|
| 51 |
+
"poly_bounds_quad_impl",
|
| 52 |
+
([&] {
|
| 53 |
+
poly_bounds_quad_impl(
|
| 54 |
+
poly.accessor<scalar_t, 2>(),
|
| 55 |
+
ret.accessor<scalar_t, 2>()
|
| 56 |
+
);
|
| 57 |
+
})
|
| 58 |
+
);
|
| 59 |
+
|
| 60 |
+
return ret;
|
| 61 |
+
}
|
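`get_poly_bounds_quad` reduces an [N, 2] polygon to its axis-aligned bounding box, returned as a [4, 2] quad ordered (minX,minY), (maxX,minY), (maxX,maxY), (minX,maxY). A quick usage sketch, assuming the declaration from geometry_api.h:

#include <torch/torch.h>
#include "geometry_api.h"

torch::Tensor example_poly_bounds()
{
    // A pentagon spanning x in [0, 4] and y in [0, 3].
    auto poly = torch::tensor({ 0.0f, 1.0f,  2.0f, 0.0f,  4.0f, 1.0f,  3.0f, 3.0f,  1.0f, 3.0f })
                    .reshape({ 5, 2 });

    return get_poly_bounds_quad(poly);   // -> [[0,0], [4,0], [4,3], [0,3]]
}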
nemo-retriever-ocr/cpp/graph_detection/encode_util.cpp
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#include "encode_util.h"
|
| 6 |
+
|
| 7 |
+
#include <algorithm>
|
| 8 |
+
#include <numeric>
|
| 9 |
+
#include <sstream>
|
| 10 |
+
|
| 11 |
+
#include "../third_party/clipper/clipper.hpp"
|
| 12 |
+
|
| 13 |
+
using namespace std;
|
| 14 |
+
|
| 15 |
+
namespace graph_detection {
|
| 16 |
+
|
| 17 |
+
template<typename T>
|
| 18 |
+
struct Candidate : Edge {
|
| 19 |
+
T C;
|
| 20 |
+
|
| 21 |
+
Candidate() = default;
|
| 22 |
+
Candidate(int32_t a, int32_t b, T c) : Edge(a, b), C(c) {}
|
| 23 |
+
};
|
| 24 |
+
|
| 25 |
+
struct DistStruct {
|
| 26 |
+
Candidate<Pointf> A;
|
| 27 |
+
Candidate<Pointf> B;
|
| 28 |
+
float Dist;
|
| 29 |
+
|
| 30 |
+
DistStruct() = default;
|
| 31 |
+
DistStruct(Candidate<Pointf> a, Candidate<Pointf> b, float dist) : A(a), B(b), Dist(dist) {}
|
| 32 |
+
};
|
| 33 |
+
|
| 34 |
+
template<typename T>
|
| 35 |
+
float vec_cos(const Point_<T> &a, const Point_<T> &b)
|
| 36 |
+
{
|
| 37 |
+
return dot(a, b) / (length(a) * length(b) + 1e-8);
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
template<typename T, typename Fn = std::less<T>>
|
| 41 |
+
vector<size_t> arg_sort(const vector<T> &vec, Fn comp = Fn())
|
| 42 |
+
{
|
| 43 |
+
vector<size_t> ret;
|
| 44 |
+
ret.reserve(vec.size());
|
| 45 |
+
for (size_t i = 0; i < vec.size(); ++i) {
|
| 46 |
+
ret.push_back(i);
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
sort(begin(ret), end(ret),
|
| 50 |
+
[&vec, &comp] (size_t idxA, size_t idxB) {
|
| 51 |
+
return comp(vec[idxA], vec[idxB]);
|
| 52 |
+
}
|
| 53 |
+
);
|
| 54 |
+
|
| 55 |
+
return ret;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
float edge_length(const Polygon_<float> &poly, const vector<Edge> &edges);
|
| 60 |
+
|
| 61 |
+
vector<Edge> find_bottom(const Polygon_<float> &poly, bool useVertexOrder)
|
| 62 |
+
{
|
| 63 |
+
if (poly.Count < 4) {
|
| 64 |
+
throw runtime_error("Invalid polygon. Fewer than 4 vertices!");
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
// If we trust the source of the geometries, this not only saves us computation,
|
| 68 |
+
// but can also be more reliable since we won't reorder the vertices
|
| 69 |
+
if (useVertexOrder) {
|
| 70 |
+
if ((poly.Count % 2) == 1) {
|
| 71 |
+
throw runtime_error("Can't use trusted vertex order when the vertex count is odd!");
|
| 72 |
+
}
|
| 73 |
+
int32_t halfCt = poly.Count / 2;
|
| 74 |
+
return { { halfCt - 1, halfCt },
|
| 75 |
+
{ static_cast<int32_t>(poly.Count) - 1, 0 } };
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
if (poly.Count == 4) {
|
| 79 |
+
float d1 = length(poly[1] - poly[0]) + length(poly[2] - poly[3]);
|
| 80 |
+
float d2 = length(poly[2] - poly[1]) + length(poly[0] - poly[3]);
|
| 81 |
+
|
| 82 |
+
if (4 * d1 < d2) {
|
| 83 |
+
return { { 0, 1 }, { 2, 3 } };
|
| 84 |
+
} else {
|
| 85 |
+
return { { 1, 2 }, { 3, 0 } };
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
auto idx_wrap = [&poly] (size_t idx) {
|
| 90 |
+
return poly[idx % poly.Count];
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
vector<Candidate<float>> candidates;
|
| 94 |
+
for (size_t i = 1; i < (poly.Count + 1); ++i) {
|
| 95 |
+
auto vPrev = idx_wrap(i) - idx_wrap(i - 1);
|
| 96 |
+
auto vNext = idx_wrap(i + 2) - idx_wrap(i + 1);
|
| 97 |
+
|
| 98 |
+
// We're looking for the segment where the preceding and following segment
|
| 99 |
+
// essentially travel in opposite directions
|
| 100 |
+
if (vec_cos(vPrev, vNext) < -0.875f) {
|
| 101 |
+
auto currSeg = idx_wrap(i) - idx_wrap(i + 1);
|
| 102 |
+
candidates.emplace_back(i % poly.Count, (i + 1) % poly.Count, length(currSeg));
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
if (candidates.size() != 2 || candidates[0].A == candidates[1].B || candidates[0].B == candidates[1].A) {
|
| 107 |
+
// If we found fewer than 2 candidates, or the two bottom edges are joined, select the 2 farthest edges instead
|
| 108 |
+
vector<Candidate<Pointf>> midList;
|
| 109 |
+
for (size_t i = 0; i < poly.Count; ++i) {
|
| 110 |
+
Pointf midPoint = (idx_wrap(i) + idx_wrap(i + 1)) / 2.0f;
|
| 111 |
+
midList.emplace_back(i, (i + 1) % poly.Count, midPoint);
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
vector<DistStruct> distList;
|
| 115 |
+
|
| 116 |
+
// Only found one good candidate, so search for the edge that's the furthest from this candidate
|
| 117 |
+
if (candidates.size() == 1) {
|
| 118 |
+
auto idx1a = candidates.back().A;
|
| 119 |
+
auto idx1b = candidates.back().B;
|
| 120 |
+
Candidate<Pointf> cand1{ idx1a, idx1b, (idx_wrap(idx1a) + idx_wrap(idx1b)) / 2.0f };
|
| 121 |
+
for (size_t j = 0; j < poly.Count; ++j) {
|
| 122 |
+
auto &cand2 = midList[j];
|
| 123 |
+
|
| 124 |
+
if (cand1.Touches(cand2)) continue;
|
| 125 |
+
|
| 126 |
+
float dist = length(cand1.C - cand2.C);
|
| 127 |
+
distList.emplace_back(cand1, cand2, dist);
|
| 128 |
+
}
|
| 129 |
+
} else {
|
| 130 |
+
for (size_t i = 0; i < poly.Count; ++i) {
|
| 131 |
+
for (size_t j = i + 1; j < poly.Count; ++j) {
|
| 132 |
+
auto &cand1 = midList[i];
|
| 133 |
+
auto &cand2 = midList[j];
|
| 134 |
+
|
| 135 |
+
if (cand1.Touches(cand2)) continue;
|
| 136 |
+
|
| 137 |
+
float dist = length(cand1.C - cand2.C);
|
| 138 |
+
distList.emplace_back(cand1, cand2, dist);
|
| 139 |
+
}
|
| 140 |
+
}
|
| 141 |
+
}
|
| 142 |
+
sort(begin(distList), end(distList), [] (auto a, auto b) { return a.Dist < b.Dist; });
|
| 143 |
+
|
| 144 |
+
if (distList.empty()) {
|
| 145 |
+
throw runtime_error("No valid bottom candidates found for this polygon!");
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
auto &bEdge = distList.back();
|
| 149 |
+
return vector<Edge>{ bEdge.A, bEdge.B };
|
| 150 |
+
|
| 151 |
+
} else {
|
| 152 |
+
return vector<Edge>{ candidates[0], candidates[1] };
|
| 153 |
+
}
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
void find_long_edges(const Polygon_<float> &poly, Edge *bottoms, vector<Edge> &outLongEdge1, vector<Edge> &outLongEdge2)
|
| 157 |
+
{
|
| 158 |
+
int32_t b1End = bottoms[0].B;
|
| 159 |
+
int32_t b2End = bottoms[1].B;
|
| 160 |
+
|
| 161 |
+
int32_t nPoints = poly.Count;
|
| 162 |
+
|
| 163 |
+
auto accum_into = [nPoints] (int32_t end1, int32_t end2, vector<Edge> &outEdge) {
|
| 164 |
+
int32_t i = (end1 + 1) % nPoints;
|
| 165 |
+
while ((i % nPoints) != end2) {
|
| 166 |
+
int32_t start = i > 0 ? i - 1 : nPoints - 1;
|
| 167 |
+
int32_t end = i % nPoints;
|
| 168 |
+
outEdge.emplace_back(start, end);
|
| 169 |
+
i = (i + 1) % nPoints;
|
| 170 |
+
}
|
| 171 |
+
};
|
| 172 |
+
|
| 173 |
+
accum_into(b1End, b2End, outLongEdge1);
|
| 174 |
+
accum_into(b2End, b1End, outLongEdge2);
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
float edge_length(const Polygon_<float> &poly, const vector<Edge> &edges)
|
| 178 |
+
{
|
| 179 |
+
float ret = 0.0f;
|
| 180 |
+
for (const Edge &e : edges) {
|
| 181 |
+
ret += length(poly[e.B] - poly[e.A]);
|
| 182 |
+
}
|
| 183 |
+
return ret;
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
vector<float> edge_lengths(const Polygon_<float> &poly, const vector<Edge> &edges)
|
| 187 |
+
{
|
| 188 |
+
if (edges.empty()) {
|
| 189 |
+
throw runtime_error("Found an empty edge!");
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
vector<float> ret;
|
| 193 |
+
ret.reserve(edges.size());
|
| 194 |
+
|
| 195 |
+
for (const Edge &e : edges) {
|
| 196 |
+
ret.push_back(length(poly[e.B] - poly[e.A]));
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
return ret;
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
void split_edge_sequence(const Polygon_<float> &poly, const vector<Edge> &edges,
|
| 203 |
+
const vector<float> &edgeLengths, float nParts,
|
| 204 |
+
vector<Pointf> &outPts);
|
| 205 |
+
|
| 206 |
+
void split_edge_sequence_by_step(const Polygon_<float> &poly, const vector<Edge> &longEdge1, const vector<Edge> &longEdge2,
|
| 207 |
+
float step, vector<Pointf> &outInnerPoints1, vector<Pointf> &outInnerPoints2)
|
| 208 |
+
{
|
| 209 |
+
auto edgeLengths1 = edge_lengths(poly, longEdge1);
|
| 210 |
+
auto edgeLengths2 = edge_lengths(poly, longEdge2);
|
| 211 |
+
|
| 212 |
+
float totalLength = (accumulate(begin(edgeLengths1), end(edgeLengths1), 0.0f) + accumulate(begin(edgeLengths2), end(edgeLengths2), 0.0f)) / 2;
|
| 213 |
+
|
| 214 |
+
float nParts = max<float>(ceil(totalLength / step), 2);
|
| 215 |
+
|
| 216 |
+
split_edge_sequence(poly, longEdge1, edgeLengths1, nParts, outInnerPoints1);
|
| 217 |
+
split_edge_sequence(poly, longEdge2, edgeLengths2, nParts, outInnerPoints2);
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
void split_edge_sequence(const Polygon_<float> &poly, const vector<Edge> &edges,
|
| 221 |
+
const vector<float> &edgeLengths, float nParts,
|
| 222 |
+
vector<Pointf> &outPts)
|
| 223 |
+
{
|
| 224 |
+
vector<float> elCumSum = vec_cumsum(edgeLengths);
|
| 225 |
+
|
| 226 |
+
float totalLength = elCumSum.back();
|
| 227 |
+
float lengthPerPart = totalLength / (nParts - 1);
|
| 228 |
+
|
| 229 |
+
size_t iNumParts = nParts;
|
| 230 |
+
size_t currNode = 0;
|
| 231 |
+
size_t ctr = 0;
|
| 232 |
+
for (float i = 0.0f; ctr < iNumParts; i += 1.0f, ++ctr) {
|
| 233 |
+
float t = min(i * lengthPerPart, totalLength);
|
| 234 |
+
|
| 235 |
+
while (t > elCumSum[currNode + 1]) {
|
| 236 |
+
++currNode;
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
Edge currEdge = edges[currNode];
|
| 240 |
+
Pointf e1 = poly[currEdge.A];
|
| 241 |
+
Pointf e2 = poly[currEdge.B];
|
| 242 |
+
|
| 243 |
+
float currLen = edgeLengths[currNode];
|
| 244 |
+
|
| 245 |
+
Pointf sampledPt;
|
| 246 |
+
|
| 247 |
+
if (currLen > 0) {
|
| 248 |
+
float deltaT = t - elCumSum[currNode];
|
| 249 |
+
float ratio = deltaT / currLen;
|
| 250 |
+
sampledPt = e1 + ratio * (e2 - e1);
|
| 251 |
+
} else {
|
| 252 |
+
sampledPt = e1;
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
outPts.push_back(sampledPt);
|
| 256 |
+
}
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
string print_poly(const Polyf &poly) {
|
| 260 |
+
ostringstream oss;
|
| 261 |
+
oss << "[";
|
| 262 |
+
for (size_t i = 0; i < poly.Count; ++i) {
|
| 263 |
+
if (i > 0) {
|
| 264 |
+
oss << ", ";
|
| 265 |
+
}
|
| 266 |
+
oss << "(" << poly[i].X << ", " << poly[i].Y << ")";
|
| 267 |
+
}
|
| 268 |
+
oss << "]";
|
| 269 |
+
return oss.str();
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
} // namespace graph_detection
|
nemo-retriever-ocr/cpp/graph_detection/encode_util.h
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <vector>
|
| 8 |
+
#include <random>
|
| 9 |
+
#include <algorithm>
|
| 10 |
+
|
| 11 |
+
#include "../geometry.h"
|
| 12 |
+
|
| 13 |
+
namespace graph_detection {
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct Edge {
|
| 18 |
+
int32_t A;
|
| 19 |
+
int32_t B;
|
| 20 |
+
|
| 21 |
+
Edge() = default;
|
| 22 |
+
Edge(int32_t a, int32_t b) : A(a), B(b) {}
|
| 23 |
+
|
| 24 |
+
bool Touches(int32_t idx) const { return A == idx || B == idx; }
|
| 25 |
+
bool Touches(const Edge &other) const;
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
inline
|
| 29 |
+
bool edge_touches(const Edge &edge, int32_t vertex) {
|
| 30 |
+
return edge.A == vertex || edge.B == vertex;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
inline
|
| 34 |
+
bool Edge::Touches(const Edge &other) const {
|
| 35 |
+
return edge_touches(other, A) || edge_touches(other, B);
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
typedef Point_<float> Pointf;
|
| 39 |
+
typedef AABB_<float> AABBf;
|
| 40 |
+
typedef Polygon_<float> Polyf;
|
| 41 |
+
typedef std::vector<Pointf> Polyline;
|
| 42 |
+
|
| 43 |
+
std::vector<Edge> find_bottom(const Polygon_<float> &poly, bool useVertexOrder);
|
| 44 |
+
|
| 45 |
+
void find_long_edges(const Polygon_<float> &poly, Edge *bottoms, std::vector<Edge> &outLongEdge1, std::vector<Edge> &outLongEdge2);
|
| 46 |
+
|
| 47 |
+
void split_edge_sequence_by_step(const Polygon_<float> &poly, const std::vector<Edge> &longEdge1, const std::vector<Edge> &longEdge2,
|
| 48 |
+
float step, std::vector<Pointf> &outInnerPoints1, std::vector<Pointf> &outInnerPoints2);
|
| 49 |
+
|
| 50 |
+
std::string print_poly(const Polyf &poly);
|
| 51 |
+
|
| 52 |
+
template<typename T>
|
| 53 |
+
inline
|
| 54 |
+
std::vector<T> vec_cumsum(const std::vector<T> &v)
|
| 55 |
+
{
|
| 56 |
+
std::vector<T> ret;
|
| 57 |
+
ret.reserve(v.size() + 1);
|
| 58 |
+
ret.push_back(0);
|
| 59 |
+
for (T val : v) {
|
| 60 |
+
ret.push_back(ret.back() + val);
|
| 61 |
+
}
|
| 62 |
+
return ret;
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
template<typename RandEng, typename Fn>
|
| 66 |
+
inline
|
| 67 |
+
void n_choose_k(size_t n, size_t k, RandEng &randEng, Fn fn)
|
| 68 |
+
{
|
| 69 |
+
if (k == 0) return;
|
| 70 |
+
|
| 71 |
+
// TODO(mranzinger): This algorithm can be replaced with sampling from a geometric
|
| 72 |
+
// distribution, which drastically reduces the runtime complexity
|
| 73 |
+
for (size_t i = 0; i < n; ++i) {
|
| 74 |
+
size_t leftover = n - i;
|
| 75 |
+
if (leftover <= k) {
|
| 76 |
+
fn(i);
|
| 77 |
+
--k;
|
| 78 |
+
} else {
|
| 79 |
+
float p = std::uniform_real_distribution<float>(0.0f, 1.0f)(randEng);
|
| 80 |
+
float probSample = float{k} / float{leftover};
|
| 81 |
+
if (p < probSample) {
|
| 82 |
+
fn(i);
|
| 83 |
+
--k;
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
template<typename T>
|
| 90 |
+
inline T clamp(T val, T minVal, T maxVal) {
|
| 91 |
+
return std::max(std::min(val, maxVal), minVal);
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
inline
|
| 95 |
+
Pointf avg_point(const std::vector<Pointf> &points)
|
| 96 |
+
{
|
| 97 |
+
return std::accumulate(std::begin(points), std::end(points), Pointf(0,0)) / float(points.size());
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
inline
|
| 101 |
+
float vector_sin(const Pointf &pt)
|
| 102 |
+
{
|
| 103 |
+
// sin = y / len(pt)
|
| 104 |
+
return pt.Y / (length(pt) + 1e-8);
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
inline
|
| 108 |
+
float vector_cos(const Pointf &pt)
|
| 109 |
+
{
|
| 110 |
+
// cos = x / len(pt)
|
| 111 |
+
return pt.X / (length(pt) + 1e-8);
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
inline
|
| 115 |
+
void vector_cos_sin(const Pointf & pt, float &outCos, float &outSin)
|
| 116 |
+
{
|
| 117 |
+
float len = length(pt) + 1e-8;
|
| 118 |
+
outCos = pt.X / len;
|
| 119 |
+
outSin = pt.Y / len;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
inline
|
| 123 |
+
float point_dist_to_line(const Pointf &l1, const Pointf &l2, const Pointf &pt)
|
| 124 |
+
{
|
| 125 |
+
auto d = l2 - l1;
|
| 126 |
+
|
| 127 |
+
auto lineLen = length(d);
|
| 128 |
+
|
| 129 |
+
if (lineLen > 0) {
|
| 130 |
+
float distance = abs(
|
| 131 |
+
d.Y * pt.X
|
| 132 |
+
- d.X * pt.Y
|
| 133 |
+
+ l2.X * l1.Y
|
| 134 |
+
- l2.Y * l1.X
|
| 135 |
+
) / lineLen;
|
| 136 |
+
return distance;
|
| 137 |
+
} else {
|
| 138 |
+
return length(pt - l1);
|
| 139 |
+
}
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
template<typename T>
|
| 143 |
+
T find_mode(std::vector<T> &inputs) {
|
| 144 |
+
using std::sort;
|
| 145 |
+
using std::begin;
|
| 146 |
+
using std::end;
|
| 147 |
+
|
| 148 |
+
if (inputs.empty()) {
|
| 149 |
+
throw std::runtime_error("Cannot find mode of empty distribution!");
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
sort(begin(inputs), end(inputs));
|
| 153 |
+
|
| 154 |
+
T currVal = inputs[0];
|
| 155 |
+
size_t currCount = 1;
|
| 156 |
+
|
| 157 |
+
T modeVal = inputs[0];
|
| 158 |
+
size_t modeCount = 1;
|
| 159 |
+
|
| 160 |
+
auto commitCurr = [&] () {
|
| 161 |
+
if (currCount > modeCount) {
|
| 162 |
+
modeCount = currCount;
|
| 163 |
+
modeVal = currVal;
|
| 164 |
+
}
|
| 165 |
+
};
|
| 166 |
+
|
| 167 |
+
for (size_t i = 1; i < inputs.size(); ++i) {
|
| 168 |
+
if (inputs[i] == currVal) {
|
| 169 |
+
++currCount;
|
| 170 |
+
} else {
|
| 171 |
+
// Start of a new value
|
| 172 |
+
commitCurr();
|
| 173 |
+
|
| 174 |
+
currCount = 1;
|
| 175 |
+
currVal = inputs[i];
|
| 176 |
+
}
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
commitCurr();
|
| 180 |
+
|
| 181 |
+
return modeVal;
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
} // namespace graph_detection
|
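`split_edge_sequence_by_step` (in encode_util.cpp above) walks the cumulative edge lengths produced by `vec_cumsum` and samples points at evenly spaced arc-length positions along each long edge. A standalone sketch of that resampling idea on a plain polyline, assuming at least two input points and nParts >= 2 (the `P` struct and `resample_polyline` name are hypothetical, for illustration only):

#include <algorithm>
#include <cmath>
#include <vector>

struct P { float x, y; };

std::vector<P> resample_polyline(const std::vector<P> &pts, std::size_t nParts)
{
    // Cumulative arc length, starting at 0 (same convention as vec_cumsum).
    std::vector<float> cs{ 0.0f };
    for (std::size_t i = 1; i < pts.size(); ++i) {
        float dx = pts[i].x - pts[i - 1].x, dy = pts[i].y - pts[i - 1].y;
        cs.push_back(cs.back() + std::sqrt(dx * dx + dy * dy));
    }

    std::vector<P> out;
    const float total = cs.back();
    std::size_t seg = 0;
    for (std::size_t k = 0; k < nParts; ++k) {
        // Target arc length for this sample, clamped to the total length.
        float t = std::min(total, total * k / (nParts - 1));
        while (t > cs[seg + 1]) ++seg;                 // advance to the segment containing t

        float segLen = cs[seg + 1] - cs[seg];
        float r = segLen > 0 ? (t - cs[seg]) / segLen : 0.0f;
        out.push_back({ pts[seg].x + r * (pts[seg + 1].x - pts[seg].x),
                        pts[seg].y + r * (pts[seg + 1].y - pts[seg].y) });
    }
    return out;
}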
nemo-retriever-ocr/cpp/half_ops.cu
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#include "half_ops.cuh"
|
nemo-retriever-ocr/cpp/half_ops.cuh
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <torch/torch.h>
|
| 8 |
+
|
| 9 |
+
#include "cuda_intellisense.cuh"
|
| 10 |
+
|
| 11 |
+
#ifndef __CUDACC__
|
| 12 |
+
#pragma message("__CUDACC__ not defined!")
|
| 13 |
+
#else
|
| 14 |
+
#pragma message("__CUDACC__ defined!")
|
| 15 |
+
#endif
|
| 16 |
+
|
| 17 |
+
#ifdef __NVCC__
|
| 18 |
+
#define __qr_device__ __device__
|
| 19 |
+
#define __qr_host__ __host__
|
| 20 |
+
#define __qr_inline__ __forceinline__
|
| 21 |
+
#else
|
| 22 |
+
#define __qr_device__
|
| 23 |
+
#define __qr_host__
|
| 24 |
+
#define __qr_inline__ inline
|
| 25 |
+
#endif
|
| 26 |
+
|
| 27 |
+
#ifdef __CUDACC__
|
| 28 |
+
#include <cuda.h>
|
| 29 |
+
#include <cuda_runtime.h>
|
| 30 |
+
#include <cuda_fp16.h>
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
__qr_inline__ __device__ __half operator-(__half v) {
|
| 34 |
+
return __hneg(v);
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
__qr_inline__ __device__ __half operator+(__half a, __half b) {
|
| 38 |
+
return __hadd(a, b);
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
__qr_inline__ __device__ __half operator-(__half a, __half b) {
|
| 42 |
+
return __hsub(a, b);
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
__qr_inline__ __device__ __half operator*(__half a, __half b) {
|
| 46 |
+
return __hmul(a, b);
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
__qr_inline__ __device__ __half operator/(__half a, __half b) {
|
| 50 |
+
return __hdiv(a, b);
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
__qr_inline__ __device__ bool operator==(__half a, __half b) {
|
| 54 |
+
return __heq(a, b);
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
__qr_inline__ __device__ bool operator<(__half a, __half b) {
|
| 58 |
+
return __hlt(a, b);
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
__qr_inline__ __device__ bool operator>(__half a, __half b) {
|
| 62 |
+
return __hgt(a, b);
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
__qr_inline__ __device__ __half sqrt(__half v) {
|
| 66 |
+
return hsqrt(v);
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
__qr_inline__ __device__ __half floor(__half v) {
|
| 70 |
+
return hfloor(v);
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
__qr_inline__ __device__ __half ceil(__half v) {
|
| 74 |
+
return hceil(v);
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
__qr_inline__ __device__ __half max(__half a, __half b) {
|
| 78 |
+
return a > b ? a : b;
|
| 79 |
+
}
|
| 80 |
+
#endif //__CUDACC__
|
| 81 |
+
|
| 82 |
+
template<typename Src, typename Dest>
|
| 83 |
+
struct Convert {
|
| 84 |
+
__qr_inline__ static __qr_host__ __qr_device__ constexpr Dest From(Src value) { return static_cast<Dest>(value); }
|
| 85 |
+
__qr_inline__ static __qr_host__ __qr_device__ constexpr Src To(Dest value) { return static_cast<Src>(value); }
|
| 86 |
+
__qr_inline__ static __qr_host__ __qr_device__ constexpr Dest LeftToRight(Src value) { return static_cast<Dest>(value); }
|
| 87 |
+
__qr_inline__ static __qr_host__ __qr_device__ constexpr Src RightToLeft(Dest value) { return static_cast<Src>(value); }
|
| 88 |
+
};
|
| 89 |
+
|
| 90 |
+
#ifdef __CUDACC__
|
| 91 |
+
template<>
|
| 92 |
+
struct Convert<__half, float> {
|
| 93 |
+
__qr_inline__ static __host__ __device__ float From(__half value) { return __half2float(value); }
|
| 94 |
+
__qr_inline__ static __host__ __device__ __half To(float value) { return __float2half(value); }
|
| 95 |
+
__qr_inline__ static __host__ __device__ float LeftToRight(__half value) { return __half2float(value); }
|
| 96 |
+
__qr_inline__ static __host__ __device__ __half RightToLeft(float value) { return __float2half(value); }
|
| 97 |
+
};
|
| 98 |
+
|
| 99 |
+
template<typename Dest>
|
| 100 |
+
struct Convert<__half, Dest> : Convert<__half, float> {
|
| 101 |
+
|
| 102 |
+
};
|
| 103 |
+
|
| 104 |
+
namespace at {
|
| 105 |
+
|
| 106 |
+
template<>
|
| 107 |
+
inline __half* TensorBase::mutable_data_ptr() const {
|
| 108 |
+
TORCH_CHECK(scalar_type() == ScalarType::Half,
|
| 109 |
+
"expected scalar type Half but found ",
|
| 110 |
+
c10::toString(scalar_type()));
|
| 111 |
+
return static_cast<__half*>(this->unsafeGetTensorImpl()->mutable_data());
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
template<>
|
| 115 |
+
inline __half* TensorBase::data_ptr() const {
|
| 116 |
+
TORCH_CHECK(scalar_type() == ScalarType::Half,
|
| 117 |
+
"expected scalar type Half but found ",
|
| 118 |
+
c10::toString(scalar_type()));
|
| 119 |
+
return static_cast<__half*>(this->unsafeGetTensorImpl()->mutable_data());
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
template<typename T>
|
| 125 |
+
struct remap_half {
|
| 126 |
+
typedef T type;
|
| 127 |
+
};
|
| 128 |
+
|
| 129 |
+
template<>
|
| 130 |
+
struct remap_half<at::Half> {
|
| 131 |
+
typedef __half type;
|
| 132 |
+
};
|
| 133 |
+
|
| 134 |
+
template<typename T>
|
| 135 |
+
__half to_half(T val) {
|
| 136 |
+
return Convert<__half, T>::RightToLeft(val);
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
template<typename T>
|
| 140 |
+
struct fp_promote {
|
| 141 |
+
typedef T type;
|
| 142 |
+
};
|
| 143 |
+
|
| 144 |
+
template<>
|
| 145 |
+
struct fp_promote<__half> {
|
| 146 |
+
typedef float type;
|
| 147 |
+
};
|
| 148 |
+
|
| 149 |
+
#endif //__CUDACC__
|
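The `Convert` trait above gives a uniform bridge between a storage type and a compute type, with the `__half` specialization routing through `__half2float`/`__float2half`, and `fp_promote` selects a wider accumulator for half inputs. A minimal sketch of how a templated kernel might lean on them, assuming a CUDA translation unit that includes half_ops.cuh (the kernel itself is hypothetical, for illustration only):

template<typename scalar_t>
__global__ void scale_kernel(const scalar_t *in, scalar_t *out, float factor, int n)
{
    // Do the arithmetic in the promoted type (float for __half), store back as scalar_t.
    using acc_t = typename fp_promote<scalar_t>::type;

    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    acc_t v = Convert<scalar_t, acc_t>::LeftToRight(in[i]);
    out[i] = Convert<scalar_t, acc_t>::RightToLeft(v * static_cast<acc_t>(factor));
}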
nemo-retriever-ocr/cpp/local_ips/local_ips.h
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <torch/torch.h>
|
| 8 |
+
|
| 9 |
+
torch::Tensor ragged_quad_all_2_all_distance_v2(torch::Tensor embedQuads, torch::Tensor quadsPerExample,
|
| 10 |
+
float xFactor, float yFactor,
|
| 11 |
+
bool allowSelfDistance);
|
nemo-retriever-ocr/cpp/local_ips/quad_all_2_all_dist_v2.cu
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
#include <iostream>
|
| 7 |
+
|
| 8 |
+
#include <cooperative_groups.h>
|
| 9 |
+
#include <cooperative_groups/reduce.h>
|
| 10 |
+
|
| 11 |
+
#include <thrust/binary_search.h>
|
| 12 |
+
#include <thrust/execution_policy.h>
|
| 13 |
+
|
| 14 |
+
#include "local_ips.h"
|
| 15 |
+
#include "../cuda_intellisense.cuh"
|
| 16 |
+
#include "../common.h"
|
| 17 |
+
#include "../geometry.h"
|
| 18 |
+
|
| 19 |
+
using namespace std;
|
| 20 |
+
namespace cg = cooperative_groups;
|
| 21 |
+
|
| 22 |
+
typedef Point_<float> Pointf;
|
| 23 |
+
|
| 24 |
+
__device__ inline
|
| 25 |
+
float square(float val) { return val * val; }
|
| 26 |
+
|
| 27 |
+
__global__
|
| 28 |
+
void device_quad_all_2_all_distance_v2(torch::PackedTensorAccessor64<float, 4> allEmbedQuads,
|
| 29 |
+
torch::PackedTensorAccessor64<int64_t, 1> allRegionCounts,
|
| 30 |
+
torch::PackedTensorAccessor64<int64_t, 1> csWorkPerExample,
|
| 31 |
+
torch::PackedTensorAccessor64<float, 3> outDistances,
|
| 32 |
+
float xFactor, float yFactor,
|
| 33 |
+
bool allowSelfDistance)
|
| 34 |
+
{
|
| 35 |
+
// Note: pairing blockIdx.x with blockDim.y / threadIdx.y here is intentional
|
| 36 |
+
int64_t workIdx = blockIdx.x * blockDim.y + threadIdx.y;
|
| 37 |
+
|
| 38 |
+
if (workIdx >= csWorkPerExample[csWorkPerExample.size(0) - 1]) return;
|
| 39 |
+
|
| 40 |
+
auto exIter = thrust::upper_bound(thrust::seq,
|
| 41 |
+
csWorkPerExample.data(), csWorkPerExample.data() + csWorkPerExample.size(0),
|
| 42 |
+
workIdx);
|
| 43 |
+
|
| 44 |
+
const int64_t exIdx = exIter - csWorkPerExample.data();
|
| 45 |
+
|
| 46 |
+
const int64_t workStart = exIdx == 0 ? 0 : csWorkPerExample[exIdx - 1];
|
| 47 |
+
const int64_t workOff = workIdx - workStart;
|
| 48 |
+
|
| 49 |
+
const int64_t row = workOff / allRegionCounts[exIdx];
|
| 50 |
+
const int64_t col = workOff % allRegionCounts[exIdx];
|
| 51 |
+
|
| 52 |
+
auto taRowQuad = allEmbedQuads[exIdx][row];
|
| 53 |
+
auto taColQuad = allEmbedQuads[exIdx][col];
|
| 54 |
+
|
| 55 |
+
Quad_<float> rowQuad(taRowQuad.data()),
|
| 56 |
+
colQuad(taColQuad.data());
|
| 57 |
+
|
| 58 |
+
auto p1 = (rowQuad[0] + rowQuad[3]) / 2.0f;
|
| 59 |
+
auto p2 = (rowQuad[1] + rowQuad[2]) / 2.0f;
|
| 60 |
+
|
| 61 |
+
auto vX = p2 - p1;
|
| 62 |
+
auto lenVX = length(vX);
|
| 63 |
+
if (lenVX > 0) {
|
| 64 |
+
vX = vX / max(lenVX, 1e-8f);
|
| 65 |
+
} else {
|
| 66 |
+
vX = { 1, 0 };
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
Pointf vY{ -vX.Y, vX.X };
|
| 70 |
+
|
| 71 |
+
auto reproj = [&vX, &vY, xFactor, yFactor] (const Pointf &pt) {
|
| 72 |
+
auto dX = dot(pt, vX);
|
| 73 |
+
if (dX >= 0) {
|
| 74 |
+
dX *= xFactor;
|
| 75 |
+
}
|
| 76 |
+
auto dY = dot(pt, vY);
|
| 77 |
+
if (dY >= 0) {
|
| 78 |
+
dY *= yFactor;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
return Pointf{ dX, dY };
|
| 82 |
+
};
|
| 83 |
+
|
| 84 |
+
auto tile16 = cg::tiled_partition<16>(cg::this_thread_block());
|
| 85 |
+
|
| 86 |
+
// Figure out which vertices this thread is processing
|
| 87 |
+
const int64_t rowVertexIdx = tile16.thread_rank() / 4;
|
| 88 |
+
const int64_t colVertexIdx = tile16.thread_rank() % 4;
|
| 89 |
+
|
| 90 |
+
float dist;
|
| 91 |
+
if (row != col) {
|
| 92 |
+
Segment_<float> rowSeg{ rowQuad[rowVertexIdx], rowQuad[(rowVertexIdx + 1) % 4] };
|
| 93 |
+
Segment_<float> colSeg{ colQuad[colVertexIdx], colQuad[(colVertexIdx + 1) % 4] };
|
| 94 |
+
|
| 95 |
+
Segment_<float> minSeg = shortest_line_between_segments(rowSeg, colSeg);
|
| 96 |
+
|
| 97 |
+
Point_<float> vSeg = minSeg.B - minSeg.A;
|
| 98 |
+
|
| 99 |
+
vSeg = reproj(vSeg);
|
| 100 |
+
|
| 101 |
+
dist = length(vSeg);
|
| 102 |
+
} else if (allowSelfDistance) {
|
| 103 |
+
dist = 0;
|
| 104 |
+
} else {
|
| 105 |
+
dist = numeric_limits<float>::infinity();
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
// Now find the minimum distance across the group
|
| 109 |
+
int lane = tile16.thread_rank();
|
| 110 |
+
// Shuffle-down tree reduction over the 16-thread tile
|
| 111 |
+
// After the final step, lane 0 holds the minimum distance across the tile
|
| 112 |
+
#pragma unroll
|
| 113 |
+
for (uint32_t i = 1; i < 16; i <<= 1) {
|
| 114 |
+
auto otherDist = tile16.shfl_down(dist, i);
|
| 115 |
+
dist = min(dist, otherDist);
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
#ifndef NDEBUG
|
| 119 |
+
float lowestDist = tile16.shfl(dist, 0);
|
| 120 |
+
assert(dist >= lowestDist);
|
| 121 |
+
#endif
|
| 122 |
+
|
| 123 |
+
if (lane == 0) {
|
| 124 |
+
outDistances[exIdx][row][col] = dist;
|
| 125 |
+
}
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
torch::Tensor ragged_quad_all_2_all_distance_v2(torch::Tensor embedQuads, torch::Tensor regionCounts,
|
| 129 |
+
float xFactor, float yFactor,
|
| 130 |
+
bool allowSelfDistance)
|
| 131 |
+
{
|
| 132 |
+
if (!embedQuads.is_contiguous()) {
|
| 133 |
+
throw std::runtime_error("Expected `embedQuads` to be contiguous!");
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
auto outDistances = torch::zeros({ embedQuads.size(0), embedQuads.size(1), embedQuads.size(1) },
|
| 137 |
+
embedQuads.options());
|
| 138 |
+
|
| 139 |
+
if (embedQuads.numel() == 0) {
|
| 140 |
+
return outDistances;
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
auto workPerExample = regionCounts * regionCounts;
|
| 144 |
+
|
| 145 |
+
auto csWorkPerExample = torch::cumsum(workPerExample, 0);
|
| 146 |
+
|
| 147 |
+
int64_t totalWork = csWorkPerExample[-1].item<int64_t>();
|
| 148 |
+
|
| 149 |
+
dim3 blockSize(16, 2);
|
| 150 |
+
dim3 gridSize(div_up(totalWork, blockSize.y), 1);
|
| 151 |
+
|
| 152 |
+
device_quad_all_2_all_distance_v2 KERNEL_ARG2(gridSize, blockSize) (
|
| 153 |
+
embedQuads.packed_accessor64<float, 4>(),
|
| 154 |
+
regionCounts.packed_accessor64<int64_t, 1>(),
|
| 155 |
+
csWorkPerExample.packed_accessor64<int64_t, 1>(),
|
| 156 |
+
outDistances.packed_accessor64<float, 3>(),
|
| 157 |
+
xFactor, yFactor,
|
| 158 |
+
allowSelfDistance
|
| 159 |
+
);
|
| 160 |
+
|
| 161 |
+
return outDistances;
|
| 162 |
+
}
|
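The kernel above computes one candidate distance per (row-edge, col-edge) pair across a 16-thread cooperative-groups tile, then folds those 16 values down to a single minimum with `shfl_down` before lane 0 writes the result. A standalone sketch of that reduction idiom (the device helper name is hypothetical; `cg::reduce` with `cg::less<float>()` from <cooperative_groups/reduce.h>, which this file already includes, would be an equivalent alternative):

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__device__ float tile16_min(float val)
{
    auto tile = cg::tiled_partition<16>(cg::this_thread_block());

    // Shuffle-down tree reduction; after the loop, lane 0 holds the tile-wide minimum.
    for (unsigned offset = 1; offset < 16; offset <<= 1) {
        val = fminf(val, tile.shfl_down(val, offset));
    }
    return val;   // only meaningful on tile.thread_rank() == 0
}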
nemo-retriever-ocr/cpp/module.cpp
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
#include "quad_rectify/quad_rectify.h"
|
| 7 |
+
#include "non_maximal_suppression/non_maximal_suppression.h"
|
| 8 |
+
#include "geometry_api/geometry_api.h"
|
| 9 |
+
#include "beam_decode/beam_decode.h"
|
| 10 |
+
#include "better_grid_sample/grid_sample.h"
|
| 11 |
+
#include "sparse_select/sparse_select.h"
|
| 12 |
+
#include "text_region_grouping/text_region_grouping.h"
|
| 13 |
+
#include "local_ips/local_ips.h"
|
| 14 |
+
|
| 15 |
+
#include <torch/extension.h>
|
| 16 |
+
#include <pybind11/stl.h>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 20 |
+
m.def("quad_rectify_calc_quad_width", &quad_rectify_calc_quad_width,
|
| 21 |
+
"Quad Rectify Calc Quad Width C++",
|
| 22 |
+
py::arg("quads"),
|
| 23 |
+
py::arg("output_height"),
|
| 24 |
+
py::arg("round_factor") = 16,
|
| 25 |
+
py::arg("max_width") = 0
|
| 26 |
+
);
|
| 27 |
+
m.def("quad_rectify_forward", &quad_rectify_forward, "Quad Rectify Forward C++",
|
| 28 |
+
py::arg("quads"),
|
| 29 |
+
py::arg("image_height"), py::arg("image_width"),
|
| 30 |
+
py::arg("output_height"), py::arg("output_width"),
|
| 31 |
+
py::arg("isotropic") = true
|
| 32 |
+
);
|
| 33 |
+
m.def("quad_rectify_backward", &quad_rectify_backward, "Quad Rectify Backward C++",
|
| 34 |
+
py::arg("quads"), py::arg("grad_output"),
|
| 35 |
+
py::arg("image_height"), py::arg("image_width"),
|
| 36 |
+
py::arg("isotropic") = true
|
| 37 |
+
);
|
| 38 |
+
m.def("quad_non_maximal_suppression", &quad_non_maximal_suppression, "Quad Non-Maximal Suppression C++",
|
| 39 |
+
py::arg("quads"), py::arg("probs"),
|
| 40 |
+
py::arg("prob_threshold"), py::arg("iou_threshold"),
|
| 41 |
+
py::arg("kernel_height"), py::arg("kernel_width"),
|
| 42 |
+
py::arg("max_regions"),
|
| 43 |
+
py::arg("verbose") = false
|
| 44 |
+
);
|
| 45 |
+
|
| 46 |
+
py::class_<LanguageModel>(m, "LanguageModel");
|
| 47 |
+
|
| 48 |
+
m.def("beam_decode", &beam_decode, "beam_decode c++",
|
| 49 |
+
py::arg("probs"),
|
| 50 |
+
py::arg("beam_size") = 100,
|
| 51 |
+
py::arg("blank") = 0,
|
| 52 |
+
py::arg("min_prob") = 0.001,
|
| 53 |
+
py::arg("lang_model") = static_cast<LanguageModel*>(nullptr),
|
| 54 |
+
py::arg("lm_weight") = 1,
|
| 55 |
+
py::arg("combine_duplicates") = true
|
| 56 |
+
);
|
| 57 |
+
|
| 58 |
+
py::class_<TokenMappingWrapper, TokenMappingWrapper::Ptr>(m, "TokenMapping");
|
| 59 |
+
|
| 60 |
+
m.def("create_token_mapping", &create_token_mapping, "create token mapping c++",
|
| 61 |
+
py::arg("token_mapping")
|
| 62 |
+
);
|
| 63 |
+
|
| 64 |
+
m.def("decode_sequences", &decode_sequences, "decode_sequences c++",
|
| 65 |
+
py::arg("tokens"), py::arg("language_model"),
|
| 66 |
+
py::arg("probs") = nullptr
|
| 67 |
+
);
|
| 68 |
+
|
| 69 |
+
m.def("create_sbo_lm", &create_sbo_lm, "create_sbo_lm c++",
|
| 70 |
+
py::arg("data_file_path"),
|
| 71 |
+
py::arg("token_mapping"),
|
| 72 |
+
py::arg("backoff") = 0.4
|
| 73 |
+
);
|
| 74 |
+
|
| 75 |
+
m.def("indirect_grid_sample_forward", &indirect_grid_sample_forward, "indirect_grid_sample::forward c++",
|
| 76 |
+
py::arg("input"), py::arg("grid"), py::arg("input_indices"), py::arg("method")
|
| 77 |
+
);
|
| 78 |
+
m.def("indirect_grad_sample_backward", &indirect_grad_sample_backward, "indirect_grid_sample::backward c++",
|
| 79 |
+
py::arg("grad_output"), py::arg("input"), py::arg("grid"), py::arg("input_indices"), py::arg("method")
|
| 80 |
+
);
|
| 81 |
+
m.def("region_counts_to_indices", ®ion_counts_to_indices, "region counts to indices",
|
| 82 |
+
py::arg("region_counts"), py::arg("num_outputs")
|
| 83 |
+
);
|
| 84 |
+
|
| 85 |
+
m.def("rrect_to_quads", &rrect_to_quads, "convert rotated rectangle to quadrangles",
|
| 86 |
+
py::arg("rrects"), py::arg("cell_size")
|
| 87 |
+
);
|
| 88 |
+
m.def("rrect_to_quads_backward", &rrect_to_quads_backward, "gradient of rrect_to_quads",
|
| 89 |
+
py::arg("rrects"), py::arg("grad_output")
|
| 90 |
+
);
|
| 91 |
+
|
| 92 |
+
m.def("sparse_select", &sparse_select, "Select sparse tensor(s) given a set of indices",
|
| 93 |
+
py::arg("sparse_counts"), py::arg("sparse_tensors"), py::arg("select_indices")
|
| 94 |
+
);
|
| 95 |
+
|
| 96 |
+
m.def("text_region_grouping", &text_region_grouping, "Clusters all of the text into lines and phrases",
|
| 97 |
+
py::arg("quads"), py::arg("counts"),
|
| 98 |
+
py::arg("horizontal_tolerance") = 2.0f,
|
| 99 |
+
py::arg("vertical_tolerance") = 0.5f,
|
| 100 |
+
py::arg("verbose") = false
|
| 101 |
+
);
|
| 102 |
+
|
| 103 |
+
m.def("dense_relations_to_graph", &dense_relations_to_graph, "Converts a dense relational tensor to a graph",
|
| 104 |
+
py::arg("relations")
|
| 105 |
+
);
|
| 106 |
+
|
| 107 |
+
m.def("ragged_quad_all_2_all_distance_v2", &ragged_quad_all_2_all_distance_v2, "get the all-to-all distances in ragged-batch quad mode",
|
| 108 |
+
py::arg("embed_quads"), py::arg("region_counts"),
|
| 109 |
+
py::arg("x_factor") = 1.0f,
|
| 110 |
+
py::arg("y_factor") = 1.0f,
|
| 111 |
+
py::arg("allow_self_distance") = true
|
| 112 |
+
);
|
| 113 |
+
|
| 114 |
+
m.def("calc_poly_min_rrect", &calc_poly_min_rrect, "calculate a reasonable bounding rectangle for a given text polygon",
|
| 115 |
+
py::arg("vertices")
|
| 116 |
+
);
|
| 117 |
+
|
| 118 |
+
m.def("get_rel_continuation_cos", &get_rel_continuation_cos, "c++ get relation cosine between 2 regions",
|
| 119 |
+
py::arg("rrect_a"), py::arg("rrect_b")
|
| 120 |
+
);
|
| 121 |
+
|
| 122 |
+
m.def("get_poly_bounds_quad", &get_poly_bounds_quad, "c++ get polygon bounds",
|
| 123 |
+
py::arg("poly")
|
| 124 |
+
);
|
| 125 |
+
}
|
nemo-retriever-ocr/cpp/non_maximal_suppression/cpu_non_maximal_suppression.cpp
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#include "non_maximal_suppression.h"
|
| 6 |
+
|
| 7 |
+
#include <algorithm>
|
| 8 |
+
#include "../geometry.h"
|
| 9 |
+
|
| 10 |
+
using namespace std;
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
template<typename scalar_t>
|
| 14 |
+
void visit_node(
|
| 15 |
+
const torch::TensorAccessor<scalar_t, 4> &quads,
|
| 16 |
+
const torch::TensorAccessor<scalar_t, 2> &probs,
|
| 17 |
+
const torch::TensorAccessor<int32_t, 3> &adjacency,
|
| 18 |
+
MergeQuad_<scalar_t> &mQuad,
|
| 19 |
+
unordered_set<int32_t> &visited,
|
| 20 |
+
int64_t r, int64_t c, int32_t vIdx)
|
| 21 |
+
{
|
| 22 |
+
if (visited.count(vIdx)) {
|
| 23 |
+
return;
|
| 24 |
+
}
|
| 25 |
+
visited.insert(vIdx);
|
| 26 |
+
|
| 27 |
+
int32_t *pAdj = adjacency[r][c].data();
|
| 28 |
+
|
| 29 |
+
int32_t adjCt = pAdj[0];
|
| 30 |
+
assert(adjCt > 0);
|
| 31 |
+
|
| 32 |
+
mQuad.Append(Quad_<scalar_t>(quads[r][c].data()), probs[r][c]);
|
| 33 |
+
|
| 34 |
+
int32_t *pOff = pAdj + 2;
|
| 35 |
+
int32_t *pEnd = pAdj + adjCt + 1;
|
| 36 |
+
|
| 37 |
+
const int32_t W = quads.size(1);
|
| 38 |
+
|
| 39 |
+
for (; pOff != pEnd; ++pOff) {
|
| 40 |
+
int32_t vIdx2 = *pOff;
|
| 41 |
+
int32_t r2 = vIdx2 / W;
|
| 42 |
+
int32_t c2 = vIdx2 % W;
|
| 43 |
+
|
| 44 |
+
visit_node(quads, probs, adjacency, mQuad, visited, r2, c2, vIdx2);
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
template<typename scalar_t>
|
| 49 |
+
std::vector<torch::Tensor> quad_nms_from_adjacency_impl(
|
| 50 |
+
const torch::TensorAccessor<scalar_t, 5> &quads,
|
| 51 |
+
const torch::TensorAccessor<scalar_t, 3> &probs,
|
| 52 |
+
const torch::TensorAccessor<int32_t, 4> &adjacency,
|
| 53 |
+
scalar_t probThreshold, scalar_t iouThreshold,
|
| 54 |
+
int64_t maxRegions)
|
| 55 |
+
{
|
| 56 |
+
const uint64_t B = quads.size((int)0);
|
| 57 |
+
const int64_t H = quads.size((int)1);
|
| 58 |
+
const int64_t W = quads.size((int)2);
|
| 59 |
+
|
| 60 |
+
typedef MergeQuad_<scalar_t> MQuad;
|
| 61 |
+
typedef EmbedQuad_<scalar_t> EFQuad;
|
| 62 |
+
|
| 63 |
+
vector<vector<EFQuad>> batchQuads{ static_cast< const unsigned int >( B ) };
|
| 64 |
+
vector<vector<EFQuad>> allQuads{ static_cast< const unsigned int >( B ) };
|
| 65 |
+
vector<vector<vector<size_t>>> batchAdjIdxs{ static_cast< const unsigned int >( B ) };
|
| 66 |
+
|
| 67 |
+
#pragma omp parallel num_threads (8)
|
| 68 |
+
{
|
| 69 |
+
#pragma omp for
|
| 70 |
+
for (int64_t b = 0; b < B; ++b) {
|
| 71 |
+
unordered_set<int32_t> visited;
|
| 72 |
+
|
| 73 |
+
for (int64_t r = 0; r < H; ++r) {
|
| 74 |
+
for (int64_t c = 0; c < W; ++c) {
|
| 75 |
+
auto currProb = probs[b][r][c];
|
| 76 |
+
|
| 77 |
+
if (currProb < probThreshold) {
|
| 78 |
+
continue;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
int32_t vIdx = r * W + c;
|
| 82 |
+
|
| 83 |
+
// Ensure that this quad hasn't already been merged
|
| 84 |
+
if (visited.count(vIdx)) {
|
| 85 |
+
continue;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
MQuad mQuad{ZeroInitTag{}};
|
| 89 |
+
visit_node(quads[b], probs[b], adjacency[b], mQuad, visited, r, c, vIdx);
|
| 90 |
+
|
| 91 |
+
batchQuads[b].push_back(mQuad.Commit());
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
#pragma omp single
|
| 97 |
+
{
|
| 98 |
+
for (size_t b = 0; b < B; ++b) {
|
| 99 |
+
size_t numQuads = batchQuads[b].size();
|
| 100 |
+
batchAdjIdxs[b].resize(numQuads);
|
| 101 |
+
for (int64_t n = 0; n < numQuads; ++n) {
|
| 102 |
+
#pragma omp task default(none) shared(batchAdjIdxs, batchQuads, iouThreshold) firstprivate(b, numQuads, n)
|
| 103 |
+
{
|
| 104 |
+
for (int64_t m = n + 1; m < numQuads; ++m) {
|
| 105 |
+
vector<size_t> &adjIdxs = batchAdjIdxs[b][n];
|
| 106 |
+
vector<EFQuad> &quads = batchQuads[b];
|
| 107 |
+
auto iou = quads[n].IOU(quads[m]);
|
| 108 |
+
|
| 109 |
+
if (iou > iouThreshold) {
|
| 110 |
+
adjIdxs.push_back(m);
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
}
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
#pragma omp taskwait
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
#pragma omp for
|
| 121 |
+
for (int64_t batchIdx = 0; batchIdx < B; ++batchIdx) {
|
| 122 |
+
vector<vector<size_t>> &adjIdxs = batchAdjIdxs[batchIdx];
|
| 123 |
+
vector<EFQuad> &quads = batchQuads[batchIdx];
|
| 124 |
+
vector<EFQuad> &finalQuads = allQuads[batchIdx];
|
| 125 |
+
|
| 126 |
+
// Step 3: Using depth first search, merge the regions
|
| 127 |
+
unordered_set<size_t> visited;
|
| 128 |
+
for (int64_t n = 0; n < quads.size(); ++n) {
|
| 129 |
+
EFQuad currQuad;
|
| 130 |
+
visit_node(quads, n, adjIdxs, currQuad, visited);
|
| 131 |
+
|
| 132 |
+
if (currQuad.NumQuads > 0) {
|
| 133 |
+
currQuad.Prepare();
|
| 134 |
+
|
| 135 |
+
finalQuads.push_back(currQuad);
|
| 136 |
+
}
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
// Only sort the part that we want to keep
|
| 140 |
+
partial_sort(begin(finalQuads),
|
| 141 |
+
begin(finalQuads) + std::min<int64_t>(finalQuads.size(), maxRegions),
|
| 142 |
+
end(finalQuads),
|
| 143 |
+
[] (auto a, auto b) {
|
| 144 |
+
return a.Confidence > b.Confidence;
|
| 145 |
+
}
|
| 146 |
+
);
|
| 147 |
+
|
| 148 |
+
// Truncate the low confidence regions
|
| 149 |
+
if (finalQuads.size() > maxRegions) {
|
| 150 |
+
finalQuads.resize(maxRegions);
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
//cout << "Ex " << batchIdx << " quads:" << endl << finalQuads << endl << endl;
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
} // End parallel
|
| 157 |
+
|
| 158 |
+
int64_t numOutQuads = 0;
|
| 159 |
+
for (int64_t batchIdx = 0; batchIdx < B; ++batchIdx) {
|
| 160 |
+
numOutQuads += allQuads[batchIdx].size();
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
// Step 4: Convert the quads into tensor representation
|
| 164 |
+
auto outQuadTensor = torch::empty({ numOutQuads, 4, 2 }, torch::kFloat32);
|
| 165 |
+
auto outConfTensor = torch::empty({ numOutQuads }, torch::kFloat32);
|
| 166 |
+
torch::Tensor outCountTensor = torch::empty({ static_cast<int64_t>( allQuads.size() ) }, torch::kInt64);
|
| 167 |
+
|
| 168 |
+
auto outQuadAccess = outQuadTensor.accessor<float, 3>();
|
| 169 |
+
auto outConfAccess = outConfTensor.accessor<float, 1>();
|
| 170 |
+
auto outCountAccess = outCountTensor.accessor<int64_t, 1>();
|
| 171 |
+
|
| 172 |
+
int64_t offset = 0;
|
| 173 |
+
for (int64_t batchIdx = 0; batchIdx < allQuads.size(); ++batchIdx) {
|
| 174 |
+
vector<EFQuad> &exQuads = allQuads[batchIdx];
|
| 175 |
+
|
| 176 |
+
outCountAccess[batchIdx] = exQuads.size();
|
| 177 |
+
|
| 178 |
+
for (int64_t qIdx = 0; qIdx < exQuads.size(); ++qIdx, ++offset) {
|
| 179 |
+
copy_quad(exQuads[qIdx], outQuadAccess[offset].data());
|
| 180 |
+
outConfAccess[offset] = exQuads[qIdx].Confidence;
|
| 181 |
+
}
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
return { outQuadTensor, outConfTensor, outCountTensor };
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
std::vector<torch::Tensor> quad_nms_from_adjacency(
|
| 188 |
+
torch::Tensor quads, torch::Tensor probs, torch::Tensor adjacency,
|
| 189 |
+
float probThreshold, float iouThreshold,
|
| 190 |
+
int64_t maxRegions)
|
| 191 |
+
{
|
| 192 |
+
std::vector<torch::Tensor> ret;
|
| 193 |
+
|
| 194 |
+
AT_DISPATCH_FLOATING_TYPES(
|
| 195 |
+
quads.scalar_type(),
|
| 196 |
+
"quad_nms_from_adjacency",
|
| 197 |
+
([&] {
|
| 198 |
+
ret = quad_nms_from_adjacency_impl<scalar_t>(
|
| 199 |
+
quads.accessor<scalar_t, 5>(),
|
| 200 |
+
probs.accessor<scalar_t, 3>(),
|
| 201 |
+
adjacency.accessor<int32_t, 4>(),
|
| 202 |
+
probThreshold, iouThreshold,
|
| 203 |
+
maxRegions
|
| 204 |
+
);
|
| 205 |
+
})
|
| 206 |
+
);
|
| 207 |
+
|
| 208 |
+
return ret;
|
| 209 |
+
}
|
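Note: the CPU entry point above, quad_nms_from_adjacency, is what the rest of the extension dispatches to when tensors live on the host. Below is a minimal, hypothetical smoke test (not part of this commit) showing how it might be driven with dummy tensors. The shapes are inferred from the accessor ranks in the implementation (quads [B, H, W, 4, 2], probs [B, H, W], adjacency [B, H, W, K] as int32 with a leading neighbor count per cell), and the threshold and region-limit values are arbitrary; with the probability threshold above 1.0 nothing is merged, so the call only exercises dispatch and the output shapes.

// Hypothetical driver -- assumes non_maximal_suppression.h declares quad_nms_from_adjacency.
#include <iostream>
#include <torch/torch.h>
#include "non_maximal_suppression.h"

int main() {
    const int64_t B = 1, H = 16, W = 16, K = 8;
    auto quads     = torch::rand({B, H, W, 4, 2});               // candidate quads per grid cell
    auto probs     = torch::rand({B, H, W});                      // per-cell confidences
    auto adjacency = torch::zeros({B, H, W, K}, torch::kInt32);   // per-cell neighbor lists (count first)

    // Threshold above 1.0 => no cell passes, so this is purely a shape/signature check.
    auto out = quad_nms_from_adjacency(quads, probs, adjacency,
                                       /*probThreshold=*/1.1f,
                                       /*iouThreshold=*/0.4f,
                                       /*maxRegions=*/200);

    std::cout << out[0].sizes() << " "   // merged quads     [N, 4, 2]
              << out[1].sizes() << " "   // confidences      [N]
              << out[2].sizes() << "\n"; // per-image counts [B]
    return 0;
}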
nemo-retriever-ocr/cpp/non_maximal_suppression/cuda_non_maximal_suppression.cu
ADDED
|
@@ -0,0 +1,1720 @@
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#include "non_maximal_suppression.h"
|
| 6 |
+
|
| 7 |
+
#include <cooperative_groups.h>
|
| 8 |
+
#include <cooperative_groups/reduce.h>
|
| 9 |
+
|
| 10 |
+
#include <thrust/binary_search.h>
|
| 11 |
+
#include <thrust/device_vector.h>
|
| 12 |
+
#include <thrust/execution_policy.h>
|
| 13 |
+
|
| 14 |
+
#include <c10/cuda/CUDACachingAllocator.h>
|
| 15 |
+
|
| 16 |
+
#include <trove/ptr.h>
|
| 17 |
+
|
| 18 |
+
#include "../cuda_intellisense.cuh"
|
| 19 |
+
#include "../geometry.h"
|
| 20 |
+
#include "../common.h"
|
| 21 |
+
#include "../scope_timer.h"
|
| 22 |
+
#include "strided_quad.h"
|
| 23 |
+
|
| 24 |
+
// If this flag is turned on, then a bunch of checks will be inserted to ensure that the same results are produced by
|
| 25 |
+
// successive calls to NMS. This makes the library unusable outside of a debug context, so beware!
|
| 26 |
+
//#define NMS_VERIFY_CORRECTNESS
|
| 27 |
+
|
| 28 |
+
namespace cg = cooperative_groups;
|
| 29 |
+
namespace ix = torch::indexing;
|
| 30 |
+
|
| 31 |
+
inline
|
| 32 |
+
void print_tensor_stats2(const std::string &msg, const torch::Tensor& tensor) {
|
| 33 |
+
|
| 34 |
+
auto fTensor = tensor.to(torch::kDouble).cpu();
|
| 35 |
+
|
| 36 |
+
std::stringstream ss;
|
| 37 |
+
if (tensor.numel() > 1) {
|
| 38 |
+
ss << msg << " Size: " << tensor.sizes() << " Type: " << tensor.dtype() << " Device: " << tensor.device() << " Max: " << fTensor.max().item<double>() << " Min: " << fTensor.min().item<double>() << " Mean: " << fTensor.mean().item<double>() << " Std: " << fTensor.std().item<double>();
|
| 39 |
+
}
|
| 40 |
+
else if (tensor.numel() == 1) {
|
| 41 |
+
ss << msg << " Size: " << tensor.sizes() << " Type: " << tensor.dtype() << " Device: " << tensor.device() << " Value: " << fTensor.item<double>() << std::endl;
|
| 42 |
+
}
|
| 43 |
+
else {
|
| 44 |
+
ss << msg << " Size: " << tensor.sizes() << " Type: " << tensor.dtype() << " Device: " << tensor.device() << std::endl;
|
| 45 |
+
}
|
| 46 |
+
std::cout << ss.str() << std::endl;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
inline
|
| 50 |
+
void print_tensor_vec_stats2(std::string msg, const std::vector<torch::Tensor>& tensorVec) {
|
| 51 |
+
std::cout << msg << " Size: " << tensorVec.size() << std::endl;
|
| 52 |
+
std::stringstream ss;
|
| 53 |
+
msg = " - ";
|
| 54 |
+
for (int i = 0; i < tensorVec.size(); ++i) {
|
| 55 |
+
ss << msg << "[" << i << "]:";
|
| 56 |
+
auto tensor = tensorVec[i];
|
| 57 |
+
print_tensor_stats2(ss.str(), tensor);
|
| 58 |
+
ss.str("");
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
std::ostream &operator<<(std::ostream &os, dim3 d)
|
| 63 |
+
{
|
| 64 |
+
return os << "(" << d.x << ", " << d.y << ", " << d.z << ")";
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
#define ADD_OP2(vector2_t) __device__ \
|
| 68 |
+
vector2_t operator+(const vector2_t &a, const vector2_t &b) { \
|
| 69 |
+
return { a.x + b.x, a.y + b.y }; \
|
| 70 |
+
}
|
| 71 |
+
ADD_OP2(float2);
|
| 72 |
+
ADD_OP2(double2);
|
| 73 |
+
#undef ADD_OP2
|
| 74 |
+
|
| 75 |
+
#define ADD_OP4(vector4_t) __device__ \
|
| 76 |
+
vector4_t operator+(const vector4_t &a, const vector4_t &b) { \
|
| 77 |
+
return { a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w }; \
|
| 78 |
+
}
|
| 79 |
+
ADD_OP4(float4);
|
| 80 |
+
ADD_OP4(double4);
|
| 81 |
+
#undef ADD_OP4
|
| 82 |
+
|
| 83 |
+
template<typename T, size_t Size>
|
| 84 |
+
__device__
|
| 85 |
+
std::array<T, Size> operator+(const std::array<T, Size> &a, const std::array<T, Size> &b) {
|
| 86 |
+
std::array<T, Size> ret;
|
| 87 |
+
#pragma unroll
|
| 88 |
+
for (size_t i = 0; i < Size; ++i) {
|
| 89 |
+
ret._Elems[i] = a._Elems[i] + b._Elems[i];
|
| 90 |
+
}
|
| 91 |
+
return ret;
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
#if __CUDA_ARCH__ >= 800
|
| 95 |
+
#define __reduce_add_full_warp(val) __reduce_add_sync(0xFFFFFFFF, val)
|
| 96 |
+
#define __reduce_max_full_warp(val) __reduce_max_sync(0xFFFFFFFF, val)
|
| 97 |
+
#define __reduce_min_full_warp(val) __reduce_min_sync(0xFFFFFFFF, val)
|
| 98 |
+
#else
|
| 99 |
+
#define __reduce_add_full_warp(val) cg::reduce(cg::tiled_partition<32>(cg::this_thread_block()), val, cg::plus<decltype(val)>())
|
| 100 |
+
#define __reduce_max_full_warp(val) cg::reduce(cg::tiled_partition<32>(cg::this_thread_block()), val, cg::greater<decltype(val)>())
|
| 101 |
+
#define __reduce_min_full_warp(val) cg::reduce(cg::tiled_partition<32>(cg::this_thread_block()), val, cg::less<decltype(val)>())
|
| 102 |
+
#endif
|
| 103 |
+
|
| 104 |
+
template<typename T>
|
| 105 |
+
struct TToVec;
|
| 106 |
+
template<>
|
| 107 |
+
struct TToVec<float> { typedef float2 type2; typedef float4 type4; };
|
| 108 |
+
template<>
|
| 109 |
+
struct TToVec<double> { typedef double2 type2; typedef double4 type4; };
|
| 110 |
+
|
| 111 |
+
template<typename T, typename accessor_t>
|
| 112 |
+
__device__
|
| 113 |
+
void write_embed_quad(accessor_t &acc, const MergeQuad_<T> &quad, int64_t storeOff)
|
| 114 |
+
{
|
| 115 |
+
constexpr auto EMBED_QUAD_SIZE = sizeof(EmbedQuad_<T>) / sizeof(T);
|
| 116 |
+
static_assert(EMBED_QUAD_SIZE == 10, "Unsupported embed quad size!");
|
| 117 |
+
|
| 118 |
+
const T *mergeBuff = reinterpret_cast<const T*>(&quad);
|
| 119 |
+
|
| 120 |
+
const T confidence = quad.Confidence;
|
| 121 |
+
const auto i = threadIdx.x;
|
| 122 |
+
|
| 123 |
+
if (i >= 10) {
|
| 124 |
+
return;
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
T outVal;
|
| 128 |
+
// Coordinates
|
| 129 |
+
if (i < 8) {
|
| 130 |
+
outVal = mergeBuff[i] / confidence;
|
| 131 |
+
// Confidence
|
| 132 |
+
} else if (i == 8) {
|
| 133 |
+
outVal = confidence / mergeBuff[9];
|
| 134 |
+
// NumQuads
|
| 135 |
+
} else {
|
| 136 |
+
outVal = mergeBuff[9];
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
acc[i][storeOff] = outVal;
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
template<typename group_t, typename ...Args>
|
| 144 |
+
__device__
|
| 145 |
+
void ordered_print(group_t &group, const char *const fmt, const Args& ...args)
|
| 146 |
+
{
|
| 147 |
+
for (uint32_t i = 0; i < group.size(); ++i) {
|
| 148 |
+
if (group.thread_rank() == i) {
|
| 149 |
+
printf(fmt, args...);
|
| 150 |
+
}
|
| 151 |
+
group.sync();
|
| 152 |
+
}
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
template<typename T>
|
| 156 |
+
__global__
|
| 157 |
+
void device_row_collapse(torch::PackedTensorAccessor64<T, 5> allQuads,
|
| 158 |
+
torch::PackedTensorAccessor64<T, 3> allConfs,
|
| 159 |
+
T confThreshold, T iouThreshold,
|
| 160 |
+
torch::PackedTensorAccessor64<int32_t, 1> allOutCounts,
|
| 161 |
+
torch::PackedTensorAccessor64<T, 3> allOutEmbedQuads
|
| 162 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 163 |
+
, torch::PackedTensorAccessor64<int32_t, 2> allOutIds
|
| 164 |
+
#endif
|
| 165 |
+
)
|
| 166 |
+
{
|
| 167 |
+
typedef InPlaceQuad_<T> Quadf;
|
| 168 |
+
static_assert(sizeof(Quadf) == sizeof(T) * 8, "Invalid QuadMem size!");
|
| 169 |
+
|
| 170 |
+
constexpr uint32_t ALL_MASK = 0xFFFFFFFF;
|
| 171 |
+
constexpr uint32_t WARP_SIZE = 32;
|
| 172 |
+
constexpr T MIN_VALID_AREA = 8;
|
| 173 |
+
|
| 174 |
+
const uint32_t B = allQuads.size(0);
|
| 175 |
+
const uint32_t H = allQuads.size(1);
|
| 176 |
+
|
| 177 |
+
const uint32_t b = blockIdx.z;
|
| 178 |
+
const uint32_t r = blockIdx.y * blockDim.y + threadIdx.y;
|
| 179 |
+
|
| 180 |
+
if (r >= H) {
|
| 181 |
+
return;
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
#define threadRank threadIdx.x
|
| 185 |
+
|
| 186 |
+
auto rawQuads = reinterpret_cast<Quadf*>(allQuads[b][r].data());
|
| 187 |
+
#if defined(NDEBUG)
|
| 188 |
+
trove::coalesced_ptr<Quadf> quads(rawQuads);
|
| 189 |
+
#else
|
| 190 |
+
auto quads = rawQuads;
|
| 191 |
+
#endif
|
| 192 |
+
|
| 193 |
+
auto confs = allConfs[b][r];
|
| 194 |
+
|
| 195 |
+
T conf = confs[threadRank];
|
| 196 |
+
|
| 197 |
+
bool quadValid = conf >= confThreshold;
|
| 198 |
+
uint32_t ballot = __ballot_sync(ALL_MASK, quadValid);
|
| 199 |
+
|
| 200 |
+
// No valid quads in this window, so we're done!
|
| 201 |
+
if (ballot == 0) {
|
| 202 |
+
return;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
const Quadf currQuad = quads[threadRank];
|
| 206 |
+
|
| 207 |
+
const T qArea = currQuad.Area();
|
| 208 |
+
|
| 209 |
+
quadValid = quadValid && qArea > MIN_VALID_AREA;
|
| 210 |
+
ballot = __ballot_sync(ALL_MASK, quadValid);
|
| 211 |
+
if (ballot == 0) {
|
| 212 |
+
return;
|
| 213 |
+
}
|
| 214 |
+
if (! quadValid) {
|
| 215 |
+
conf = 0;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
MergeQuad_<T> qAccum{ZeroInitTag{}};
|
| 219 |
+
|
| 220 |
+
Quadf prevQuad;
|
| 221 |
+
auto pCurrQuad = reinterpret_cast<const T*>(&currQuad);
|
| 222 |
+
auto pPrevQuad = reinterpret_cast<T*>(&prevQuad);
|
| 223 |
+
#pragma unroll
|
| 224 |
+
for (uint32_t i = 0; i < 8; ++i) {
|
| 225 |
+
pPrevQuad[i] = __shfl_up_sync(ALL_MASK, pCurrQuad[i], 1);
|
| 226 |
+
}
|
| 227 |
+
T prevConf = __shfl_up_sync(ALL_MASK, conf, 1);
|
| 228 |
+
|
| 229 |
+
if (threadRank == 0) {
|
| 230 |
+
prevConf = 0;
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
bool iouValid = false;
|
| 234 |
+
T iou = 0;
|
| 235 |
+
if (quadValid) {
|
| 236 |
+
qAccum.Append(currQuad, conf);
|
| 237 |
+
|
| 238 |
+
if (prevConf >= confThreshold) {
|
| 239 |
+
iou = prevQuad.IOU_UpperBound(currQuad);
|
| 240 |
+
if (iou >= iouThreshold) {
|
| 241 |
+
iouValid = true;
|
| 242 |
+
}
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
// This is the start of a span if the current confidence is above threshold, but the quad to the left is either below threshold,
|
| 247 |
+
// or the IOU between the quads is below threshold
|
| 248 |
+
const bool isStartOfSpan = quadValid && !iouValid;
|
| 249 |
+
|
| 250 |
+
uint32_t label = isStartOfSpan;
|
| 251 |
+
// All labels start out as 0 or 1, and we'll then do a cumsum over the warp, which gives each thread an assigned label
|
| 252 |
+
// We also know that the final thread also contains the number of labels.
|
| 253 |
+
#pragma unroll
|
| 254 |
+
for (uint32_t offset = 1; offset < 32; offset <<= 1) {
|
| 255 |
+
auto inc = __shfl_up_sync(ALL_MASK, label, offset);
|
| 256 |
+
if (threadRank >= offset) {
|
| 257 |
+
label += inc;
|
| 258 |
+
}
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
// Before we zero out invalid labels, get the total number of labels
|
| 262 |
+
const uint32_t numLabels = __shfl_sync(ALL_MASK, label, WARP_SIZE - 1);
|
| 263 |
+
|
| 264 |
+
// Zero out the label if the current quad isn't valid
|
| 265 |
+
label = quadValid ? label : 0;
|
| 266 |
+
|
| 267 |
+
T* accumPtr = reinterpret_cast<T*>(&qAccum);
|
| 268 |
+
// Reduce all of the quads s.t. the left-most position in the span contains the full quad.
|
| 269 |
+
// We use `label` to decide whether to do the accumulation
|
| 270 |
+
#pragma unroll
|
| 271 |
+
for (uint32_t offset = 1; offset < 32; offset <<= 1) {
|
| 272 |
+
const auto otherLabel = __shfl_down_sync(ALL_MASK, label, offset);
|
| 273 |
+
|
| 274 |
+
// Regardless of whether the labels match, all threads in the warp must make the shfl_down
|
| 275 |
+
// call. So we use factor to modulate whether the given merge is valid
|
| 276 |
+
const T factor = otherLabel == label && offset + threadRank < WARP_SIZE ? 1.0f : 0.0f;
|
| 277 |
+
|
| 278 |
+
#pragma unroll
|
| 279 |
+
for (uint32_t i = 0; i < 10; ++i) {
|
| 280 |
+
accumPtr[i] += factor * __shfl_down_sync(ALL_MASK, accumPtr[i], offset);
|
| 281 |
+
}
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
// Elect thread-0 to figure out where to store the results
|
| 285 |
+
uint32_t storeOff = 0;
|
| 286 |
+
if (threadRank == 0) {
|
| 287 |
+
storeOff = atomicAdd(&allOutCounts[b], numLabels);
|
| 288 |
+
}
|
| 289 |
+
// Broadcast that offset to the whole warp
|
| 290 |
+
storeOff = __shfl_sync(ALL_MASK, storeOff, 0);
|
| 291 |
+
|
| 292 |
+
auto outEmbedQuads = allOutEmbedQuads[b];
|
| 293 |
+
// Now write out each quad, but collectively
|
| 294 |
+
for (uint32_t procLabel = 1; procLabel <= numLabels; ++procLabel) {
|
| 295 |
+
// Discover the index of the start of each label span
|
| 296 |
+
ballot = __ballot_sync(ALL_MASK, procLabel == label);
|
| 297 |
+
// ffs will find the (1-based) index of the least significant bit in ballot.
|
| 298 |
+
// This just so happens to be the start of the span for the current label
|
| 299 |
+
uint32_t startIdx = __ffs(ballot) - 1;
|
| 300 |
+
|
| 301 |
+
const T* inT = reinterpret_cast<T*>(&qAccum);
|
| 302 |
+
MergeQuad_<T> outQuad;
|
| 303 |
+
T* outT = reinterpret_cast<T*>(&outQuad);
|
| 304 |
+
#pragma unroll
|
| 305 |
+
for (uint32_t i = 0; i < 10; ++i) {
|
| 306 |
+
outT[i] = __shfl_sync(ALL_MASK, inT[i], startIdx);
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
write_embed_quad(outEmbedQuads, outQuad, storeOff + procLabel - 1);
|
| 310 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 311 |
+
if (threadRank == 0) {
|
| 312 |
+
allOutIds[b][storeOff + procLabel - 1] = r * 32 + startIdx;
|
| 313 |
+
}
|
| 314 |
+
#endif
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
if (threadRank == 0) {
|
| 318 |
+
// Increment the total number of quads by the number encountered on this row
|
| 319 |
+
atomicAdd(&allOutCounts[B], numLabels);
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
#undef threadRank
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
template<bool IsSingleExample, typename T>
|
| 326 |
+
__global__
|
| 327 |
+
void device_a2a_adjacency_sparse(const uint64_t punCounts,
|
| 328 |
+
T iouThreshold,
|
| 329 |
+
torch::PackedTensorAccessor64<T, 3> embedQuads,
|
| 330 |
+
torch::PackedTensorAccessor64<bool, 2> outIsStart,
|
| 331 |
+
torch::PackedTensorAccessor64<int32_t, 2> outAdjCounts,
|
| 332 |
+
torch::PackedTensorAccessor64<int32_t, 3> outSparseAdj)
|
| 333 |
+
{
|
| 334 |
+
const uint32_t b = blockIdx.y;
|
| 335 |
+
|
| 336 |
+
const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
|
| 337 |
+
|
| 338 |
+
const int32_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 339 |
+
const int32_t row = jobIdx / quadCt;
|
| 340 |
+
const int32_t col = jobIdx % quadCt;
|
| 341 |
+
|
| 342 |
+
// Only compute the upper triangular portion of the matrix
|
| 343 |
+
if (row >= quadCt || col < row) {
|
| 344 |
+
return;
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
T* exData = IsSingleExample ? embedQuads.data() : embedQuads[b].data();
|
| 348 |
+
|
| 349 |
+
const auto qRow = StridedEmbedQuad_<T>{ exData + row * embedQuads.stride(2), embedQuads.stride(1) }.Bounds(),
|
| 350 |
+
qCol = StridedEmbedQuad_<T>{ exData + col * embedQuads.stride(2), embedQuads.stride(1) }.Bounds();
|
| 351 |
+
|
| 352 |
+
T pctRow, pctCol, iou;
|
| 353 |
+
thrust::tie(pctRow, pctCol, iou) = geometry_region_sizes(qRow, qCol);
|
| 354 |
+
|
| 355 |
+
auto warpGroup = cg::tiled_partition<32>(cg::this_thread_block());
|
| 356 |
+
|
| 357 |
+
auto rowGroup = cg::labeled_partition(warpGroup, row);
|
| 358 |
+
|
| 359 |
+
const bool isValid = iou >= iouThreshold;
|
| 360 |
+
|
| 361 |
+
const uint32_t ballot = rowGroup.ballot(isValid);
|
| 362 |
+
const uint32_t numValid = __popc(ballot);
|
| 363 |
+
|
| 364 |
+
auto exAdjCounts = outAdjCounts[b].data();
|
| 365 |
+
|
| 366 |
+
int32_t storeOff = 0;
|
| 367 |
+
if (numValid > 0 && rowGroup.thread_rank() == 0) {
|
| 368 |
+
storeOff = atomicAdd(exAdjCounts + row, numValid);
|
| 369 |
+
}
|
| 370 |
+
storeOff = rowGroup.shfl(storeOff, 0);
|
| 371 |
+
|
| 372 |
+
if (isValid) {
|
| 373 |
+
// This sets all of the bits below (less significant than) this thread's bit to 1, and the rest to 0.
|
| 374 |
+
// We can use this to count the number of bits that are set, and are less significant than this one,
|
| 375 |
+
// to get the local storage offset
|
| 376 |
+
uint32_t lowerMask = (1 << rowGroup.thread_rank()) - 1;
|
| 377 |
+
|
| 378 |
+
storeOff += __popc(ballot & lowerMask);
|
| 379 |
+
|
| 380 |
+
outSparseAdj[b][row][storeOff] = col;
|
| 381 |
+
if (row != col) {
|
| 382 |
+
// Because `col` gets merged into `row`, we mark it as inactive for reduction purposes.
|
| 383 |
+
// All of the quads that `col` is adjacent to will be absorbed by `row`.
|
| 384 |
+
outIsStart[b][col] = false;
|
| 385 |
+
|
| 386 |
+
// Also store the transposed relation
|
| 387 |
+
storeOff = atomicAdd(exAdjCounts + col, 1);
|
| 388 |
+
outSparseAdj[b][col][storeOff] = row;
|
| 389 |
+
}
|
| 390 |
+
} else if (pctRow > 0.8f || pctCol > 0.8f) {
|
| 391 |
+
T anchorHeight = qRow.Height();
|
| 392 |
+
T otherHeight = qCol.Height();
|
| 393 |
+
|
| 394 |
+
T ratio = anchorHeight > otherHeight ?
|
| 395 |
+
otherHeight / anchorHeight :
|
| 396 |
+
anchorHeight / otherHeight;
|
| 397 |
+
if (ratio > 0.9f) {
|
| 398 |
+
if (pctRow > 0.8f) {
|
| 399 |
+
// Other envelops anchor
|
| 400 |
+
outIsStart[b][row] = false;
|
| 401 |
+
}
|
| 402 |
+
else {
|
| 403 |
+
outIsStart[b][col] = false;
|
| 404 |
+
}
|
| 405 |
+
}
|
| 406 |
+
}
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
template<uint32_t NumWarps, bool IsSingleExample, typename T, int32_t I_CELL_SIZE>
|
| 410 |
+
__global__
|
| 411 |
+
void device_a2a_adjacency_build_grid(const uint64_t punCounts,
|
| 412 |
+
torch::PackedTensorAccessor64<T, 3> embedQuads,
|
| 413 |
+
torch::PackedTensorAccessor64<int32_t, 4> outGridCells,
|
| 414 |
+
torch::PackedTensorAccessor64<int32_t, 3> outQuadCells)
|
| 415 |
+
{
|
| 416 |
+
constexpr T MIN_T = std::numeric_limits<T>::min();
|
| 417 |
+
constexpr T MAX_T = std::numeric_limits<T>::max();
|
| 418 |
+
constexpr uint32_t WARP_SIZE = 32;
|
| 419 |
+
constexpr uint32_t BLOCK_SIZE = NumWarps * WARP_SIZE;
|
| 420 |
+
constexpr uint32_t FULL_WARP = 0xFFFFFFFF;
|
| 421 |
+
constexpr uint32_t FIRST_16_THREADS = 0x0FFFF;
|
| 422 |
+
constexpr T CELL_SIZE = I_CELL_SIZE;
|
| 423 |
+
constexpr T INV_CELL_SIZE = 1 / CELL_SIZE;
|
| 424 |
+
|
| 425 |
+
const uint32_t b = blockIdx.z;
|
| 426 |
+
|
| 427 |
+
const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
|
| 428 |
+
const uint32_t quadIdx = blockIdx.y;
|
| 429 |
+
|
| 430 |
+
if (!IsSingleExample && quadIdx >= quadCt) {
|
| 431 |
+
return;
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
const uint32_t threadRank = threadIdx.x;
|
| 435 |
+
const uint32_t localThreadRank = threadRank & 0x1F;
|
| 436 |
+
|
| 437 |
+
auto exQuads = embedQuads[b];
|
| 438 |
+
|
| 439 |
+
const uint32_t numCells[2] = { outGridCells.size(2), outGridCells.size(1) };
|
| 440 |
+
|
| 441 |
+
const uint32_t numRows = outGridCells.size(1);
|
| 442 |
+
const uint32_t numCols = outGridCells.size(2);
|
| 443 |
+
|
| 444 |
+
// We use flip so that we can compute min and max simultaneously.
|
| 445 |
+
// The first 8 threads compute the min, the next 8 compute the max (via the sign flip)
|
| 446 |
+
T sign = localThreadRank < 8 ? 1.0f : -1.0f;
|
| 447 |
+
T myVal = sign * (localThreadRank < 16 ? exQuads[localThreadRank & 0x7][quadIdx] : MIN_T);
|
| 448 |
+
#pragma unroll
|
| 449 |
+
for (uint32_t offset = 2; offset < 8; offset <<= 1) {
|
| 450 |
+
T nextVal = __shfl_down_sync(FIRST_16_THREADS, myVal, offset);
|
| 451 |
+
myVal = min(myVal, nextVal);
|
| 452 |
+
}
|
| 453 |
+
const uint32_t cellVal = max(0.0f, sign * INV_CELL_SIZE * myVal);
|
| 454 |
+
|
| 455 |
+
uint32_t minCell[2] = { __shfl_sync(FULL_WARP, cellVal, 0), __shfl_sync(FULL_WARP, cellVal, 1) },
|
| 456 |
+
maxCell[2] = { __shfl_sync(FULL_WARP, cellVal, 8), __shfl_sync(FULL_WARP, cellVal, 9) };
|
| 457 |
+
|
| 458 |
+
#pragma unroll
|
| 459 |
+
for (uint32_t i = 0; i < 2; ++i) {
|
| 460 |
+
maxCell[i] = min(numCells[i] - 1, maxCell[i]);
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
const uint32_t sizes[2] = { maxCell[0] - minCell[0] + 1, maxCell[1] - minCell[1] + 1 };
|
| 464 |
+
|
| 465 |
+
const uint32_t totalCells = sizes[0] * sizes[1];
|
| 466 |
+
|
| 467 |
+
auto exGridCells = outGridCells[b];
|
| 468 |
+
|
| 469 |
+
for (uint32_t i = threadRank; i < totalCells; i += BLOCK_SIZE) {
|
| 470 |
+
uint32_t row = minCell[1] + i / sizes[0];
|
| 471 |
+
uint32_t col = minCell[0] + i % sizes[0];
|
| 472 |
+
|
| 473 |
+
int32_t *pCell = exGridCells[row][col].data();
|
| 474 |
+
|
| 475 |
+
// The first value in the array is the count, and the rest are the quad indices
|
| 476 |
+
int32_t storeOff = atomicAdd(pCell, 1) + 1;
|
| 477 |
+
pCell[storeOff] = quadIdx;
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
if (threadRank < 2) {
|
| 481 |
+
outQuadCells[b][quadIdx][threadRank] = minCell[threadRank];
|
| 482 |
+
} else if (threadRank < 4) {
|
| 483 |
+
outQuadCells[b][quadIdx][threadRank] = maxCell[threadRank - 2];
|
| 484 |
+
}
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
typedef uint8_t visit_mask_t;
|
| 488 |
+
|
| 489 |
+
template<uint32_t NumWarps, bool IsSingleExample, typename T>
|
| 490 |
+
__global__
|
| 491 |
+
void device_a2a_adjacency_with_grid(const uint64_t punCounts,
|
| 492 |
+
T iouThreshold,
|
| 493 |
+
torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
|
| 494 |
+
torch::PackedTensorAccessor64<int32_t, 4> allCells,
|
| 495 |
+
torch::PackedTensorAccessor64<int32_t, 3> allQuadExtents,
|
| 496 |
+
torch::PackedTensorAccessor64<bool, 2> outIsStart,
|
| 497 |
+
torch::PackedTensorAccessor64<int32_t, 2> outAdjCounts,
|
| 498 |
+
torch::PackedTensorAccessor64<int32_t, 3> outSparseAdj)
|
| 499 |
+
{
|
| 500 |
+
constexpr T MIN_T = std::numeric_limits<T>::min();
|
| 501 |
+
constexpr T MAX_T = std::numeric_limits<T>::max();
|
| 502 |
+
constexpr uint32_t WARP_SIZE = 32;
|
| 503 |
+
constexpr uint32_t BLOCK_SIZE = NumWarps * WARP_SIZE;
|
| 504 |
+
|
| 505 |
+
const uint32_t b = blockIdx.z;
|
| 506 |
+
|
| 507 |
+
const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
|
| 508 |
+
const uint32_t quadIdx = blockIdx.y;
|
| 509 |
+
|
| 510 |
+
if (!IsSingleExample && quadIdx >= quadCt) {
|
| 511 |
+
return;
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
const uint32_t threadRank = threadIdx.x;
|
| 515 |
+
|
| 516 |
+
auto exQuads = allEmbedQuads[b];
|
| 517 |
+
|
| 518 |
+
__shared__ T s_quadVerts[8];
|
| 519 |
+
__shared__ uint32_t s_quadExtent[4];
|
| 520 |
+
extern __shared__ uint32_t s_alreadyVisited[];
|
| 521 |
+
|
| 522 |
+
if (threadRank < 8) {
|
| 523 |
+
s_quadVerts[threadRank] = exQuads[threadRank][quadIdx];
|
| 524 |
+
} else if (threadRank < 12) {
|
| 525 |
+
s_quadExtent[threadRank - 8] = reinterpret_cast<uint32_t*>(allQuadExtents[b][quadIdx].data())[threadRank - 8];
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
uint32_t zeroTerm = (quadCt + 31u) >> 5u; // Fast version of div_up(quadCt, 32)
|
| 529 |
+
for (uint32_t col = threadRank; col < zeroTerm; col += BLOCK_SIZE) {
|
| 530 |
+
s_alreadyVisited[col] = 0;
|
| 531 |
+
}
|
| 532 |
+
|
| 533 |
+
__syncthreads();
|
| 534 |
+
|
| 535 |
+
auto exCells = allCells[b];
|
| 536 |
+
auto exAdjCounts = reinterpret_cast<uint32_t*>(outAdjCounts[b].data());
|
| 537 |
+
auto exAdjValues = outSparseAdj[b][quadIdx].data();
|
| 538 |
+
|
| 539 |
+
T *exData = IsSingleExample ? allEmbedQuads.data() : allEmbedQuads[b].data();
|
| 540 |
+
|
| 541 |
+
const auto bdsAnchor = Quad_<T>{ s_quadVerts }.Bounds();
|
| 542 |
+
|
| 543 |
+
const uint32_t startCol = s_quadExtent[0],
|
| 544 |
+
endCol = s_quadExtent[2];
|
| 545 |
+
for (uint32_t row = s_quadExtent[1], endRow = s_quadExtent[3]; row <= endRow; ++row) {
|
| 546 |
+
auto rowCells = exCells[row];
|
| 547 |
+
|
| 548 |
+
for (uint32_t col = startCol; col <= endCol; ++col) {
|
| 549 |
+
auto colCells = reinterpret_cast<const uint32_t*>(rowCells[col].data());
|
| 550 |
+
|
| 551 |
+
const uint32_t ct = colCells[0];
|
| 552 |
+
|
| 553 |
+
for (uint32_t i = threadRank + 1; i <= ct; i += BLOCK_SIZE) {
|
| 554 |
+
const uint32_t otherIdx = colCells[i];
|
| 555 |
+
|
| 556 |
+
const uint32_t maskIdx = otherIdx >> 5; // Divide by 32, since there are 32 bits per mask slot
|
| 557 |
+
const uint32_t maskBit = 1 << (otherIdx & 0x1F); // Set the relevant bit for this mask ID
|
| 558 |
+
|
| 559 |
+
const bool alreadyVisited = atomicOr(s_alreadyVisited + maskIdx, maskBit) & maskBit;
|
| 560 |
+
|
| 561 |
+
if (!alreadyVisited) {
|
| 562 |
+
const auto bdsOther = StridedEmbedQuad_<T>{ exData + otherIdx * allEmbedQuads.stride(2), allEmbedQuads.stride(1) }.Bounds();
|
| 563 |
+
|
| 564 |
+
T pctAnchor, pctOther, iou;
|
| 565 |
+
thrust::tie(pctAnchor, pctOther, iou) = geometry_region_sizes(bdsAnchor, bdsOther);
|
| 566 |
+
|
| 567 |
+
if (iou >= iouThreshold) {
|
| 568 |
+
auto validGroup = cg::coalesced_threads();
|
| 569 |
+
|
| 570 |
+
uint32_t storeOff = 0;
|
| 571 |
+
if (validGroup.thread_rank() == 0) {
|
| 572 |
+
storeOff = atomicAdd(exAdjCounts + quadIdx, validGroup.size());
|
| 573 |
+
}
|
| 574 |
+
storeOff = validGroup.shfl(storeOff, 0) + validGroup.thread_rank();
|
| 575 |
+
|
| 576 |
+
exAdjValues[storeOff] = otherIdx;
|
| 577 |
+
|
| 578 |
+
if (otherIdx > quadIdx) {
|
| 579 |
+
outIsStart[b][otherIdx] = false;
|
| 580 |
+
}
|
| 581 |
+
} else if (pctAnchor > 0.8f || pctOther > 0.8f) {
|
| 582 |
+
T anchorHeight = bdsAnchor.Height();
|
| 583 |
+
T otherHeight = bdsOther.Height();
|
| 584 |
+
|
| 585 |
+
T ratio = anchorHeight > otherHeight ?
|
| 586 |
+
otherHeight / anchorHeight :
|
| 587 |
+
anchorHeight / otherHeight;
|
| 588 |
+
if (ratio > 0.9f) {
|
| 589 |
+
if (pctAnchor > 0.8f) {
|
| 590 |
+
// Other envelops anchor
|
| 591 |
+
outIsStart[b][quadIdx] = false;
|
| 592 |
+
} else {
|
| 593 |
+
outIsStart[b][otherIdx] = false;
|
| 594 |
+
}
|
| 595 |
+
}
|
| 596 |
+
}
|
| 597 |
+
}
|
| 598 |
+
}
|
| 599 |
+
}
|
| 600 |
+
}
|
| 601 |
+
}
|
| 602 |
+
|
| 603 |
+
template<bool IsSingleExample>
|
| 604 |
+
__global__
|
| 605 |
+
void device_flatten_graph_iterative(const uint64_t punCounts,
|
| 606 |
+
torch::PackedTensorAccessor64<bool, 2> allIsStart,
|
| 607 |
+
volatile uint32_t *allAdjCounts,
|
| 608 |
+
volatile uint32_t *allAdjValues
|
| 609 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 610 |
+
, int32_t *maxDepth
|
| 611 |
+
#endif
|
| 612 |
+
)
|
| 613 |
+
{
|
| 614 |
+
constexpr uint32_t WARP_SIZE = 32;
|
| 615 |
+
constexpr uint32_t VISIT_STACK_SIZE = 9;
|
| 616 |
+
constexpr uint32_t TERM_VALUE = std::numeric_limits<uint32_t>::max();
|
| 617 |
+
|
| 618 |
+
constexpr visit_mask_t VISITED_MASK = 0b001;
|
| 619 |
+
constexpr visit_mask_t ADDED_MASK = 0b010;
|
| 620 |
+
constexpr visit_mask_t QUEUED_MASK = 0b100;
|
| 621 |
+
constexpr visit_mask_t QUEUED_OR_VISITED_MASK = VISITED_MASK | QUEUED_MASK;
|
| 622 |
+
|
| 623 |
+
const uint32_t b = blockIdx.z;
|
| 624 |
+
const uint32_t anchorRow = blockIdx.y;
|
| 625 |
+
|
| 626 |
+
const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
|
| 627 |
+
|
| 628 |
+
// Only need to check this if there are multiple examples, since in the case of a single example,
|
| 629 |
+
// the grid is precisely sized to that quadCt
|
| 630 |
+
if constexpr (!IsSingleExample) {
|
| 631 |
+
if (anchorRow >= quadCt) {
|
| 632 |
+
return;
|
| 633 |
+
}
|
| 634 |
+
}
|
| 635 |
+
|
| 636 |
+
auto isStart = allIsStart[b].data();
|
| 637 |
+
|
| 638 |
+
const uint32_t threadRank = threadIdx.x;
|
| 639 |
+
|
| 640 |
+
extern __shared__ visit_mask_t s_visitedMask[];
|
| 641 |
+
|
| 642 |
+
#ifndef NMS_VERIFY_CORRECTNESS
|
| 643 |
+
// Only need to process the anchor rows, since they're the only ones
|
| 644 |
+
// that will make it through the full NMS operation.
|
| 645 |
+
// NOTE: There's a race condition where some rows may be marked as anchor,
|
| 646 |
+
// but they'll later be marked non-anchor over the course of this kernel.
|
| 647 |
+
// That's fine. It's a bit of extra work, but there's no real way around it.
|
| 648 |
+
const bool anchorIsStart = isStart[anchorRow];
|
| 649 |
+
if (!anchorIsStart) {
|
| 650 |
+
return;
|
| 651 |
+
}
|
| 652 |
+
#endif
|
| 653 |
+
|
| 654 |
+
uint32_t *pIntVisitedMask = reinterpret_cast<uint32_t*>(s_visitedMask);
|
| 655 |
+
uint32_t zeroTerm = (quadCt + 3) >> 2; // Fast version of div_up(quadCt, 4)
|
| 656 |
+
for (uint32_t col = threadRank; col < zeroTerm; col += blockDim.x) {
|
| 657 |
+
pIntVisitedMask[col] = 0;
|
| 658 |
+
}
|
| 659 |
+
|
| 660 |
+
__syncthreads();
|
| 661 |
+
|
| 662 |
+
const uint32_t maxExCount = allIsStart.size(1);
|
| 663 |
+
auto adjCounts = allAdjCounts + (b * maxExCount);
|
| 664 |
+
auto adjValues = allAdjValues + (b * maxExCount * maxExCount);
|
| 665 |
+
|
| 666 |
+
auto adjAnchorValues = adjValues + (anchorRow * maxExCount);
|
| 667 |
+
// For the anchor row, set the visited mask to 0b10, which will signify that we haven't visited it yet,
|
| 668 |
+
// but that the value is already in the adjacency vector.
|
| 669 |
+
// 0bx1 signifies that the value has been visited
|
| 670 |
+
for (uint32_t i = threadRank, ct = adjCounts[anchorRow]; i < ct; i += blockDim.x) {
|
| 671 |
+
const auto adjCol = adjAnchorValues[i];
|
| 672 |
+
s_visitedMask[adjCol] = ADDED_MASK;
|
| 673 |
+
}
|
| 674 |
+
|
| 675 |
+
__syncthreads();
|
| 676 |
+
|
| 677 |
+
if (threadRank == 0) {
|
| 678 |
+
s_visitedMask[anchorRow] |= QUEUED_MASK;
|
| 679 |
+
}
|
| 680 |
+
|
| 681 |
+
__syncthreads();
|
| 682 |
+
|
| 683 |
+
// TODO(mranzinger): Is it worth incorporating these other threads?
|
| 684 |
+
// It seems like the vast majority of adjacency counts is <32
|
| 685 |
+
if (threadRank >= WARP_SIZE) {
|
| 686 |
+
return;
|
| 687 |
+
}
|
| 688 |
+
|
| 689 |
+
uint32_t visitStack[VISIT_STACK_SIZE];
|
| 690 |
+
visitStack[0] = TERM_VALUE;
|
| 691 |
+
visitStack[1] = anchorRow;
|
| 692 |
+
#ifndef NDEBUG
|
| 693 |
+
for (uint32_t i = 2; i < VISIT_STACK_SIZE; ++i) {
|
| 694 |
+
visitStack[i] = -2;
|
| 695 |
+
}
|
| 696 |
+
#endif
|
| 697 |
+
int32_t visitPtr = 1;
|
| 698 |
+
|
| 699 |
+
while (true) {
|
| 700 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 701 |
+
assert(visitPtr >= 0 && visitPtr < VISIT_STACK_SIZE);
|
| 702 |
+
#endif
|
| 703 |
+
const uint32_t threadNextCol = visitStack[visitPtr];
|
| 704 |
+
const uint32_t warpNextCol = __reduce_min_full_warp(threadNextCol);
|
| 705 |
+
|
| 706 |
+
// Check to see if this thread got chosen.
|
| 707 |
+
// If so, decrement the stack counter
|
| 708 |
+
if (threadNextCol == warpNextCol) {
|
| 709 |
+
#ifndef NDEBUG
|
| 710 |
+
// This makes it easier to debug where the pointer is
|
| 711 |
+
visitStack[visitPtr] = -2;
|
| 712 |
+
#endif
|
| 713 |
+
--visitPtr;
|
| 714 |
+
}
|
| 715 |
+
|
| 716 |
+
// If the maximum value encountered is -1, that means that none of the threads
|
| 717 |
+
// had another value to process
|
| 718 |
+
if (warpNextCol == TERM_VALUE) {
|
| 719 |
+
break;
|
| 720 |
+
}
|
| 721 |
+
|
| 722 |
+
const uint32_t procRow = warpNextCol;
|
| 723 |
+
|
| 724 |
+
__syncthreads();
|
| 725 |
+
|
| 726 |
+
bool isAlreadyVisited = s_visitedMask[procRow] & VISITED_MASK;
|
| 727 |
+
|
| 728 |
+
if (isAlreadyVisited) {
|
| 729 |
+
continue;
|
| 730 |
+
}
|
| 731 |
+
|
| 732 |
+
const uint32_t procAdjCount = adjCounts[procRow];
|
| 733 |
+
auto procAdjValues = adjValues + (procRow * maxExCount);
|
| 734 |
+
|
| 735 |
+
// Offsetting by the iteration number will help balance out the maximum depth of any stack in the warp.
|
| 736 |
+
// The reason behind this is due to how otherwise, warp-0 will always get a new element, warp-1 iff the adj graph
|
| 737 |
+
// has more than one element, warp-2 iff the adj graph has more than two elements, and so on. Basically,
|
| 738 |
+
// the warps have decreasing pressure. With the rotation mechanism, it helps to balance out stack usage.
|
| 739 |
+
for (uint32_t i = threadRank; i < procAdjCount; i += WARP_SIZE) {
|
| 740 |
+
const uint32_t adjCol = procAdjValues[i];
|
| 741 |
+
|
| 742 |
+
// This will set the queued flag for this column, if it's not already set.
|
| 743 |
+
// It also returns the old state. In our case, we only want to add this value to the
|
| 744 |
+
// stack iff it hasn't already been visited, and hasn't been queued elsewhere
|
| 745 |
+
// NOTE: CUDA doesn't support atomicOr on uint8_t :(, but it's not necessary that
|
| 746 |
+
// the operation be absolutely atomic, so the poor man's version is probably okay
|
| 747 |
+
const auto oldMask = s_visitedMask[adjCol];
|
| 748 |
+
auto newMask = oldMask;
|
| 749 |
+
|
| 750 |
+
bool alreadyAdded = oldMask & ADDED_MASK;
|
| 751 |
+
|
| 752 |
+
auto group = cg::coalesced_threads();
|
| 753 |
+
const uint32_t gThreadRank = group.thread_rank();
|
| 754 |
+
uint32_t notAddedBallot = group.ballot(!alreadyAdded);
|
| 755 |
+
if (notAddedBallot) {
|
| 756 |
+
// Only one warp will ever be adding values to a given row, which means
|
| 757 |
+
// that we don't need atomics. However, other warps may be reading data
|
| 758 |
+
// from anchorRow, which means that we need to add the values first,
|
| 759 |
+
// followed by incrementing the count. This order makes things
|
| 760 |
+
// concurrency safe.
|
| 761 |
+
const uint32_t globalStoreOff = adjCounts[anchorRow];
|
| 762 |
+
// Gets the count of the bits to the left of this thread
|
| 763 |
+
const uint32_t localStoreOff = __popc(notAddedBallot & ((1 << gThreadRank) - 1));
|
| 764 |
+
|
| 765 |
+
if (!alreadyAdded) {
|
| 766 |
+
adjAnchorValues[globalStoreOff + localStoreOff] = adjCol;
|
| 767 |
+
if (adjCol > anchorRow) {
|
| 768 |
+
// Also, ensure that this quad is no longer marked as a starting quad
|
| 769 |
+
isStart[adjCol] = false;
|
| 770 |
+
}
|
| 771 |
+
newMask |= ADDED_MASK;
|
| 772 |
+
}
|
| 773 |
+
|
| 774 |
+
// Finally, commit the change by incrementing the counter
|
| 775 |
+
if (gThreadRank == 0) {
|
| 776 |
+
adjCounts[anchorRow] += __popc(notAddedBallot);
|
| 777 |
+
}
|
| 778 |
+
}
|
| 779 |
+
|
| 780 |
+
bool alreadyHandled = oldMask & QUEUED_OR_VISITED_MASK;
|
| 781 |
+
|
| 782 |
+
if (!alreadyHandled) {
|
| 783 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 784 |
+
newMask |= QUEUED_MASK;
|
| 785 |
+
++visitPtr;
|
| 786 |
+
assert(visitPtr < VISIT_STACK_SIZE);
|
| 787 |
+
atomicMax(maxDepth, visitPtr);
|
| 788 |
+
visitStack[visitPtr] = adjCol;
|
| 789 |
+
#else
|
| 790 |
+
// Prefer potentially inconsistent results over buffer overflow
|
| 791 |
+
if (visitPtr < VISIT_STACK_SIZE - 1) {
|
| 792 |
+
newMask |= QUEUED_MASK;
|
| 793 |
+
++visitPtr;
|
| 794 |
+
visitStack[visitPtr] = adjCol;
|
| 795 |
+
}
|
| 796 |
+
#endif
|
| 797 |
+
}
|
| 798 |
+
|
| 799 |
+
if (newMask != oldMask) {
|
| 800 |
+
s_visitedMask[adjCol] = newMask;
|
| 801 |
+
}
|
| 802 |
+
}
|
| 803 |
+
|
| 804 |
+
// We actually rely on the `pop_next` function largely to handle recursing down into the next row
|
| 805 |
+
__syncthreads();
|
| 806 |
+
}
|
| 807 |
+
}
|
| 808 |
+
|
| 809 |
+
void add_to_set(const torch::TensorAccessor<int32_t, 1>& adjCounts,
|
| 810 |
+
const torch::TensorAccessor<int32_t, 2>& adjValues,
|
| 811 |
+
int32_t row,
|
| 812 |
+
std::unordered_set<int32_t>& possible)
|
| 813 |
+
{
|
| 814 |
+
if (possible.count(row)) {
|
| 815 |
+
return;
|
| 816 |
+
}
|
| 817 |
+
|
| 818 |
+
possible.insert(row);
|
| 819 |
+
|
| 820 |
+
const int32_t adjCount = adjCounts[row];
|
| 821 |
+
auto values = adjValues[row].data();
|
| 822 |
+
|
| 823 |
+
for (int32_t i = 0; i < adjCount; ++i) {
|
| 824 |
+
const int32_t col = values[i];
|
| 825 |
+
add_to_set(adjCounts, adjValues, col, possible);
|
| 826 |
+
}
|
| 827 |
+
}
|
| 828 |
+
|
| 829 |
+
template<bool IsSingleExample>
|
| 830 |
+
void cpu_flatten_graph(const uint64_t punCounts,
|
| 831 |
+
torch::Tensor isStartTensorGPU,
|
| 832 |
+
torch::Tensor adjCountsTensorGPU,
|
| 833 |
+
torch::Tensor adjValuesTensorGPU)
|
| 834 |
+
{
|
| 835 |
+
auto isStartTensor = isStartTensorGPU.cpu();
|
| 836 |
+
auto adjCountsTensor = adjCountsTensorGPU.cpu();
|
| 837 |
+
auto adjValuesTensor = adjValuesTensorGPU.cpu();
|
| 838 |
+
|
| 839 |
+
auto allIsStart = isStartTensor.accessor<bool, 2>();
|
| 840 |
+
auto allAdjCounts = adjCountsTensor.accessor<int32_t, 2>();
|
| 841 |
+
auto allAdjValues = adjValuesTensor.accessor<int32_t, 3>();
|
| 842 |
+
|
| 843 |
+
for (int32_t b = 0; b < allAdjCounts.size(0); ++b) {
|
| 844 |
+
const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
|
| 845 |
+
|
| 846 |
+
for (int32_t row = 0; row < quadCt; ++row) {
|
| 847 |
+
std::unordered_set<int32_t> fullAdjSet;
|
| 848 |
+
add_to_set(allAdjCounts[b], allAdjValues[b], row, fullAdjSet);
|
| 849 |
+
|
| 850 |
+
int32_t &currCt = allAdjCounts[b][row];
|
| 851 |
+
int32_t *currValues = allAdjValues[b][row].data();
|
| 852 |
+
std::unordered_set<int32_t> existingSet{ currValues, currValues + currCt };
|
| 853 |
+
|
| 854 |
+
for (int32_t adjCol : fullAdjSet) {
|
| 855 |
+
if (existingSet.count(adjCol)) {
|
| 856 |
+
continue;
|
| 857 |
+
}
|
| 858 |
+
|
| 859 |
+
currValues[currCt] = adjCol;
|
| 860 |
+
++currCt;
|
| 861 |
+
|
| 862 |
+
if (adjCol > row) {
|
| 863 |
+
allIsStart[b][adjCol] = false;
|
| 864 |
+
}
|
| 865 |
+
}
|
| 866 |
+
}
|
| 867 |
+
}
|
| 868 |
+
|
| 869 |
+
isStartTensorGPU.copy_(isStartTensor);
|
| 870 |
+
adjCountsTensorGPU.copy_(adjCountsTensor);
|
| 871 |
+
adjValuesTensorGPU.copy_(adjValuesTensor);
|
| 872 |
+
}
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
__global__
|
| 876 |
+
void device_a2a_adj_cleanup(const int32_t *counts,
|
| 877 |
+
torch::PackedTensorAccessor64<uint8_t, 3> inOutAdjacency)
|
| 878 |
+
{
|
| 879 |
+
const uint32_t b = blockIdx.y;
|
| 880 |
+
const uint32_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 881 |
+
const uint32_t numQuads = counts[b];
|
| 882 |
+
const uint32_t row = jobIdx / numQuads;
|
| 883 |
+
const uint32_t col = jobIdx % numQuads;
|
| 884 |
+
|
| 885 |
+
if (row >= numQuads) {
|
| 886 |
+
return;
|
| 887 |
+
}
|
| 888 |
+
|
| 889 |
+
auto adjacency = inOutAdjacency[b];
|
| 890 |
+
|
| 891 |
+
bool rowPivot = adjacency[row][row] > 0;
|
| 892 |
+
bool colPivot = adjacency[col][col] > 0;
|
| 893 |
+
|
| 894 |
+
if (!rowPivot || !colPivot) {
|
| 895 |
+
adjacency[row][col] = 0;
|
| 896 |
+
}
|
| 897 |
+
}
|
| 898 |
+
|
| 899 |
+
template<uint32_t NumWarps, typename T, bool IsSingleExample>
|
| 900 |
+
__global__
|
| 901 |
+
void device_a2a_collapse(const uint64_t punCounts,
|
| 902 |
+
torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
|
| 903 |
+
torch::PackedTensorAccessor64<bool, 2> allIsLeadRow,
|
| 904 |
+
const int64_t *regionCounts,
|
| 905 |
+
torch::PackedTensorAccessor64<int32_t, 2> allAdjCounts,
|
| 906 |
+
torch::PackedTensorAccessor64<int32_t, 3> allAdjValues,
|
| 907 |
+
//torch::PackedTensorAccessor64<int32_t, 2> allOutPositions,
|
| 908 |
+
torch::PackedTensorAccessor64<T, 3> outQuads,
|
| 909 |
+
T *outConf)
|
| 910 |
+
{
|
| 911 |
+
constexpr uint32_t WARP_SIZE = 32;
|
| 912 |
+
constexpr uint32_t FULL_WARP = 0xFFFFFFFF;
|
| 913 |
+
constexpr uint32_t BLOCK_WIDTH = NumWarps * WARP_SIZE;
|
| 914 |
+
constexpr size_t MERGE_QUAD_SIZE = sizeof(MergeQuad_<T>) / sizeof(T);
|
| 915 |
+
|
| 916 |
+
static_assert(NumWarps < WARP_SIZE, "Only a single warp currently supported!");
|
| 917 |
+
|
| 918 |
+
const uint32_t b = blockIdx.z;
|
| 919 |
+
const uint32_t row = blockIdx.y;
|
| 920 |
+
|
| 921 |
+
const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
|
| 922 |
+
|
| 923 |
+
if constexpr (!IsSingleExample) {
|
| 924 |
+
if (row >= quadCt) {
|
| 925 |
+
return;
|
| 926 |
+
}
|
| 927 |
+
}
|
| 928 |
+
|
| 929 |
+
// Only process the lead rows
|
| 930 |
+
const auto isLeadRow = IsSingleExample ? allIsLeadRow.data() : allIsLeadRow[b].data();
|
| 931 |
+
if (!isLeadRow[row]) {
|
| 932 |
+
return;
|
| 933 |
+
}
|
| 934 |
+
|
| 935 |
+
const uint32_t threadRank = threadIdx.x;
|
| 936 |
+
const uint32_t localThreadRank = threadRank & 0x1F;
|
| 937 |
+
const uint32_t warpIdx = threadRank >> 5;
|
| 938 |
+
|
| 939 |
+
__shared__ T s_mergeQuad[MERGE_QUAD_SIZE];
|
| 940 |
+
|
| 941 |
+
if constexpr (NumWarps > 1) {
|
| 942 |
+
if (threadRank < MERGE_QUAD_SIZE) {
|
| 943 |
+
s_mergeQuad[threadRank] = 0.0f;
|
| 944 |
+
}
|
| 945 |
+
|
| 946 |
+
__syncthreads();
|
| 947 |
+
}
|
| 948 |
+
|
| 949 |
+
T *exData = IsSingleExample ? allEmbedQuads.data() : allEmbedQuads[b].data();
|
| 950 |
+
|
| 951 |
+
const int32_t adjCount = allAdjCounts[b][row];
|
| 952 |
+
const int32_t *adjIdxs = allAdjValues[b][row].data();
|
| 953 |
+
|
| 954 |
+
MergeQuad_<T> localMerge{ZeroInitTag{}};
|
| 955 |
+
|
| 956 |
+
for (int32_t i = threadRank; i < adjCount; i += BLOCK_WIDTH) {
|
| 957 |
+
const int32_t currQuadIdx = adjIdxs[i];
|
| 958 |
+
const StridedEmbedQuad_<T> qCurr{ exData + currQuadIdx * allEmbedQuads.stride(2), allEmbedQuads.stride(1) };
|
| 959 |
+
|
| 960 |
+
localMerge.Append(qCurr);
|
| 961 |
+
}
|
| 962 |
+
|
| 963 |
+
T *mqV = reinterpret_cast<T*>(&localMerge);
|
| 964 |
+
#pragma unroll
|
| 965 |
+
for (uint32_t offset = 1; offset < WARP_SIZE; offset <<= 1) {
|
| 966 |
+
T mergeFactor = offset + localThreadRank < 32;
|
| 967 |
+
#pragma unroll
|
| 968 |
+
for (uint32_t i = 0; i < MERGE_QUAD_SIZE; ++i) {
|
| 969 |
+
mqV[i] += mergeFactor * __shfl_down_sync(FULL_WARP, mqV[i], offset);
|
| 970 |
+
}
|
| 971 |
+
}
|
| 972 |
+
#pragma unroll
|
| 973 |
+
for (uint32_t i = 0; i < MERGE_QUAD_SIZE; ++i) {
|
| 974 |
+
mqV[i] = __shfl_sync(FULL_WARP, mqV[i], 0);
|
| 975 |
+
}
|
| 976 |
+
|
| 977 |
+
// Only need to do a multi-warp merge if there are enough quads to justify it
|
| 978 |
+
if (NumWarps > 1 && adjCount > WARP_SIZE) {
|
| 979 |
+
if (localThreadRank < MERGE_QUAD_SIZE) {
|
| 980 |
+
atomicAdd(s_mergeQuad + localThreadRank, mqV[localThreadRank]);
|
| 981 |
+
}
|
| 982 |
+
|
| 983 |
+
__syncthreads();
|
| 984 |
+
|
| 985 |
+
mqV = s_mergeQuad;
|
| 986 |
+
}
|
| 987 |
+
|
| 988 |
+
// Figure out the output position
|
| 989 |
+
uint32_t writePosition = 0;
|
| 990 |
+
if constexpr (!IsSingleExample) {
|
| 991 |
+
for (int32_t i = threadRank; i < b; i += BLOCK_WIDTH) {
|
| 992 |
+
writePosition += regionCounts[i];
|
| 993 |
+
}
|
| 994 |
+
}
|
| 995 |
+
|
| 996 |
+
const int32_t numLongs = row >> 3; // Divide by 8
|
| 997 |
+
const uint8_t *pCurrIsLeadRow = reinterpret_cast<const uint8_t*>(isLeadRow);
|
| 998 |
+
const uint64_t *lpCurrIsLeadRow = reinterpret_cast<const uint64_t*>(pCurrIsLeadRow);
|
| 999 |
+
|
| 1000 |
+
for (int32_t i = threadRank; i < numLongs; i += BLOCK_WIDTH) {
|
| 1001 |
+
writePosition += __popcll(lpCurrIsLeadRow[i]);
|
| 1002 |
+
}
|
| 1003 |
+
for (int32_t i = (numLongs * 8) + threadRank; i < row; i += BLOCK_WIDTH) {
|
| 1004 |
+
if (pCurrIsLeadRow[i]) {
|
| 1005 |
+
++writePosition;
|
| 1006 |
+
}
|
| 1007 |
+
}
|
| 1008 |
+
// Sum all of the individual offsets over the warp
|
| 1009 |
+
writePosition = __reduce_add_full_warp(writePosition);
|
| 1010 |
+
// Reduce across warps, if applicable
|
| 1011 |
+
if constexpr (NumWarps > 1) {
|
| 1012 |
+
__shared__ uint32_t s_threadWritePositions[NumWarps];
|
| 1013 |
+
if (localThreadRank == 0) {
|
| 1014 |
+
s_threadWritePositions[warpIdx] = writePosition;
|
| 1015 |
+
}
|
| 1016 |
+
__syncthreads();
|
| 1017 |
+
writePosition = threadRank < NumWarps ? s_threadWritePositions[threadRank] : 0;
|
| 1018 |
+
writePosition = __reduce_add_full_warp(writePosition);
|
| 1019 |
+
}
|
| 1020 |
+
|
| 1021 |
+
if (threadRank >= 9) {
|
| 1022 |
+
return;
|
| 1023 |
+
}
|
| 1024 |
+
|
| 1025 |
+
const T sumConfidence = mqV[8];
|
| 1026 |
+
const T numQuads = mqV[9];
|
| 1027 |
+
const T divisor = threadRank < 8 ? sumConfidence : numQuads;
|
| 1028 |
+
|
| 1029 |
+
const T myVal = mqV[threadRank] / divisor;
|
| 1030 |
+
|
| 1031 |
+
auto writeVerts = outQuads[writePosition].data();
|
| 1032 |
+
|
| 1033 |
+
if (threadRank < 8) {
|
| 1034 |
+
writeVerts[threadRank] = myVal;
|
| 1035 |
+
} else {
|
| 1036 |
+
outConf[writePosition] = myVal;
|
| 1037 |
+
}
|
| 1038 |
+
}
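The reduction above folds every lane's partial MergeQuad_ accumulator down to lane 0 with __shfl_down_sync and then broadcasts the result back to all lanes with __shfl_sync. A minimal, self-contained sketch of that warp-shuffle reduction pattern follows (plain float sum, hypothetical kernel name warp_sum_demo; it is an illustration, not part of the module):

// Minimal sketch of the shuffle-down warp reduction used in device_a2a_collapse.
// Hypothetical demo kernel; not part of the library.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void warp_sum_demo(const float *in, float *out)
{
    const unsigned FULL_WARP = 0xFFFFFFFFu;
    float v = in[threadIdx.x];                     // one value per lane

    // Fold lanes 16..31 into 0..15, then 8..15 into 0..7, and so on.
    for (unsigned offset = 16; offset > 0; offset >>= 1) {
        v += __shfl_down_sync(FULL_WARP, v, offset);
    }
    // Lane 0 now holds the warp-wide sum; broadcast it to every lane.
    v = __shfl_sync(FULL_WARP, v, 0);

    if (threadIdx.x == 0) *out = v;
}

int main()
{
    float h_in[32], h_out, *d_in, *d_out;
    for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;   // expected sum: 32
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    warp_sum_demo<<<1, 32>>>(d_in, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("%f\n", h_out);
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}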
|
| 1039 |
+
|
| 1040 |
+
struct CollapseRowsResult {
|
| 1041 |
+
torch::Tensor ExCounts;
|
| 1042 |
+
torch::Tensor StridedMergeQuads;
|
| 1043 |
+
int32_t TotalNumQuads;
|
| 1044 |
+
// NOTE: This will only be available in Debug builds
|
| 1045 |
+
torch::Tensor QuadIds;
|
| 1046 |
+
int32_t ImageWidth;
|
| 1047 |
+
int32_t ImageHeight;
|
| 1048 |
+
};
|
| 1049 |
+
|
| 1050 |
+
template<typename scalar_t>
|
| 1051 |
+
CollapseRowsResult collapse_rows(
|
| 1052 |
+
torch::Tensor quads, torch::Tensor probs, scalar_t probThreshold, scalar_t iouThreshold
|
| 1053 |
+
)
|
| 1054 |
+
{
|
| 1055 |
+
if (! quads.is_contiguous()) {
|
| 1056 |
+
throw std::runtime_error("Expected `quads` to be contiguous!");
|
| 1057 |
+
}
|
| 1058 |
+
|
| 1059 |
+
if ((quads.size(2) % 32) != 0) {
|
| 1060 |
+
throw std::runtime_error("Expected the width of the `quads` buffer to be a multiple of 32!");
|
| 1061 |
+
}
|
| 1062 |
+
|
| 1063 |
+
int32_t imageWidth = quads.size(2) * 4;
|
| 1064 |
+
int32_t imageHeight = quads.size(1) * 4;
|
| 1065 |
+
|
| 1066 |
+
quads = quads.reshape({ quads.size(0), -1, 32, 4, 2 });
|
| 1067 |
+
probs = probs.reshape({ probs.size(0), -1, 32 });
|
| 1068 |
+
|
| 1069 |
+
if (quads.size(0) != probs.size(0) || quads.size(1) != probs.size(1)) {
|
| 1070 |
+
throw std::runtime_error("Dimension mismatch between `quads` and `probs`");
|
| 1071 |
+
}
|
| 1072 |
+
|
| 1073 |
+
// The final counter is for the total number of quads for the entire batch
|
| 1074 |
+
auto counts = torch::zeros({ quads.size(0) + 1 }, quads.options().dtype(torch::kInt32));
|
| 1075 |
+
|
| 1076 |
+
int64_t embedSize = sizeof(EmbedQuad_<scalar_t>) / sizeof(scalar_t);
|
| 1077 |
+
auto rowMergeTensor = torch::empty({ quads.size(0), embedSize, quads.size(1) * quads.size(2) }, quads.options());
|
| 1078 |
+
|
| 1079 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1080 |
+
auto idsTensor = torch::full({ quads.size(0), quads.size(1) * quads.size(2) },
|
| 1081 |
+
std::numeric_limits<int32_t>::max(),
|
| 1082 |
+
counts.options().dtype(torch::kInt32));
|
| 1083 |
+
#else
|
| 1084 |
+
torch::Tensor idsTensor;
|
| 1085 |
+
#endif
|
| 1086 |
+
|
| 1087 |
+
dim3 blockSize(32, 3, 1);
|
| 1088 |
+
dim3 gridSize(1,
|
| 1089 |
+
div_up(quads.size(1), blockSize.y),
|
| 1090 |
+
quads.size(0));
|
| 1091 |
+
|
| 1092 |
+
device_row_collapse KERNEL_ARG2(gridSize, blockSize) (
|
| 1093 |
+
quads.packed_accessor64<scalar_t, 5>(),
|
| 1094 |
+
probs.packed_accessor64<scalar_t, 3>(),
|
| 1095 |
+
probThreshold, iouThreshold,
|
| 1096 |
+
counts.packed_accessor64<int32_t, 1>(),
|
| 1097 |
+
rowMergeTensor.packed_accessor64<scalar_t, 3>()
|
| 1098 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1099 |
+
, idsTensor.packed_accessor64<int32_t, 2>()
|
| 1100 |
+
#endif
|
| 1101 |
+
);
|
| 1102 |
+
|
| 1103 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1104 |
+
static std::unordered_set<int32_t> s_quadIds;
|
| 1105 |
+
auto cpuIdsTensor = idsTensor.cpu();
|
| 1106 |
+
const int32_t *idsPtr = cpuIdsTensor.data_ptr<int32_t>();
|
| 1107 |
+
if (s_quadIds.empty()) {
|
| 1108 |
+
s_quadIds.insert(idsPtr, idsPtr + idsTensor.numel());
|
| 1109 |
+
} else {
|
| 1110 |
+
std::unordered_set<int32_t> otherIds{ idsPtr, idsPtr + idsTensor.numel() };
|
| 1111 |
+
|
| 1112 |
+
if (s_quadIds != otherIds) {
|
| 1113 |
+
throw std::runtime_error("Inconsistent Ids!");
|
| 1114 |
+
}
|
| 1115 |
+
}
|
| 1116 |
+
#endif
|
| 1117 |
+
|
| 1118 |
+
// The final value in `counts` is actually the total number of quads for the entire batch
|
| 1119 |
+
int32_t totalQuads = counts[-1].item<int32_t>();
|
| 1120 |
+
|
| 1121 |
+
counts = counts.slice(/*dim=*/ 0, 0, counts.size(0) - 1);
|
| 1122 |
+
|
| 1123 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1124 |
+
int64_t maxExCount;
|
| 1125 |
+
if (counts.size(0) > 1) {
|
| 1126 |
+
maxExCount = counts.max().item<int32_t>();
|
| 1127 |
+
} else {
|
| 1128 |
+
maxExCount = totalQuads;
|
| 1129 |
+
}
|
| 1130 |
+
|
| 1131 |
+
static bool s_sortOrder = false;
|
| 1132 |
+
|
| 1133 |
+
rowMergeTensor = rowMergeTensor.slice(2, 0, maxExCount);
|
| 1134 |
+
idsTensor = idsTensor.slice(1, 0, maxExCount);
|
| 1135 |
+
auto order = torch::argsort(idsTensor, /*dim=*/ 1, s_sortOrder); s_sortOrder = !s_sortOrder;
|
| 1136 |
+
|
| 1137 |
+
auto embOrder = order.unsqueeze(1).expand_as(rowMergeTensor);
|
| 1138 |
+
|
| 1139 |
+
rowMergeTensor = torch::gather(rowMergeTensor, /*dim=*/ 2, embOrder);
|
| 1140 |
+
idsTensor = torch::gather(idsTensor, /*dim=*/ 1, order);
|
| 1141 |
+
#endif
|
| 1142 |
+
|
| 1143 |
+
return { counts, rowMergeTensor, totalQuads, idsTensor, imageWidth, imageHeight };
|
| 1144 |
+
}
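A note on the counts convention used by collapse_rows: the tensor has one slot per example plus a trailing slot that accumulates the batch-wide total, the total is read back from counts[-1], and the tensor is then sliced down to B entries. A small illustration of that layout (the values here are made up):

// Illustration only: the `counts` convention used by collapse_rows.
// The last element holds the batch-wide total; per-example counts come first.
#include <torch/torch.h>
#include <iostream>

int main()
{
    // Hypothetical per-example counts for B = 3, plus the running total in the last slot.
    auto counts = torch::tensor({4, 0, 7, 11}, torch::kInt32);

    int32_t totalQuads = counts[-1].item<int32_t>();               // 11
    counts = counts.slice(/*dim=*/0, 0, counts.size(0) - 1);       // {4, 0, 7}

    std::cout << totalQuads << "\n" << counts << std::endl;
    return 0;
}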
|
| 1145 |
+
|
| 1146 |
+
|
| 1147 |
+
|
| 1148 |
+
void verify_row(const torch::TensorAccessor<int32_t, 1> &adjCounts,
|
| 1149 |
+
const torch::TensorAccessor<int32_t, 2> &adjValues,
|
| 1150 |
+
int32_t row)
|
| 1151 |
+
{
|
| 1152 |
+
// Traverse the graph, and accumulate all set flags across all rows marked
|
| 1153 |
+
// adjacent by the current row. If the merge_up algorithm works correctly, then
|
| 1154 |
+
// `possible` will contain exactly the same set of values as the current row
|
| 1155 |
+
std::unordered_set<int32_t> possible;
|
| 1156 |
+
add_to_set(adjCounts, adjValues, row, possible);
|
| 1157 |
+
|
| 1158 |
+
std::unordered_set<int32_t> thisRow{ row };
|
| 1159 |
+
const int32_t thisCount = adjCounts[row];
|
| 1160 |
+
auto thisValues = adjValues[row].data();
|
| 1161 |
+
thisRow.insert(thisValues, thisValues + thisCount);
|
| 1162 |
+
|
| 1163 |
+
if (thisRow != possible) {
|
| 1164 |
+
throw std::runtime_error("The merge_up algorithm is not correct!");
|
| 1165 |
+
}
|
| 1166 |
+
}
|
| 1167 |
+
|
| 1168 |
+
struct AdjacencyResult {
|
| 1169 |
+
// Shape: BxQ
|
| 1170 |
+
// Specifies whether the given row is a result row
|
| 1171 |
+
torch::Tensor IsLeadRow;
|
| 1172 |
+
// Shape: BxQ
|
| 1173 |
+
// The number of quads that need to be merged with the given quad
|
| 1174 |
+
torch::Tensor AdjCounts;
|
| 1175 |
+
// Shape: BxQx<Num Adjacent>
|
| 1176 |
+
// The indices of the adjacent quads.
|
| 1177 |
+
torch::Tensor AdjValues;
|
| 1178 |
+
int64_t MaxExCount;
|
| 1179 |
+
};
|
| 1180 |
+
|
| 1181 |
+
template<bool IsSingleExample, typename T>
|
| 1182 |
+
void cpu_a2a_adjacency_sparse(const uint64_t punCounts,
|
| 1183 |
+
const T iouThreshold,
|
| 1184 |
+
torch::Tensor embedQuadsTensor,
|
| 1185 |
+
torch::Tensor outIsStartTensorGPU,
|
| 1186 |
+
torch::Tensor outAdjCountsTensorGPU,
|
| 1187 |
+
torch::Tensor outSparseAdjTensorGPU)
|
| 1188 |
+
{
|
| 1189 |
+
embedQuadsTensor = embedQuadsTensor.cpu();
|
| 1190 |
+
auto outIsStartTensor = outIsStartTensorGPU.cpu();
|
| 1191 |
+
auto outAdjCountsTensor = outAdjCountsTensorGPU.cpu();
|
| 1192 |
+
auto outSparseAdjTensor = outSparseAdjTensorGPU.cpu();
|
| 1193 |
+
|
| 1194 |
+
auto embedQuads = embedQuadsTensor.accessor<T, 3>();
|
| 1195 |
+
auto isStart = outIsStartTensor.accessor<bool, 2>();
|
| 1196 |
+
auto adjCounts = outAdjCountsTensor.accessor<int32_t, 2>();
|
| 1197 |
+
auto adjValues = outSparseAdjTensor.accessor<int32_t, 3>();
|
| 1198 |
+
|
| 1199 |
+
for (int32_t b = 0; b < embedQuadsTensor.size(0); ++b) {
|
| 1200 |
+
const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
|
| 1201 |
+
|
| 1202 |
+
T *exData = embedQuads[b].data();
|
| 1203 |
+
|
| 1204 |
+
for (int32_t row = 0; row < quadCt; ++row) {
|
| 1205 |
+
const auto qRow = StridedEmbedQuad_<T>{ exData + row, embedQuads.stride(1) }.Bounds();
|
| 1206 |
+
|
| 1207 |
+
for (int32_t col = 0; col < quadCt; ++col) {
|
| 1208 |
+
const auto qCol = StridedEmbedQuad_<T>{ exData + col, embedQuads.stride(1) }.Bounds();
|
| 1209 |
+
|
| 1210 |
+
T pctRow, pctCol, iou;
|
| 1211 |
+
thrust::tie(pctRow, pctCol, iou) = geometry_region_sizes(qRow, qCol);
|
| 1212 |
+
|
| 1213 |
+
if (iou >= iouThreshold) {
|
| 1214 |
+
int32_t &storeIdx = adjCounts[b][row];
|
| 1215 |
+
adjValues[b][row][storeIdx] = col;
|
| 1216 |
+
++storeIdx;
|
| 1217 |
+
if (row < col) {
|
| 1218 |
+
isStart[b][col] = false;
|
| 1219 |
+
}
|
| 1220 |
+
} else if (pctRow > 0.8f || pctCol > 0.8f) {
|
| 1221 |
+
T anchorHeight = qRow.Height();
|
| 1222 |
+
T otherHeight = qCol.Height();
|
| 1223 |
+
|
| 1224 |
+
T ratio = anchorHeight > otherHeight ?
|
| 1225 |
+
otherHeight / anchorHeight :
|
| 1226 |
+
anchorHeight / otherHeight;
|
| 1227 |
+
if (ratio > 0.9f) {
|
| 1228 |
+
if (pctRow > 0.8f) {
|
| 1229 |
+
// Other envelops anchor
|
| 1230 |
+
isStart[b][row] = false;
|
| 1231 |
+
}
|
| 1232 |
+
else {
|
| 1233 |
+
isStart[b][col] = false;
|
| 1234 |
+
}
|
| 1235 |
+
}
|
| 1236 |
+
}
|
| 1237 |
+
}
|
| 1238 |
+
}
|
| 1239 |
+
}
|
| 1240 |
+
|
| 1241 |
+
outIsStartTensorGPU.copy_(outIsStartTensor);
|
| 1242 |
+
outAdjCountsTensorGPU.copy_(outAdjCountsTensor);
|
| 1243 |
+
outSparseAdjTensorGPU.copy_(outSparseAdjTensor);
|
| 1244 |
+
}
|
| 1245 |
+
|
| 1246 |
+
template<typename T>
|
| 1247 |
+
std::string to_flat_string(torch::Tensor tensor) {
|
| 1248 |
+
tensor = tensor.flatten();
|
| 1249 |
+
|
| 1250 |
+
auto acc = tensor.accessor<T, 1>();
|
| 1251 |
+
|
| 1252 |
+
std::ostringstream oss;
|
| 1253 |
+
oss << "[";
|
| 1254 |
+
if (acc.size(0) > 0) {
|
| 1255 |
+
oss << acc[0];
|
| 1256 |
+
for (int64_t i = 1; i < acc.size(0); ++i) {
|
| 1257 |
+
oss << ", " << acc[i];
|
| 1258 |
+
}
|
| 1259 |
+
}
|
| 1260 |
+
oss << "]";
|
| 1261 |
+
return oss.str();
|
| 1262 |
+
}
|
| 1263 |
+
|
| 1264 |
+
template<typename scalar_t>
|
| 1265 |
+
AdjacencyResult compute_all_to_all_adjacency(
|
| 1266 |
+
const CollapseRowsResult &collapseResult,
|
| 1267 |
+
scalar_t iouThreshold)
|
| 1268 |
+
{
|
| 1269 |
+
torch::Tensor counts = collapseResult.ExCounts;
|
| 1270 |
+
|
| 1271 |
+
int64_t maxExCount;
|
| 1272 |
+
if (counts.size(0) > 1) {
|
| 1273 |
+
maxExCount = counts.max().item<int32_t>();
|
| 1274 |
+
} else {
|
| 1275 |
+
maxExCount = collapseResult.TotalNumQuads;
|
| 1276 |
+
}
|
| 1277 |
+
|
| 1278 |
+
auto isStartTensor = torch::ones({ counts.size(0), maxExCount }, counts.options().dtype(torch::kBool));
|
| 1279 |
+
auto adjCountsTensor = torch::zeros({ counts.size(0), maxExCount }, counts.options().dtype(torch::kInt32));
|
| 1280 |
+
#ifndef NMS_VERIFY_CORRECTNESS
|
| 1281 |
+
auto adjValuesTensor = torch::empty({ counts.size(0), maxExCount, maxExCount }, counts.options().dtype(torch::kInt32));
|
| 1282 |
+
#else
|
| 1283 |
+
auto adjValuesTensor = torch::full({ counts.size(0), maxExCount, maxExCount },
|
| 1284 |
+
5000,
|
| 1285 |
+
counts.options().dtype(torch::kInt32));
|
| 1286 |
+
#endif
|
| 1287 |
+
|
| 1288 |
+
// If the batch is only a single example, instead of hitting global memory for the count, we can
|
| 1289 |
+
// just encode the count into the pointer instead
|
| 1290 |
+
uint64_t ptrCounts = reinterpret_cast<uint64_t>(counts.data_ptr<int32_t>());
|
| 1291 |
+
if (counts.size(0) == 1) {
|
| 1292 |
+
ptrCounts = maxExCount;
|
| 1293 |
+
}
|
| 1294 |
+
|
| 1295 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1296 |
+
auto cpuAdjValuesTensor = adjValuesTensor.cpu();
|
| 1297 |
+
auto cpuAdjCountsTensor = adjCountsTensor.cpu();
|
| 1298 |
+
auto cpuIsStartTensor = isStartTensor.cpu();
|
| 1299 |
+
#endif
|
| 1300 |
+
|
| 1301 |
+
size_t smemSize;
|
| 1302 |
+
dim3 gridSize, blockSize;
|
| 1303 |
+
|
| 1304 |
+
///////////////////
|
| 1305 |
+
// NOTE(mranzinger): This algorithm uses a fixed sized grid to spatially subdivide the canvas. For virtually all test conditions
|
| 1306 |
+
// I ran this through, it was slightly slower than the brute force approach that parallelizes better.
|
| 1307 |
+
// It's possible that there is some number of words present (e.g. >500) where this algorithm becomes
|
| 1308 |
+
// faster.
|
| 1309 |
+
//
|
| 1310 |
+
//constexpr int32_t CELL_SIZE = 100;
|
| 1311 |
+
//constexpr int64_t NUM_BINS_PER_CELL = 200;
|
| 1312 |
+
//int32_t numXCells = div_up(collapseResult.ImageWidth, CELL_SIZE);
|
| 1313 |
+
//int32_t numYCells = div_up(collapseResult.ImageHeight, CELL_SIZE);
|
| 1314 |
+
//auto gridCellsTensor = torch::zeros({ counts.size(0), numYCells, numXCells, NUM_BINS_PER_CELL }, adjCountsTensor.options());
|
| 1315 |
+
//auto quadCellExtentsTensor = torch::empty({ counts.size(0), maxExCount, 4 }, gridCellsTensor.options());
|
| 1316 |
+
//smemSize = div_up(static_cast<uint32_t>(maxExCount), 32);
|
| 1317 |
+
|
| 1318 |
+
//constexpr uint32_t GRID_NUM_WARPS = 3;
|
| 1319 |
+
//blockSize = dim3{ GRID_NUM_WARPS * 32, 1, 1 };
|
| 1320 |
+
//gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
|
| 1321 |
+
|
| 1322 |
+
//auto buildGridFn = counts.size(0) == 1 ?
|
| 1323 |
+
// device_a2a_adjacency_build_grid<GRID_NUM_WARPS, true, scalar_t, CELL_SIZE> :
|
| 1324 |
+
// device_a2a_adjacency_build_grid<GRID_NUM_WARPS, false, scalar_t, CELL_SIZE>;
|
| 1325 |
+
|
| 1326 |
+
//buildGridFn KERNEL_ARG2(gridSize, blockSize) (
|
| 1327 |
+
// ptrCounts,
|
| 1328 |
+
// collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
|
| 1329 |
+
// gridCellsTensor.packed_accessor64<int32_t, 4>(),
|
| 1330 |
+
// quadCellExtentsTensor.packed_accessor64<int32_t, 3>()
|
| 1331 |
+
//);
|
| 1332 |
+
|
| 1333 |
+
//auto adjGridFn = counts.size(0) == 1 ?
|
| 1334 |
+
// device_a2a_adjacency_with_grid<GRID_NUM_WARPS, true, scalar_t> :
|
| 1335 |
+
// device_a2a_adjacency_with_grid<GRID_NUM_WARPS, false, scalar_t>;
|
| 1336 |
+
|
| 1337 |
+
//adjGridFn KERNEL_ARG3(gridSize, blockSize, smemSize) (
|
| 1338 |
+
// ptrCounts,
|
| 1339 |
+
// iouThreshold,
|
| 1340 |
+
// collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
|
| 1341 |
+
// gridCellsTensor.packed_accessor64<int32_t, 4>(),
|
| 1342 |
+
// quadCellExtentsTensor.packed_accessor64<int32_t, 3>(),
|
| 1343 |
+
// isStartTensor.packed_accessor64<bool, 2>(),
|
| 1344 |
+
// adjCountsTensor.packed_accessor64<int32_t, 2>(),
|
| 1345 |
+
// adjValuesTensor.packed_accessor64<int32_t, 3>()
|
| 1346 |
+
//);
|
| 1347 |
+
///////////////////
|
| 1348 |
+
|
| 1349 |
+
uint32_t totalWork = maxExCount * maxExCount;
|
| 1350 |
+
|
| 1351 |
+
blockSize = dim3{96, 1};
|
| 1352 |
+
gridSize = dim3{div_up(totalWork, blockSize.x),
|
| 1353 |
+
static_cast<uint32_t>(counts.size(0))};
|
| 1354 |
+
|
| 1355 |
+
auto adjFn = counts.size(0) == 1 ? device_a2a_adjacency_sparse<true, scalar_t> : device_a2a_adjacency_sparse<false, scalar_t>;
|
| 1356 |
+
|
| 1357 |
+
// This algorithm is O(n^2) with n being the current number of quads
|
| 1358 |
+
adjFn KERNEL_ARG2(gridSize, blockSize) (
|
| 1359 |
+
ptrCounts,
|
| 1360 |
+
iouThreshold,
|
| 1361 |
+
collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
|
| 1362 |
+
isStartTensor.packed_accessor64<bool, 2>(),
|
| 1363 |
+
adjCountsTensor.packed_accessor64<int32_t, 2>(),
|
| 1364 |
+
adjValuesTensor.packed_accessor64<int32_t, 3>()
|
| 1365 |
+
);
|
| 1366 |
+
|
| 1367 |
+
|
| 1368 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1369 |
+
cpu_a2a_adjacency_sparse<true>(ptrCounts, iouThreshold,
|
| 1370 |
+
collapseResult.StridedMergeQuads, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
|
| 1371 |
+
|
| 1372 |
+
adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
|
| 1373 |
+
|
| 1374 |
+
assert(torch::all(cpuAdjCountsTensor == adjCountsTensor.cpu()).item<bool>());
|
| 1375 |
+
assert(torch::all(cpuIsStartTensor == isStartTensor.cpu()).item<bool>());
|
| 1376 |
+
assert(torch::all(cpuAdjValuesTensor == adjValuesTensor.cpu()).item<bool>());
|
| 1377 |
+
|
| 1378 |
+
std::cout << "\tA2A Is Start Count: " << isStartTensor.sum(torch::kInt32).item<int32_t>()
|
| 1379 |
+
<< ", Most Adjacent: " << adjCountsTensor.max().item<int32_t>() << std::endl;
|
| 1380 |
+
|
| 1381 |
+
auto maxDepthTensor = torch::tensor(0, adjCountsTensor.options());
|
| 1382 |
+
#endif
|
| 1383 |
+
|
| 1384 |
+
auto traverseFn = counts.size(0) == 1 ?
|
| 1385 |
+
device_flatten_graph_iterative<true> :
|
| 1386 |
+
device_flatten_graph_iterative<false>;
|
| 1387 |
+
|
| 1388 |
+
blockSize = dim3{ 128, 1, 1 };
|
| 1389 |
+
gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
|
| 1390 |
+
smemSize = div_up(maxExCount * sizeof(visit_mask_t), sizeof(uint32_t)) * sizeof(uint32_t);
|
| 1391 |
+
|
| 1392 |
+
traverseFn KERNEL_ARG3(gridSize, blockSize, smemSize) (
|
| 1393 |
+
ptrCounts,
|
| 1394 |
+
isStartTensor.packed_accessor64<bool, 2>(),
|
| 1395 |
+
reinterpret_cast<uint32_t*>(adjCountsTensor.data_ptr<int32_t>()),
|
| 1396 |
+
reinterpret_cast<uint32_t*>(adjValuesTensor.data_ptr<int32_t>())
|
| 1397 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1398 |
+
, maxDepthTensor.data_ptr<int32_t>()
|
| 1399 |
+
#endif
|
| 1400 |
+
);
|
| 1401 |
+
|
| 1402 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1403 |
+
cpu_flatten_graph<true>(ptrCounts, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
|
| 1404 |
+
|
| 1405 |
+
cpuAdjValuesTensor = std::get<0>(torch::sort(cpuAdjValuesTensor, /*dim=*/ 2));
|
| 1406 |
+
adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
|
| 1407 |
+
|
| 1408 |
+
torch::Tensor diffStartIdxs = (cpuIsStartTensor != isStartTensor.cpu()).nonzero_numpy()[0];
|
| 1409 |
+
|
| 1410 |
+
assert(diffStartIdxs.numel() == 0);
|
| 1411 |
+
|
| 1412 |
+
torch::Tensor diffCountIdxs = (cpuAdjCountsTensor != adjCountsTensor.cpu()).nonzero_numpy()[0];
|
| 1413 |
+
|
| 1414 |
+
assert(diffCountIdxs.numel() == 0);
|
| 1415 |
+
|
| 1416 |
+
auto diffValuesTensor = torch::any(cpuAdjValuesTensor != adjValuesTensor.cpu(), /*dim=*/ 2, /*keepdim=*/ false).flatten().nonzero().flatten();
|
| 1417 |
+
|
| 1418 |
+
std::cout << "\t\tDiff Indices: " << to_flat_string<int64_t>(diffValuesTensor) << std::endl;
|
| 1419 |
+
|
| 1420 |
+
auto cpuDiffCountsTensor = cpuAdjCountsTensor.flatten().index({ diffValuesTensor });
|
| 1421 |
+
auto cpuDiffRowsTensor = cpuAdjValuesTensor.flatten(0, 1).index({ diffValuesTensor });
|
| 1422 |
+
auto gpuDiffRowsTensor = adjValuesTensor.cpu().flatten(0, 1).index({ diffValuesTensor });
|
| 1423 |
+
|
| 1424 |
+
for (int64_t i = 0, ct = cpuDiffRowsTensor.size(0); i < ct; ++i) {
|
| 1425 |
+
auto z = cpuDiffCountsTensor[i].item<int32_t>();
|
| 1426 |
+
auto diffRow = diffValuesTensor[i].item<int64_t>();
|
| 1427 |
+
std::cout << "\t\tRow " << diffRow << std::endl;
|
| 1428 |
+
std::cout << "\t\t\tExpected: " << to_flat_string<int32_t>(cpuDiffRowsTensor[i].slice(0, 0, z + 1)) << std::endl;
|
| 1429 |
+
std::cout << "\t\t\t GPU: " << to_flat_string<int32_t>(gpuDiffRowsTensor[i].slice(0, 0, z + 1)) << std::endl;
|
| 1430 |
+
}
|
| 1431 |
+
|
| 1432 |
+
assert(diffValuesTensor.size(0) == 0);
|
| 1433 |
+
|
| 1434 |
+
std::cout << "\tA2A - Flatten - Is Start Count: " << isStartTensor.sum(torch::kInt32).item<int32_t>()
|
| 1435 |
+
<< ", Most Adjacent: " << adjCountsTensor.max().item<int32_t>()
|
| 1436 |
+
<< ", Max Depth: " << maxDepthTensor.item<int32_t>() << std::endl;
|
| 1437 |
+
|
| 1438 |
+
cpuIsStartTensor = isStartTensor.cpu();
|
| 1439 |
+
cpuAdjCountsTensor = adjCountsTensor.cpu();
|
| 1440 |
+
cpuAdjValuesTensor = adjValuesTensor.cpu();
|
| 1441 |
+
auto cpuCounts = counts.cpu();
|
| 1442 |
+
auto cpuCollapseIds = collapseResult.QuadIds.cpu();
|
| 1443 |
+
|
| 1444 |
+
static std::vector<std::unordered_set<int32_t>> s_knownGroups;
|
| 1445 |
+
static std::unordered_map<int32_t, std::unordered_set<int32_t>> s_groupLookup;
|
| 1446 |
+
|
| 1447 |
+
std::vector<std::unordered_set<int32_t>> idGroups;
|
| 1448 |
+
decltype(s_groupLookup) groupLookup;
|
| 1449 |
+
for (int64_t b = 0; b < counts.size(0); ++b) {
|
| 1450 |
+
int64_t quadCt = cpuCounts[b].item<int32_t>();
|
| 1451 |
+
for (int64_t row = 0; row < quadCt; ++row) {
|
| 1452 |
+
bool isLeadRow = cpuIsStartTensor[b][row].item<bool>();
|
| 1453 |
+
auto bCountsTensor = cpuAdjCountsTensor[b];
|
| 1454 |
+
auto bValuesTensor = cpuAdjValuesTensor[b];
|
| 1455 |
+
auto bCounts = bCountsTensor.accessor<int32_t, 1>();
|
| 1456 |
+
auto bValues = bValuesTensor.accessor<int32_t, 2>();
|
| 1457 |
+
|
| 1458 |
+
auto bIdsTensor = cpuCollapseIds[b];
|
| 1459 |
+
auto bIds = bIdsTensor.accessor<int32_t, 1>();
|
| 1460 |
+
|
| 1461 |
+
std::unordered_set<int32_t> sIds;
|
| 1462 |
+
for (int32_t i = 0, ct = bCounts[row]; i < ct; ++i) {
|
| 1463 |
+
int32_t col = bValues[row][i];
|
| 1464 |
+
int32_t id = bIds[col];
|
| 1465 |
+
sIds.insert(id);
|
| 1466 |
+
}
|
| 1467 |
+
|
| 1468 |
+
if (sIds.empty()) {
|
| 1469 |
+
throw std::runtime_error("The ids tensor is empty!");
|
| 1470 |
+
}
|
| 1471 |
+
|
| 1472 |
+
groupLookup[bIds[row]] = sIds;
|
| 1473 |
+
|
| 1474 |
+
if (isLeadRow) {
|
| 1475 |
+
verify_row(bCounts, bValues, row);
|
| 1476 |
+
idGroups.push_back(move(sIds));
|
| 1477 |
+
}
|
| 1478 |
+
}
|
| 1479 |
+
}
|
| 1480 |
+
|
| 1481 |
+
if (s_knownGroups.empty()) {
|
| 1482 |
+
s_knownGroups = move(idGroups);
|
| 1483 |
+
s_groupLookup = move(groupLookup);
|
| 1484 |
+
} else {
|
| 1485 |
+
// Make a copy
|
| 1486 |
+
auto remOrigGroups = s_knownGroups;
|
| 1487 |
+
auto remOrigGroupLookup = s_groupLookup;
|
| 1488 |
+
|
| 1489 |
+
std::vector<int32_t> quadIds;
|
| 1490 |
+
for (auto &kv : remOrigGroupLookup) {
|
| 1491 |
+
quadIds.push_back(kv.first);
|
| 1492 |
+
}
|
| 1493 |
+
for (int32_t qId : quadIds) {
|
| 1494 |
+
assert(groupLookup.count(qId));
|
| 1495 |
+
}
|
| 1496 |
+
assert(groupLookup.size() == remOrigGroupLookup.size());
|
| 1497 |
+
|
| 1498 |
+
for (int32_t qId : quadIds) {
|
| 1499 |
+
auto &oldGroup = remOrigGroupLookup[qId];
|
| 1500 |
+
auto &newGroup = groupLookup[qId];
|
| 1501 |
+
|
| 1502 |
+
if (oldGroup == newGroup) {
|
| 1503 |
+
remOrigGroupLookup.erase(qId);
|
| 1504 |
+
groupLookup.erase(qId);
|
| 1505 |
+
} else {
|
| 1506 |
+
throw std::runtime_error("Group mismatch!");
|
| 1507 |
+
}
|
| 1508 |
+
}
|
| 1509 |
+
|
| 1510 |
+
for (int i = idGroups.size() - 1; i >= 0; --i) {
|
| 1511 |
+
for (int j = remOrigGroups.size() - 1; j >= 0; --j) {
|
| 1512 |
+
auto &idGroup = idGroups[i];
|
| 1513 |
+
auto &knownGroup = remOrigGroups[j];
|
| 1514 |
+
|
| 1515 |
+
if (idGroup == knownGroup) {
|
| 1516 |
+
idGroups.erase(begin(idGroups) + i);
|
| 1517 |
+
remOrigGroups.erase(begin(remOrigGroups) + j);
|
| 1518 |
+
break;
|
| 1519 |
+
}
|
| 1520 |
+
}
|
| 1521 |
+
}
|
| 1522 |
+
|
| 1523 |
+
if (!idGroups.empty() || !remOrigGroups.empty()) {
|
| 1524 |
+
auto group_str = [] (auto &group) {
|
| 1525 |
+
std::vector<int32_t> vGroup{ std::begin(group), std::end(group) };
|
| 1526 |
+
std::sort(std::begin(vGroup), std::end(vGroup));
|
| 1527 |
+
|
| 1528 |
+
auto id_str = [] (int32_t id) {
|
| 1529 |
+
std::ostringstream oss;
|
| 1530 |
+
//oss << "(" << (id / 32) << ", " << (id % 32) << ")";
|
| 1531 |
+
oss << id;
|
| 1532 |
+
return oss.str();
|
| 1533 |
+
};
|
| 1534 |
+
|
| 1535 |
+
std::ostringstream oss;
|
| 1536 |
+
oss << "[" << id_str(vGroup[0]);
|
| 1537 |
+
for (size_t i = 1; i < vGroup.size(); ++i) {
|
| 1538 |
+
oss << ", " << id_str(vGroup[i]);
|
| 1539 |
+
}
|
| 1540 |
+
oss << "]";
|
| 1541 |
+
return oss.str();
|
| 1542 |
+
};
|
| 1543 |
+
|
| 1544 |
+
std::cout << "\tEncountered a difference in groups!" << std::endl
|
| 1545 |
+
<< "\t\tOrig groups:" << std::endl;
|
| 1546 |
+
for (auto &group : remOrigGroups) {
|
| 1547 |
+
std::cout << "\t\t\t" << group_str(group) << std::endl;
|
| 1548 |
+
}
|
| 1549 |
+
std::cout << "\t\tNew groups:" << std::endl;
|
| 1550 |
+
for (auto &group : idGroups) {
|
| 1551 |
+
std::cout << "\t\t\t" << group_str(group) << std::endl;
|
| 1552 |
+
}
|
| 1553 |
+
}
|
| 1554 |
+
}
|
| 1555 |
+
#endif
|
| 1556 |
+
|
| 1557 |
+
return { isStartTensor, adjCountsTensor, adjValuesTensor, maxExCount };
|
| 1558 |
+
}
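The ptrCounts trick above is shared by the adjacency, graph-flatten, and collapse kernels: for a batch of one, the quad count itself is passed by value inside the 64-bit argument, while for larger batches the argument is a device pointer to the per-example counts, and the IsSingleExample template flag selects the interpretation. A host-side sketch of the same encode/decode logic (decode_count is a hypothetical helper name):

// Sketch of the punCounts encoding used by the adjacency/collapse kernels:
// for a batch of one, the count itself rides in the 64-bit argument; otherwise
// the argument is a pointer to the per-example counts.
#include <cstdint>
#include <cassert>

template<bool IsSingleExample>
int32_t decode_count(uint64_t punCounts, int b)
{
    if constexpr (IsSingleExample) {
        return static_cast<int32_t>(punCounts);                    // value, not a pointer
    } else {
        return reinterpret_cast<const int32_t*>(punCounts)[b];     // dereference per example
    }
}

int main()
{
    // Single example: encode the count directly.
    assert(decode_count<true>(/*punCounts=*/42, /*b=*/0) == 42);

    // Batched: encode the address of a counts array.
    int32_t counts[2] = {5, 9};
    assert(decode_count<false>(reinterpret_cast<uint64_t>(counts), 1) == 9);
    return 0;
}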
|
| 1559 |
+
|
| 1560 |
+
|
| 1561 |
+
|
| 1562 |
+
template<typename scalar_t>
|
| 1563 |
+
nms_result_t
|
| 1564 |
+
all_to_all_collapse(
|
| 1565 |
+
const CollapseRowsResult &collapseRowsRes,
|
| 1566 |
+
const AdjacencyResult &adjResult)
|
| 1567 |
+
{
|
| 1568 |
+
auto counts = collapseRowsRes.ExCounts;
|
| 1569 |
+
auto embedQuads = collapseRowsRes.StridedMergeQuads;
|
| 1570 |
+
|
| 1571 |
+
if (!embedQuads.is_contiguous()) {
|
| 1572 |
+
throw std::runtime_error("Input embed quads were not contiguous!");
|
| 1573 |
+
}
|
| 1574 |
+
|
| 1575 |
+
torch::Tensor isLeadRow;
|
| 1576 |
+
if (counts.size(0) == 1) {
|
| 1577 |
+
isLeadRow = adjResult.IsLeadRow;
|
| 1578 |
+
} else {
|
| 1579 |
+
// For multiple examples: IsLeadRow will have true values beyond the extent of the number of quads
|
| 1580 |
+
// However, we know that Counts > 0 only happens within the extent, so the set intersection
|
| 1581 |
+
// tells us which rows are actually lead
|
| 1582 |
+
isLeadRow = torch::logical_and(adjResult.IsLeadRow, adjResult.AdjCounts > 0);
|
| 1583 |
+
}
|
| 1584 |
+
|
| 1585 |
+
auto regionCounts = isLeadRow.sum(/*dim=*/ 1, /*keepdim=*/ false, torch::kInt64);
|
| 1586 |
+
|
| 1587 |
+
const int64_t numOutQuads = counts.size(0) == 1 ? regionCounts.item<int64_t>() : regionCounts.sum().item<int64_t>();
|
| 1588 |
+
|
| 1589 |
+
constexpr int32_t NUM_WARPS = 4;
|
| 1590 |
+
dim3 blockSize(NUM_WARPS * 32, 1, 1);
|
| 1591 |
+
dim3 gridSize(1, adjResult.MaxExCount, counts.size(0));
|
| 1592 |
+
|
| 1593 |
+
// If the batch is only a single example, instead of hitting global memory for the count, we can
|
| 1594 |
+
// just encode the count into the pointer instead
|
| 1595 |
+
uint64_t ptrCounts = reinterpret_cast<uint64_t>(counts.data_ptr<int32_t>());
|
| 1596 |
+
if (counts.size(0) == 1) {
|
| 1597 |
+
ptrCounts = adjResult.MaxExCount;
|
| 1598 |
+
}
|
| 1599 |
+
|
| 1600 |
+
torch::Tensor outQuads = torch::empty({ numOutQuads, 4, 2 }, embedQuads.options());
|
| 1601 |
+
torch::Tensor outConf = torch::empty({ numOutQuads }, embedQuads.options());
|
| 1602 |
+
|
| 1603 |
+
auto collapseFn = counts.size(0) == 1 ?
|
| 1604 |
+
device_a2a_collapse<NUM_WARPS, scalar_t, true> :
|
| 1605 |
+
device_a2a_collapse<NUM_WARPS, scalar_t, false>;
|
| 1606 |
+
|
| 1607 |
+
collapseFn KERNEL_ARG2(gridSize, blockSize) (
|
| 1608 |
+
ptrCounts,
|
| 1609 |
+
embedQuads.packed_accessor64<scalar_t, 3>(),
|
| 1610 |
+
isLeadRow.packed_accessor64<bool, 2>(),
|
| 1611 |
+
regionCounts.data_ptr<int64_t>(),
|
| 1612 |
+
adjResult.AdjCounts.packed_accessor64<int32_t, 2>(),
|
| 1613 |
+
adjResult.AdjValues.packed_accessor64<int32_t, 3>(),
|
| 1614 |
+
outQuads.packed_accessor64<scalar_t, 3>(),
|
| 1615 |
+
outConf.data_ptr<scalar_t>()
|
| 1616 |
+
);
|
| 1617 |
+
|
| 1618 |
+
return { outQuads, outConf, regionCounts };
|
| 1619 |
+
}
|
| 1620 |
+
|
| 1621 |
+
template<typename scalar_t>
|
| 1622 |
+
nms_result_t cuda_quad_non_maximal_suppression_impl(
|
| 1623 |
+
torch::Tensor quads, torch::Tensor probs,
|
| 1624 |
+
scalar_t probThreshold, scalar_t iouThreshold,
|
| 1625 |
+
int64_t maxRegions, bool verbose)
|
| 1626 |
+
{
|
| 1627 |
+
static const bool s_timerEnabled = true;
|
| 1628 |
+
static const bool s_verboseLevel2 = true;
|
| 1629 |
+
|
| 1630 |
+
// Make sure there's a batch dimension
|
| 1631 |
+
if (quads.dim() == 4) {
|
| 1632 |
+
// B,H,W,V,2
|
| 1633 |
+
quads = quads.unsqueeze(0);
|
| 1634 |
+
// B,H,W
|
| 1635 |
+
probs = probs.unsqueeze(0);
|
| 1636 |
+
}
|
| 1637 |
+
|
| 1638 |
+
//print_tensor_vec_stats2("NMS Input (quads, probs): ", { quads, probs });
|
| 1639 |
+
|
| 1640 |
+
double msRowCollapse = -1,
|
| 1641 |
+
msAdjacency = -1,
|
| 1642 |
+
msA2ACollapse = -1,
|
| 1643 |
+
msTotal = -1;
|
| 1644 |
+
|
| 1645 |
+
CollapseRowsResult collapseRows;
|
| 1646 |
+
AdjacencyResult adjacency;
|
| 1647 |
+
torch::Tensor retQuads, retConf, regionCounts;
|
| 1648 |
+
|
| 1649 |
+
{
|
| 1650 |
+
CudaStoreTimer tTotal{msTotal, s_timerEnabled};
|
| 1651 |
+
{
|
| 1652 |
+
CudaStoreTimer t{msRowCollapse, s_timerEnabled && verbose && s_verboseLevel2};
|
| 1653 |
+
|
| 1654 |
+
// First combine all of the quads in each row
|
| 1655 |
+
collapseRows = collapse_rows(quads, probs, probThreshold, iouThreshold);
|
| 1656 |
+
|
| 1657 |
+
if (collapseRows.TotalNumQuads == 0) {
|
| 1658 |
+
return {
|
| 1659 |
+
torch::empty({ 0, 4, 2 }, quads.options()),
|
| 1660 |
+
torch::empty({ 0 }, probs.options()),
|
| 1661 |
+
collapseRows.ExCounts.toType(torch::kInt64)
|
| 1662 |
+
};
|
| 1663 |
+
}
|
| 1664 |
+
}
|
| 1665 |
+
{
|
| 1666 |
+
CudaStoreTimer t{msAdjacency, s_timerEnabled && verbose && s_verboseLevel2};
|
| 1667 |
+
adjacency = compute_all_to_all_adjacency(collapseRows, iouThreshold);
|
| 1668 |
+
}
|
| 1669 |
+
{
|
| 1670 |
+
CudaStoreTimer t{msA2ACollapse, s_timerEnabled && verbose && s_verboseLevel2};
|
| 1671 |
+
std::tie(retQuads, retConf, regionCounts) = all_to_all_collapse<scalar_t>(collapseRows, adjacency);
|
| 1672 |
+
}
|
| 1673 |
+
}
|
| 1674 |
+
|
| 1675 |
+
#ifndef NDEBUG
|
| 1676 |
+
assert(regionCounts.sum().item<int64_t>() == retQuads.size(0));
|
| 1677 |
+
#endif
|
| 1678 |
+
|
| 1679 |
+
//print_tensor_vec_stats2(" Full NMS (quads, conf, counts): ", { retQuads, retConf, retCounts });
|
| 1680 |
+
|
| 1681 |
+
if (s_timerEnabled && verbose) {
|
| 1682 |
+
std::cout << "NMS Cuda " << retQuads.size(0)
|
| 1683 |
+
<< " - Row Collapse (" << quads.size(0) << ", " << quads.size(1) << ", " << quads.size(2) << ") - (" << collapseRows.TotalNumQuads << "): " << msRowCollapse << "ms"
|
| 1684 |
+
<< ", Adjacency (" << adjacency.AdjCounts.sum(torch::kInt32).item<int32_t>() << "): " << msAdjacency << "ms"
|
| 1685 |
+
<< ", A2A Collapse (" << retQuads.size(0) << "): " << msA2ACollapse << "ms"
|
| 1686 |
+
<< ", Total: " << msTotal << "ms"
|
| 1687 |
+
<< std::endl;
|
| 1688 |
+
}
|
| 1689 |
+
|
| 1690 |
+
return { retQuads, retConf, regionCounts };
|
| 1691 |
+
}
|
| 1692 |
+
|
| 1693 |
+
nms_result_t cuda_quad_non_maximal_suppression(
|
| 1694 |
+
torch::Tensor quads, torch::Tensor probs,
|
| 1695 |
+
float probThreshold, float iouThreshold,
|
| 1696 |
+
int64_t kernelHeight, int64_t kernelWidth,
|
| 1697 |
+
int64_t maxRegions, bool verbose)
|
| 1698 |
+
{
|
| 1699 |
+
nms_result_t ret;
|
| 1700 |
+
|
| 1701 |
+
ret = cuda_quad_non_maximal_suppression_impl<float>(
|
| 1702 |
+
quads.toType(torch::kFloat32), probs.toType(torch::kFloat32),
|
| 1703 |
+
probThreshold, iouThreshold,
|
| 1704 |
+
maxRegions, verbose
|
| 1705 |
+
);
|
| 1706 |
+
|
| 1707 |
+
// AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
| 1708 |
+
// quads.scalar_type(),
|
| 1709 |
+
// "cuda_quad_non_maximal_suppression_impl",
|
| 1710 |
+
// ([&] {
|
| 1711 |
+
// ret = cuda_quad_non_maximal_suppression_impl<scalar_t>(
|
| 1712 |
+
// move(quads), move(probs),
|
| 1713 |
+
// probThreshold, iouThreshold,
|
| 1714 |
+
// maxRegions
|
| 1715 |
+
// );
|
| 1716 |
+
// })
|
| 1717 |
+
// );
|
| 1718 |
+
|
| 1719 |
+
return ret;
|
| 1720 |
+
}
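For reference, a hedged host-side usage sketch of the entry point defined above. The shapes follow the checks in collapse_rows (quads is (H, W, 4, 2) with W a multiple of 32, probs is (H, W)); the nms_result_t alias is assumed here to be a three-tensor tuple, as the return statements above suggest, and the declaration would normally come from this module's header rather than being repeated:

// Illustrative call into the NMS entry point; shapes and thresholds are examples.
#include <torch/torch.h>
#include <tuple>

// Assumption: nms_result_t is (merged quads, confidences, per-example region counts).
using nms_result_t = std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>;

nms_result_t cuda_quad_non_maximal_suppression(
    torch::Tensor quads, torch::Tensor probs,
    float probThreshold, float iouThreshold,
    int64_t kernelHeight, int64_t kernelWidth,
    int64_t maxRegions, bool verbose);

int main()
{
    auto quads = torch::rand({128, 160, 4, 2}, torch::kCUDA);   // H=128, W=160 (multiple of 32)
    auto probs = torch::rand({128, 160}, torch::kCUDA);

    auto [outQuads, outConf, regionCounts] = cuda_quad_non_maximal_suppression(
        quads, probs,
        /*probThreshold=*/0.5f, /*iouThreshold=*/0.4f,
        /*kernelHeight=*/3, /*kernelWidth=*/3,
        /*maxRegions=*/1000, /*verbose=*/false);

    // outQuads: (N, 4, 2) merged quads, outConf: (N,), regionCounts: quads per example.
    return 0;
}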
|
nemo-retriever-ocr/cpp/non_maximal_suppression/nms_common.h
ADDED
|
@@ -0,0 +1,227 @@
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <memory>
|
| 8 |
+
#include <vector>
|
| 9 |
+
#include <unordered_set>
|
| 10 |
+
|
| 11 |
+
#include "../geometry.h"
|
| 12 |
+
#include "../cuda_intellisense.cuh"
|
| 13 |
+
#include "strided_quad.h"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
std::vector<torch::Tensor> quad_nms_from_adjacency(
|
| 18 |
+
torch::Tensor quads, torch::Tensor probs, torch::Tensor adjacency,
|
| 19 |
+
float probThreshold, float iouThreshold,
|
| 20 |
+
int64_t maxRegions);
|
| 21 |
+
|
| 22 |
+
template<typename T>
|
| 23 |
+
struct EmbedQuad_ : public QuadBase_<T, EmbedQuad_<T> > {
|
| 24 |
+
Point_<T> Vertices[4];
|
| 25 |
+
T Confidence;
|
| 26 |
+
T NumQuads = 0;
|
| 27 |
+
|
| 28 |
+
__device__
|
| 29 |
+
EmbedQuad_(T confidence = 0)
|
| 30 |
+
{
|
| 31 |
+
Reset();
|
| 32 |
+
Confidence = confidence;
|
| 33 |
+
}
|
| 34 |
+
__device__
|
| 35 |
+
EmbedQuad_(const EmbedQuad_ &other) = default;
|
| 36 |
+
|
| 37 |
+
__device__
|
| 38 |
+
void swap(EmbedQuad_ &other) noexcept {
|
| 39 |
+
using std::swap;
|
| 40 |
+
|
| 41 |
+
for (size_t i = 0; i < 4; ++i) {
|
| 42 |
+
swap(Vertices[i], other.Vertices[i]);
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
SWAP(Confidence, other.Confidence);
|
| 46 |
+
SWAP(NumQuads, other.NumQuads);
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
__device__
|
| 50 |
+
EmbedQuad_(EmbedQuad_ &&other) : EmbedQuad_() {
|
| 51 |
+
other.swap(*this);
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
__device__
|
| 55 |
+
EmbedQuad_ &operator=(EmbedQuad_ other) {
|
| 56 |
+
other.swap(*this);
|
| 57 |
+
return *this;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
__device__
|
| 61 |
+
void Append(const EmbedQuad_ &other) {
|
| 62 |
+
Append(other, other.Confidence, other.NumQuads);
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
template<typename Derived>
|
| 66 |
+
__device__
|
| 67 |
+
void Append(const QuadBase_<T, Derived> &q, T conf, T numQuads = 1) {
|
| 68 |
+
Confidence *= NumQuads;
|
| 69 |
+
|
| 70 |
+
if (Confidence > 0) {
|
| 71 |
+
for (size_t i = 0; i < 4; ++i) {
|
| 72 |
+
Vertices[i] *= Confidence;
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
Confidence += conf * numQuads;
|
| 77 |
+
|
| 78 |
+
auto qVertices = static_cast<const Derived *>(&q)->Vertices;
|
| 79 |
+
for (size_t i = 0; i < 4; ++i) {
|
| 80 |
+
Vertices[i] += conf * numQuads * qVertices[i];
|
| 81 |
+
Vertices[i] /= Confidence;
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
NumQuads += numQuads;
|
| 85 |
+
Confidence /= NumQuads;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
__device__
|
| 89 |
+
void Prepare() {
|
| 90 |
+
// T factor = 1.0 / Confidence;
|
| 91 |
+
// for (size_t i = 0; i < 4; ++i) {
|
| 92 |
+
// Vertices[i] *= factor;
|
| 93 |
+
// }
|
| 94 |
+
// Confidence /= numQuads;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
__device__
|
| 98 |
+
void Reset() {
|
| 99 |
+
for (size_t i = 0; i < 4; ++i) {
|
| 100 |
+
Vertices[i] = Point_<T>{0, 0};
|
| 101 |
+
}
|
| 102 |
+
Confidence = 0.0f;
|
| 103 |
+
NumQuads = 0;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
__device__
|
| 107 |
+
const Point_<T> &operator[](size_t v) const { return Vertices[v]; }
|
| 108 |
+
__device__
|
| 109 |
+
Point_<T> &operator[](size_t v) { return Vertices[v]; }
|
| 110 |
+
};
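EmbedQuad_::Append maintains Vertices as a confidence-weighted running mean of the merged quads' vertices and Confidence as a plain mean over the merged quads. A host-side sketch (hypothetical RunningQuad struct, reduced to a single coordinate) showing that the running update matches the closed-form weighted mean:

// Host-side illustration of the running update in EmbedQuad_::Append:
// the vertex tracks a confidence-weighted mean, the confidence tracks a plain mean.
// Hypothetical and simplified to one coordinate per quad.
#include <cassert>
#include <cmath>

struct RunningQuad {
    double x = 0, conf = 0, n = 0;

    void append(double qx, double qConf, double qN = 1) {
        conf *= n;                   // back to weighted-sum form
        if (conf > 0) x *= conf;
        conf += qConf * qN;
        x = (x + qConf * qN * qx) / conf;
        n += qN;
        conf /= n;                   // back to mean form
    }
};

int main()
{
    RunningQuad r;
    r.append(/*qx=*/10.0, /*qConf=*/0.9);
    r.append(/*qx=*/20.0, /*qConf=*/0.1);

    // Closed form: weighted mean of x, plain mean of confidence.
    double x_ref = (0.9 * 10.0 + 0.1 * 20.0) / (0.9 + 0.1);   // 11.0
    double c_ref = (0.9 + 0.1) / 2.0;                          // 0.5

    assert(std::fabs(r.x - x_ref) < 1e-9);
    assert(std::fabs(r.conf - c_ref) < 1e-9);
    return 0;
}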
|
| 111 |
+
|
| 112 |
+
struct ZeroInitTag {};
|
| 113 |
+
|
| 114 |
+
template<typename T>
|
| 115 |
+
struct MergeQuad_ : public QuadBase_<T, MergeQuad_<T>> {
|
| 116 |
+
Point_<T> Vertices[4];
|
| 117 |
+
T Confidence;
|
| 118 |
+
T NumQuads;
|
| 119 |
+
|
| 120 |
+
MergeQuad_() = default;
|
| 121 |
+
|
| 122 |
+
__device__
|
| 123 |
+
MergeQuad_(ZeroInitTag) : Confidence(0), NumQuads(0) {
|
| 124 |
+
for (size_t i = 0; i < 4; ++i) {
|
| 125 |
+
Vertices[i] = Point_<T>{0, 0};
|
| 126 |
+
}
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
template<typename Derived>
|
| 130 |
+
__device__
|
| 131 |
+
void Append(const QuadBase_<T, Derived> &q, T conf) {
|
| 132 |
+
Confidence += conf;
|
| 133 |
+
++NumQuads;
|
| 134 |
+
|
| 135 |
+
auto &d = static_cast<const Derived&>(q);
|
| 136 |
+
for (size_t i = 0; i < 4; ++i) {
|
| 137 |
+
Vertices[i] += conf * d[i];
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
__device__
|
| 141 |
+
void Append(const EmbedQuad_<T> &q) {
|
| 142 |
+
T qConf = q.NumQuads * q.Confidence;
|
| 143 |
+
|
| 144 |
+
Confidence += qConf;
|
| 145 |
+
NumQuads += q.NumQuads;
|
| 146 |
+
for (size_t i = 0; i < 4; ++i) {
|
| 147 |
+
Vertices[i] += qConf * q.Vertices[i];
|
| 148 |
+
}
|
| 149 |
+
}
|
| 150 |
+
__device__
|
| 151 |
+
void Append(const StridedEmbedQuad_<T> &q) {
|
| 152 |
+
const T numQuads = q.NumQuads();
|
| 153 |
+
const T qConf = numQuads * q.Confidence();
|
| 154 |
+
|
| 155 |
+
Confidence += qConf;
|
| 156 |
+
NumQuads += numQuads;
|
| 157 |
+
for (size_t i = 0; i < 4; ++i) {
|
| 158 |
+
Vertices[i] += qConf * q[i];
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
__device__
|
| 163 |
+
EmbedQuad_<T> Commit() {
|
| 164 |
+
EmbedQuad_<T> ret;
|
| 165 |
+
for (size_t i = 0; i < 4; ++i) {
|
| 166 |
+
ret.Vertices[i] = Vertices[i] / Confidence;
|
| 167 |
+
}
|
| 168 |
+
ret.Confidence = Confidence / NumQuads;
|
| 169 |
+
ret.NumQuads = NumQuads;
|
| 170 |
+
|
| 171 |
+
return ret;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
__device__
|
| 175 |
+
const Point_<T> &operator[](size_t v) const { return Vertices[v]; }
|
| 176 |
+
__device__
|
| 177 |
+
Point_<T> &operator[](size_t v) { return Vertices[v]; }
|
| 178 |
+
};
|
| 179 |
+
|
| 180 |
+
template<typename T, typename Intermediate=float>
|
| 181 |
+
__device__
|
| 182 |
+
inline T triangle_root(T val)
|
| 183 |
+
{
|
| 184 |
+
// It's easier to visualize this algorithm for a lower triangular matrix
|
| 185 |
+
// What we're trying to find is the `row` of a lower triangular matrix that a given `val` resides in.
|
| 186 |
+
// e.g. 0->0, 2->1, 4->2, etc.
|
| 187 |
+
//
|
| 188 |
+
// 0: 0
|
| 189 |
+
// 1: 1 2
|
| 190 |
+
// 2: 3 4 5
|
| 191 |
+
// 3: 6 7 8 9
|
| 192 |
+
//
|
| 193 |
+
// See https://math.stackexchange.com/questions/698961/finding-the-triangular-root-of-a-number for explanation
|
| 194 |
+
Intermediate numer = Intermediate(-1) + sqrt(Intermediate(1) + Intermediate(8) * Intermediate(val));
|
| 195 |
+
Intermediate denom = Intermediate(2);
|
| 196 |
+
|
| 197 |
+
Intermediate ret = floor(numer / denom);
|
| 198 |
+
return T(ret);
|
| 199 |
+
}
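A quick host-side check of the triangular-root formula above, reproducing the row mapping listed in the comment (0 -> 0, 1..2 -> 1, 3..5 -> 2, 6..9 -> 3); triangle_root_host is a hypothetical stand-in for the device function:

// Host-side check of the triangular-root formula documented above.
#include <cmath>
#include <cstdio>

static int triangle_root_host(int val)
{
    double numer = -1.0 + std::sqrt(1.0 + 8.0 * static_cast<double>(val));
    return static_cast<int>(std::floor(numer / 2.0));
}

int main()
{
    const int expected[] = {0, 1, 1, 2, 2, 2, 3, 3, 3, 3};
    for (int v = 0; v < 10; ++v) {
        std::printf("%d -> %d (expected %d)\n", v, triangle_root_host(v), expected[v]);
    }
    return 0;
}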
|
| 200 |
+
|
| 201 |
+
template<typename T>
|
| 202 |
+
void visit_node(const std::vector<EmbedQuad_<T>> &allQuads, size_t quadIdx,
|
| 203 |
+
const std::vector<std::vector<size_t>> &adjIdxs, EmbedQuad_<T> &currQuad,
|
| 204 |
+
std::unordered_set<size_t> &visited)
|
| 205 |
+
{
|
| 206 |
+
if (visited.count(quadIdx) > 0) return;
|
| 207 |
+
|
| 208 |
+
const EmbedQuad_<T> &vQuad = allQuads[quadIdx];
|
| 209 |
+
|
| 210 |
+
currQuad.Append(vQuad);
|
| 211 |
+
visited.insert(quadIdx);
|
| 212 |
+
|
| 213 |
+
for (size_t childIdx : adjIdxs[quadIdx]) {
|
| 214 |
+
visit_node(allQuads, childIdx, adjIdxs, currQuad, visited);
|
| 215 |
+
}
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
template<typename T, typename Derived, typename scalar_t>
|
| 219 |
+
void copy_quad(const QuadBase_<T, Derived> &srcQuad, scalar_t *pDest)
|
| 220 |
+
{
|
| 221 |
+
auto vertices = static_cast<const Derived*>(&srcQuad)->Vertices;
|
| 222 |
+
for (size_t i = 0; i < 4; ++i) {
|
| 223 |
+
const Point_<T> &v = vertices[i];
|
| 224 |
+
*pDest++ = v.X;
|
| 225 |
+
*pDest++ = v.Y;
|
| 226 |
+
}
|
| 227 |
+
}
|
nemo-retriever-ocr/cpp/non_maximal_suppression/nms_kd_tree.h
ADDED
|
@@ -0,0 +1,449 @@
| 1 |
+
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
// SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <memory>
|
| 8 |
+
#include <vector>
|
| 9 |
+
#include <stack>
|
| 10 |
+
|
| 11 |
+
#include "../geometry.h"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
#define MODE_GEOMETRY 0x02ull
|
| 15 |
+
#define MODE_CHILDREN 0x00ull
|
| 16 |
+
|
| 17 |
+
#define DIM_X 0x0ull
|
| 18 |
+
#define DIM_Y 0x1ull
|
| 19 |
+
|
| 20 |
+
static const size_t INVALID_IDX = -1;
|
| 21 |
+
|
| 22 |
+
template<typename T>
|
| 23 |
+
struct NMS_BoundsWrapper
|
| 24 |
+
{
|
| 25 |
+
typedef std::unique_ptr<NMS_BoundsWrapper> Ptr;
|
| 26 |
+
typedef AABB_<typename T::inner_type> bds_t;
|
| 27 |
+
|
| 28 |
+
size_t GeoIdx;
|
| 29 |
+
const T *Geometry;
|
| 30 |
+
bds_t Bounds;
|
| 31 |
+
|
| 32 |
+
NMS_BoundsWrapper(size_t geoIdx, const T *geometry) : GeoIdx(geoIdx), Geometry(geometry), Bounds(geometry->Bounds()) { }
|
| 33 |
+
};
|
| 34 |
+
|
| 35 |
+
template<typename T>
|
| 36 |
+
class NMS_NodeAllocator;
|
| 37 |
+
|
| 38 |
+
template<typename T>
|
| 39 |
+
class NMS_KDTree;
|
| 40 |
+
|
| 41 |
+
template<typename T>
|
| 42 |
+
class NMS_BuildCache;
|
| 43 |
+
|
| 44 |
+
template<typename T>
|
| 45 |
+
class NMS_KDNode
|
| 46 |
+
{
|
| 47 |
+
friend class NMS_KDTree<T>;
|
| 48 |
+
|
| 49 |
+
public:
|
| 50 |
+
typedef NMS_BoundsWrapper<T> bds_t;
|
| 51 |
+
typedef std::unique_ptr<NMS_KDNode[]> UPtr;
|
| 52 |
+
typedef typename T::inner_type inner_type;
|
| 53 |
+
typedef std::vector<bds_t*> geo_vec_t;
|
| 54 |
+
typedef std::unique_ptr<geo_vec_t> geo_vec_ptr;
|
| 55 |
+
|
| 56 |
+
void Build(geo_vec_ptr geometries, const typename bds_t::bds_t &envelope,
|
| 57 |
+
NMS_NodeAllocator<T> &allocator, NMS_BuildCache<T> &buildCache);
|
| 58 |
+
|
| 59 |
+
template<typename Fn>
|
| 60 |
+
void FindIntersections(size_t geoIdx, const typename bds_t::bds_t &bds, const Fn &fn) const;
|
| 61 |
+
|
| 62 |
+
private:
|
| 63 |
+
inline uintptr_t Dim() const { return reinterpret_cast<uintptr_t>(m_ptr) & 0x01ull; }
|
| 64 |
+
inline uintptr_t Mode() const { return reinterpret_cast<uintptr_t>(m_ptr) & 0x02ull; }
|
| 65 |
+
inline void Children(NMS_KDNode *&children, inner_type &splitPos) const
|
| 66 |
+
{
|
| 67 |
+
auto vPtr = Geometries();
|
| 68 |
+
splitPos = *reinterpret_cast<inner_type*>(vPtr);
|
| 69 |
+
children = reinterpret_cast<NMS_KDNode*>(vPtr + sizeof(inner_type));
|
| 70 |
+
}
|
| 71 |
+
inline uint8_t* Geometries() const
|
| 72 |
+
{
|
| 73 |
+
return reinterpret_cast<uint8_t*>(reinterpret_cast<uintptr_t>(m_ptr) & ~0x3ull);
|
| 74 |
+
}
|
| 75 |
+
inline void SetPtr(uint8_t *vPtr, uintptr_t mode, uintptr_t dim)
|
| 76 |
+
{
|
| 77 |
+
m_ptr = reinterpret_cast<uint8_t*>(
|
| 78 |
+
reinterpret_cast<uintptr_t>(vPtr) | mode | dim
|
| 79 |
+
);
|
| 80 |
+
}
|
| 81 |
+
void AssignGeometries(geo_vec_ptr geometries, NMS_BuildCache<T> &buildCache);
|
| 82 |
+
|
| 83 |
+
uint8_t *m_ptr;
|
| 84 |
+
};
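NMS_KDNode keeps a single pointer and packs the node mode into bit 1 and the split dimension into bit 0 (SetPtr / Mode() / Dim() / Geometries()), which works because the payload buffers are at least 4-byte aligned, so the two low bits of a real address are always zero. A standalone sketch of the same low-bit tagging scheme (hypothetical TaggedPtr type):

// Standalone sketch of the low-bit pointer tagging used by NMS_KDNode:
// bit 0 carries the split dimension, bit 1 the node mode, and the payload
// address is recovered by masking both bits off.
#include <cstdint>
#include <cstdlib>
#include <cassert>

constexpr uintptr_t MODE_BIT = 0x2ull;   // children vs. geometry list
constexpr uintptr_t DIM_BIT  = 0x1ull;   // x vs. y split

struct TaggedPtr {
    uint8_t *raw = nullptr;

    void set(void *p, uintptr_t mode, uintptr_t dim) {
        raw = reinterpret_cast<uint8_t*>(reinterpret_cast<uintptr_t>(p) | mode | dim);
    }
    uintptr_t mode() const { return reinterpret_cast<uintptr_t>(raw) & MODE_BIT; }
    uintptr_t dim()  const { return reinterpret_cast<uintptr_t>(raw) & DIM_BIT; }
    void* payload()  const {
        return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(raw) & ~0x3ull);
    }
};

int main()
{
    void *buf = std::malloc(64);                 // malloc returns a suitably aligned block
    TaggedPtr t;
    t.set(buf, MODE_BIT, DIM_BIT);               // "geometry" mode, y split

    assert(t.mode() == MODE_BIT);
    assert(t.dim() == DIM_BIT);
    assert(t.payload() == buf);
    std::free(buf);
    return 0;
}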
|
| 85 |
+
|
| 86 |
+
template<typename T>
|
| 87 |
+
class NMS_NodeAllocator
|
| 88 |
+
{
|
| 89 |
+
public:
|
| 90 |
+
typedef NMS_KDNode<T> node_t;
|
| 91 |
+
typedef typename node_t::inner_type inner_type;
|
| 92 |
+
|
| 93 |
+
NMS_NodeAllocator(size_t initialGuess = 512);
|
| 94 |
+
~NMS_NodeAllocator();
|
| 95 |
+
|
| 96 |
+
void Get(size_t numNodes, NMS_KDNode<T> *&outNodes, inner_type *&outSplitPos, uint8_t *&outRawPtr);
|
| 97 |
+
|
| 98 |
+
private:
|
| 99 |
+
std::vector<std::pair<size_t, uint8_t*>> m_buffers;
|
| 100 |
+
size_t m_offset;
|
| 101 |
+
};
|
| 102 |
+
|
| 103 |
+
template<typename T>
|
| 104 |
+
class NMS_BuildCache
|
| 105 |
+
{
|
| 106 |
+
public:
|
| 107 |
+
typedef typename NMS_KDNode<T>::bds_t bds_t;
|
| 108 |
+
typedef std::unique_ptr<NMS_BuildCache> Ptr;
|
| 109 |
+
typedef std::vector<bds_t*> geo_vec_t;
|
| 110 |
+
typedef std::unique_ptr<geo_vec_t> geo_vec_ptr;
|
| 111 |
+
|
| 112 |
+
NMS_BuildCache(size_t initialSize);
|
| 113 |
+
~NMS_BuildCache();
|
| 114 |
+
|
| 115 |
+
geo_vec_ptr Get(size_t sizeHint);
|
| 116 |
+
bds_t** GetRawBuffer(size_t numGeos, uint8_t *&rawPtr);
|
| 117 |
+
|
| 118 |
+
void Release(geo_vec_ptr buff);
|
| 119 |
+
|
| 120 |
+
private:
|
| 121 |
+
std::stack<geo_vec_ptr> m_cache;
|
| 122 |
+
std::vector<std::pair<size_t, uint8_t*>> m_rawBuffers;
|
| 123 |
+
size_t m_rawOffset;
|
| 124 |
+
};
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
template<typename T>
|
| 128 |
+
class NMS_KDTree
|
| 129 |
+
{
|
| 130 |
+
typedef typename T::inner_type inner_type;
|
| 131 |
+
typedef NMS_BoundsWrapper<T> bds_t;
|
| 132 |
+
typedef NMS_KDNode<T> node_t;
|
| 133 |
+
|
| 134 |
+
public:
|
| 135 |
+
NMS_KDTree();
|
| 136 |
+
~NMS_KDTree();
|
| 137 |
+
|
| 138 |
+
void Build(const std::vector<T> &geometries);
|
| 139 |
+
|
| 140 |
+
template<typename Fn>
|
| 141 |
+
void FindIntersections(size_t geoIdx, const Fn &fn) const;
|
| 142 |
+
|
| 143 |
+
template<typename Fn>
|
| 144 |
+
void FindIntersections(const T &geo, const Fn &fn) const;
|
| 145 |
+
|
| 146 |
+
private:
|
| 147 |
+
bds_t *m_wrappers;
|
| 148 |
+
NMS_NodeAllocator<T> m_allocator;
|
| 149 |
+
node_t m_root;
|
| 150 |
+
typename NMS_BuildCache<T>::Ptr m_buildCache;
|
| 151 |
+
};
|
| 152 |
+
|
| 153 |
+
template<typename T>
|
| 154 |
+
NMS_KDTree<T>::NMS_KDTree()
|
| 155 |
+
: m_wrappers(nullptr)
|
| 156 |
+
{
|
| 157 |
+
m_root.m_ptr = nullptr;
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
template<typename T>
|
| 161 |
+
NMS_KDTree<T>::~NMS_KDTree()
|
| 162 |
+
{
|
| 163 |
+
free(m_wrappers);
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
template<typename T>
|
| 167 |
+
void NMS_KDTree<T>::Build(const std::vector<T> &geometries)
|
| 168 |
+
{
|
| 169 |
+
if (geometries.empty()) {
|
| 170 |
+
m_root.m_ptr = nullptr;
|
| 171 |
+
return;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
// Doing this so that we can perform placement-new on the array buffer, and thus
|
| 175 |
+
// need only perform a single memory allocation for all geometries at once
|
| 176 |
+
m_wrappers = reinterpret_cast<bds_t*>(malloc(sizeof(bds_t) * geometries.size()));
|
| 177 |
+
|
| 178 |
+
m_buildCache.reset(new NMS_BuildCache<T>(geometries.size()));
|
| 179 |
+
|
| 180 |
+
auto bdsGeos = m_buildCache->Get(geometries.size());
|
| 181 |
+
|
| 182 |
+
typename bds_t::bds_t envelope;
|
| 183 |
+
|
| 184 |
+
for (size_t i = 0; i < geometries.size(); ++i) {
|
| 185 |
+
// Placement new. Constructs the object in the place specified in the first (...)
|
| 186 |
+
new (m_wrappers + i) bds_t(i, &geometries[i]);
|
| 187 |
+
|
| 188 |
+
bdsGeos->push_back(m_wrappers + i);
|
| 189 |
+
if (i == 0) {
|
| 190 |
+
envelope = m_wrappers[i].Bounds;
|
| 191 |
+
} else {
|
| 192 |
+
envelope = envelope.Union(m_wrappers[i].Bounds);
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
m_root.Build(std::move(bdsGeos), envelope, m_allocator, *m_buildCache);
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
template<typename T>
|
| 201 |
+
void NMS_KDNode<T>::Build(geo_vec_ptr geometries, const typename bds_t::bds_t &envelope,
|
| 202 |
+
NMS_NodeAllocator<T> &allocator, NMS_BuildCache<T> &buildCache)
|
| 203 |
+
{
|
| 204 |
+
static const size_t MAX_GEOMETRIES = 8;
|
| 205 |
+
|
| 206 |
+
if (geometries->size() <= MAX_GEOMETRIES) {
|
| 207 |
+
AssignGeometries(std::move(geometries), buildCache);
|
| 208 |
+
} else {
|
| 209 |
+
geo_vec_ptr leftGeos = buildCache.Get(geometries->size()),
|
| 210 |
+
rightGeos = buildCache.Get(geometries->size());
|
| 211 |
+
|
| 212 |
+
inner_type szX = envelope[2] - envelope[0];
|
| 213 |
+
inner_type szY = envelope[3] - envelope[1];
|
| 214 |
+
|
| 215 |
+
int64_t dim = szX > szY ? 0 : 1;
|
| 216 |
+
auto emn = envelope[dim];
|
| 217 |
+
auto emx = envelope[dim + 2];
|
| 218 |
+
|
| 219 |
+
auto pivotPos = (emn + emx) / 2;
|
| 220 |
+
for (bds_t *g : *geometries) {
|
| 221 |
+
auto mn = g->Bounds[dim];
|
| 222 |
+
auto mx = g->Bounds[dim + 2];
|
| 223 |
+
|
| 224 |
+
if (mn < pivotPos) {
|
| 225 |
+
leftGeos->push_back(g);
|
| 226 |
+
}
|
| 227 |
+
if (mx > pivotPos) {
|
| 228 |
+
rightGeos->push_back(g);
|
| 229 |
+
}
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
if (leftGeos->size() == geometries->size() || rightGeos->size() == geometries->size()) {
|
| 233 |
+
AssignGeometries(std::move(geometries), buildCache);
|
| 234 |
+
buildCache.Release(std::move(leftGeos));
|
| 235 |
+
buildCache.Release(std::move(rightGeos));
|
| 236 |
+
} else {
|
| 237 |
+
buildCache.Release(std::move(geometries));
|
| 238 |
+
|
| 239 |
+
inner_type *nodeSplitPos;
|
| 240 |
+
uint8_t *nodeRawPtr;
|
| 241 |
+
NMS_KDNode *children;
|
| 242 |
+
allocator.Get(2, children, nodeSplitPos, nodeRawPtr);
|
| 243 |
+
|
| 244 |
+
SetPtr(nodeRawPtr, MODE_CHILDREN, dim);
|
| 245 |
+
*nodeSplitPos = pivotPos;
|
| 246 |
+
|
| 247 |
+
typename bds_t::bds_t leftEnv{envelope}, rightEnv{envelope};
|
| 248 |
+
// Set the max of the left envelope to the split plane
|
| 249 |
+
leftEnv[dim + 2] = pivotPos;
|
| 250 |
+
// Set the min of the right envelope to the split plane
|
| 251 |
+
rightEnv[dim] = pivotPos;
|
| 252 |
+
|
| 253 |
+
children[0].Build(std::move(leftGeos), leftEnv, allocator, buildCache);
|
| 254 |
+
children[1].Build(std::move(rightGeos), rightEnv, allocator, buildCache);
|
| 255 |
+
}
|
| 256 |
+
}
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
template<typename T>
|
| 260 |
+
void NMS_KDNode<T>::AssignGeometries(geo_vec_ptr geometries, NMS_BuildCache<T> &buildCache)
|
| 261 |
+
{
|
| 262 |
+
if (geometries->empty()) {
|
| 263 |
+
SetPtr(nullptr, MODE_GEOMETRY, 0);
|
| 264 |
+
} else {
|
| 265 |
+
uint8_t *vPtr;
|
| 266 |
+
bds_t **geoPtr = buildCache.GetRawBuffer(geometries->size(), vPtr);
|
| 267 |
+
std::copy(geometries->begin(), geometries->end(), geoPtr);
|
| 268 |
+
|
| 269 |
+
SetPtr(vPtr, MODE_GEOMETRY, 0);
|
| 270 |
+
}
|
| 271 |
+
buildCache.Release(std::move(geometries));
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
template<typename T>
|
| 275 |
+
template<typename Fn>
|
| 276 |
+
void NMS_KDTree<T>::FindIntersections(size_t geoIdx, const Fn &fn) const
|
| 277 |
+
{
|
| 278 |
+
if (!m_wrappers) return;
|
| 279 |
+
|
| 280 |
+
auto &bds = m_wrappers[geoIdx].Bounds;
|
| 281 |
+
|
| 282 |
+
m_root.FindIntersections(geoIdx, bds, fn);
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
template<typename T>
|
| 286 |
+
template<typename Fn>
|
| 287 |
+
void NMS_KDTree<T>::FindIntersections(const T &geo, const Fn &fn) const
|
| 288 |
+
{
|
| 289 |
+
if (!m_wrappers) return;
|
| 290 |
+
|
| 291 |
+
NMS_BoundsWrapper<T> bdsWrapper(INVALID_IDX, &geo);
|
| 292 |
+
|
| 293 |
+
m_root.FindIntersections(INVALID_IDX, bdsWrapper.Bounds, fn);
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
template<typename T>
template<typename Fn>
void NMS_KDNode<T>::FindIntersections(size_t geoIdx, const typename bds_t::bds_t &bds, const Fn &fn) const
{
    auto mode = Mode();

    if (mode == MODE_GEOMETRY) {
        auto *vPtr = Geometries();

        // Leaf layout: a size_t geometry count followed by the pointer array
        size_t numGeos = *reinterpret_cast<size_t*>(vPtr);
        bds_t **geoPtr = reinterpret_cast<bds_t**>(vPtr + sizeof(size_t));

        bds_t **endPtr = geoPtr + numGeos;
        for (; geoPtr != endPtr; ++geoPtr) {
            const bds_t *child = *geoPtr;

            // Skip self and lower-indexed geometries so each pair is reported only once
            if (geoIdx != INVALID_IDX && child->GeoIdx <= geoIdx) continue;

            typename bds_t::bds_t::inner_type pctN, pctM, iou;
            std::tie(pctN, pctM, iou) = geometry_region_sizes(bds, child->Bounds);

            if (iou > 0) {
                fn(child->GeoIdx, pctN, pctM, iou);
            }
        }
    } else {
        auto dim = Dim();

        auto mn = bds[dim];
        auto mx = bds[dim + 2];

        NMS_KDNode *children;
        inner_type splitPos;
        Children(children, splitPos);

        // Recurse into whichever half-spaces the query bounds straddle
        if (mn < splitPos) {
            children[0].FindIntersections(geoIdx, bds, fn);
        }
        if (mx > splitPos) {
            children[1].FindIntersections(geoIdx, bds, fn);
        }
    }
}

template<typename T>
NMS_NodeAllocator<T>::NMS_NodeAllocator(size_t initialGuess)
    : m_offset(0)
{
    // Reserve room for roughly initialGuess splits, each needing one split
    // position and two child nodes. The buffer records its byte capacity.
    size_t allocSize = initialGuess * (sizeof(inner_type) + 2 * sizeof(node_t));
    auto ptr = reinterpret_cast<uint8_t*>(malloc(allocSize));
    m_buffers.emplace_back(allocSize, ptr);
}

template<typename T>
NMS_NodeAllocator<T>::~NMS_NodeAllocator()
{
    for (auto &p : m_buffers) {
        free(p.second);
    }
}

template<typename T>
void NMS_NodeAllocator<T>::Get(size_t numNodes, node_t *&outNodes, inner_type *&outSplitPos, uint8_t *&outRawPtr)
{
    auto &currBuff = m_buffers.back();

    size_t rem = currBuff.first - m_offset;

    size_t reqSize = sizeof(inner_type) + sizeof(node_t) * numNodes;

    if (rem >= reqSize) {
        // Bump-allocate out of the current buffer: split position first, then the nodes
        outRawPtr = currBuff.second + m_offset;
        outSplitPos = reinterpret_cast<inner_type*>(outRawPtr);
        outNodes = reinterpret_cast<node_t*>(outRawPtr + sizeof(inner_type));
        m_offset += reqSize;
        return;
    }

    // Grow by at least doubling, rounded up to an even number of bytes
    size_t allocSize = (std::max(currBuff.first * 2, reqSize) + 1) & ~0x01ull;
    auto ptr = reinterpret_cast<uint8_t*>(malloc(allocSize));
    m_buffers.emplace_back(allocSize, ptr);
    m_offset = 0;

    Get(numNodes, outNodes, outSplitPos, outRawPtr);
}

template<typename T>
NMS_BuildCache<T>::NMS_BuildCache(size_t initialSize)
    : m_rawOffset(0)
{
    auto allocSize = sizeof(bds_t*) * initialSize * 2;
    auto raw1 = reinterpret_cast<uint8_t*>(malloc(allocSize));
    m_rawBuffers.emplace_back(allocSize, raw1);
}

template<typename T>
NMS_BuildCache<T>::~NMS_BuildCache()
{
    for (auto &p : m_rawBuffers) {
        free(p.second);
    }
}

template<typename T>
typename NMS_BuildCache<T>::geo_vec_ptr NMS_BuildCache<T>::Get(size_t sizeHint)
{
    geo_vec_ptr ret;
    if (!m_cache.empty()) {
        // Reuse a previously released vector instead of allocating a new one
        ret = std::move(m_cache.top());
        m_cache.pop();
        ret->clear();
    } else {
        ret.reset(new std::vector<bds_t*>);
    }

    ret->reserve(sizeHint);
    return ret;
}

template<typename T>
typename NMS_BuildCache<T>::bds_t** NMS_BuildCache<T>::GetRawBuffer(size_t numGeos, uint8_t *&rawPtr)
{
    auto &currBuff = m_rawBuffers.back();
    size_t rem = currBuff.first - m_rawOffset;

    // Each block stores the geometry count followed by the pointer array
    size_t reqSize = sizeof(size_t) + sizeof(bds_t*) * numGeos;

    if (rem >= reqSize) {
        rawPtr = currBuff.second + m_rawOffset;
        m_rawOffset += reqSize;
        reinterpret_cast<size_t*>(rawPtr)[0] = numGeos;
        return reinterpret_cast<bds_t**>(rawPtr + sizeof(size_t));
    }

    // Grow by at least doubling, rounded up to an even number of bytes
    size_t allocSize = (std::max(currBuff.first * 2, reqSize) + 1) & ~0x01ull;
    auto ptr = reinterpret_cast<uint8_t*>(malloc(allocSize));
    m_rawBuffers.emplace_back(allocSize, ptr);
    m_rawOffset = 0;

    return GetRawBuffer(numGeos, rawPtr);
}

template<typename T>
void NMS_BuildCache<T>::Release(geo_vec_ptr buff)
{
    m_cache.push(std::move(buff));
}

#undef MODE_GEOMETRY
#undef MODE_CHILDREN
#undef DIM_X
#undef DIM_Y
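A minimal usage sketch, not part of the header above: it assumes an already-built NMS_KDTree<T> named tree over numGeometries indexed geometries. Only the callback shape (geoIdx, pctN, pctM, iou) is taken from the fn(child->GeoIdx, pctN, pctM, iou) call in this file; the variable names and the 0.5 IoU threshold are illustrative.

// Hypothetical driver (sketch): collect geometry pairs whose IoU exceeds a threshold.
// Because leaf traversal skips child->GeoIdx <= geoIdx, each pair is reported once.
std::vector<std::pair<size_t, size_t>> overlapping;
for (size_t i = 0; i < numGeometries; ++i) {
    tree.FindIntersections(i, [&](size_t otherIdx, auto pctN, auto pctM, auto iou) {
        if (iou > 0.5) {
            overlapping.emplace_back(i, otherIdx);
        }
    });
}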