Theo Viel committed
Commit 98a67a0 · 1 Parent(s): 694c514

add weights and code

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. checkpoints/charset.txt +1 -0
  2. checkpoints/detector.pth +3 -0
  3. checkpoints/recognizer.pth +3 -0
  4. checkpoints/relational.pth +3 -0
  5. example.py +43 -0
  6. nemo-retriever-ocr/cpp/.gitattributes +1 -0
  7. nemo-retriever-ocr/cpp/.gitignore +6 -0
  8. nemo-retriever-ocr/cpp/.gitmodules +3 -0
  9. nemo-retriever-ocr/cpp/README.md +15 -0
  10. nemo-retriever-ocr/cpp/beam_decode/beam_decode.cpp +460 -0
  11. nemo-retriever-ocr/cpp/beam_decode/beam_decode.h +18 -0
  12. nemo-retriever-ocr/cpp/beam_decode/kn_lm.cpp +86 -0
  13. nemo-retriever-ocr/cpp/beam_decode/kn_lm.h +27 -0
  14. nemo-retriever-ocr/cpp/beam_decode/language_model.cpp +147 -0
  15. nemo-retriever-ocr/cpp/beam_decode/language_model.h +66 -0
  16. nemo-retriever-ocr/cpp/beam_decode/log_sum_exp.cpp +7 -0
  17. nemo-retriever-ocr/cpp/beam_decode/log_sum_exp.h +54 -0
  18. nemo-retriever-ocr/cpp/beam_decode/ngram_lm_base.cpp +330 -0
  19. nemo-retriever-ocr/cpp/beam_decode/ngram_lm_base.h +80 -0
  20. nemo-retriever-ocr/cpp/beam_decode/prefix.cpp +23 -0
  21. nemo-retriever-ocr/cpp/beam_decode/prefix.h +158 -0
  22. nemo-retriever-ocr/cpp/beam_decode/sbo_lm.cpp +47 -0
  23. nemo-retriever-ocr/cpp/beam_decode/sbo_lm.h +21 -0
  24. nemo-retriever-ocr/cpp/better_grid_sample/cpu_indirect_grid_sample.cpp +94 -0
  25. nemo-retriever-ocr/cpp/better_grid_sample/gpu_grid_sample_utils.cuh +42 -0
  26. nemo-retriever-ocr/cpp/better_grid_sample/gpu_indirect_grid_sample.cu +328 -0
  27. nemo-retriever-ocr/cpp/better_grid_sample/grid_sample.h +67 -0
  28. nemo-retriever-ocr/cpp/common.cpp +13 -0
  29. nemo-retriever-ocr/cpp/common.h +58 -0
  30. nemo-retriever-ocr/cpp/cuda_intellisense.cuh +51 -0
  31. nemo-retriever-ocr/cpp/geometry.h +1101 -0
  32. nemo-retriever-ocr/cpp/geometry_api/calc_poly_min_rrect.cpp +165 -0
  33. nemo-retriever-ocr/cpp/geometry_api/geometry_api.cpp +101 -0
  34. nemo-retriever-ocr/cpp/geometry_api/geometry_api.h +16 -0
  35. nemo-retriever-ocr/cpp/geometry_api/geometry_api_common.h +121 -0
  36. nemo-retriever-ocr/cpp/geometry_api/geometry_api_gpu.cu +142 -0
  37. nemo-retriever-ocr/cpp/geometry_api/get_rel_continuation_cos.cpp +60 -0
  38. nemo-retriever-ocr/cpp/geometry_api/matrix2x2.h +93 -0
  39. nemo-retriever-ocr/cpp/geometry_api/poly_bounds_quad.cpp +61 -0
  40. nemo-retriever-ocr/cpp/graph_detection/encode_util.cpp +272 -0
  41. nemo-retriever-ocr/cpp/graph_detection/encode_util.h +184 -0
  42. nemo-retriever-ocr/cpp/half_ops.cu +5 -0
  43. nemo-retriever-ocr/cpp/half_ops.cuh +149 -0
  44. nemo-retriever-ocr/cpp/local_ips/local_ips.h +11 -0
  45. nemo-retriever-ocr/cpp/local_ips/quad_all_2_all_dist_v2.cu +162 -0
  46. nemo-retriever-ocr/cpp/module.cpp +125 -0
  47. nemo-retriever-ocr/cpp/non_maximal_suppression/cpu_non_maximal_suppression.cpp +209 -0
  48. nemo-retriever-ocr/cpp/non_maximal_suppression/cuda_non_maximal_suppression.cu +1720 -0
  49. nemo-retriever-ocr/cpp/non_maximal_suppression/nms_common.h +227 -0
  50. nemo-retriever-ocr/cpp/non_maximal_suppression/nms_kd_tree.h +449 -0
checkpoints/charset.txt ADDED
@@ -0,0 +1 @@
+ [" ", "!", "\"", "#", "$", "%", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", ":", ";", "<", "=", ">", "?", "@", "A", "B", "C", "D", "E", "F", "FI", "G", "H", "I", "I\u0307", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "SS", "T", "U", "V", "W", "X", "Y", "Z", "[", "\\", "]", "^", "_", "`", "a", "b", "c", "d", "e", "f", "fi", "g", "h", "i", "i\u0307", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "ss", "t", "u", "v", "w", "x", "y", "z", "{", "|", "}", "~", "\u00b2", "\u00b3", "\u00b5", "\u00b9", "\u00ba", "\u00c0", "\u00c1", "\u00c2", "\u00c3", "\u00c4", "\u00c5", "\u00c6", "\u00c7", "\u00c8", "\u00c9", "\u00ca", "\u00cb", "\u00cc", "\u00cd", "\u00ce", "\u00cf", "\u00d0", "\u00d1", "\u00d2", "\u00d3", "\u00d4", "\u00d5", "\u00d6", "\u00d8", "\u00d9", "\u00da", "\u00db", "\u00dc", "\u00dd", "\u00de", "\u00df", "\u00e0", "\u00e1", "\u00e2", "\u00e3", "\u00e4", "\u00e5", "\u00e6", "\u00e7", "\u00e8", "\u00e9", "\u00ea", "\u00eb", "\u00ec", "\u00ed", "\u00ee", "\u00ef", "\u00f0", "\u00f1", "\u00f2", "\u00f3", "\u00f4", "\u00f5", "\u00f6", "\u00f8", "\u00f9", "\u00fa", "\u00fb", "\u00fc", "\u00fd", "\u00fe", "\u00ff", "\u0100", "\u0101", "\u0102", "\u0103", "\u0104", "\u0105", "\u0106", "\u0107", "\u010c", "\u010d", "\u010e", "\u010f", "\u0110", "\u0111", "\u0112", "\u0113", "\u0116", "\u0117", "\u0118", "\u0119", "\u011a", "\u011b", "\u011e", "\u011f", "\u0120", "\u0121", "\u0126", "\u0127", "\u0128", "\u0129", "\u012a", "\u012b", "\u0130", "\u0131", "\u0136", "\u0137", "\u013d", "\u013e", "\u0141", "\u0142", "\u0143", "\u0144", "\u0145", "\u0146", "\u0147", "\u0148", "\u014a", "\u014b", "\u014c", "\u014d", "\u014e", "\u014f", "\u0150", "\u0151", "\u0152", "\u0153", "\u0158", "\u0159", "\u015a", "\u015b", "\u015e", "\u015f", "\u0160", "\u0161", "\u0162", "\u0163", "\u0164", "\u0165", "\u0168", "\u0169", "\u016a", "\u016b", "\u016c", "\u016d", "\u016e", "\u016f", "\u0172", "\u0173", "\u0174", "\u0175", "\u0176", "\u0177", "\u0178", "\u0179", "\u017a", "\u017b", "\u017c", "\u017d", "\u017e", "\u0181", "\u0186", "\u0189", "\u018a", "\u018f", "\u0190", "\u0191", "\u0192", "\u0194", "\u0197", "\u019c", "\u019d", "\u019f", "\u01a0", "\u01a1", "\u01a6", "\u01a9", "\u01ae", "\u01af", "\u01b0", "\u01b1", "\u01b2", "\u01b7", "\u01c2", "\u01cd", "\u01ce", "\u01cf", "\u01d0", "\u01d1", "\u01d2", "\u01d3", "\u01d4", "\u01ea", "\u01eb", "\u0218", "\u0219", "\u021a", "\u021b", "\u0245", "\u0250", "\u0251", "\u0252", "\u0253", "\u0254", "\u0255", "\u0256", "\u0257", "\u0259", "\u025b", "\u025f", "\u0261", "\u0262", "\u0263", "\u0266", "\u0267", "\u0268", "\u026a", "\u026c", "\u026f", "\u0272", "\u0274", "\u0275", "\u0278", "\u027b", "\u027e", "\u0280", "\u0281", "\u0282", "\u0283", "\u0287", "\u0288", "\u028a", "\u028b", "\u028c", "\u028d", "\u028e", "\u0292", "\u0294", "\u0295", "\u0298", "\u029d", "\u029f", "\u02b0", "\u02b2", "\u02b7", "\u02bb", "\u02bc", "\u02be", "\u02bf", "\u02c0", "\u02c1", "\u02c8", "\u02cc", "\u02d0", "\u02e0", "\u02e4", "\u0386", "\u0388", "\u038a", "\u038c", "\u038e", "\u038f", "\u0391", "\u0391\u0342", "\u0392", "\u0393", "\u0394", "\u0395", "\u0396", "\u0397", "\u0397\u0342", "\u0398", "\u0399", "\u0399\u0342", "\u039a", "\u039b", "\u039c", "\u039d", "\u039e", "\u039f", "\u03a0", "\u03a1", "\u03a3", "\u03a4", "\u03a5", "\u03a5\u0313", "\u03a5\u0342", "\u03a6", "\u03a7", "\u03a8", "\u03a9", "\u03a9\u0342", "\u03a9\u0342\u0399", "\u03ac", "\u03ad", "\u03af", "\u03b1", "\u03b1\u0342", "\u03b2", "\u03b3", "\u03b4", 
"\u03b5", "\u03b6", "\u03b7", "\u03b7\u0342", "\u03b8", "\u03b9", "\u03b9\u0342", "\u03ba", "\u03bb", "\u03bc", "\u03bd", "\u03be", "\u03bf", "\u03c0", "\u03c1", "\u03c2", "\u03c3", "\u03c4", "\u03c5", "\u03c5\u0313", "\u03c5\u0342", "\u03c6", "\u03c7", "\u03c8", "\u03c9", "\u03c9\u0342", "\u03c9\u0342\u03b9", "\u03cc", "\u03cd", "\u03ce", "\u03d5", "\u0401", "\u0406", "\u0408", "\u0410", "\u0411", "\u0412", "\u0413", "\u0414", "\u0415", "\u0416", "\u0417", "\u0418", "\u0419", "\u041a", "\u041b", "\u041c", "\u041d", "\u041e", "\u041f", "\u0420", "\u0421", "\u0422", "\u0423", "\u0425", "\u0426", "\u0427", "\u0428", "\u042a", "\u042b", "\u042c", "\u042d", "\u042e", "\u042f", "\u0430", "\u0431", "\u0432", "\u0433", "\u0434", "\u0435", "\u0436", "\u0437", "\u0438", "\u0439", "\u043a", "\u043b", "\u043c", "\u043d", "\u043e", "\u043f", "\u0440", "\u0441", "\u0442", "\u0443", "\u0445", "\u0446", "\u0447", "\u0448", "\u044a", "\u044b", "\u044c", "\u044d", "\u044e", "\u044f", "\u0451", "\u0456", "\u0458", "\u05b5", "\u05b6", "\u05bc", "\u05d0", "\u05d1", "\u05d2", "\u05d3", "\u05d5", "\u05d7", "\u05d9", "\u05dc", "\u05dd", "\u05de", "\u05e0", "\u05e1", "\u05e2", "\u05e6", "\u05e8", "\u05e9", "\u05ea", "\u0621", "\u0623", "\u0625", "\u0627", "\u0628", "\u0629", "\u062a", "\u062c", "\u062d", "\u062e", "\u062f", "\u0631", "\u0632", "\u0633", "\u0634", "\u0635", "\u0637", "\u0639", "\u063a", "\u0641", "\u0642", "\u0643", "\u0644", "\u0645", "\u0646", "\u0647", "\u0648", "\u064a", "\u06cc", "\u0902", "\u0905", "\u0906", "\u0909", "\u0915", "\u0917", "\u091f", "\u0921", "\u0924", "\u0926", "\u0928", "\u092a", "\u092c", "\u092d", "\u092e", "\u092f", "\u0930", "\u0932", "\u0936", "\u0937", "\u0938", "\u0939", "\u093e", "\u093f", "\u0940", "\u0947", "\u094b", "\u0995", "\u09a4", "\u09b2", "\u09be", "\u09bf", "\u0b95", "\u0ba9", "\u0bb3", "\u0e02", "\u0e07", "\u0e08", "\u0e0a", "\u0e10", "\u0e15", "\u0e17", "\u0e19", "\u0e1b", "\u0e1e", "\u0e23", "\u0e27", "\u0e30", "\u0e31", "\u0e32", "\u0e40", "\u0e41", "\u16c3", "\u16cb", "\u16df", "\u1e0c", "\u1e0d", "\u1e24", "\u1e25", "\u1e36", "\u1e37", "\u1e3a", "\u1e3b", "\u1e42", "\u1e43", "\u1e44", "\u1e45", "\u1e46", "\u1e47", "\u1e48", "\u1e49", "\u1e5a", "\u1e5b", "\u1e5e", "\u1e5f", "\u1e62", "\u1e63", "\u1e6c", "\u1e6d", "\u1e6e", "\u1e6f", "\u1ea0", "\u1ea1", "\u1ea2", "\u1ea3", "\u1ea4", "\u1ea5", "\u1ea6", "\u1ea7", "\u1ea8", "\u1ea9", "\u1eaa", "\u1eab", "\u1eac", "\u1ead", "\u1eae", "\u1eaf", "\u1eb4", "\u1eb5", "\u1eb6", "\u1eb7", "\u1eb8", "\u1eb9", "\u1ebe", "\u1ebf", "\u1ec2", "\u1ec3", "\u1ec4", "\u1ec5", "\u1ec6", "\u1ec7", "\u1eca", "\u1ecb", "\u1ecc", "\u1ecd", "\u1ece", "\u1ecf", "\u1ed0", "\u1ed1", "\u1ed2", "\u1ed3", "\u1ed4", "\u1ed5", "\u1ed6", "\u1ed7", "\u1ed8", "\u1ed9", "\u1eda", "\u1edb", "\u1edc", "\u1edd", "\u1ede", "\u1edf", "\u1ee2", "\u1ee3", "\u1ee4", "\u1ee5", "\u1ee6", "\u1ee7", "\u1ee8", "\u1ee9", "\u1eea", "\u1eeb", "\u1eec", "\u1eed", "\u1eee", "\u1eef", "\u1ef0", "\u1ef1", "\u1ef2", "\u1ef3", "\u1ef4", "\u1ef5", "\u1ef8", "\u1ef9", "\u1f00", "\u1f04", "\u1f08", "\u1f0c", "\u1f10", "\u1f15", "\u1f18", "\u1f1d", "\u1f20", "\u1f21", "\u1f28", "\u1f29", "\u1f30", "\u1f31", "\u1f38", "\u1f39", "\u1f41", "\u1f44", "\u1f49", "\u1f4c", "\u1f50", "\u1f51", "\u1f59", "\u1f61", "\u1f69", "\u1f70", "\u1f72", "\u1f74", "\u1f76", "\u1f78", "\u1f7a", "\u1f7c", "\u1fb6", "\u1fba", "\u1fc6", "\u1fc8", "\u1fca", "\u1fd6", "\u1fda", "\u1fe6", "\u1fea", "\u1ff6", "\u1ff7", "\u1ff8", "\u1ffa", "\u2081", "\u2082", "\u2083", "\u2113", "\u2460", 
"\u2461", "\u2463", "\u2c6d", "\u2c6f", "\u2c70", "\u3044", "\u3045", "\u3046", "\u304a", "\u304b", "\u304d", "\u304f", "\u3050", "\u3053", "\u3057", "\u3059", "\u305b", "\u305f", "\u3064", "\u3069", "\u306e", "\u3070", "\u307d", "\u3088", "\u3089", "\u3093", "\u30a1", "\u30a2", "\u30a3", "\u30a4", "\u30a6", "\u30a7", "\u30a8", "\u30a9", "\u30aa", "\u30ab", "\u30ac", "\u30af", "\u30b0", "\u30b3", "\u30b4", "\u30b5", "\u30b6", "\u30b7", "\u30b8", "\u30b9", "\u30ba", "\u30bb", "\u30bc", "\u30bd", "\u30bf", "\u30c1", "\u30c3", "\u30c4", "\u30c6", "\u30c7", "\u30c8", "\u30c9", "\u30ca", "\u30cb", "\u30ce", "\u30cf", "\u30d0", "\u30d1", "\u30d2", "\u30d3", "\u30d5", "\u30d6", "\u30d7", "\u30d9", "\u30da", "\u30dc", "\u30de", "\u30df", "\u30e1", "\u30e3", "\u30e4", "\u30e5", "\u30e6", "\u30e9", "\u30ea", "\u30eb", "\u30ec", "\u30ed", "\u30ef", "\u30f3", "\u30f4", "\u30fc", "\ua7aa", "\ua7ac", "\ua7ad", "\ua7ae", "\ua7b1", "\ua7b2", "\ua7c5", "\uac70", "\ub9c8", "\ub9c9", "\ub9d0", "\uc0ac", "\uc778", "\uc804", "\uc9c0", "\uc9d3", "\ud22c", "\ufb01"]
checkpoints/detector.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8b7d50c74b2dba9acb8dd76d2fbcf75e6eeae0cb3e9688edf42c91aa5550ade1
+ size 181677320
checkpoints/recognizer.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:db307d9b0dcb6cd15ab6c71e302fd62ca90ce077c3013c9f63a4ba0dbfdf3f50
+ size 19823477
checkpoints/relational.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b1db5a62853269aabd8a040eeb05038a871032e8275def77653631657cb8ca4a
+ size 9048309
example.py ADDED
@@ -0,0 +1,43 @@
+ #!/usr/bin/env python3
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+ # All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import argparse
+
+ from nemo_retriever_ocr.inference.pipeline import NemoRetrieverOCR
+
+
+ def main(image_path, merge_level, no_visualize, model_dir):
+     ocr_pipeline = NemoRetrieverOCR()
+
+     predictions = ocr_pipeline(image_path, merge_level=merge_level, visualize=not no_visualize)
+
+     print(f"Found {len(predictions)} text regions.")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Run OCR inference and annotate image.")
+     parser.add_argument("image_path", type=str, help="Path to the input image.")
+     parser.add_argument(
+         "--merge-level",
+         type=str,
+         choices=["word", "sentence", "paragraph"],
+         default="paragraph",
+         help="Merge level for OCR output (word, sentence, paragraph).",
+     )
+     parser.add_argument("--no-visualize", action="store_true", help="Do not save the annotated image.")
+     parser.add_argument(
+         "--model-dir",
+         type=str,
+         help="Path to the model checkpoints.",
+         default="./checkpoints",
+     )
+     args = parser.parse_args()
+
+     main(
+         args.image_path,
+         merge_level=args.merge_level,
+         no_visualize=args.no_visualize,
+         model_dir=args.model_dir,
+     )
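For orientation, the script above reduces to a direct call into the pipeline class it imports. Below is a minimal sketch of those same calls; the image path is illustrative, and note that in this version of `example.py` the `--model-dir` argument is parsed but not forwarded to `NemoRetrieverOCR`.

```python
# Minimal sketch mirroring example.py; "sample.jpg" is an illustrative path.
from nemo_retriever_ocr.inference.pipeline import NemoRetrieverOCR

ocr_pipeline = NemoRetrieverOCR()
predictions = ocr_pipeline("sample.jpg", merge_level="paragraph", visualize=True)
print(f"Found {len(predictions)} text regions.")
```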
nemo-retriever-ocr/cpp/.gitattributes ADDED
@@ -0,0 +1 @@
+ load_png/wuffs-v0.3.c filter=lfs diff=lfs merge=lfs -text
nemo-retriever-ocr/cpp/.gitignore ADDED
@@ -0,0 +1,6 @@
+ __pycache__
+ .vscode
+ build
+ *.egg-info
+ dist
+ .vs
nemo-retriever-ocr/cpp/.gitmodules ADDED
@@ -0,0 +1,3 @@
+ [submodule "trove"]
+ path = trove
+ url = https://github.com/bryancatanzaro/trove.git
nemo-retriever-ocr/cpp/README.md ADDED
@@ -0,0 +1,15 @@
+ # Optimized Image Operations for PyTorch
+
+ ## Installation
+
+ ```
+ python setup.py install
+ ```
+
+ ## Usage
+
+ ```
+ # It's important that you do this first
+ import torch
+ from pytorch_image_ops import color_transform, spatial_transform
+ ```
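The import-order note in this README most likely reflects the usual requirement for compiled PyTorch extensions: importing `torch` first loads the libtorch/c10 shared libraries that `pytorch_image_ops` links against. A minimal sketch, assuming the extension has been installed with the `setup.py` step above:

```python
# Sketch of the import order the README asks for (extension assumed installed).
import torch  # load torch (and its shared libraries) before the extension
from pytorch_image_ops import color_transform, spatial_transform

print(torch.__version__)
```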
nemo-retriever-ocr/cpp/beam_decode/beam_decode.cpp ADDED
@@ -0,0 +1,460 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #include "beam_decode.h"
6
+
7
+ #include <vector>
8
+ #include <deque>
9
+ #include <limits>
10
+ #include <memory>
11
+ #include <unordered_set>
12
+ #include <set>
13
+ #include <algorithm>
14
+ #include <chrono>
15
+
16
+ #include "../common.h"
17
+ #include "prefix.h"
18
+ #include "log_sum_exp.h"
19
+ #include "sbo_lm.h"
20
+
21
+ using namespace std;
22
+
23
+ template<typename scalar_t>
24
+ using pred_seq_t = torch::TensorAccessor<scalar_t, 2>;
25
+
26
+ struct PrefixScore
27
+ {
28
+ float_t lProbBlank;
29
+ float_t lProbChar;
30
+ // float_t raw_lProbBlank;
31
+ // float_t raw_lProbChar;
32
+ mutable float_t _lProb;
33
+
34
+ PrefixScore(float_t lProbBlank = NEG_INF /* log P(0) */, float_t lProbChar = NEG_INF /* log P(0) */)
35
+ : lProbBlank(lProbBlank), lProbChar(lProbChar), _lProb(NEG_INF)
36
+ // , raw_lProbBlank(lProbBlank), raw_lProbChar(lProbChar)
37
+ {}
38
+
39
+ float_t get_lScore() const {
40
+ if (_lProb == NEG_INF) {
41
+ _lProb = log_sum_exp(lProbBlank, lProbChar);
42
+ }
43
+ return _lProb;
44
+ }
45
+
46
+ // float_t get_raw_lScore() const {
47
+ // return log_sum_exp(raw_lProbBlank, raw_lProbChar);
48
+ // }
49
+ };
50
+
51
+ typedef std::unordered_map<Prefix*, PrefixScore> PrefixMap;
52
+ typedef std::pair<Prefix*, PrefixScore> BeamItem;
53
+ typedef std::vector<BeamItem> Beam;
54
+
55
+ /*
56
+ Allows us to get an estimate of the vision model confidence, irrespective of how the language
57
+ model guided the decoding. NOTE: This scoring could follow an entirely different path than
58
+ the returned decoded sequence.
59
+ */
60
+ template<typename scalar_t>
61
+ scalar_t get_vision_confidence(const pred_seq_t<scalar_t> &logProbs, scalar_t minProb)
62
+ {
63
+ const int64_t T = logProbs.size(0);
64
+ const int64_t S = logProbs.size(1);
65
+
66
+ scalar_t ret = 0; // log(1)
67
+
68
+ for (size_t t = 0; t < T; ++t) {
69
+ float_t maxP = logProbs[t][0];
70
+ int64_t maxC = 0;
71
+ for (int64_t c = 1; c < S; ++c) {
72
+ float_t p = logProbs[t][c];
73
+ if (p > maxP) {
74
+ maxP = p;
75
+ maxC = c;
76
+ }
77
+ }
78
+ ret += maxP;
79
+ // Ignore everything past the sequence terminator
80
+ if (maxC == 1) {
81
+ break;
82
+ }
83
+
84
+ if (ret < minProb) {
85
+ break;
86
+ }
87
+ }
88
+
89
+ return ret;
90
+ }
91
+
92
+
93
+ template<typename scalar_t>
94
+ pair<vector<token_t>, float_t>
95
+ ctc_beam_decode_impl(const pred_seq_t<scalar_t> &probs, const int64_t beamSize,
96
+ const int64_t blank, scalar_t minProb,
97
+ const LanguageModel &langModel, scalar_t lmWeight)
98
+ {
99
+ if (blank != 0) {
100
+ throw runtime_error("Currently, only ordinal 0 supported for the blank prediction");
101
+ }
102
+
103
+ const int64_t T = probs.size(0);
104
+ const int64_t S = probs.size(1);
105
+
106
+ // NOTE: In log space, the following is true:
107
+ // 1. Adding two probabilities: log_sum_exp(l_p_a, l_p_b)
108
+ // 2. Multiplying two probabilities: l_p_a + l_p_b
109
+ // 3. log P(0) = -inf
110
+ // 4. log P(1) = 0
111
+
112
+ // Convert to log-space
113
+ if (minProb > 0) {
114
+ minProb = log(minProb);
115
+ } else {
116
+ minProb = NEG_INF;
117
+ }
118
+
119
+ auto retScore = get_vision_confidence(probs, minProb);
120
+
121
+ if (retScore < minProb) {
122
+ return { {}, NEG_INF };
123
+ }
124
+
125
+ PrefixAllocator prefixAlloc;
126
+
127
+ Beam beam;
128
+ beam.emplace_back(prefixAlloc.GetPrefix(), PrefixScore{0, NEG_INF}); // Add a dummy first node
129
+
130
+ Beam terminated;
131
+
132
+ typedef tuple<Prefix*, token_t> lm_cache_key_t;
133
+ unordered_map<lm_cache_key_t, float_t> lmScoreCache;
134
+
135
+ for (int64_t t = 0; t < T; ++t) {
136
+ PrefixMap nextBeam;
137
+
138
+ // Add all of the completed paths to the next beam.
139
+ // This allows us to accumulate new paths into these,
140
+ // but otherwise not process them
141
+ for (const BeamItem &prevNode : beam) {
142
+ if (prevNode.first->Token == 1) {
143
+ nextBeam.insert(prevNode);
144
+ }
145
+ }
146
+
147
+ // Loop over vocab
148
+ for (int64_t s = 0; s < S; ++s) {
149
+ float_t lpEmit = probs[t][s];
150
+
151
+ if (lpEmit < minProb) {
152
+ continue;
153
+ }
154
+
155
+ for (const BeamItem &prevNode : beam) {
156
+ Prefix *prevPrefix = prevNode.first;
157
+ const PrefixScore &prevScore = prevNode.second;
158
+
159
+ // Ignore already completed paths
160
+ if (prevPrefix->Token == 1) {
161
+ continue;
162
+ }
163
+
164
+ // Ignore impossible paths
165
+ if (prevScore.lProbBlank == NEG_INF && prevScore.lProbChar == NEG_INF) {
166
+ continue;
167
+ }
168
+
169
+ // If we propose a blank the prefix doesn't change.
170
+ // Only the probability of ending in blank gets updated.
171
+ if (s == blank) {
172
+ PrefixScore &score = nextBeam[prevPrefix];
173
+ score.lProbBlank = log_sum_exp(score.lProbBlank , prevScore.lProbBlank + lpEmit, prevScore.lProbChar + lpEmit);
174
+ // score.raw_lProbBlank = log_sum_exp(score.raw_lProbBlank, prevScore.raw_lProbBlank + lpEmit, prevScore.raw_lProbChar + lpEmit);
175
+ continue;
176
+ }
177
+
178
+ // Extend the prefix by the new character s and add it to the beam.
179
+ // Only the probability of not ending in blank gets updated.
180
+ token_t prevToken = prevPrefix->Token;
181
+
182
+ // NOTE: We always create a new prefix regardless of duplication because the PrefixScore
183
+ // is simultaneously tracking prefixes that do and don't end in a blank. And it's those
184
+ // that end in a blank that would cause the prefix to be extended.
185
+ auto extendPrefix = prefixAlloc.GetPrefix(s, prevPrefix);
186
+
187
+ // Evaluate the language model, but use the cache if we've already considered this string before
188
+ auto lmCacheItem = make_tuple(prevPrefix, s);
189
+ auto lmCacheIter = lmScoreCache.find(lmCacheItem);
190
+ float_t lpLang = 0;
191
+ if (lmCacheIter == lmScoreCache.end()) {
192
+ lpLang = langModel.ScoreTransition(prevPrefix, s);
193
+ lpLang *= lmWeight;
194
+ lmCacheIter = lmScoreCache.emplace(lmCacheItem, lpLang).first;
195
+ }
196
+ lpLang = lmCacheIter->second;
197
+
198
+ PrefixScore &extendScore = nextBeam[extendPrefix];
199
+ // Remember, adding two log probabilities is equivalent to multiplying two probabilities
200
+ if (s != prevToken) {
201
+ extendScore.lProbChar = log_sum_exp(extendScore.lProbChar, prevScore.lProbBlank + lpEmit + lpLang, prevScore.lProbChar + lpEmit + lpLang);
202
+ // extendScore.raw_lProbChar = log_sum_exp(extendScore.raw_lProbChar, prevScore.raw_lProbBlank + lpEmit , prevScore.raw_lProbChar + lpEmit );
203
+ } else {
204
+ // We don't include the previous probability of not ending in blank if s is repeated at the end. The CTC
205
+ // algorithm merges characters not separated by a blank.
206
+ extendScore.lProbChar = log_sum_exp(extendScore.lProbChar , prevScore.lProbBlank + lpEmit + lpLang);
207
+ // extendScore.raw_lProbChar = log_sum_exp(extendScore.raw_lProbChar, prevScore.raw_lProbBlank + lpEmit );
208
+ }
209
+
210
+ // If the token is repeated, we also have to deal with the unchanged prefix since repeated characters are collapsed
211
+ if (s == prevToken) {
212
+ PrefixScore &collapseScore = nextBeam[prevPrefix];
213
+ collapseScore.lProbChar = log_sum_exp(collapseScore.lProbChar , prevScore.lProbChar + lpEmit);
214
+ // collapseScore.raw_lProbChar = log_sum_exp(collapseScore.raw_lProbChar, prevScore.raw_lProbChar + lpEmit);
215
+ }
216
+
217
+ }
218
+ }
219
+
220
+ Beam vecNextBeam(begin(nextBeam), end(nextBeam));
221
+
222
+ if (vecNextBeam.size() > beamSize) {
223
+ partial_sort(begin(vecNextBeam), begin(vecNextBeam) + beamSize, end(vecNextBeam),
224
+ [] (const BeamItem &a, const BeamItem &b) {
225
+ return a.second.get_lScore() > b.second.get_lScore();
226
+ }
227
+ );
228
+ vecNextBeam.resize(beamSize);
229
+ }
230
+
231
+ beam = move(vecNextBeam);
232
+ }
233
+
234
+ // Find the best raw score
235
+ const BeamItem *bestItem = nullptr;
236
+ // for (const BeamItem &b : beam) {
237
+ // if (bestItem == nullptr or b.second.get_raw_lScore() > bestItem->second.get_raw_lScore()) {
238
+ // bestItem = &b;
239
+ // }
240
+ // }
241
+ if (! beam.empty()) {
242
+ bestItem = &beam[0];
243
+ }
244
+
245
+ if (bestItem != nullptr) {
246
+ auto retList = bestItem->first->ToList();
247
+
248
+ return { move(retList), retScore };
249
+ } else {
250
+ return { {}, NEG_INF };
251
+ }
252
+ }
253
+
254
+ typedef std::pair<Prefix*, float_t> RegBeamItem;
255
+
256
+ bool operator<(const RegBeamItem &a, const RegBeamItem &b) {
257
+ return a.second > b.second;
258
+ }
259
+
260
+ template<typename scalar_t>
261
+ pair<vector<token_t>, float_t>
262
+ reg_beam_decode_impl(const pred_seq_t<scalar_t> &logProbs, const int64_t beamSize,
263
+ scalar_t minProb,
264
+ const LanguageModel &langModel, scalar_t lmWeight)
265
+ {
266
+ const int64_t T = logProbs.size(0);
267
+ const int64_t S = logProbs.size(1);
268
+
269
+ // NOTE: In log space, the following is true:
270
+ // 1. Adding two probabilities: log_sum_exp(l_p_a, l_p_b)
271
+ // 2. Multiplying two probabilities: l_p_a + l_p_b
272
+ // 3. log P(0) = -inf
273
+ // 4. log P(1) = 0
274
+
275
+ // Convert to log-space
276
+ if (minProb > 0) {
277
+ minProb = log(minProb);
278
+ } else {
279
+ minProb = NEG_INF;
280
+ }
281
+
282
+ auto retScore = get_vision_confidence(logProbs, minProb);
283
+
284
+ if (retScore < minProb) {
285
+ return { {}, NEG_INF };
286
+ }
287
+
288
+ PrefixAllocator prefixAlloc;
289
+
290
+ vector<RegBeamItem> beam, nextBeam;
291
+ beam.emplace_back(prefixAlloc.GetPrefix(), 0); // log(1) = 0
292
+
293
+ for (int64_t t = 0; t < T && !beam.empty(); ++t) {
294
+ nextBeam.clear();
295
+
296
+ auto addToBeam = [&nextBeam, beamSize] (const RegBeamItem &rbi) {
297
+ nextBeam.push_back(rbi);
298
+ };
299
+
300
+ // Expand each path in the beam
301
+ for (const RegBeamItem &prevNode : beam) {
302
+ if (prevNode.first->Token == 1) {
303
+ // Move completed paths along without processing further
304
+ addToBeam(prevNode);
305
+ continue;
306
+ }
307
+
308
+ Prefix *prevPrefix = prevNode.first;
309
+ float_t prevScore = prevNode.second;
310
+
311
+ // Loop over vocab
312
+ for (int64_t s = 0; s < S; ++s) {
313
+ float_t lpEmit = logProbs[t][s];
314
+
315
+ if (lpEmit < minProb) {
316
+ // The probability dropped below threshold, so stop processing this path
317
+ continue;
318
+ }
319
+
320
+ auto extendPrefix = prefixAlloc.GetPrefix(s, prevPrefix);
321
+
322
+ float_t lpLang = langModel.ScoreTransition(prevPrefix, s);
323
+
324
+ float_t lpNext = prevScore + lpLang + lpEmit;
325
+
326
+ addToBeam({extendPrefix, lpNext});
327
+ }
328
+ }
329
+
330
+ if (nextBeam.size() > beamSize) {
331
+ // Find the top-k items, and then truncate the rest
332
+ partial_sort(begin(nextBeam), begin(nextBeam) + beamSize, end(nextBeam));
333
+ nextBeam.resize(beamSize);
334
+ }
335
+
336
+ std::swap(beam, nextBeam);
337
+ }
338
+
339
+ if (!beam.empty()) {
340
+ // The highest probability element will always be in the back
341
+ RegBeamItem rbi{ nullptr, NEG_INF };
342
+ for (auto &rb : beam) {
343
+ if (rbi.first == nullptr || rb.second > rbi.second) {
344
+ rbi = rb;
345
+ }
346
+ }
347
+
348
+ auto retList = rbi.first->ToList();
349
+
350
+ return { move(retList), retScore };
351
+ } else {
352
+ return { {}, NEG_INF };
353
+ }
354
+ }
355
+
356
+
357
+
358
+ template<typename scalar_t>
359
+ void dp_beam_decode_impl(const torch::TensorAccessor<scalar_t, 3> &probsAccess,
360
+ torch::TensorAccessor<int64_t, 2> retAccess,
361
+ torch::TensorAccessor<scalar_t, 1> confAccess,
362
+ int64_t beamSize, int64_t blank,
363
+ scalar_t minProb,
364
+ const LanguageModel *langModel,
365
+ scalar_t lmWeight,
366
+ bool combineDuplicates)
367
+ {
368
+ const int64_t N = probsAccess.size(0);
369
+
370
+ #pragma omp parallel for num_threads(8)
371
+ for (int64_t i = 0; i < N; ++i) {
372
+ vector<token_t> seq;
373
+ float_t lConf;
374
+ if (combineDuplicates) {
375
+ tie(seq, lConf) = ctc_beam_decode_impl(probsAccess[i], beamSize, blank,
376
+ minProb,
377
+ *langModel, lmWeight);
378
+ } else {
379
+ tie(seq, lConf) = reg_beam_decode_impl(probsAccess[i], beamSize,
380
+ minProb,
381
+ *langModel, lmWeight);
382
+ }
383
+
384
+ int64_t sz = min<int64_t>(seq.size(), retAccess.size(1));
385
+
386
+ for (int64_t k = 0; k < sz; ++k) {
387
+ retAccess[i][k] = seq[k];
388
+ }
389
+
390
+ confAccess[i] = exp(lConf);
391
+ }
392
+ }
393
+
394
+ std::tuple<torch::Tensor, torch::Tensor>
395
+ beam_decode(torch::Tensor probs, int64_t beamSize, int64_t blank,
396
+ float minProb,
397
+ const LanguageModel *langModel,
398
+ float lmWeight,
399
+ bool combineDuplicates)
400
+ {
401
+ if (langModel == nullptr) {
402
+ langModel = &NullLanguageModel;
403
+ }
404
+
405
+ auto tStart = chrono::high_resolution_clock::now();
406
+
407
+ probs = probs.contiguous();
408
+
409
+ bool collapse = false;
410
+ if (probs.dim() == 2) {
411
+ // N,T,C
412
+ probs = probs.unsqueeze(0);
413
+ collapse = true;
414
+ }
415
+
416
+ probs = probs.log();
417
+
418
+ torch::Tensor ret = torch::ones({ probs.size(0), probs.size(1) }, torch::kInt64);
419
+ torch::Tensor conf = torch::zeros({ probs.size(0) }, probs.options());
420
+
421
+ auto retAccess = ret.accessor<int64_t, 2>();
422
+
423
+ AT_DISPATCH_FLOATING_TYPES(
424
+ probs.scalar_type(),
425
+ "cpu_beam_decode",
426
+ ([&] {
427
+ dp_beam_decode_impl(
428
+ probs.accessor<scalar_t, 3>(),
429
+ retAccess,
430
+ conf.accessor<scalar_t, 1>(),
431
+ beamSize, blank,
432
+ static_cast<scalar_t>(minProb),
433
+ langModel,
434
+ static_cast<scalar_t>(lmWeight),
435
+ combineDuplicates
436
+ );
437
+ })
438
+ );
439
+
440
+ if (collapse) {
441
+ ret = ret.squeeze(0);
442
+ conf = conf[0];
443
+ }
444
+
445
+ auto tEnd = chrono::high_resolution_clock::now();
446
+
447
+ typedef chrono::duration<double, std::milli> tp_t;
448
+ tp_t totalElapsed = tEnd - tStart;
449
+
450
+ cout << "Beam Decode " << probs.size(0) << " - "
451
+ << "Total: " << totalElapsed.count() << "ms"
452
+ << endl;
453
+
454
+ return { ret, conf };
455
+ }
456
+
457
+ std::unique_ptr<LanguageModel> create_sbo_lm(const std::string &dataFilePath, token_mapping_t tokenMapping, float_t backoffWeight)
458
+ {
459
+ return make_unique<SBO_LanguageModel>(dataFilePath, move(tokenMapping), backoffWeight);
460
+ }
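For readers skimming `ctc_beam_decode_impl` above: it is a CTC prefix beam search that keeps, for every prefix, separate log-probabilities of ending in a blank and of ending in its last character, and only lets a repeated symbol extend a prefix via a path through blank. Below is a minimal Python sketch of that bookkeeping, leaving out this implementation's language model, prefix allocator, confidence threshold, and end-of-sequence token handling.

```python
import math
from collections import defaultdict

NEG_INF = float("-inf")

def log_sum_exp(*xs):
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))

def ctc_prefix_beam_search(log_probs, beam_size, blank=0):
    """log_probs: T x S nested list of per-frame log-probabilities."""
    # Each prefix carries (log P(ending in blank), log P(ending in its last char)).
    beam = {(): (0.0, NEG_INF)}
    for frame in log_probs:
        next_beam = defaultdict(lambda: (NEG_INF, NEG_INF))
        for prefix, (lp_blank, lp_char) in beam.items():
            for s, lp_emit in enumerate(frame):
                if s == blank:
                    b, c = next_beam[prefix]
                    next_beam[prefix] = (log_sum_exp(b, lp_blank + lp_emit, lp_char + lp_emit), c)
                elif prefix and s == prefix[-1]:
                    # Repeated symbol: only the path through blank extends the prefix...
                    b, c = next_beam[prefix + (s,)]
                    next_beam[prefix + (s,)] = (b, log_sum_exp(c, lp_blank + lp_emit))
                    # ...otherwise the repeat collapses onto the unchanged prefix.
                    b, c = next_beam[prefix]
                    next_beam[prefix] = (b, log_sum_exp(c, lp_char + lp_emit))
                else:
                    b, c = next_beam[prefix + (s,)]
                    next_beam[prefix + (s,)] = (b, log_sum_exp(c, lp_blank + lp_emit, lp_char + lp_emit))
        ranked = sorted(next_beam.items(), key=lambda kv: log_sum_exp(*kv[1]), reverse=True)
        beam = dict(ranked[:beam_size])
    return max(beam.items(), key=lambda kv: log_sum_exp(*kv[1]))[0]
```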
nemo-retriever-ocr/cpp/beam_decode/beam_decode.h ADDED
@@ -0,0 +1,18 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #pragma once
6
+
7
+ #include <torch/torch.h>
8
+
9
+ #include "language_model.h"
10
+
11
+ std::tuple<torch::Tensor, torch::Tensor>
12
+ beam_decode(torch::Tensor probs, int64_t beamSize, int64_t blank,
13
+ float minProb,
14
+ const LanguageModel *langModel,
15
+ float lmWeight,
16
+ bool combineDuplicates);
17
+
18
+ std::unique_ptr<LanguageModel> create_sbo_lm(const std::string &dataFilePath, token_mapping_t tokenMapping, float_t backoffWeight);
nemo-retriever-ocr/cpp/beam_decode/kn_lm.cpp ADDED
@@ -0,0 +1,86 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #include "kn_lm.h"
6
+
7
+ using namespace std;
8
+
9
+
10
+ KN_LanguageModel::KN_LanguageModel(const string &dataFilePath, token_mapping_t tokenMapping, float_t knDelta)
11
+ : NGramLMBase(dataFilePath, move(tokenMapping)), m_knDelta(knDelta)
12
+ {
13
+ }
14
+
15
+ float KN_LanguageModel::ScoreTransitionImpl(const std::wstring &prefix, const std::wstring &suffix) const
16
+ {
17
+ if (prefix.empty()) {
18
+ return ScoreUnigram(suffix);
19
+ } else {
20
+ return ScoreTransition(prefix, suffix);
21
+ }
22
+ }
23
+
24
+ float_t KN_LanguageModel::ScoreUnigram(const std::wstring &uni) const
25
+ {
26
+ auto lIter = m_lookup[1].find(L""s);
27
+ if (lIter == m_lookup[1].end()) {
28
+ throw std::runtime_error("Unigrams not supported by this model!");
29
+ }
30
+
31
+ auto uniIter = lIter->second.find(uni);
32
+ float_t ctUni = 1e-8;
33
+ if (uniIter != lIter->second.end()) {
34
+ ctUni = uniIter->second;
35
+ }
36
+
37
+ float_t ctSuffixes = GetPrefixSum(L""s);
38
+
39
+ return ctUni / ctSuffixes;
40
+ }
41
+
42
+ float_t KN_LanguageModel::ScoreTransition(const std::wstring &prefix, const std::wstring &suffix) const
43
+ {
44
+ if (prefix.empty()) {
45
+ // The number of distinct bigrams that end with this token
46
+ auto rlIter = m_reverseLookup.find(suffix);
47
+
48
+ float_t ctEndingBigrams = 0;
49
+ if (rlIter != m_reverseLookup.end()) {
50
+ ctEndingBigrams = rlIter->second[2].size();
51
+ }
52
+
53
+ float_t ctAllBigrams = m_lookup[2].size();
54
+
55
+ return ctEndingBigrams / ctAllBigrams;
56
+ }
57
+
58
+ auto lIter = m_lookup[prefix.size() + 1].find(prefix);
59
+ float_t ctUqSuffixes = 0;
60
+ float_t ctSuffixes = 0;
61
+ float_t ctSuffix = 0;
62
+ if (lIter != m_lookup[prefix.size() + 1].end()) {
63
+ ctUqSuffixes = lIter->second.size();
64
+
65
+ ctSuffixes = GetPrefixSum(prefix);
66
+
67
+ auto sIter = lIter->second.find(suffix);
68
+ if (sIter != lIter->second.end()) {
69
+ ctSuffix = sIter->second;
70
+ }
71
+ }
72
+
73
+ float_t factor = 0;
74
+ float_t main = 0;
75
+ if (ctSuffixes != 0) {
76
+ factor = m_knDelta * ctUqSuffixes / ctSuffixes;
77
+ // TODO: Figure out how to make this call without copying the string!
78
+ factor *= ScoreTransition({begin(prefix) + 1, end(prefix)}, suffix);
79
+
80
+ main = max<float_t>(ctSuffix - m_knDelta, 0) / ctSuffixes;
81
+ }
82
+
83
+ float_t total = main + factor;
84
+
85
+ return total;
86
+ }
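In equation form, `ScoreTransition` above implements an absolute-discounting recursion in the spirit of Kneser–Ney smoothing. Writing `h` for the prefix, `w` for the suffix, δ for `m_knDelta`, c(·) for the stored counts, and N₁₊(h,·) for the number of distinct suffixes observed after `h`, the transition score computed by the code is, as far as can be read from the source:

```latex
P(w \mid h) = \frac{\max\bigl(c(h,w) - \delta,\, 0\bigr)}{c(h,\cdot)}
            + \frac{\delta \, N_{1+}(h,\cdot)}{c(h,\cdot)} \; P(w \mid h'),
\qquad h' = h \text{ with its oldest character dropped}
```

When the prefix is empty, the recursion bottoms out in a continuation-style estimate: the number of distinct bigrams ending in `w`, normalized by the size of the bigram table.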
nemo-retriever-ocr/cpp/beam_decode/kn_lm.h ADDED
@@ -0,0 +1,27 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #pragma once
6
+
7
+ #include <unordered_map>
8
+ #include <vector>
9
+
10
+ #include "ngram_lm_base.h"
11
+
12
+
13
+ class KN_LanguageModel
14
+ : public NGramLMBase
15
+ {
16
+ public:
17
+ KN_LanguageModel(const std::string &dataFilePath, token_mapping_t tokenMapping, float_t knDelta);
18
+
19
+ protected:
20
+ virtual float_t ScoreTransitionImpl(const std::wstring &prefix, const std::wstring &suffix) const override;
21
+
22
+ private:
23
+ float_t ScoreUnigram(const std::wstring &uni) const;
24
+ float_t ScoreTransition(const std::wstring &prefix, const std::wstring &suffix) const;
25
+
26
+ float_t m_knDelta;
27
+ };
nemo-retriever-ocr/cpp/beam_decode/language_model.cpp ADDED
@@ -0,0 +1,147 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #include "language_model.h"
6
+
7
+ #include <locale>
8
+ #include <codecvt>
9
+
10
+ using namespace std;
11
+
12
+ const NullLanguageModel_t NullLanguageModel;
13
+
14
+ NullLanguageModel_t::NullLanguageModel_t()
15
+ : LanguageModel({})
16
+ {
17
+ }
18
+
19
+ TokenMappingWrapper::TokenMappingWrapper(token_mapping_t mapping)
20
+ : token_mapping(move(mapping))
21
+ {
22
+ for (const auto &mp : token_mapping) {
23
+ if (mp.second.size() == 1) {
24
+ wchar_t c = mp.second.front();
25
+ reverse_token_mapping.emplace(c, mp.first);
26
+ }
27
+ }
28
+ }
29
+
30
+ TokenMappingWrapper::Ptr create_token_mapping(token_mapping_t tokenMapping)
31
+ {
32
+ return make_shared<TokenMappingWrapper>(move(tokenMapping));
33
+ }
34
+
35
+
36
+ template<typename token_t>
37
+ vector<tuple<wstring, float>>
38
+ decode_sequences_impl(torch::Tensor tokens, const TokenMappingWrapper *tokenMapping,
39
+ c10::optional<torch::Tensor> probs)
40
+ {
41
+ const token_mapping_t &mapping = tokenMapping->token_mapping;
42
+
43
+ auto tokensAccess = tokens.accessor<token_t, 2>();
44
+
45
+ torch::Tensor pTens = probs.value_or(torch::ones({ tokens.size(0) }, torch::kFloat32));
46
+ if (pTens.dim() == 1) {
47
+ pTens = pTens.unsqueeze(1);
48
+ }
49
+
50
+ auto probsAccess = pTens.accessor<float, 2>();
51
+
52
+ const int64_t B = tokens.size(0);
53
+ const int64_t T = tokens.size(1);
54
+
55
+ vector<tuple<wstring, float>> ret;
56
+
57
+ for (int64_t b = 0; b < B; ++b) {
58
+ wstring buff;
59
+
60
+ float logProb = 0.0f; // log 1
61
+ bool done = false;
62
+ for (int64_t t = 0; t < T && ! done; ++t) {
63
+ typename token_mapping_t::key_type tokIdx = tokensAccess[b][t];
64
+
65
+ if (t < probsAccess.size(1)) {
66
+ logProb += log(probsAccess[b][t]);
67
+ }
68
+
69
+ switch (tokIdx) {
70
+ case 0:
71
+ // Blank char
72
+ continue;
73
+ case 1:
74
+ // End of sequence char
75
+ done = true;
76
+ break;
77
+ case 2:
78
+ buff.push_back('^');
79
+ break;
80
+ default:
81
+ auto iter = mapping.find(tokIdx);
82
+ if (iter == mapping.end()) {
83
+ throw std::runtime_error("The token mapping doesn't contain an entry for index " + to_string(tokIdx));
84
+ }
85
+ buff += iter->second;
86
+ break;
87
+ }
88
+ }
89
+
90
+ ret.emplace_back(move(buff), exp(logProb));
91
+ }
92
+
93
+ return ret;
94
+ }
95
+
96
+ vector<tuple<wstring, float>>
97
+ decode_sequences(torch::Tensor tokens, const TokenMappingWrapper *tokenMapping,
98
+ c10::optional<torch::Tensor> probs)
99
+ {
100
+ if (tokens.dim() != 2) {
101
+ throw std::runtime_error("`tokens` must be 2-dimensions of type B,T!");
102
+ }
103
+
104
+ if (tokenMapping == nullptr) {
105
+ throw std::runtime_error("Cannot supply a null token mapping!");
106
+ }
107
+
108
+ const token_mapping_t &mapping = tokenMapping->token_mapping;
109
+
110
+ if (mapping.empty()) {
111
+ throw std::runtime_error("The token mapping hasn't been initialized!");
112
+ }
113
+
114
+ if (probs.has_value()) {
115
+ if (probs.value().scalar_type() != torch::kFloat32) {
116
+ throw std::runtime_error("If the probability distribution is specified, then it must be of type `torch.float32`");
117
+ }
118
+ if (probs.value().size(0) != tokens.size(0)) {
119
+ throw std::runtime_error("The probability distribution batch size doesn't match the tokens batch size!");
120
+ }
121
+ if (probs.value().dim() == 2 && probs.value().size(1) != tokens.size(1)) {
122
+ throw std::runtime_error("Invalid probability distribution shape!");
123
+ }
124
+ }
125
+
126
+ vector<tuple<wstring, float>> ret;
127
+
128
+ AT_DISPATCH_INTEGRAL_TYPES(
129
+ tokens.scalar_type(),
130
+ "decode_sequences_impl",
131
+ ([&] {
132
+ ret = decode_sequences_impl<scalar_t>(tokens, tokenMapping, probs);
133
+ })
134
+ );
135
+
136
+ return ret;
137
+ }
138
+
139
+
140
+ std::string ws2s(const std::wstring& wstr)
141
+ {
142
+ using convert_typeX = std::codecvt_utf8<wchar_t>;
143
+ std::wstring_convert<convert_typeX, wchar_t> converterX;
144
+
145
+ return converterX.to_bytes(wstr);
146
+ }
147
+
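As a quick reference for the token conventions hard-coded in `decode_sequences_impl` above (index 0 is the CTC blank, 1 terminates the sequence, 2 is rendered as `^`, everything else goes through the token mapping), here is an equivalent Python sketch. The toy mapping is illustrative; in practice the mapping would presumably be derived from `checkpoints/charset.txt`.

```python
def decode_sequence(token_ids, mapping):
    """mapping: dict of token index -> string (indices 3+ in this sketch)."""
    out = []
    for tok in token_ids:
        if tok == 0:        # CTC blank: skip
            continue
        if tok == 1:        # end-of-sequence token: stop
            break
        if tok == 2:        # special marker rendered as '^' in the C++ code
            out.append("^")
            continue
        out.append(mapping[tok])  # KeyError if the mapping lacks the index
    return "".join(out)

# Toy usage: tokens are assumed already collapsed (no repeat merging here).
print(decode_sequence([3, 3, 0, 4, 1, 5], {3: "h", 4: "i", 5: "!"}))  # -> "hhi"
```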
nemo-retriever-ocr/cpp/beam_decode/language_model.h ADDED
@@ -0,0 +1,66 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #pragma once
6
+
7
+ #include <memory>
8
+
9
+ #include <torch/torch.h>
10
+
11
+ #include "prefix.h"
12
+ #include "log_sum_exp.h"
13
+
14
+ typedef std::unordered_map<int64_t, std::wstring> token_mapping_t;
15
+ typedef std::unordered_map<wchar_t, int64_t> reverse_token_mapping_t;
16
+
17
+
18
+ class LanguageModel
19
+ {
20
+ public:
21
+ virtual ~LanguageModel() {}
22
+
23
+ virtual float_t ScoreTransition(const Prefix *p, token_t nextToken) const = 0;
24
+
25
+ const token_mapping_t &TokenMapping() const { return m_tokenMapping; }
26
+
27
+ protected:
28
+ LanguageModel(token_mapping_t tokenMapping)
29
+ : m_tokenMapping(std::move(tokenMapping))
30
+ {}
31
+
32
+ token_mapping_t m_tokenMapping;
33
+ };
34
+
35
+
36
+ class NullLanguageModel_t
37
+ : public LanguageModel
38
+ {
39
+ public:
40
+ NullLanguageModel_t();
41
+
42
+ virtual float_t ScoreTransition(const Prefix *p, token_t nextToken) const override
43
+ {
44
+ // log P(1)
45
+ // Which means the probability is unchanged
46
+ return 0;
47
+ }
48
+ };
49
+
50
+ extern const NullLanguageModel_t NullLanguageModel;
51
+
52
+ struct TokenMappingWrapper
53
+ {
54
+ typedef std::shared_ptr<TokenMappingWrapper> Ptr;
55
+
56
+ TokenMappingWrapper(token_mapping_t mapping);
57
+
58
+ token_mapping_t token_mapping;
59
+ reverse_token_mapping_t reverse_token_mapping;
60
+ };
61
+
62
+ TokenMappingWrapper::Ptr create_token_mapping(token_mapping_t tokenMapping);
63
+
64
+ std::vector<std::tuple<std::wstring, float>>
65
+ decode_sequences(torch::Tensor tokens, const TokenMappingWrapper *tokenMapping,
66
+ c10::optional<torch::Tensor> probs = torch::nullopt);
nemo-retriever-ocr/cpp/beam_decode/log_sum_exp.cpp ADDED
@@ -0,0 +1,7 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #include "log_sum_exp.h"
6
+
7
+ const float_t NEG_INF = -std::numeric_limits<float_t>::infinity();
nemo-retriever-ocr/cpp/beam_decode/log_sum_exp.h ADDED
@@ -0,0 +1,54 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #pragma once
6
+
7
+ #include <cmath>
8
+ #include <limits>
9
+ #include <algorithm>
10
+
11
+ typedef float float_t;
12
+ extern const float_t NEG_INF;
13
+
14
+ template<typename T>
15
+ inline T max_val(T v)
16
+ {
17
+ return v;
18
+ }
19
+
20
+ template<typename T, typename ...Args>
21
+ inline T max_val(T v, Args... rest)
22
+ {
23
+ auto restMax = max_val(rest...);
24
+
25
+ return std::max(v, restMax);
26
+ }
27
+
28
+ template<typename T>
29
+ inline T sum_exp(T maxVal, T v)
30
+ {
31
+ return std::exp(v - maxVal);
32
+ }
33
+
34
+ template<typename T, typename ...Args>
35
+ inline T sum_exp(T maxVal, T v, Args... rest)
36
+ {
37
+ auto restSum = sum_exp(maxVal, rest...);
38
+
39
+ return sum_exp(maxVal, v) + restSum;
40
+ }
41
+
42
+ template<typename T, typename ...Args>
43
+ inline T log_sum_exp(T v, Args ...args)
44
+ {
45
+ auto maxVal = max_val(v, args...);
46
+
47
+ if (maxVal == -std::numeric_limits<T>::infinity()) {
48
+ return -std::numeric_limits<T>::infinity();
49
+ }
50
+
51
+ auto sumExp = sum_exp(maxVal, v, args...);
52
+
53
+ return maxVal + std::log(sumExp);
54
+ }
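The variadic templates above implement the standard max-shifted log-sum-exp identity, which keeps the result finite when summing very small probabilities in log space:

```latex
\operatorname{logsumexp}(x_1, \dots, x_n) = m + \log \sum_{i=1}^{n} e^{x_i - m},
\qquad m = \max_i x_i
```

with the convention (the early return on `maxVal`) that the result is −∞ when every input is −∞.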
nemo-retriever-ocr/cpp/beam_decode/ngram_lm_base.cpp ADDED
@@ -0,0 +1,330 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #include "ngram_lm_base.h"
6
+
7
+ #include <iostream>
8
+ #include <fstream>
9
+
10
+ #if defined( USE_BOOST )
11
+
12
+ #include <boost/archive/binary_oarchive.hpp>
13
+ #include <boost/archive/binary_iarchive.hpp>
14
+ #include <boost/serialization/vector.hpp>
15
+ #include <boost/serialization/string.hpp>
16
+ #include <boost/serialization/unordered_map.hpp>
17
+
18
+ #endif // USE_BOOST
19
+
20
+ using namespace std;
21
+
22
+ const std::wstring WORD_END(1, 2);
23
+ const std::wstring NUMERIC(1, 3);
24
+ const std::wstring UNMODELED(1, 4);
25
+
26
+ struct LMStorage
27
+ {
28
+ lookup_t Lookup;
29
+ reverse_lookup_t ReverseLookup;
30
+
31
+ template<class Archive>
32
+ void serialize(Archive &ar, const unsigned int version) {
33
+ ar & Lookup;
34
+ ar & ReverseLookup;
35
+ }
36
+ };
37
+
38
+ void save_suffix_map(std::fstream& fs, const suffix_map_t& suffix_map)
39
+ {
40
+ // write out number of elements for Lookup
41
+ std::size_t suffix_map_count = suffix_map.size();
42
+ fs.write((char*)(&suffix_map_count), sizeof(suffix_map_count));
43
+ for (suffix_map_t::const_iterator reverse_lookup_it = suffix_map.begin(); reverse_lookup_it != suffix_map.end(); ++reverse_lookup_it)
44
+ {
45
+ // write out the key
46
+ size_t key_len = reverse_lookup_it->first.length();
47
+ fs.write((char*)(&key_len), sizeof(key_len));
48
+ fs.write((char*)(reverse_lookup_it->first.data()), key_len * sizeof(wchar_t));
49
+
50
+ // write out value
51
+ fs.write((char*)(&reverse_lookup_it->second), sizeof(reverse_lookup_it->second));
52
+ }
53
+ }
54
+
55
+ void save_lookup(std::fstream& fs, const lookup_t& lookup)
56
+ {
57
+ // write out number of elements for Lookup
58
+ std::size_t lookup_count = lookup.size();
59
+ fs.write((char*)(&lookup_count), sizeof(lookup_count));
60
+ for (lookup_t::const_iterator lookup_it = lookup.begin(); lookup_it != lookup.end(); ++lookup_it)
61
+ {
62
+ // write out element map size
63
+ std::size_t map_elem_count = lookup_it->size();
64
+ fs.write((char*)(&map_elem_count), sizeof(map_elem_count));
65
+
66
+ for (string_suffix_map_t::const_iterator str_sfx_it = lookup_it->begin(); str_sfx_it != lookup_it->end(); ++str_sfx_it)
67
+ {
68
+ // write out key
69
+ size_t key_len = str_sfx_it->first.length();
70
+ fs.write((char*)(&key_len), sizeof(key_len));
71
+ fs.write((char*)(str_sfx_it->first.data()), key_len * sizeof(wchar_t));
72
+ save_suffix_map(fs, str_sfx_it->second);
73
+ }
74
+ }
75
+ }
76
+
77
+ void save_reverse_lookup(std::fstream& fs, const reverse_lookup_t& reverse_lookup)
78
+ {
79
+ // write out number of elements for Lookup
80
+ std::size_t reverse_lookup_count = reverse_lookup.size();
81
+ fs.write((char*)(&reverse_lookup_count), sizeof(reverse_lookup_count));
82
+ for (reverse_lookup_t::const_iterator reverse_lookup_it = reverse_lookup.begin(); reverse_lookup_it != reverse_lookup.end(); ++reverse_lookup_it)
83
+ {
84
+ // write out the key
85
+ size_t key_len = reverse_lookup_it->first.length();
86
+ fs.write((char*)(&key_len), sizeof(key_len));
87
+ fs.write((char*)(reverse_lookup_it->first.data()), key_len * sizeof(wchar_t));
88
+
89
+ // write out value vector length
90
+ size_t val_vec_len = reverse_lookup_it->second.size();
91
+ fs.write((char*)(&val_vec_len), sizeof(val_vec_len));
92
+
93
+ for (suffix_map_vec_t::const_iterator val_vec_it = reverse_lookup_it->second.begin();
94
+ val_vec_it != reverse_lookup_it->second.end();
95
+ ++val_vec_it)
96
+ {
97
+ save_suffix_map(fs, *val_vec_it);
98
+ }
99
+ }
100
+ }
101
+
102
+ void load_suffix_map(std::fstream& fs, suffix_map_t& suffix_map)
103
+ {
104
+ // read in number of elements
105
+ std::size_t suffix_map_count = 0;
106
+ fs.read((char*)(&suffix_map_count), sizeof(suffix_map_count));
107
+ for (size_t suffix_map_index = 0; suffix_map_index < suffix_map_count; ++suffix_map_index )
108
+ {
109
+ // read in key
110
+ std::size_t key_len = 0;
111
+ fs.read((char*)(&key_len), sizeof(key_len));
112
+
113
+ std::wstring wkey(key_len, 0);
114
+ fs.read((char*)(wkey.data()), key_len * sizeof(wchar_t));
115
+ uint32_t value = 0;
116
+ fs.read((char*)(&value), sizeof(value));
117
+
118
+ suffix_map.insert(std::make_pair(wkey, value));
119
+ }
120
+ }
121
+
122
+ void load_lookup(std::fstream& fs, lookup_t& lookup)
123
+ {
124
+ // read in number of elements
125
+ std::size_t lookup_count = 0;
126
+ fs.read((char*)(&lookup_count), sizeof(lookup_count));
127
+ for (size_t lookup_index = 0; lookup_index < lookup_count; ++lookup_index)
128
+ {
129
+ std::size_t map_elem_count = 0;
130
+ fs.read((char*)(&map_elem_count), sizeof(map_elem_count));
131
+
132
+ lookup.push_back(string_suffix_map_t());
133
+ string_suffix_map_t& str_sfx_map = lookup.back();
134
+
135
+ for (size_t str_sfx_map_index = 0; str_sfx_map_index < map_elem_count; ++str_sfx_map_index)
136
+ {
137
+ std::size_t key_len = 0;
138
+ fs.read((char*)(&key_len), sizeof(key_len));
139
+
140
+ std::wstring wkey(key_len, 0);
141
+ fs.read((char*)(wkey.data()), key_len * sizeof(wchar_t));
142
+ str_sfx_map.insert(std::make_pair<wstring, suffix_map_t>(std::wstring(wkey), suffix_map_t()));
143
+ suffix_map_t& suffix_map = str_sfx_map[wkey];
144
+
145
+ load_suffix_map(fs, suffix_map);
146
+ }
147
+ }
148
+ }
149
+
150
+ void load_reverse_lookup(std::fstream& fs, reverse_lookup_t& reverse_lookup)
151
+ {
152
+ // read in number of elements
153
+ std::size_t reverse_lookup_count = 0;
154
+ fs.read((char*)(&reverse_lookup_count), sizeof(reverse_lookup_count));
155
+ for (size_t rev_lookup_index = 0; rev_lookup_index < reverse_lookup_count; ++rev_lookup_index )
156
+ {
157
+ // read in the key
158
+ std::size_t key_len = 0;
159
+ fs.read((char*)(&key_len), sizeof(key_len));
160
+
161
+ std::wstring wkey(key_len, 0);
162
+ fs.read((char*)(wkey.data()), key_len * sizeof(wchar_t));
163
+ reverse_lookup.insert(std::make_pair(wkey, suffix_map_vec_t()));
164
+ suffix_map_vec_t& val_vec = reverse_lookup[wkey];
165
+
166
+ std::size_t val_vec_len = 0;
167
+ fs.read((char*)(&val_vec_len), sizeof(val_vec_len));
168
+
169
+ for (size_t val_vec_index = 0; val_vec_index < val_vec_len; ++val_vec_index)
170
+ {
171
+ val_vec.push_back(suffix_map_t());
172
+ suffix_map_t& suffix_map = val_vec.back();
173
+ load_suffix_map(fs, suffix_map);
174
+ }
175
+ }
176
+ }
177
+
178
+ #if ! defined( USE_BOOST )
179
+
180
+ NGramLMBase::NGramLMBase(const string &dataFilePath, token_mapping_t tokenMapping)
181
+ : LanguageModel(move(tokenMapping))
182
+ {
183
+ std::fstream in(dataFilePath, std::ios::in | std::ios::binary);
184
+ load_lookup(in, m_lookup);
185
+ load_reverse_lookup(in, m_reverseLookup);
186
+
187
+ if (m_lookup.size() >= 10) {
188
+ throw runtime_error("Only N-Grams of 9 or less are supported!");
189
+ }
190
+
191
+ for (auto &ngLevel : m_lookup) {
192
+ for (auto &kvPrefixLevel : ngLevel) {
193
+ uint32_t ct = 0;
194
+ for (auto &kvSfx : kvPrefixLevel.second) {
195
+ ct += kvSfx.second;
196
+ }
197
+ m_prefixSumLookup.emplace(kvPrefixLevel.first, ct);
198
+ }
199
+ }
200
+ }
201
+
202
+ void save_ngram_data_file(const lookup_t& lookup, const reverse_lookup_t& reverseLookup, const std::string &outputPath)
203
+ {
204
+ std::fstream out(outputPath, std::ios::out | std::ios::binary);
205
+
206
+ save_lookup(out, lookup);
207
+ save_reverse_lookup(out, reverseLookup);
208
+ }
209
+
210
+ #else // USE_BOOST
211
+
212
+ NGramLMBase::NGramLMBase(const string &dataFilePath, token_mapping_t tokenMapping)
213
+ : LanguageModel(move(tokenMapping))
214
+ {
215
+ {
216
+ ifstream dfStr(dataFilePath, ios_base::in | ios_base::binary);
217
+ boost::archive::binary_iarchive ia(dfStr);
218
+
219
+ LMStorage s;
220
+ ia >> s;
221
+
222
+
223
+ m_lookup = move(s.Lookup);
224
+
225
+ m_reverseLookup = move(s.ReverseLookup);
226
+ }
227
+
228
+ if (m_lookup.size() >= 10) {
229
+ throw runtime_error("Only N-Grams of 9 or less are supported!");
230
+ }
231
+
232
+ for (auto &ngLevel : m_lookup) {
233
+ for (auto &kvPrefixLevel : ngLevel) {
234
+ uint32_t ct = 0;
235
+ for (auto &kvSfx : kvPrefixLevel.second) {
236
+ ct += kvSfx.second;
237
+ }
238
+ m_prefixSumLookup.emplace(kvPrefixLevel.first, ct);
239
+ }
240
+ }
241
+ }
242
+
243
+ void save_ngram_data_file(lookup_t lookup, reverse_lookup_t reverseLookup, const std::string &outputPath)
244
+ {
245
+ ofstream ofs(outputPath, ios_base::out | ios_base::binary);
246
+
247
+ LMStorage s;
248
+ s.Lookup = move(lookup);
249
+ s.ReverseLookup = move(reverseLookup);
250
+
251
+ boost::archive::binary_oarchive oa(ofs);
252
+ oa << s;
253
+ }
254
+
255
+ #endif // USE_BOOST
256
+
257
+ float_t NGramLMBase::ScoreTransition(const Prefix *p, token_t nextToken) const
258
+ {
259
+ std::wstring prefix;
260
+ if (! ConvertToString(p, prefix)) {
261
+ return NEG_INF;
262
+ }
263
+
264
+ const std::wstring *pSuffix = nullptr;
265
+
266
+ if (nextToken != 1) {
267
+ auto iter = m_tokenMapping.find(nextToken);
268
+ if (iter == m_tokenMapping.end()) {
269
+ pSuffix = &UNMODELED;
270
+ } else {
271
+ pSuffix = &iter->second;
272
+
273
+ if (iswdigit(pSuffix->at(0))) {
274
+ pSuffix = &NUMERIC;
275
+ }
276
+ }
277
+
278
+ } else {
279
+ pSuffix = &WORD_END;
280
+ }
281
+
282
+ float_t ret = ScoreTransitionImpl(prefix, *pSuffix);
283
+
284
+ if (ret > 0) {
285
+ return log(ret);
286
+ } else {
287
+ return NEG_INF;
288
+ }
289
+ }
290
+
291
+ bool NGramLMBase::ConvertToString(const Prefix *p, std::wstring &prefix) const
292
+ {
293
+ const Prefix *stk[10];
294
+ int32_t sz = -1;
295
+ const Prefix *curr = p;
296
+ decltype(sz) mlSz{(int)m_lookup.size() - 2};
297
+ while (curr && sz < mlSz) {
298
+ stk[++sz] = curr;
299
+ curr = curr->Parent;
300
+ }
301
+
302
+ // Either blank or empty prefix
303
+ if (sz < 1) { return true; }
304
+
305
+ --sz;
306
+ for (; sz >= 0; --sz) {
307
+ token_t tok = stk[sz]->Token;
308
+ // End of word token, which maps to the null character
309
+ if (tok == 1) {
310
+ prefix.push_back(WORD_END[0]);
311
+ } else if (tok == 0) {
312
+ // Do nothing
313
+ } else {
314
+ auto iter = m_tokenMapping.find(tok);
315
+ if (iter == m_tokenMapping.end()) {
316
+ prefix += UNMODELED;
317
+ } else {
318
+ const std::wstring &wChar = iter->second;
319
+
320
+ if (iswdigit(wChar[0])) {
321
+ prefix += NUMERIC;
322
+ } else {
323
+ prefix += wChar;
324
+ }
325
+ }
326
+ }
327
+ }
328
+
329
+ return true;
330
+ }
nemo-retriever-ocr/cpp/beam_decode/ngram_lm_base.h ADDED
@@ -0,0 +1,80 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #pragma once
6
+
7
+ #include <unordered_map>
8
+ #include <vector>
9
+
10
+ #include "language_model.h"
11
+
12
+ // #define USE_BOOST 1
13
+
14
+
15
+ typedef std::unordered_map<std::wstring, uint32_t> suffix_map_t;
16
+
17
+ /* Tells us the number of suffixes for a given ngram of order K
18
+ Keys:
19
+ 1. NGram Order
20
+ 2. Prefix
21
+ 3. Suffix
22
+ Value:
23
+ Count
24
+ */
25
+ typedef std::unordered_map<std::wstring, suffix_map_t> string_suffix_map_t;
26
+ typedef std::vector<string_suffix_map_t> lookup_t;
27
+ /* Tells us the number of K-gram prefixes found for a given suffix
28
+ Keys:
29
+ 1. Suffix
30
+ 2. NGram Order
31
+ 3. Prefix
32
+ Values:
33
+ Count
34
+ */
35
+ typedef std::vector<suffix_map_t> suffix_map_vec_t;
36
+ typedef std::unordered_map<std::wstring, suffix_map_vec_t> reverse_lookup_t;
37
+
38
+
39
+
40
+ extern const std::wstring WORD_END;
41
+ extern const std::wstring NUMERIC;
42
+ extern const std::wstring UNMODELED;
43
+
44
+ class NGramLMBase
45
+ : public LanguageModel
46
+ {
47
+ public:
48
+ virtual float_t ScoreTransition(const Prefix *p, token_t nextToken) const override;
49
+
50
+ protected:
51
+ NGramLMBase(const std::string &dataFilePath, token_mapping_t tokenMapping);
52
+
53
+ virtual float_t ScoreTransitionImpl(const std::wstring &prefix, const std::wstring &suffix) const = 0;
54
+
55
+ bool ConvertToString(const Prefix *p, std::wstring &prefix) const;
56
+
57
+ float_t GetPrefixSum(const std::wstring &prefix) const;
58
+
59
+ lookup_t m_lookup;
60
+ reverse_lookup_t m_reverseLookup;
61
+
62
+ std::unordered_map<std::wstring, uint32_t> m_prefixSumLookup;
63
+ };
64
+
65
+ #if ! defined( USE_BOOST )
66
+ void save_ngram_data_file(const lookup_t& lookup, const reverse_lookup_t& reverseLookup, const std::string &output_path);
67
+ #else // USE_BOOST
68
+ void save_ngram_data_file(lookup_t lookup, reverse_lookup_t reverseLookup, const std::string &output_path);
69
+ #endif // USE_BOOST
70
+
71
+ inline float_t NGramLMBase::GetPrefixSum(const std::wstring &prefix) const
72
+ {
73
+ auto iter = m_prefixSumLookup.find(prefix);
74
+
75
+ if (iter == m_prefixSumLookup.end()) {
76
+ return 0;
77
+ } else {
78
+ return iter->second;
79
+ }
80
+ }
nemo-retriever-ocr/cpp/beam_decode/prefix.cpp ADDED
@@ -0,0 +1,23 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #include "prefix.h"
6
+
7
+ using namespace std;
8
+
9
+ vector<token_t> Prefix::ToList() const
10
+ {
11
+ vector<token_t> ret;
12
+
13
+ auto curr = this;
14
+
15
+ while (curr) {
16
+ if (curr->Token != 0) {
17
+ ret.push_back(curr->Token);
18
+ }
19
+ curr = curr->Parent;
20
+ }
21
+
22
+ return { rbegin(ret), rend(ret) };
23
+ }
nemo-retriever-ocr/cpp/beam_decode/prefix.h ADDED
@@ -0,0 +1,158 @@
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+ // All rights reserved.
+ // SPDX-License-Identifier: Apache-2.0
+
+ #pragma once
+
+ #include <cstdlib>
+ #include <memory>
+ #include <new>
+ #include <vector>
+ #include <unordered_map>
+ #include <list>
+
+ typedef int32_t token_t;
+
+ class Prefix;
+
+ // typedef std::shared_ptr<Prefix> PrefixPtr;
+
+ class Prefix
+ {
+ public:
+     token_t Token;
+     Prefix *Parent;
+
+     Prefix(token_t token = 0 /* blank */, Prefix *parent = nullptr)
+         : Token(token), Parent(parent)
+     {}
+
+     std::vector<token_t> ToList() const;
+
+     size_t size() const;
+ };
+
+
+ ///// Borrowed from Boost libraries
+ template<typename T>
+ void hash_combine(size_t & seed, T const& v)
+ {
+     seed ^= std::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+ }
+ /////
+
+ namespace std {
+     template<>
+     struct hash<Prefix*>
+     {
+         size_t operator()(const Prefix *p) const noexcept
+         {
+             size_t seed = 0;
+
+             while (p) {
+                 if (p->Token != 0) {
+                     hash_combine(seed, p->Token);
+                 }
+                 p = p->Parent;
+             }
+             return seed;
+         }
+     };
+
+     template<>
+     struct hash<tuple<Prefix*, token_t>>
+     {
+         size_t operator()(const tuple<Prefix*, token_t> &t) const noexcept
+         {
+             size_t seed = 0;
+             hash_combine(seed, get<0>(t));
+             hash_combine(seed, get<1>(t));
+             return seed;
+         }
+     };
+
+     template<>
+     struct equal_to<Prefix*>
+     {
+         bool operator()(const Prefix *a, const Prefix *b) const noexcept
+         {
+             while (a != nullptr && b != nullptr) {
+                 if (a->Token != b->Token) {
+                     return false;
+                 }
+                 a = a->Parent;
+                 b = b->Parent;
+             }
+             // If one chain is shorter than the other
+             return a == b;
+         }
+     };
+ }
+
+ inline size_t Prefix::size() const
+ {
+     size_t ret = 0;
+     auto p = this;
+     while (p != nullptr) {
+         ret += 1;
+         p = p->Parent;
+     }
+     return ret;
+ }
+
+
+ class PrefixAllocator
+ {
+ public:
+     PrefixAllocator() = default;
+     ~PrefixAllocator();
+
+     template<typename ...Args>
+     Prefix *GetPrefix(Args&& ...ctorArgs);
+
+ private:
+     void AllocateNextBuffer();
+
+     std::list<Prefix*> m_buffers;
+     size_t m_allocSize = 0;
+     size_t m_currOff = 0;
+ };
+
+ inline PrefixAllocator::~PrefixAllocator()
+ {
+     for (auto p : m_buffers) {
+         // Prefix is trivially destructible, and the buffers are allocated with malloc
+         // (no per-element initialization) to prevent redundant work upfront
+         // delete[] p;
+         free(p);
+     }
+ }
+
+ inline void PrefixAllocator::AllocateNextBuffer()
+ {
+     size_t nextSize = m_allocSize == 0 ? 1000 : 2 * m_allocSize;
+
+     // Using malloc here to prevent the ctor of Prefix being called for each item.
+     // Instead, the ctor will be called upon first access using GetPrefix
+     auto pBuff = reinterpret_cast<Prefix*>(malloc(sizeof(Prefix) * nextSize));
+
+     m_buffers.push_back(pBuff);
+
+     m_allocSize = nextSize;
+     m_currOff = 0;
+ }
+
+ template<typename ...Args>
+ Prefix *PrefixAllocator::GetPrefix(Args&& ...ctorArgs)
+ {
+     if (m_currOff == m_allocSize) {
+         AllocateNextBuffer();
+     }
+
+     auto buff = m_buffers.back() + m_currOff;
+
+     auto ret = new (buff) Prefix(std::forward<Args>(ctorArgs)...);
+
+     ++m_currOff;
+
+     return ret;
+ }
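A short usage sketch of the types above: prefixes form a reversed linked list (child to parent), so extending a beam hypothesis costs one placement-new from the allocator, and the std::hash / std::equal_to specializations compare chains by token sequence rather than by pointer identity. The token values and the score map below are made up for illustration.

// Illustrative use of Prefix / PrefixAllocator; builds the chain root -> 'H' -> 'i'.
#include <iostream>
#include <unordered_map>

#include "prefix.h"

int main()
{
    PrefixAllocator alloc;

    Prefix *root = alloc.GetPrefix();           // blank root (token 0)
    Prefix *h    = alloc.GetPrefix(72, root);   // token for 'H'
    Prefix *hi   = alloc.GetPrefix(105, h);     // token for 'i'

    // Keyed by token sequence: two distinct chains spelling the same tokens would collide.
    std::unordered_map<Prefix*, float> beamScores;
    beamScores[hi] = -1.25f;

    for (token_t tok : hi->ToList()) {          // prints "72 105"
        std::cout << tok << " ";
    }
    std::cout << "\nchain length (including root): " << hi->size() << "\n";
    return 0;
}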
nemo-retriever-ocr/cpp/beam_decode/sbo_lm.cpp ADDED
@@ -0,0 +1,47 @@
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+ // All rights reserved.
+ // SPDX-License-Identifier: Apache-2.0
+
+ #include "sbo_lm.h"
+
+ #include <assert.h>
+ #include <utility>
+
+ // Reference paper: https://www.aclweb.org/anthology/D07-1090.pdf
+
+
+ SBO_LanguageModel::SBO_LanguageModel(const std::string &dataFilePath, token_mapping_t tokenMapping, float_t backoff)
+     : NGramLMBase(dataFilePath, std::move(tokenMapping)), m_backoff(backoff)
+ {
+ }
+
+ float_t SBO_LanguageModel::ScoreTransitionImpl(const std::wstring &prefix, const std::wstring &suffix) const
+ {
+     auto lIter = m_lookup[prefix.size() + 1].find(prefix);
+
+     // This prefix doesn't exist. Shrink it!
+     if (lIter == m_lookup[prefix.size() + 1].end()) {
+         return m_backoff * ScoreTransitionImpl({ begin(prefix) + 1, end(prefix) }, suffix);
+     }
+
+     const suffix_map_t &suffixMap = lIter->second;
+
+     auto sfIter = suffixMap.find(suffix);
+
+     if (sfIter == suffixMap.end()) {
+         // This is an entirely novel character!
+         if (prefix.empty()) {
+             return 1e-8;
+         } else {
+             return m_backoff * ScoreTransitionImpl({ begin(prefix) + 1, end(prefix) }, suffix);
+         }
+     }
+
+     float_t ctSuffix = sfIter->second;
+     float_t ctNgram = GetPrefixSum(prefix);
+
+     float_t score = ctSuffix / ctNgram;
+
+     assert(score >= 0 && score <= 1);
+
+     return score;
+ }
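For orientation, the scoring rule above is the Stupid Backoff scheme from the referenced paper: S(c | p) = count(p, c) / count(p) whenever the context p has been observed followed by c, and S(c | p) = backoff * S(c | p[1:]) otherwise, bottoming out at the 1e-8 floor for a character that was never seen even with an empty context. A worked example with an assumed backoff of 0.4 (the constructor takes this as a parameter): if the context "th" was counted 10 times and "th" followed by "e" 7 times, the transition scores 7 / 10 = 0.7; if "th" followed by "q" was never counted, the model instead returns 0.4 * S("q" | "h"), and so on down the shortened contexts.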
nemo-retriever-ocr/cpp/beam_decode/sbo_lm.h ADDED
@@ -0,0 +1,21 @@
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+ // All rights reserved.
+ // SPDX-License-Identifier: Apache-2.0
+
+ #pragma once
+
+ #include "ngram_lm_base.h"
+
+
+ class SBO_LanguageModel
+     : public NGramLMBase
+ {
+ public:
+     SBO_LanguageModel(const std::string &dataFilePath, token_mapping_t tokenMapping, float_t backoff);
+
+ protected:
+     virtual float_t ScoreTransitionImpl(const std::wstring &prefix, const std::wstring &suffix) const override;
+
+ private:
+     float_t m_backoff;
+ };
nemo-retriever-ocr/cpp/better_grid_sample/cpu_indirect_grid_sample.cpp ADDED
@@ -0,0 +1,94 @@
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+ // All rights reserved.
+ // SPDX-License-Identifier: Apache-2.0
+
+ #include "grid_sample.h"
+ #include "gpu_grid_sample_utils.cuh"
+
+ template<typename T>
+ void indirect_grid_sample_forward_bilinear(torch::TensorAccessor<T, 4> input,
+                                            torch::TensorAccessor<T, 4> grid,
+                                            torch::TensorAccessor<int64_t, 1> inputIndices,
+                                            torch::TensorAccessor<T, 4> output)
+ {
+     const int64_t N = inputIndices.size(0);
+     const int64_t C = output.size(1);
+
+     T fInputHeight = input.size(2);
+     T fInputWidth = input.size(3);
+     int64_t outputHeight = output.size(2);
+     int64_t outputWidth = output.size(3);
+
+ #pragma omp parallel for num_threads(8)
+     for (int64_t i = 0; i < N; ++i) {
+         int64_t inputIdx = inputIndices[i];
+
+         for (int64_t c = 0; c < C; ++c) {
+             for (int64_t outY = 0; outY < outputHeight; ++outY) {
+                 for (int64_t outX = 0; outX < outputWidth; ++outX) {
+                     T u = grid[i][outY][outX][0];
+                     T v = grid[i][outY][outX][1];
+
+                     if (u < -1 || u > 1 || v < -1 || v > 1) {
+                         output[i][c][outY][outX] = 0;
+                         continue;
+                     }
+
+                     // Denormalize the coordinates
+                     u = (u + 1) * ((fInputWidth - 1) / 2);
+                     v = (v + 1) * ((fInputHeight - 1) / 2);
+
+                     // Calculate coordinates
+                     const T inX = u;
+                     const T inXint = std::floor(inX);
+                     const T inXfrac = inX - inXint;
+
+                     const T inY = v;
+                     const T inYint = std::floor(inY);
+                     const T inYfrac = inY - inYint;
+
+                     T ps[] = { 1 - inXfrac, inXfrac };
+                     T rs[] = { 1 - inYfrac, inYfrac };
+                     T opVal = 0;
+
+ #pragma unroll
+                     for (int64_t row = 0; row < 2; ++row) {
+ #pragma unroll
+                         for (int64_t col = 0; col < 2; ++col) {
+                             T Tpx = utils::get_pixel_clamped(input, inputIdx, c, inXint + col, inYint + row);
+                             opVal += rs[row] * ps[col] * Tpx;
+                         }
+                     }
+
+                     output[i][c][outY][outX] = opVal;
+                 }
+             }
+         }
+     }
+ }
+
+ torch::Tensor cpu_indirect_grid_sample_forward(torch::Tensor input, torch::Tensor grid,
+                                                torch::Tensor inputIndices, const std::string &method)
+ {
+     auto output = input.new_empty({ inputIndices.size(0), input.size(1), grid.size(1), grid.size(2) });
+
+     AT_DISPATCH_FLOATING_TYPES(
+         input.scalar_type(),
+         "cpu_indirect_grid_sample_forward_impl",
+         ([&] {
+             typedef scalar_t T;
+             if (method == "bilinear") {
+                 indirect_grid_sample_forward_bilinear(
+                     input.accessor<T, 4>(),
+                     grid.accessor<T, 4>(),
+                     inputIndices.accessor<int64_t, 1>(),
+                     output.accessor<T, 4>()
+                 );
+             } else {
+                 throw std::runtime_error("Unsupported resample method: " + method);
+             }
+         })
+     );
+
+     return output;
+ }
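The core of the routine above is ordinary align-corners bilinear sampling. Below is a scalar sketch of the same math, written against a plain row-major float buffer instead of torch accessors, assuming the same (size - 1) / 2 denormalization and clamp-to-edge behaviour as the file.

// Scalar bilinear sample at normalized coordinates (u, v) in [-1, 1]; illustrative only.
#include <algorithm>
#include <cmath>
#include <cstdint>

float sample_bilinear(const float *img, int64_t height, int64_t width, float u, float v)
{
    if (u < -1.f || u > 1.f || v < -1.f || v > 1.f) {
        return 0.f;                                    // out-of-range samples produce zero
    }

    const float x = (u + 1.f) * (width  - 1) * 0.5f;   // denormalize to [0, width - 1]
    const float y = (v + 1.f) * (height - 1) * 0.5f;   // denormalize to [0, height - 1]

    const int64_t x0 = static_cast<int64_t>(std::floor(x));
    const int64_t y0 = static_cast<int64_t>(std::floor(y));
    const float fx = x - x0;
    const float fy = y - y0;

    auto px = [&](int64_t xi, int64_t yi) {
        xi = std::clamp<int64_t>(xi, 0, width - 1);    // clamp-to-edge, like get_pixel_clamped
        yi = std::clamp<int64_t>(yi, 0, height - 1);
        return img[yi * width + xi];
    };

    return (1.f - fy) * ((1.f - fx) * px(x0, y0)     + fx * px(x0 + 1, y0))
         +        fy  * ((1.f - fx) * px(x0, y0 + 1) + fx * px(x0 + 1, y0 + 1));
}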
nemo-retriever-ocr/cpp/better_grid_sample/gpu_grid_sample_utils.cuh ADDED
@@ -0,0 +1,42 @@
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+ // All rights reserved.
+ // SPDX-License-Identifier: Apache-2.0
+
+ #pragma once
+
+ #include <torch/torch.h>
+
+ #include "../cuda_intellisense.cuh"
+
+ #ifndef __NVCC__
+ #include <algorithm>
+ #define __device__
+ #endif
+
+ namespace utils {
+
+ #ifdef __NVCC__
+
+ template<typename T>
+ __device__ __lib_inline__
+ T clamp(T val, T minVal, T maxVal)
+ {
+     return max(minVal, min(val, maxVal));
+ }
+
+ #else
+ using std::clamp;
+ #endif
+
+ template<typename accessor_t>
+ __device__ __lib_inline__
+ auto &get_pixel_clamped(accessor_t &inputs,
+                         int64_t n, int64_t c, int64_t x, int64_t y)
+ {
+     x = clamp<decltype(x)>(x, 0, inputs.size(3) - 1);
+     y = clamp<decltype(y)>(y, 0, inputs.size(2) - 1);
+
+     return inputs[n][c][y][x];
+ }
+
+ }
nemo-retriever-ocr/cpp/better_grid_sample/gpu_indirect_grid_sample.cu ADDED
@@ -0,0 +1,328 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #include "grid_sample.h"
6
+
7
+ #include "../cuda_intellisense.cuh"
8
+ #include "../half_ops.cuh"
9
+ #include "gpu_grid_sample_utils.cuh"
10
+
11
+ using namespace std;
12
+
13
+ template<typename accessor_t, typename index_t>
14
+ __device__ __lib_inline__
15
+ auto &my_get_pixel_clamped(accessor_t &inputs, index_t x, index_t y)
16
+ {
17
+ x = utils::clamp(x, 0, inputs.size(1) - 1);
18
+ y = utils::clamp(y, 0, inputs.size(0) - 1);
19
+
20
+ return inputs[y][x];
21
+ }
22
+
23
+ __global__
24
+ void single_ex_grid_sample_bilinear_kernel(const float *pInputImage,
25
+ uint32_t imgHeight, uint32_t imgWidth, uint32_t numChannels,
26
+ const float2 *pGrid,
27
+ uint32_t numGridCells,
28
+ float *pOutputImage)
29
+ {
30
+ const uint32_t z = blockDim.x * blockIdx.x + threadIdx.x;
31
+ const uint32_t c = blockDim.y * blockIdx.y + threadIdx.y;
32
+
33
+ if (c >= numChannels || z >= numGridCells) {
34
+ return;
35
+ }
36
+
37
+ const uint32_t g = blockIdx.z;
38
+
39
+ const float2 uv = pGrid[g * numGridCells + z];
40
+
41
+ float &outPx = pOutputImage[(g * numChannels + c) * numGridCells + z];
42
+ if (abs(uv.x) > 1.0f || abs(uv.y) > 1.0f) {
43
+ outPx = 0.0f;
44
+ } else {
45
+ const uint32_t maxX = imgWidth - 1;
46
+ const uint32_t maxY = imgHeight - 1;
47
+
48
+ const float u = (uv.x + 1.0f) * maxX * 0.5f;
49
+ const float v = (uv.y + 1.0f) * maxY * 0.5f;
50
+
51
+ // calculate coordinates
52
+ const float inX = u;
53
+ const uint32_t inXint = inX;
54
+ const float inXfrac = inX - inXint;
55
+
56
+ const float inY = v;
57
+ const uint32_t inYint = inY;
58
+ const float inYfrac = inY - inYint;
59
+
60
+ const float *pChanImage = pInputImage + c * imgHeight * imgWidth;
61
+
62
+ // By being in this conditional block, we know that u and v are >= 0, which means
63
+ // that their truncated value is also >= 0. Instead of clamping the value to within the buffer,
64
+ // we set the multiplication factor to be 0 if the interpolated value is outside the buffer
65
+ const float ps[] = { 1.0f - inXfrac, inXfrac * (inXint < maxX) };
66
+ const float rs[] = { 1.0f - inYfrac, inYfrac * (inYint < maxY) };
67
+ float opVal = 0.0f;
68
+ #pragma unroll
69
+ for (uint32_t row = 0; row < 2; ++row) {
70
+ const float *pRowImage = pChanImage + (inYint + row) * imgWidth;
71
+
72
+ #pragma unroll
73
+ for (uint32_t col = 0; col < 2; ++col) {
74
+ const float px = pRowImage[inXint + col];
75
+ opVal += rs[row] * ps[col] * px;
76
+ }
77
+ }
78
+
79
+ outPx = opVal;
80
+ }
81
+ }
82
+
83
+ template<typename T>
84
+ __global__
85
+ void indirect_grid_sample_forward_bilinear_kernel(torch::PackedTensorAccessor32<T, 4> inputs,
86
+ torch::PackedTensorAccessor32<T, 4> grid,
87
+ torch::PackedTensorAccessor32<int64_t, 1> inputIndices,
88
+ torch::PackedTensorAccessor32<T, 4> outputs)
89
+ {
90
+ static_assert(std::is_same<T, float>::value, "Currently only float32 is supported!");
91
+ //typedef typename fp_promote<T>::type accum_t;
92
+ typedef float accum_t;
93
+ constexpr T NEG_ONE = -1;
94
+ constexpr T ONE = 1;
95
+ constexpr T ZERO = 0;
96
+ constexpr T TWO = 2;
97
+ constexpr T ZERO_PT_5 = 0.5;
98
+ typedef decltype(inputs.stride(0)) index_t;
99
+
100
+ const index_t n = blockDim.z * blockIdx.z + threadIdx.z;
101
+
102
+ if (n >= inputIndices.size(0)) return;
103
+
104
+ const index_t c = blockDim.y * blockIdx.y + threadIdx.y;
105
+
106
+ const index_t z = blockDim.x * blockIdx.x + threadIdx.x;
107
+
108
+ const accum_t inputHeight = inputs.size(2);
109
+ const accum_t inputWidth = inputs.size(3);
110
+ const index_t outputHeight = outputs.size(2);
111
+ const index_t outputWidth = outputs.size(3);
112
+
113
+ const index_t outY = z / outputWidth;
114
+ //const index_t outX = z % outputWidth;
115
+ const index_t outX = z - (outY * outputWidth);
116
+
117
+ if (outY >= outputHeight) return;
118
+
119
+ index_t inputIdx = inputIndices[n];
120
+ const float2 f2uv = *reinterpret_cast<const float2*>(grid[n][outY][outX].data());
121
+ float u = f2uv.x;
122
+ float v = f2uv.y;
123
+
124
+ if (u < NEG_ONE || u > ONE || v < NEG_ONE || v > ONE) {
125
+ outputs[n][c][outY][outX] = ZERO;
126
+ return;
127
+ }
128
+
129
+ // Denormalize the coordinates
130
+ u = (u + ONE) * ((inputWidth - ONE) * ZERO_PT_5);
131
+ v = (v + ONE) * ((inputHeight - ONE) * ZERO_PT_5);
132
+
133
+ // calculate coordinates
134
+ const accum_t inX = u;
135
+ const index_t inXint = inX;
136
+ const accum_t inXfrac = inX - inXint;
137
+
138
+ const accum_t inY = v;
139
+ const index_t inYint = inY;
140
+ const accum_t inYfrac = inY - inYint;
141
+
142
+ accum_t ps[] = { ONE - inXfrac, inXfrac };
143
+ accum_t rs[] = { ONE - inYfrac, inYfrac };
144
+ accum_t opVal = ZERO;
145
+
146
+ auto localInputs = inputs[inputIdx][c];
147
+
148
+ #pragma unroll
149
+ for (index_t row = 0; row < 2; ++row) {
150
+ #pragma unroll
151
+ for (index_t col = 0; col < 2; ++col) {
152
+ T Tpx = my_get_pixel_clamped(localInputs, inXint + col, inYint + row);
153
+ opVal += rs[row] * ps[col] * Convert<T, accum_t>::LeftToRight(Tpx);
154
+ }
155
+ }
156
+
157
+ outputs[n][c][outY][outX] = Convert<T, accum_t>::RightToLeft(opVal);
158
+ }
159
+
160
+ template<typename T>
161
+ __global__
162
+ void indirect_grid_sample_backward_bilinear_kernel(torch::PackedTensorAccessor64<T, 4> inputs,
163
+ torch::PackedTensorAccessor64<T, 4> grid,
164
+ torch::PackedTensorAccessor64<int64_t, 1> inputIndices,
165
+ torch::PackedTensorAccessor64<T, 4> gradOutput,
166
+ torch::PackedTensorAccessor64<T, 4> gradInput,
167
+ torch::PackedTensorAccessor64<T, 4> gradGrid)
168
+ {
169
+ typedef typename fp_promote<T>::type accum_t;
170
+ constexpr T NEG_ONE = -1;
171
+ constexpr T ONE = 1;
172
+
173
+ const int64_t n = blockDim.z * blockIdx.z + threadIdx.z;
174
+
175
+ if (n >= inputIndices.size(0)) return;
176
+
177
+ const int64_t c = blockDim.y * blockIdx.y + threadIdx.y;
178
+
179
+ const int64_t z = blockDim.x * blockIdx.x + threadIdx.x;
180
+
181
+ const accum_t inputHeight = inputs.size(2);
182
+ const accum_t inputWidth = inputs.size(3);
183
+ const int64_t outputHeight = gradOutput.size(2);
184
+ const int64_t outputWidth = gradOutput.size(3);
185
+
186
+ const int64_t outY = z / outputWidth;
187
+ const int64_t outX = z % outputWidth;
188
+
189
+ if (outY >= outputHeight) return;
190
+
191
+ int64_t inputIdx = inputIndices[n];
192
+ const float2 f2uv = *reinterpret_cast<const float2*>(grid[n][outY][outX].data());
193
+ float u = f2uv.x;
194
+ float v = f2uv.y;
195
+
196
+ // No output gradient contribution from this position
197
+ if (u < NEG_ONE || u > ONE || v < NEG_ONE || v > ONE) {
198
+ return;
199
+ }
200
+
201
+ // Denormalize the coordinates
202
+ u = (u + 1) * ((inputWidth - 1) / 2);
203
+ v = (v + 1) * ((inputHeight - 1) / 2);
204
+
205
+ // calculate coordinates
206
+ const accum_t inX = u;
207
+ const accum_t inXint = floor(inX);
208
+ const accum_t inXfrac = inX - inXint;
209
+
210
+ const accum_t inY = v;
211
+ const accum_t inYint = floor(inY);
212
+ const accum_t inYfrac = inY - inYint;
213
+
214
+ accum_t ps[] = { 1 - inXfrac, inXfrac };
215
+ accum_t rs[] = { 1 - inYfrac, inYfrac };
216
+
217
+ const accum_t gOut = Convert<T, accum_t>::LeftToRight(gradOutput[n][c][outY][outX]);
218
+
219
+ #pragma unroll
220
+ for (size_t row = 0; row < 2; ++row) {
221
+ #pragma unroll
222
+ for (size_t col = 0; col < 2; ++col) {
223
+ T &gIn = utils::get_pixel_clamped(gradInput, inputIdx, c, inXint + col, inYint + row);
224
+
225
+ T gContrib = Convert<T, accum_t>::RightToLeft(rs[row] * ps[col] * gOut);
226
+
227
+ atomicAdd(&gIn, gContrib);
228
+ }
229
+ }
230
+ }
231
+
232
+ torch::Tensor gpu_indirect_grid_sample_forward(torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method)
233
+ {
234
+ auto output = input.new_empty({ inputIndices.size(0), input.size(1), grid.size(1), grid.size(2) });
235
+
236
+
237
+ if (method != "bilinear"s) {
238
+ throw runtime_error("Only 'bilinear' sampling is currently supported!");
239
+ }
240
+
241
+ if (input.size(0) == 1 && input.is_contiguous() && grid.is_contiguous()) {
242
+ uint32_t gridNumCells = grid.size(1) * grid.size(2);
243
+ dim3 blockDim(32, 3, 1);
244
+ dim3 gridDim(div_up(gridNumCells, blockDim.x),
245
+ div_up(input.size(1), blockDim.y),
246
+ div_up(grid.size(0), blockDim.z));
247
+ single_ex_grid_sample_bilinear_kernel KERNEL_ARG2(gridDim, blockDim) (
248
+ input.data_ptr<float>(),
249
+ input.size(2), input.size(3), input.size(1),
250
+ reinterpret_cast<const float2*>(grid.data_ptr()),
251
+ gridNumCells,
252
+ output.data_ptr<float>()
253
+ );
254
+
255
+ } else {
256
+ // z is batch idx
257
+ // y is channel
258
+ // x is w*h
259
+ dim3 blockDim(32, 1, 3);
260
+ dim3 gridDim(div_up(grid.size(1) * grid.size(2), blockDim.x),
261
+ div_up(input.size(1), blockDim.y),
262
+ div_up(inputIndices.size(0), blockDim.z));
263
+ indirect_grid_sample_forward_bilinear_kernel KERNEL_ARG2(gridDim, blockDim) (
264
+ input.packed_accessor32<float, 4>(),
265
+ grid.packed_accessor32<float, 4>(),
266
+ inputIndices.packed_accessor32<int64_t, 1>(),
267
+ output.packed_accessor32<float, 4>()
268
+ );
269
+ }
270
+
271
+ //AT_DISPATCH_FLOATING_TYPES_AND_HALF(
272
+ // input.scalar_type(),
273
+ // "gpu_indirect_grid_sample_forward",
274
+ // ([&] {
275
+ // typedef typename remap_half<scalar_t>::type T;
276
+ // // typedef scalar_t T;
277
+ // if (method == "bilinear") {
278
+ // indirect_grid_sample_forward_bilinear_kernel KERNEL_ARG2(gridDim, blockDim) (
279
+ // input.packed_accessor64<T, 4>(),
280
+ // grid.packed_accessor64<T, 4>(),
281
+ // inputIndices.packed_accessor64<int64_t, 1>(),
282
+ // output.packed_accessor64<T, 4>()
283
+ // );
284
+ // } else {
285
+ // throw runtime_error("Unsupported resample method: " + method);
286
+ // }
287
+ // })
288
+ //);
289
+
290
+ return output;
291
+ }
292
+
293
+ std::vector<torch::Tensor> gpu_indirect_grad_sample_backward(torch::Tensor gradOutput, torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method)
294
+ {
295
+ auto gradInput = torch::zeros_like(input);
296
+ auto gradGrid = torch::zeros_like(grid);
297
+
298
+ // z is batch idx
299
+ // y is channel
300
+ // x is w*h
301
+ dim3 blockDim(32, 1, 1);
302
+ dim3 gridDim(div_up(grid.size(1) * grid.size(2), blockDim.x),
303
+ div_up(input.size(1), blockDim.y),
304
+ div_up(inputIndices.size(0), blockDim.z));
305
+
306
+ AT_DISPATCH_FLOATING_TYPES(
307
+ input.scalar_type(),
308
+ "gpu_indirect_grid_sample_backward",
309
+ ([&] {
310
+ typedef typename remap_half<scalar_t>::type T;
311
+ // typedef scalar_t T;
312
+ if (method == "bilinear") {
313
+ indirect_grid_sample_backward_bilinear_kernel KERNEL_ARG2(gridDim, blockDim) (
314
+ input.packed_accessor64<T, 4>(),
315
+ grid.packed_accessor64<T, 4>(),
316
+ inputIndices.packed_accessor64<int64_t, 1>(),
317
+ gradOutput.packed_accessor64<T, 4>(),
318
+ gradInput.packed_accessor64<T, 4>(),
319
+ gradGrid.packed_accessor64<T, 4>()
320
+ );
321
+ } else {
322
+ throw runtime_error("Unsupported resample method: " + method);
323
+ }
324
+ })
325
+ );
326
+
327
+ return { gradInput, gradGrid };
328
+ }
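For the backward pass above, a useful mental model is that each output gradient is scattered back to the four input pixels that produced it, using the same bilinear weights as the forward pass, with atomicAdd resolving collisions between threads. Below is a minimal single-channel kernel sketch of that scatter; it is not the kernel above (whose accessors and indexing are richer), just the core arithmetic under those assumptions.

// Sketch: scatter one output gradient gOut at continuous position (x, y) into gradInput,
// an H x W single-channel buffer, using bilinear weights and clamp-to-edge addressing.
__global__ void scatter_bilinear_grad(float *gradInput, int height, int width,
                                      float x, float y, float gOut)
{
    const int x0 = static_cast<int>(floorf(x));
    const int y0 = static_cast<int>(floorf(y));
    const float fx = x - x0;
    const float fy = y - y0;
    const float wx[2] = { 1.f - fx, fx };
    const float wy[2] = { 1.f - fy, fy };

    for (int r = 0; r < 2; ++r) {
        for (int c = 0; c < 2; ++c) {
            const int xi = min(max(x0 + c, 0), width - 1);   // clamp-to-edge
            const int yi = min(max(y0 + r, 0), height - 1);
            // Several samples may target the same pixel, hence the atomic accumulate.
            atomicAdd(&gradInput[yi * width + xi], wy[r] * wx[c] * gOut);
        }
    }
}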
nemo-retriever-ocr/cpp/better_grid_sample/grid_sample.h ADDED
@@ -0,0 +1,67 @@
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+ // All rights reserved.
+ // SPDX-License-Identifier: Apache-2.0
+
+ #pragma once
+
+ #include <torch/torch.h>
+
+ inline
+ torch::Tensor region_counts_to_indices(torch::Tensor regionCounts, int64_t numOutputs)
+ {
+     // If there's only one example, we can trivially return idx 0 for all
+     if (regionCounts.size(0) == 1) {
+         return torch::zeros({ numOutputs }, regionCounts.options().dtype(torch::kInt64));
+     }
+
+     // regionCounts will be some tensor like [ 5, 1, 10, 2 ] which means that the first 5 outputs
+     // correspond to the first input, the next output to the second input, 10 to the third, and so on.
+
+     // We want to convert this to instead have an entry for each output which specifies the index of the corresponding input.
+     // To do this, we can count the number of times the output index exceeds the cumulative input counts.
+     // e.g. the cumulative region count for the above tensor is [ 5, 6, 16, 18 ].
+     // The output indices 0-4 are not greater than or equal to any cumulative count, so they get the input index of 0.
+     // The output index 5 is equal to a single count, therefore index 1.
+     // The outputs 6-15 are all greater than or equal to two cumulative counts, therefore index 2.
+     // And so on.
+
+     auto indices = torch::arange(regionCounts.size(0), regionCounts.options().dtype(torch::kInt64));
+
+     auto outputIndices = torch::repeat_interleave(indices, regionCounts, /*dim=*/ 0, /*output_size=*/ numOutputs);
+
+     return outputIndices;
+ }
+
+ torch::Tensor gpu_indirect_grid_sample_forward(torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method);
+ torch::Tensor cpu_indirect_grid_sample_forward(torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method);
+ std::vector<torch::Tensor> gpu_indirect_grad_sample_backward(torch::Tensor gradOutput, torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method);
+
+ inline
+ torch::Tensor indirect_grid_sample_forward(torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method)
+ {
+     if (input.is_cuda() != grid.is_cuda() || input.is_cuda() != inputIndices.is_cuda()) {
+         throw std::runtime_error("Input tensors must all be on the same device!");
+     }
+     if (inputIndices.size(0) != grid.size(0)) {
+         throw std::runtime_error("The batch dimensions must match!");
+     }
+     if (grid.size(-1) != 2) {
+         throw std::runtime_error("The final grid dimension must be 2.");
+     }
+
+     if (input.is_cuda()) {
+         return gpu_indirect_grid_sample_forward(std::move(input), std::move(grid), std::move(inputIndices), method);
+     } else {
+         return cpu_indirect_grid_sample_forward(std::move(input), std::move(grid), std::move(inputIndices), method);
+     }
+ }
+
+ inline
+ std::vector<torch::Tensor> indirect_grad_sample_backward(torch::Tensor gradOutput, torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method)
+ {
+     if (gradOutput.is_cuda()) {
+         return gpu_indirect_grad_sample_backward(std::move(gradOutput), std::move(input), std::move(grid), std::move(inputIndices), method);
+     } else {
+         throw std::runtime_error("Not implemented!");
+     }
+ }
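A quick check of the mapping that region_counts_to_indices describes, using the worked example from its own comment; the main() harness is illustrative only and just needs to link against libtorch.

// Counts [5, 1, 10, 2] expand to 18 per-output input indices.
#include <iostream>
#include <torch/torch.h>

#include "grid_sample.h"

int main()
{
    auto counts = torch::tensor({5, 1, 10, 2}, torch::kInt64);
    auto idx = region_counts_to_indices(counts, /*numOutputs=*/ 18);
    // indices 0..4 map to input 0, index 5 to input 1, 6..15 to input 2, 16..17 to input 3
    std::cout << idx << std::endl;
    return 0;
}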
nemo-retriever-ocr/cpp/common.cpp ADDED
@@ -0,0 +1,13 @@
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+ // All rights reserved.
+ // SPDX-License-Identifier: Apache-2.0
+
+ #include "common.h"
+
+ #include <iostream>
+
+ using namespace std;
+
+ void print_tensor(const torch::Tensor &t) {
+     cout << t << endl;
+ }
nemo-retriever-ocr/cpp/common.h ADDED
@@ -0,0 +1,58 @@
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+ // All rights reserved.
+ // SPDX-License-Identifier: Apache-2.0
+
+ #pragma once
+
+ #include <ostream>
+ #include <tuple>
+ #include <vector>
+
+ #include <torch/torch.h>
+
+ template<typename T>
+ inline
+ std::ostream &operator<<(std::ostream &os, const std::vector<T> &v) {
+     os << "[";
+     if (! v.empty()) {
+         os << v[0];
+         for (size_t i = 1; i < v.size(); ++i) {
+             os << ", " << v[i];
+         }
+     }
+     os << "]";
+     return os;
+ }
+
+ template<int Counter, typename ...Args>
+ struct _inner_tuple_print
+ {
+     inline
+     static std::ostream &print(std::ostream &os, const std::tuple<Args...> &t) {
+         _inner_tuple_print<Counter - 1, Args...>::print(os, t);
+
+         os << ", " << std::get<Counter>(t);
+         return os;
+     }
+ };
+
+ template<typename ...Args>
+ struct _inner_tuple_print<0, Args...>
+ {
+     inline
+     static std::ostream &print(std::ostream &os, const std::tuple<Args...> &t) {
+         os << std::get<0>(t);
+         return os;
+     }
+ };
+
+
+ template<typename... Args>
+ inline
+ std::ostream &operator<<(std::ostream &os, const std::tuple<Args...> &t) {
+     os << "(";
+     _inner_tuple_print<sizeof...(Args) - 1, Args...>::print(os, t);
+     os << ")";
+     return os;
+ }
+
+ void print_tensor(const torch::Tensor &t);
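Small usage example for the stream helpers above; the values are arbitrary.

// Prints "[1, 2, 3]" followed by "(7, x, 2.5)".
#include <iostream>
#include <string>
#include <tuple>
#include <vector>

#include "common.h"

int main()
{
    std::vector<int> v{ 1, 2, 3 };
    auto t = std::make_tuple(7, std::string("x"), 2.5);
    std::cout << v << " " << t << std::endl;
    return 0;
}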
nemo-retriever-ocr/cpp/cuda_intellisense.cuh ADDED
@@ -0,0 +1,51 @@
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+ // All rights reserved.
+ // SPDX-License-Identifier: Apache-2.0
+
+ #pragma once
+
+ #if defined(__INTELLISENSE__) || !defined(__NVCC__)
+ #ifndef KERNEL_ARG2
+ #define KERNEL_ARG2(grid, block)
+ #define KERNEL_ARG3(grid, block, sh_mem)
+ #define KERNEL_ARG4(grid, block, sh_mem, stream)
+ #define __global__
+ #define __device__
+ #define __host__
+ #endif
+ #endif
+
+ #ifdef __INTELLISENSE__
+ #define __CUDACC__
+ #include <cuda_runtime.h>
+
+ void __syncthreads(); // workaround __syncthreads warning
+
+ dim3 threadIdx;
+ dim3 blockIdx;
+ dim3 blockDim;
+ dim3 gridDim;
+
+ #else
+ #ifndef KERNEL_ARG2
+ #define KERNEL_ARG2(grid, block) <<< grid, block >>>
+ #define KERNEL_ARG3(grid, block, sh_mem) <<< grid, block, sh_mem >>>
+ #define KERNEL_ARG4(grid, block, sh_mem, stream) <<< grid, block, sh_mem, stream >>>
+ #endif
+ #endif
+
+ #define __any_device__ __host__ __device__
+
+ #ifdef __NVCC__
+ #define __lib_inline__ __forceinline__
+
+ #else
+ #define __lib_inline__ inline
+ #endif
+
+ template<typename T1, typename T2>
+ __any_device__
+ inline auto div_up(T1 n, T2 d)
+ {
+     return (n + d - 1) / d;
+ }
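Typical use of div_up and the KERNEL_ARG macros above when configuring a launch; the kernel and the block size are placeholders. Under nvcc the macro expands to the usual <<< >>> launch syntax, and to nothing for IntelliSense or host-only parsing.

// Placeholder kernel and launcher in a .cu file; illustrative only.
#include "cuda_intellisense.cuh"

__global__ void fill_ones(float *out, int n)
{
    const int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < n) {
        out[i] = 1.0f;
    }
}

void launch_fill_ones(float *out, int n)
{
    const dim3 block(256);
    const dim3 grid(div_up(n, block.x));   // round up so every element is covered
    fill_ones KERNEL_ARG2(grid, block) (out, n);
}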
nemo-retriever-ocr/cpp/geometry.h ADDED
@@ -0,0 +1,1101 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #pragma once
6
+
7
+ #include <algorithm>
8
+ #include <cmath>
9
+ #include <iostream>
10
+ #include <type_traits>
11
+
12
+ #ifndef _GEOMETRY_NO_TORCH
13
+ #include <torch/torch.h>
14
+ #endif
15
+
16
+ #include "cuda_intellisense.cuh"
17
+
18
+ #ifndef __NVCC__
19
+ #define SORT_ALGO std::sort
20
+ #define SWAP std::swap
21
+
22
+ template<typename ...Args>
23
+ using tuple_t = std::tuple<Args...>;
24
+
25
+ #else
26
+
27
+ #include <thrust/sort.h>
28
+ #include <thrust/tuple.h>
29
+
30
+ #define SORT_ALGO thrust::sort
31
+ #define SWAP thrust::swap
32
+
33
+ template<typename ...Args>
34
+ using tuple_t = thrust::tuple<Args...>;
35
+ #endif
36
+
37
+ template<typename T>
38
+ struct Point_ {
39
+ typedef T inner_type;
40
+
41
+ T X, Y;
42
+
43
+ Point_() = default;
44
+
45
+ __any_device__
46
+ Point_(T x, T y) : X(x), Y(y) {}
47
+
48
+ __any_device__
49
+ Point_(T *ptr) : X(ptr[0]), Y(ptr[1]) {}
50
+
51
+ #ifndef _GEOMETRY_NO_TORCH
52
+ template<typename T2>
53
+ __any_device__
54
+ Point_(const torch::TensorAccessor<T2, 1> &accessor) : X(accessor[0]), Y(accessor[1]) {}
55
+
56
+ template<typename T2>
57
+ __any_device__
58
+ Point_(const torch::PackedTensorAccessor64<T2, 1> &accessor) : X(accessor[0]), Y(accessor[1]) {}
59
+ #endif
60
+
61
+ __any_device__
62
+ Point_ &operator+=(const Point_ &other);
63
+
64
+ __any_device__
65
+ Point_ &operator-=(const Point_ &other);
66
+
67
+ __any_device__
68
+ Point_ &operator*=(const Point_ &other);
69
+
70
+ __any_device__
71
+ Point_ &operator/=(const Point_ &other);
72
+
73
+ template<typename W>
74
+ __any_device__
75
+ Point_ &operator/=(W w);
76
+
77
+ template<typename W>
78
+ __any_device__
79
+ Point_ &operator*=(W w);
80
+
81
+ __any_device__
82
+ Point_ operator-() {
83
+ return { -X, -Y };
84
+ }
85
+
86
+ __any_device__
87
+ T Sum() const { return X + Y; }
88
+
89
+ __any_device__
90
+ T Angle() const;
91
+
92
+ __any_device__
93
+ void swap(Point_ &other) noexcept {
94
+ SWAP(X, other.X);
95
+ SWAP(Y, other.Y);
96
+ }
97
+ };
98
+
99
+ template<typename T>
100
+ __lib_inline__ __any_device__
101
+ void swap(Point_<T> &a, Point_<T> &b) {
102
+ a.swap(b);
103
+ }
104
+
105
+
106
+ template<typename T>
107
+ __any_device__
108
+ __lib_inline__ T Point_<T>::Angle() const {
109
+ #ifndef __NVCC__
110
+ using std::atan2;
111
+ #endif
112
+ return atan2(Y, X);
113
+ }
114
+
115
+ template<typename T>
116
+ __any_device__
117
+ __lib_inline__ Point_<T> min(const Point_<T> &a, const Point_<T> &b) {
118
+ #ifndef __NVCC__
119
+ using std::min;
120
+ #endif
121
+ return {
122
+ min(a.X, b.X),
123
+ min(a.Y, b.Y)
124
+ };
125
+ }
126
+
127
+ template<typename T>
128
+ __any_device__
129
+ __lib_inline__ Point_<T> max(const Point_<T> &a, const Point_<T> &b) {
130
+ #ifndef __NVCC__
131
+ using std::max;
132
+ #endif
133
+ return {
134
+ max(a.X, b.X),
135
+ max(a.Y, b.Y)
136
+ };
137
+ }
138
+
139
+ template<typename T>
140
+ struct AABB_ {
141
+ typedef T inner_type;
142
+
143
+ T X;
144
+ T Y;
145
+ T MaxX;
146
+ T MaxY;
147
+
148
+ AABB_() = default;
149
+ __any_device__
150
+ AABB_(T x, T y, T maxX, T maxY)
151
+ : X(x), Y(y), MaxX(maxX), MaxY(maxY) {}
152
+
153
+ __any_device__
154
+ bool Contains(const Point_<T> &p) const {
155
+ return p.X >= X && p.X < MaxX &&
156
+ p.Y >= Y && p.Y < MaxY;
157
+ }
158
+
159
+ __any_device__ __lib_inline__
160
+ AABB_ Union(const AABB_ &other) const {
161
+ #ifndef __NVCC__
162
+ using std::min;
163
+ using std::max;
164
+ #endif
165
+ T minX = min(X, other.X);
166
+ T maxX = max(MaxX, other.MaxX);
167
+ T minY = min(Y, other.Y);
168
+ T maxY = max(MaxY, other.MaxY);
169
+
170
+ return { minX, minY, maxX, maxY };
171
+ }
172
+
173
+ __any_device__
174
+ AABB_ &operator-=(const Point_<T> &offset) {
175
+ X -= offset.X;
176
+ MaxX -= offset.X;
177
+ Y -= offset.Y;
178
+ MaxY -= offset.Y;
179
+ return *this;
180
+ }
181
+
182
+ __any_device__
183
+ __lib_inline__ T Width() const { return MaxX - X; }
184
+ __any_device__
185
+ __lib_inline__ T Height() const { return MaxY - Y; }
186
+ __any_device__
187
+ __lib_inline__ T Area() const { return Width() * Height(); }
188
+
189
+ __lib_inline__ T &operator[] (int64_t idx)
190
+ {
191
+ static_assert(std::is_standard_layout<AABB_<T>>::value, "This function is only valid for standard layout");
192
+ return (&X)[idx];
193
+ }
194
+ __lib_inline__ T operator[] (int64_t idx) const
195
+ {
196
+ static_assert(std::is_standard_layout<AABB_<T>>::value, "This function is only valid for standard layout");
197
+ return (&X)[idx];
198
+ }
199
+
200
+ __any_device__ __lib_inline__
201
+ AABB_ Intersection(const AABB_ &other) const {
202
+ #ifndef __NVCC__
203
+ using std::min;
204
+ using std::max;
205
+ #endif
206
+ T minX = max(X, other.X);
207
+ T minY = max(Y, other.Y);
208
+ T maxX = min(MaxX, other.MaxX);
209
+ T maxY = min(MaxY, other.MaxY);
210
+ // Prevent negative area
211
+ minX = min(minX, maxX);
212
+ minY = min(minY, maxY);
213
+ return { minX, minY, maxX, maxY };
214
+ }
215
+
216
+ __any_device__ __lib_inline__
217
+ T IntersectionArea(const AABB_ &other) const { return Intersection(other).Area(); }
218
+ };
219
+
220
+ template<typename T, typename Derived>
221
+ struct QuadBase_ {
222
+ typedef T inner_type;
223
+
224
+ __any_device__
225
+ AABB_<T> Bounds() const;
226
+
227
+ __any_device__
228
+ bool Contains(const Point_<T> &p) const;
229
+
230
+ __any_device__
231
+ T Area() const;
232
+
233
+ __any_device__
234
+ T Height() const;
235
+
236
+ __any_device__
237
+ T Width() const;
238
+
239
+ template<typename Derived2>
240
+ __any_device__
241
+ T IntersectionArea(const QuadBase_<T, Derived2> &other) const;
242
+
243
+ template<typename Derived2>
244
+ __any_device__
245
+ T IOU(const QuadBase_<T, Derived2> &other) const;
246
+
247
+ template<typename Derived2>
248
+ __any_device__
249
+ T IOU_UpperBound(const QuadBase_<T, Derived2> &other) const;
250
+
251
+ __any_device__
252
+ Point_<T> Center() const;
253
+
254
+ template<typename Derived2>
255
+ __any_device__
256
+ /*
257
+ Returns 3 geometric associations between the two quads:
258
+ 0: The percent shared area between this and other relative to this (e.g. if other contains this, then it returns 1)
259
+ 1: The percent shared area between other and this relative to other (e.g. if this contains other, then it return 1)
260
+ 2: The IOU of the two quads
261
+ */
262
+ tuple_t<T, T, T> RegionSizes(const QuadBase_<T, Derived2> &other) const;
263
+
264
+ template<typename Derived2>
265
+ __any_device__
266
+ tuple_t<T, T, T> RegionSizes_UpperBound(const QuadBase_<T, Derived2> &other) const;
267
+
268
+ __any_device__
269
+ Derived &operator/=(T val) {
270
+ auto rcp = 1 / val;
271
+ return *this *= rcp;
272
+ }
273
+
274
+ __any_device__
275
+ Derived &operator*=(T val) {
276
+ auto dThis = static_cast<Derived*>(this);
277
+ #pragma unroll
278
+ for (size_t i = 0; i < 4; ++i) {
279
+ dThis->Vertices[i] *= val;
280
+ }
281
+ return *dThis;
282
+ }
283
+
284
+ friend auto begin(const QuadBase_ &q) { return static_cast<const Derived&>(q).Vertices; }
285
+ friend auto begin(QuadBase_& q) { return static_cast<const Derived&>(q).Vertices; }
286
+ friend auto end(const QuadBase_ &q) { return static_cast<const Derived&>(q).Vertices + 4; }
287
+ friend auto end(QuadBase_ &q) { return static_cast<const Derived&>(q).Vertices + 4; }
288
+ };
289
+
290
+ template<typename T>
291
+ struct Quad_ : QuadBase_<T, Quad_<T>> {
292
+ Point_<T> *Vertices = nullptr;
293
+
294
+ Quad_() = default;
295
+ __any_device__
296
+ Quad_(T *dataPtr)
297
+ : Vertices(reinterpret_cast<Point_<T>*>(dataPtr)) {}
298
+ __any_device__
299
+ Quad_(Point_<T> *dataPtr)
300
+ : Vertices(dataPtr) {}
301
+
302
+ template<typename index_t>
303
+ __any_device__ __lib_inline__
304
+ const Point_<T> &operator[](index_t offset) const { return Vertices[offset]; }
305
+ template<typename index_t>
306
+ __any_device__ __lib_inline__
307
+ Point_<T> &operator[](index_t offset) { return Vertices[offset]; }
308
+ };
309
+
310
+ template<typename T>
311
+ struct InPlaceQuad_ : public QuadBase_<T, InPlaceQuad_<T>> {
312
+ Point_<T> Vertices[4];
313
+
314
+ InPlaceQuad_() = default;
315
+ __any_device__
316
+ InPlaceQuad_(const T *dataPtr)
317
+ {
318
+ #if defined(__NVCC__)
319
+ T *pVals = reinterpret_cast<T*>(Vertices);
320
+ #pragma unroll
321
+ for (uint32_t i = 0; i < 8; ++i) {
322
+ pVals[i] = dataPtr[i];
323
+ }
324
+ #else
325
+ using std::copy;
326
+ copy(dataPtr, dataPtr + 8, reinterpret_cast<T*>(Vertices));
327
+ #endif
328
+ }
329
+ __any_device__
330
+ InPlaceQuad_(const Point_<T> *dataPtr)
331
+ {
332
+ #if defined(__NVCC__)
333
+ #pragma unroll
334
+ for (uint32_t i = 0; i < 4; ++i) {
335
+ Vertices[i] = dataPtr[i];
336
+ }
337
+ #else
338
+ using std::copy;
339
+ copy(dataPtr, dataPtr + 4, Vertices);
340
+ #endif
341
+ }
342
+
343
+ template<typename index_t>
344
+ __any_device__ __lib_inline__
345
+ const Point_<T> &operator[](index_t v) const { return Vertices[v]; }
346
+
347
+ template<typename index_t>
348
+ __any_device__ __lib_inline__
349
+ Point_<T> &operator[](index_t v) { return Vertices[v]; }
350
+ };
351
+
352
+ template<typename T, typename Derived>
353
+ struct PolygonBase_ {
354
+ typedef T inner_type;
355
+
356
+ __any_device__
357
+ AABB_<T> Bounds() const;
358
+
359
+ __any_device__
360
+ bool Contains(const Point_<T> &p) const;
361
+
362
+ __any_device__
363
+ T EdgeLength() const;
364
+
365
+ __any_device__
366
+ Point_<T> Center() const;
367
+
368
+ __any_device__
369
+ T Area() const;
370
+ };
371
+
372
+ template<typename T>
373
+ struct Polygon_ : PolygonBase_<T, Polygon_<T>> {
374
+ Point_<T> *Vertices = nullptr;
375
+ size_t Count = 0;
376
+
377
+ Polygon_() = default;
378
+ __any_device__
379
+ Polygon_(T *dataPtr, size_t vertexCount)
380
+ : Vertices(reinterpret_cast<Point_<T>*>(dataPtr)), Count(vertexCount) {}
381
+ __any_device__
382
+ Polygon_(Point_<T> *dataPtr, size_t vertexCount)
383
+ : Vertices(dataPtr), Count(vertexCount) {}
384
+
385
+ __any_device__
386
+ const Point_<T> &operator[](size_t offset) const { return Vertices[offset]; }
387
+ __any_device__
388
+ Point_<T> &operator[](size_t offset) { return Vertices[offset]; }
389
+ };
390
+
391
+ template<typename T>
392
+ struct Segment_ {
393
+ Point_<T> A, B;
394
+
395
+ Segment_() = default;
396
+ __any_device__
397
+ Segment_(const Point_<T> &a, const Point_<T> &b) : A(a), B(b) {}
398
+
399
+ __any_device__
400
+ T Length() const;
401
+ __any_device__
402
+ T LengthSq() const;
403
+ __any_device__
404
+ bool Intersection(const Segment_<T> &other, Point_<T> &out_ptAlong) const;
405
+ };
406
+
407
+ template<typename T>
408
+ __any_device__
409
+ __lib_inline__ Point_<T> operator+(const Point_<T> &a, const Point_<T> &b) {
410
+ return { a.X + b.X, a.Y + b.Y };
411
+ }
412
+
413
+ template<typename T>
414
+ __any_device__
415
+ __lib_inline__ Point_<T> operator-(const Point_<T> &a, const Point_<T> &b) {
416
+ return { a.X - b.X, a.Y - b.Y };
417
+ }
418
+
419
+ template<typename T, typename W>
420
+ __any_device__
421
+ __lib_inline__ Point_<T> operator*(W scale, const Point_<T> &p) {
422
+ return { scale * p.X, scale * p.Y };
423
+ }
424
+
425
+ template<typename T, typename W>
426
+ __any_device__
427
+ __lib_inline__ Point_<T> operator*(const Point_<T> &p, W scale) {
428
+ return { scale * p.X, scale * p.Y };
429
+ }
430
+
431
+ template<typename T, typename W>
432
+ __any_device__
433
+ __lib_inline__ Point_<T> operator/(const Point_<T> &p, W divisor) {
434
+ return { p.X / divisor, p.Y / divisor };
435
+ }
436
+
437
+ template<typename T>
438
+ __any_device__
439
+ __lib_inline__ Point_<T> operator*(const Point_<T> &a, const Point_<T> &b) {
440
+ return { a.X * b.X, a.Y * b.Y };
441
+ }
442
+
443
+ template<typename T, typename W>
444
+ __any_device__
445
+ __lib_inline__ Point_<T> operator-(const Point_<T> &p, W v) {
446
+ return { p.X - v, p.Y - v };
447
+ }
448
+
449
+ template<typename T>
450
+ __any_device__
451
+ __lib_inline__ Point_<T> &Point_<T>::operator+=(const Point_<T> &p) {
452
+ X = X + p.X;
453
+ Y = Y + p.Y;
454
+ return *this;
455
+ }
456
+
457
+ template<typename T>
458
+ __any_device__
459
+ __lib_inline__ Point_<T> &Point_<T>::operator-=(const Point_<T> &p) {
460
+ X = X - p.X;
461
+ Y = Y - p.Y;
462
+ return *this;
463
+ }
464
+
465
+ template<typename T>
466
+ __any_device__
467
+ __lib_inline__ Point_<T> &Point_<T>::operator*=(const Point_<T> &p) {
468
+ X = X * p.X;
469
+ Y = Y * p.Y;
470
+ return *this;
471
+ }
472
+
473
+ template<typename T>
474
+ __any_device__
475
+ __lib_inline__ Point_<T> &Point_<T>::operator/=(const Point_<T> &p) {
476
+ X = X / p.X;
477
+ Y = Y / p.Y;
478
+ return *this;
479
+ }
480
+
481
+ template<typename T>
482
+ template<typename W>
483
+ __any_device__
484
+ __lib_inline__ Point_<T> &Point_<T>::operator/=(W val) {
485
+ // TODO: This can be more efficient for float types by computing the reciprocal
486
+ X /= val;
487
+ Y /= val;
488
+ return *this;
489
+ }
490
+
491
+ template<typename T>
492
+ template<typename W>
493
+ __any_device__
494
+ __lib_inline__ Point_<T> &Point_<T>::operator*=(W val) {
495
+ X *= val;
496
+ Y *= val;
497
+ return *this;
498
+ }
499
+
500
+ template<typename T>
501
+ __any_device__
502
+ __lib_inline__ T dot(const Point_<T> &a, const Point_<T> &b) {
503
+ return a.X * b.X + a.Y * b.Y;
504
+ }
505
+
506
+ template<typename T>
507
+ __any_device__
508
+ __lib_inline__ T dot(const Point_<T> &p) {
509
+ return dot(p, p);
510
+ }
511
+
512
+ template<typename T>
513
+ __any_device__
514
+ __lib_inline__ T length(const Point_<T> &p) {
515
+ #ifndef __NVCC__
516
+ using std::sqrt;
517
+ #endif
518
+ return sqrt(dot(p));
519
+ }
520
+
521
+ template<typename T>
522
+ __any_device__
523
+ __lib_inline__ Point_<T> normalize(const Point_<T> &p) {
524
+ static constexpr T epsilon = std::numeric_limits<T>::epsilon();
525
+ auto len = length(p) + epsilon;
526
+ return { p.X / len, p.Y / len };
527
+ }
528
+
529
+ template<typename T>
530
+ __any_device__
531
+ __lib_inline__ Point_<T> ortho_2d(const Point_<T> &p) {
532
+ return { -p.Y, p.X };
533
+ }
534
+
535
+ template<typename T>
536
+ __host__
537
+ __lib_inline__ std::ostream &operator<<(std::ostream &os, const Point_<T> &p) {
538
+ return os << "(" << p.X << ", " << p.Y << ")";
539
+ }
540
+
541
+ template<typename T>
542
+ __host__
543
+ __lib_inline__ std::ostream &operator<<(std::ostream &os, const AABB_<T> &b) {
544
+ return os << "[(" << b.X << ", " << b.Y << "), (" << b.MaxX << ", " << b.MaxY << ")]";
545
+ }
546
+
547
+ template<typename T>
548
+ __host__
549
+ __lib_inline__ std::ostream &operator<<(std::ostream &os, const Segment_<T> &s) {
550
+ return os << "[(" << s.A.X << ", " << s.A.Y << "), (" << s.B.X << ", " << s.B.Y << ")]";
551
+ }
552
+
553
+ template<typename T>
554
+ __host__
555
+ __lib_inline__ std::ostream &operator<<(std::ostream &os, const Quad_<T> &q) {
556
+ os << "[" << q.Vertices[0];
557
+ for (size_t i = 1; i < 4; ++i) {
558
+ os << ", " << q.Vertices[i];
559
+ }
560
+ return os << "]";
561
+ }
562
+
563
+ template<typename T>
564
+ __any_device__
565
+ __lib_inline__ int _signum(T val) {
566
+ return (T(0) < val) - (val < T(0));
567
+ }
568
+
569
+ template<typename T>
570
+ __any_device__
571
+ __lib_inline__ T sign(const Point_<T> &p1, const Point_<T> &p2, const Point_<T> &p3) {
572
+ T ret = (p1.X - p3.X) * (p2.Y - p3.Y) - (p2.X - p3.X) * (p1.Y - p3.Y);
573
+ auto sgn = _signum(ret);
574
+ return sgn;
575
+ }
576
+
577
+ template<typename T>
578
+ __any_device__
579
+ __lib_inline__ T Segment_<T>::Length() const
580
+ {
581
+ #ifndef __NVCC__
582
+ using std::sqrt;
583
+ #endif
584
+ return sqrt(LengthSq());
585
+ }
586
+
587
+ template<typename T>
588
+ __any_device__
589
+ __lib_inline__ T Segment_<T>::LengthSq() const
590
+ {
591
+ return dot(B - A);
592
+ }
593
+
594
+ template<typename T>
595
+ __any_device__
596
+ inline bool Segment_<T>::Intersection(const Segment_<T> &other, Point_<T> &out_ptAlong) const
597
+ {
598
+ auto p1 = A, p2 = B, p3 = other.A, p4 = other.B;
599
+
600
+ auto denom = (p4.Y - p3.Y) * (p2.X - p1.X) - (p4.X - p3.X) * (p2.Y - p1.Y);
601
+
602
+ if (abs(denom) < 1e-8) {
603
+ return false;
604
+ }
605
+
606
+ auto numer = (p4.X - p3.X) * (p1.Y - p3.Y) - (p4.Y - p3.Y) * (p1.X - p3.X);
607
+
608
+ auto t = numer / denom;
609
+
610
+ auto Bnumer = (p2.X - p1.X) * (p1.Y - p3.Y) - (p2.Y - p1.Y) * (p1.X - p3.X);
611
+
612
+ auto Bt = Bnumer / denom;
613
+
614
+ if (t < 0 || t > 1 || Bt < 0 || Bt > 1) {
615
+ return false;
616
+ }
617
+
618
+ out_ptAlong = A + t * (B - A);
619
+
620
+ return true;
621
+ }
622
+
623
+ template<typename quad_t>
624
+ __any_device__
625
+ auto quad_center(const quad_t &quad) -> Point_<typename quad_t::inner_type>
626
+ {
627
+ typedef typename quad_t::inner_type T;
628
+
629
+ Point_<T> center = quad[0];
630
+ for (size_t i = 1; i < 4; ++i) {
631
+ center += quad[i];
632
+ }
633
+
634
+ return center / T{ 4 };
635
+ }
636
+
637
+ template<typename T, typename Derived>
638
+ __any_device__
639
+ Point_<T> QuadBase_<T, Derived>::Center() const {
640
+ return quad_center(static_cast<const Derived&>(*this));
641
+ }
642
+
643
+ template<typename quad_t>
644
+ __any_device__
645
+ auto quad_bounds(const quad_t &quad) -> AABB_<typename quad_t::inner_type>
646
+ {
647
+ #ifndef __NVCC__
648
+ using std::min;
649
+ using std::max;
650
+ #endif
651
+ auto minP = quad[0];
652
+ auto maxP = minP;
653
+ for (size_t i = 1; i < 4; ++i) {
654
+ auto qp = quad[i];
655
+ minP = min(minP, qp);
656
+ maxP = max(maxP, qp);
657
+ }
658
+ return { minP.X, minP.Y, maxP.X, maxP.Y };
659
+ }
660
+
661
+ template<typename T, typename Derived>
662
+ __any_device__
663
+ AABB_<T> QuadBase_<T, Derived>::Bounds() const {
664
+ return quad_bounds(static_cast<const Derived&>(*this));
665
+ }
666
+
667
+ template<typename Quad_t, typename point_t>
668
+ __any_device__
669
+ inline bool quad_contains(const Quad_t &quad, const point_t &pt)
670
+ {
671
+ #ifndef __NVCC__
672
+ using std::abs;
673
+ #endif
674
+
675
+ // Checks that the point lies on the interior side of each half plane
676
+ auto d1 = sign(pt, quad[0], quad[1]);
677
+ auto d2 = sign(pt, quad[1], quad[2]);
678
+ auto d3 = sign(pt, quad[2], quad[3]);
679
+ auto d4 = sign(pt, quad[3], quad[0]);
680
+
681
+ // bool has_neg = (d1 < 0) || (d2 < 0) || (d3 < 0) || (d4 < 0);
682
+ // bool has_pos = (d1 > 0) || (d2 > 0) || (d3 > 0) || (d4 > 0);
683
+ int tot = d1 + d2 + d3 + d4;
684
+
685
+ // return !(has_neg && has_pos);
686
+ return abs(tot) == 4;
687
+ }
688
+
689
+ template<typename T, typename Derived>
690
+ __any_device__
691
+ __lib_inline__ bool QuadBase_<T, Derived>::Contains(const Point_<T> &pt) const
692
+ {
693
+ return quad_contains(static_cast<const Derived&>(*this), pt);
694
+ }
695
+
696
+ template<typename PtList>
697
+ __any_device__
698
+ inline auto shoelace_area(const PtList &points, size_t numPts, bool isSigned=false) -> decltype(points[0].X)
699
+ {
700
+ #ifndef __NVCC__
701
+ using std::abs;
702
+ #endif
703
+
704
+ decltype(points[0].X) area = 0;
705
+
706
+ size_t j = numPts - 1;
707
+ for (size_t i = 0; i < numPts; ++i) {
708
+ auto Pi = points[i];
709
+ auto Pj = points[j];
710
+
711
+ area += (Pj.X + Pi.X) * (Pj.Y - Pi.Y);
712
+ j = i;
713
+ }
714
+
715
+ area = area / 2;
716
+
717
+ if (! isSigned) {
718
+ area = abs(area);
719
+ }
720
+
721
+ return area;
722
+ }
723
+
724
+ template<typename T, typename Derived>
725
+ __any_device__
726
+ __lib_inline__ T QuadBase_<T, Derived>::Height() const
727
+ {
728
+ auto &d = static_cast<const Derived&>(*this);
729
+ auto h1 = Segment_<T>(d[1], d[2]).Length();
730
+ auto h2 = Segment_<T>(d[3], d[0]).Length();
731
+ return (h1 + h2) / 2;
732
+ }
733
+
734
+ template<typename T, typename Derived>
735
+ __any_device__
736
+ __lib_inline__ T QuadBase_<T, Derived>::Width() const
737
+ {
738
+ auto &d = static_cast<const Derived&>(*this);
739
+ auto w1 = Segment_<T>(d[0], d[1]).Length();
740
+ auto w2 = Segment_<T>(d[3], d[2]).Length();
741
+ return (w1 + w2) / 2;
742
+ }
743
+
744
+ // A quad can be defined as the sum of the area of two triangles
745
+ template<typename T, typename Derived>
746
+ __any_device__
747
+ inline T QuadBase_<T, Derived>::Area() const
748
+ {
749
+ // auto vertices = static_cast<const Derived *>(this)->Vertices;
750
+ return shoelace_area(static_cast<const Derived&>(*this), 4);
751
+ }
752
+
753
+ template<typename Quad_t1, typename Quad_t2>
754
+ __any_device__
755
+ inline auto intersection_area(const Quad_t1 &quadsA, const Quad_t2 &quadsB) -> typename Quad_t1::inner_type
756
+ {
757
+ #ifndef __NVCC__
758
+ using std::atan2;
759
+ #endif
760
+
761
+ typedef typename Quad_t1::inner_type T;
762
+
763
+ static const size_t MAX_PTS = 32;
764
+
765
+ Point_<T> points[MAX_PTS], sortedPoints[MAX_PTS];
766
+ T angles[MAX_PTS];
767
+ size_t indices[MAX_PTS];
768
+ size_t numPts = 0;
769
+
770
+ auto addPt = [&] (const Point_<T> &p) {
771
+ points[numPts] = p;
772
+ ++numPts;
773
+ };
774
+
775
+ for (size_t i = 0; i < 4; ++i) {
776
+ Point_<T> aPt = quadsA[i];
777
+ Point_<T> bPt = quadsB[i];
778
+
779
+ if (quadsA.Contains(bPt)) {
780
+ addPt(bPt);
781
+ }
782
+ if (quadsB.Contains(aPt)) {
783
+ addPt(aPt);
784
+ }
785
+ }
786
+
787
+ for (size_t i = 0; i < 4; ++i) {
788
+ Segment_<T> segA{ quadsA[i], quadsA[(i + 1) % 4] };
789
+
790
+ for (size_t j = 0; j < 4; ++j) {
791
+ Segment_<T> segB{ quadsB[j], quadsB[(j + 1) % 4] };
792
+
793
+ Point_<T> ptAlong;
794
+ if (segA.Intersection(segB, ptAlong)) {
795
+ addPt(ptAlong);
796
+ }
797
+ }
798
+ }
799
+
800
+ if (numPts == 0) {
801
+ return 0;
802
+ }
803
+
804
+ Point_<T> center{ 0, 0 };
805
+ for (size_t i = 0; i < numPts; ++i) {
806
+ center += points[i];
807
+ }
808
+ center /= numPts;
809
+
810
+ for (size_t i = 0; i < numPts; ++i) {
811
+ points[i] -= center;
812
+
813
+ angles[i] = atan2(points[i].Y, points[i].X);
814
+
815
+ indices[i] = i;
816
+ }
817
+
818
+ // Perform an argsort over the angles
819
+ SORT_ALGO(indices, indices + numPts,
820
+ [&] (size_t a, size_t b) {
821
+ return angles[a] < angles[b];
822
+ }
823
+ );
824
+
825
+ for (size_t i = 0; i < numPts; ++i) {
826
+ sortedPoints[i] = points[indices[i]];
827
+ }
828
+
829
+ // Finally, we can compute the area of this polygon using the shoelace formula
830
+ T area = shoelace_area(sortedPoints, numPts);
831
+
832
+ return area;
833
+ }
834
+
835
+ template<typename T, typename Derived>
836
+ template<typename Derived2>
837
+ __any_device__
838
+ __lib_inline__ T QuadBase_<T, Derived>::IntersectionArea(const QuadBase_<T, Derived2> &other) const
839
+ {
840
+ return intersection_area(
841
+ static_cast<const Derived&>(*this),
842
+ static_cast<const Derived2&>(other)
843
+ );
844
+ }
845
+
846
+ template<typename T1, typename T2>
847
+ __any_device__
848
+ __lib_inline__ auto geometry_iou(const T1 &a, const T2 &b) -> decltype(a.Area())
849
+ {
850
+ auto aArea = a.Area();
851
+ auto bArea = b.Area();
852
+ auto ixArea = a.IntersectionArea(b);
853
+
854
+ auto unionArea = aArea + bArea - ixArea;
855
+
856
+ return ixArea / unionArea;
857
+ }
858
+
859
+ template<typename T, typename Derived>
860
+ template<typename Derived2>
861
+ __any_device__
862
+ __lib_inline__ T QuadBase_<T, Derived>::IOU(const QuadBase_<T, Derived2> &other) const
863
+ {
864
+ return geometry_iou(
865
+ static_cast<const Derived&>(*this),
866
+ static_cast<const Derived2&>(other)
867
+ );
868
+ }
869
+
870
+ template<typename T, typename Derived>
871
+ template<typename Derived2>
872
+ __any_device__
873
+ __lib_inline__ T QuadBase_<T, Derived>::IOU_UpperBound(const QuadBase_<T, Derived2> &other) const
874
+ {
875
+ return geometry_iou(
876
+ Bounds(),
877
+ other.Bounds()
878
+ );
879
+ }
880
+
881
+ template<typename T1, typename T2>
882
+ __any_device__ __lib_inline__
883
+ auto geometry_region_sizes(const T1 &a, const T2 &b) -> tuple_t<decltype(a.Area()), decltype(a.Area()), decltype(a.IntersectionArea(b))>
884
+ {
885
+ auto aArea = a.Area();
886
+ auto bArea = b.Area();
887
+ auto ixArea = a.IntersectionArea(b);
888
+
889
+ auto unionArea = aArea + bArea - ixArea;
890
+ auto iou = ixArea / unionArea;
891
+
892
+ return { ixArea / aArea, ixArea / bArea, iou };
893
+ }
894
+
895
+
896
+ template<typename T, typename Derived>
897
+ template<typename Derived2>
898
+ __any_device__ __lib_inline__
899
+ tuple_t<T, T, T> QuadBase_<T, Derived>::RegionSizes(const QuadBase_<T, Derived2> &other) const
900
+ {
901
+ return geometry_region_sizes(
902
+ static_cast<const Derived&>(*this),
903
+ static_cast<const Derived2&>(other)
904
+ );
905
+ }
906
+
907
+ template<typename T, typename Derived>
908
+ template<typename Derived2>
909
+ __any_device__ __lib_inline__
910
+ tuple_t<T, T, T> QuadBase_<T, Derived>::RegionSizes_UpperBound(const QuadBase_<T, Derived2> &other) const
911
+ {
912
+ return geometry_region_sizes(
913
+ Bounds(),
914
+ other.Bounds()
915
+ );
916
+ }
917
+
918
+ template<typename polygon_t>
919
+ __any_device__
920
+ auto polygon_bounds(const polygon_t &poly) -> AABB_<typename polygon_t::inner_type>
921
+ {
922
+ #ifndef __NVCC__
923
+ using std::min;
924
+ using std::max;
925
+ #endif
926
+ auto minP = poly[0];
927
+ auto maxP = minP;
928
+ for (size_t i = 1; i < poly.Count; ++i) {
929
+ auto qp = poly[i];
930
+ minP = min(minP, qp);
931
+ maxP = max(maxP, qp);
932
+ }
933
+ return { minP.X, minP.Y, maxP.X, maxP.Y };
934
+ }
935
+
936
+ template<typename T, typename Derived>
937
+ __any_device__
938
+ AABB_<T> PolygonBase_<T, Derived>::Bounds() const {
939
+ return polygon_bounds(static_cast<const Derived&>(*this));
940
+ }
941
+
942
+ template<typename polygon_t, typename point_t>
943
+ __any_device__
944
+ bool polygon_contains(const polygon_t &poly, const point_t &pt)
945
+ {
946
+ typedef typename polygon_t::inner_type T;
947
+
948
+ // Some arbitrary segment. Technically this should be a ray, but functionally this will work
949
+ Segment_<T> testSeg{ pt, { -1e6, -2e6 }};
950
+ Point_<T> trash;
951
+
952
+ int32_t ixCount = 0;
953
+ for (size_t i = 0; i < poly.Count; ++i) {
954
+ Segment_<T> polySeg{ poly[i], poly[(i + 1) % poly.Count] };
955
+
956
+ if (testSeg.Intersection(polySeg, trash)) {
957
+ ++ixCount;
958
+ }
959
+ }
960
+
961
+ // If there are an odd number of intersections, then the point is inside
962
+ return (ixCount % 2) == 1;
963
+ }
964
+
965
+ template<typename T, typename Derived>
966
+ __any_device__
967
+ bool PolygonBase_<T, Derived>::Contains(const Point_<T> &pt) const {
968
+ return polygon_contains(static_cast<const Derived&>(*this), pt);
969
+ }
970
+
971
+ template<typename polygon_t>
972
+ __any_device__
973
+ auto polygon_edge_length(const polygon_t &poly) -> typename polygon_t::inner_type
974
+ {
975
+ typedef typename polygon_t::inner_type T;
976
+
977
+ T ret = 0;
978
+
979
+ for (size_t i = 0; i < poly.Count; ++i) {
980
+ Segment_<T> seg{ poly[i], poly[(i + 1) % poly.Count] };
981
+
982
+ ret += seg.Length();
983
+ }
984
+
985
+ return ret;
986
+ }
987
+
988
+ template<typename T, typename Derived>
989
+ __any_device__
990
+ T PolygonBase_<T, Derived>::EdgeLength() const {
991
+ return polygon_edge_length(static_cast<const Derived&>(*this));
992
+ }
993
+
994
+ template<typename polygon_t>
995
+ __any_device__
996
+ auto polygon_center(const polygon_t &poly) -> Point_<typename polygon_t::inner_type>
997
+ {
998
+ typedef typename polygon_t::inner_type T;
999
+
1000
+ T cx = 0, cy = 0, a = 0;
1001
+ size_t j = poly.Count - 1;
1002
+ for (size_t i = 0; i < poly.Count; ++i) {
1003
+ Point_<T> p0 = poly[i];
1004
+ Point_<T> p1 = poly[j];
1005
+
1006
+ T common = (p0.X * p1.Y - p1.X * p0.Y);
1007
+ cx += (p0.X + p1.X) * common;
1008
+ cy += (p0.Y + p1.Y) * common;
1009
+ a += common;
1010
+
1011
+ j = i;
1012
+ }
1013
+
1014
+ a /= 2;
1015
+
1016
+ Point_<T> center{ cx / (6 * a), cy / (6 * a) };
1017
+
1018
+ return center;
1019
+ }
1020
+
1021
+ template<typename T, typename Derived>
1022
+ __any_device__
1023
+ Point_<T> PolygonBase_<T, Derived>::Center() const {
1024
+ return polygon_center(static_cast<const Derived&>(*this));
1025
+ }
1026
+
1027
+ template<typename T, typename Derived>
1028
+ __any_device__
1029
+ T PolygonBase_<T, Derived>::Area() const {
1030
+ const Derived &dThis = static_cast<const Derived&>(*this);
1031
+ return shoelace_area(dThis, dThis.Count);
1032
+ }
1033
+
1034
+
1035
+ template<typename T>
1036
+ __any_device__
1037
+ Point_<T> nearest_point_on_segment(const Point_<T> &pt, const Segment_<T> &seg)
1038
+ {
1039
+ #ifndef __NVCC__
1040
+ using std::max;
1041
+ using std::min;
1042
+ #endif
1043
+
1044
+ const T l2 = seg.LengthSq();
1045
+
1046
+ if (l2 == 0.0) {
1047
+ return seg.A;
1048
+ }
1049
+
1050
+ const auto v = seg.A;
1051
+ const auto w = seg.B;
1052
+ // Consider the line extending the segment, parameterized as v + t*(w-v)
1053
+ // Find projection of point p onto the line
1054
+ auto t = dot(pt - v, w - v) / l2;
1055
+
1056
+ // Clamp between t=0 and t=1
1057
+ t = max(static_cast<T>(0), min(static_cast<T>(1), t));
1058
+
1059
+ const auto projection = v + t * (w - v);
1060
+
1061
+ return projection;
1062
+ }
1063
+
1064
+
1065
+ template<typename T>
1066
+ __any_device__
1067
+ Segment_<T> shortest_line_between_segments(const Segment_<T> &a, const Segment_<T> &b)
1068
+ {
1069
+ Segment_<T> segs[] = {
1070
+ { a.A, nearest_point_on_segment(a.A, b) },
1071
+ { a.B, nearest_point_on_segment(a.B, b) },
1072
+ { nearest_point_on_segment(b.A, a), b.A },
1073
+ { nearest_point_on_segment(b.B, a), b.B }
1074
+ };
1075
+
1076
+ T minDist = std::numeric_limits<T>::max();
1077
+ size_t idx = 0; // default to the first candidate; also avoids an uninitialized-use warning
1078
+
1079
+ #pragma unroll
1080
+ for (size_t i = 0; i < 4; ++i) {
1081
+ T dist = segs[i].LengthSq();
1082
+ if (dist < minDist) {
1083
+ minDist = dist;
1084
+ idx = i;
1085
+ }
1086
+ }
1087
+
1088
+ return segs[idx];
1089
+ }
1090
+
1091
+ // Find the distance between a point and the nearest point along the specified segment
1092
+ template<typename T>
1093
+ __any_device__
1094
+ T distance_to_segment(const Point_<T> &pt, const Segment_<T> &seg)
1095
+ {
1096
+ auto projection = nearest_point_on_segment(pt, seg);
1097
+
1098
+ auto dist = length(pt - projection);
1099
+
1100
+ return dist;
1101
+ }
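The intersection routine above boils down to two steps: order the collected candidate points by angle around their centroid, then measure the resulting polygon with the shoelace formula (the IOU helpers then reduce to ixArea / (aArea + bArea - ixArea)). Below is a minimal standalone sketch of that angle-sort plus shoelace step in plain std C++; the Pt struct and function names are illustrative only and are not the repo's types.

// Standalone sketch of the angle-sort + shoelace step used by intersection_area above.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct Pt { double x, y; };

// Shoelace formula over an already ordered (CW or CCW) polygon.
double shoelace_area(const std::vector<Pt> &poly) {
    double acc = 0.0;
    for (size_t i = 0; i < poly.size(); ++i) {
        const Pt &a = poly[i];
        const Pt &b = poly[(i + 1) % poly.size()];
        acc += a.x * b.y - b.x * a.y;
    }
    return std::fabs(acc) / 2.0;
}

// Order an unsorted convex point set by angle around its centroid, then measure its area.
double area_of_unordered_convex_points(std::vector<Pt> pts) {
    if (pts.empty()) return 0.0;
    Pt c{0.0, 0.0};
    for (const Pt &p : pts) { c.x += p.x; c.y += p.y; }
    c.x /= pts.size(); c.y /= pts.size();
    std::sort(pts.begin(), pts.end(), [&](const Pt &a, const Pt &b) {
        return std::atan2(a.y - c.y, a.x - c.x) < std::atan2(b.y - c.y, b.x - c.x);
    });
    return shoelace_area(pts);
}

int main() {
    // Unit square given in scrambled order: the area should come out as 1.
    std::vector<Pt> square{{0, 0}, {1, 1}, {1, 0}, {0, 1}};
    std::printf("area = %f\n", area_of_unordered_convex_points(square));
}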
nemo-retriever-ocr/cpp/geometry_api/calc_poly_min_rrect.cpp ADDED
@@ -0,0 +1,165 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #include "geometry_api.h"
6
+
7
+ #include "../graph_detection/encode_util.h"
8
+
9
+ #include "../geometry.h"
10
+ #include "matrix2x2.h"
11
+
12
+ using namespace std;
13
+
14
+ template<typename T>
15
+ void _calc_poly_min_rrect(const torch::TensorAccessor<T, 2> vertices, torch::TensorAccessor<T, 2> outRRect);
16
+ template<typename T>
17
+ void _calc_quad_min_rrect(const torch::TensorAccessor<T, 2> vertices, torch::TensorAccessor<T, 2> outRRect);
18
+
19
+ torch::Tensor calc_poly_min_rrect(torch::Tensor vertices)
20
+ {
21
+ if (vertices.size(0) < 3) {
22
+ throw runtime_error("Invalid polygon! Expected >= 3 vertices, got " + to_string(vertices.size(0)));
23
+ }
24
+
25
+ auto ret = torch::empty({ 4, 2 }, vertices.options());
26
+
27
+ auto retAcc = ret.accessor<float, 2>();
28
+
29
+ if (vertices.size(0) != 4) {
30
+ // OpenCV requires this to be a contiguous buffer
31
+ vertices = vertices.contiguous();
32
+ _calc_poly_min_rrect(vertices.accessor<float, 2>(), retAcc);
33
+ } else {
34
+ _calc_quad_min_rrect(vertices.accessor<float, 2>(), retAcc);
35
+ }
36
+
37
+ return ret;
38
+ }
39
+
40
+
41
+ template<typename T>
42
+ void _calc_bounds(const torch::TensorAccessor<T, 2> &vertices, torch::TensorAccessor<T, 2> &outRRect,
43
+ const Point_<T> &leftCenter, const Point_<T> &rightCenter)
44
+ {
45
+ typedef Point_<T> Pointf;
46
+
47
+ Pointf vecAlong = rightCenter - leftCenter;
48
+ auto alongMag = length(vecAlong);
49
+
50
+ if (alongMag == 0.0f) {
51
+ throw runtime_error("Invalid polygon!");
52
+ }
53
+
54
+ vecAlong /= alongMag;
55
+
56
+ Pointf dOrtho{ -vecAlong.Y, vecAlong.X };
57
+
58
+ Pointf center = (leftCenter + rightCenter) / 2.0f;
59
+
60
+ Matrix2x2<T> rotMat{ vecAlong, dOrtho };
61
+
62
+ auto get_fn = [&vertices, &center] (int64_t i) {
63
+ return Pointf{ vertices[i] } - center;
64
+ };
65
+
66
+ // All we care about is getting the bounds in the normalized space, so this saves
67
+ // us from having to do any memory allocation
68
+ Pointf minPt{ 0, 0 }, maxPt{ 0, 0 };
69
+ auto tx_fn = [&minPt, &maxPt] (int64_t i, const Pointf &pt) {
70
+ minPt = min(minPt, pt);
71
+ maxPt = max(maxPt, pt);
72
+ };
73
+
74
+ matmul_fn(vertices.size(0), get_fn, rotMat, tx_fn, transpose_tag{});
75
+
76
+ Pointf rotBox[4] = {
77
+ minPt,
78
+ { maxPt.X, minPt.Y },
79
+ maxPt,
80
+ { minPt.X, maxPt.Y }
81
+ };
82
+
83
+ auto get_fn2 = [&rotBox] (int64_t i) {
84
+ return rotBox[i];
85
+ };
86
+
87
+ auto assign_fn = [&center, &outRRect] (int64_t i, const Pointf &pt) {
88
+ outRRect[i][0] = pt.X + center.X;
89
+ outRRect[i][1] = pt.Y + center.Y;
90
+ };
91
+
92
+ matmul_fn(4, get_fn2, rotMat, assign_fn, contiguous_tag{});
93
+ }
94
+
95
+
96
+ template<typename T>
97
+ void _calc_poly_min_rrect(const torch::TensorAccessor<T, 2> vertices, torch::TensorAccessor<T, 2> outRRect)
98
+ {
99
+ typedef Point_<T> Pointf;
100
+ typedef Polygon_<T> Polygonf;
101
+
102
+ Polygonf poly{ vertices.data(), vertices.size(0) };
103
+
104
+ vector<graph_detection::Edge> bottoms = graph_detection::find_bottom(poly, false);
105
+
106
+ if (bottoms.size() != 2) {
107
+ throw runtime_error("Invalid polygon!");
108
+ }
109
+
110
+ vector<graph_detection::Edge> longEdges[2];
111
+ graph_detection::find_long_edges(poly, bottoms.data(), longEdges[0], longEdges[1]);
112
+
113
+ ////
114
+ // Determine which edge is above the other
115
+ Pointf cpts[2];
116
+ for (size_t i = 0; i < 2; ++i) {
117
+ auto &pedge = longEdges[i];
118
+
119
+ cpts[i] = Pointf{0.0f, 0.0f};
120
+ float ct = 0;
121
+ for (size_t z = 0; z < pedge.size(); ++z) {
122
+ auto edge = pedge[z];
123
+ Pointf p1 = poly[edge.A];
124
+ Pointf p2 = poly[edge.B];
125
+ cpts[i] += (p1 + p2) / 2.0f;
126
+ ct += 1.0f;
127
+ }
128
+
129
+ if (ct < 1.0f) {
130
+ throw runtime_error("Edge was empty!");
131
+ }
132
+ cpts[i] /= ct;
133
+ }
134
+
135
+ float vpp = graph_detection::vector_sin(cpts[0] - cpts[1]);
136
+ if (vpp >= 0) {
137
+ swap(bottoms[0], bottoms[1]);
138
+ }
139
+ ////
140
+
141
+ Pointf edge1[2] = { poly[bottoms[0].A], poly[bottoms[0].B] };
142
+ Pointf edge2[2] = { poly[bottoms[1].A], poly[bottoms[1].B] };
143
+
144
+ Pointf c0 = (edge1[0] + edge1[1]) / 2.0f;
145
+ Pointf c1 = (edge2[0] + edge2[1]) / 2.0f;
146
+
147
+ _calc_bounds(vertices, outRRect, c0, c1);
148
+ }
149
+
150
+ template<typename T>
151
+ void _calc_quad_min_rrect(const torch::TensorAccessor<T, 2> vertices, torch::TensorAccessor<T, 2> outRRect)
152
+ {
153
+ typedef Point_<T> Pointf;
154
+
155
+ // Instead of finding an arbitrary rotated box, find a reasonable
156
+ // fit for the quadrangle
157
+ Pointf pts[4] = {
158
+ vertices[0], vertices[1], vertices[2], vertices[3]
159
+ };
160
+
161
+ Pointf c0 = (pts[0] + pts[3]) / 2.0f;
162
+ Pointf c1 = (pts[1] + pts[2]) / 2.0f;
163
+
164
+ _calc_bounds(vertices, outRRect, c0, c1);
165
+ }
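For reference, here is a plain std C++ sketch of the bounding step that _calc_bounds performs: the box axis is the unit vector between the two edge centers, every vertex is projected into that rotated frame, the axis-aligned min/max is taken there, and the four corners are rotated back out. The names below are illustrative only; this is not the repo's Matrix2x2 path, just the same arithmetic spelled out.

// Illustrative sketch of the rotated-bounds step used by _calc_bounds above.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct Pt { float x, y; };

std::vector<Pt> rotated_bounds(const std::vector<Pt> &verts, Pt leftCenter, Pt rightCenter) {
    // Unit vector along the box, and its orthogonal
    float ax = rightCenter.x - leftCenter.x, ay = rightCenter.y - leftCenter.y;
    const float mag = std::sqrt(ax * ax + ay * ay);
    ax /= mag; ay /= mag;
    const float ox = -ay, oy = ax;
    const Pt center{(leftCenter.x + rightCenter.x) / 2.0f, (leftCenter.y + rightCenter.y) / 2.0f};

    // Project every vertex into the (along, ortho) frame and track the bounds.
    // Starting from 0 implicitly assumes the frame origin lies inside the box, as the code above also does.
    float minX = 0, minY = 0, maxX = 0, maxY = 0;
    for (const Pt &v : verts) {
        const float dx = v.x - center.x, dy = v.y - center.y;
        const float px = dx * ax + dy * ay;   // coordinate along the box axis
        const float py = dx * ox + dy * oy;   // coordinate across the box axis
        minX = std::min(minX, px); maxX = std::max(maxX, px);
        minY = std::min(minY, py); maxY = std::max(maxY, py);
    }

    // Rotate the axis-aligned corners back into image space
    auto back = [&](float px, float py) {
        return Pt{center.x + px * ax + py * ox, center.y + px * ay + py * oy};
    };
    return { back(minX, minY), back(maxX, minY), back(maxX, maxY), back(minX, maxY) };
}

int main() {
    // A 4x2 axis-aligned box: left/right edge centers at (0,1) and (4,1); output is the box itself
    std::vector<Pt> quad{ {0, 0}, {4, 0}, {4, 2}, {0, 2} };
    for (const Pt &p : rotated_bounds(quad, {0, 1}, {4, 1}))
        std::printf("(%g, %g)\n", p.x, p.y);
}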
nemo-retriever-ocr/cpp/geometry_api/geometry_api.cpp ADDED
@@ -0,0 +1,101 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #include "geometry_api.h"
6
+
7
+ #include "geometry_api_common.h"
8
+
9
+ using namespace std;
10
+
11
+ torch::Tensor rrect_to_quads_gpu(torch::Tensor rrects, float cellSize);
12
+
13
+ template<typename T>
14
+ torch::Tensor rrect_to_quads_impl(torch::Tensor rrects, T cellSize)
15
+ {
16
+ // BHW(5)
17
+ auto rrectAccess = rrects.accessor<T, 4>();
18
+
19
+ T cellOff = cellSize / 2;
20
+
21
+ auto quads = torch::empty({ rrects.size(0), rrects.size(1), rrects.size(2), 4, 2 }, rrects.options());
22
+
23
+ auto quadsAccess = quads.accessor<T, 5>();
24
+
25
+ for (long b = 0; b < rrects.size(0); ++b) {
26
+ for (long y = 0; y < rrects.size(1); ++y) {
27
+ for (long x = 0; x < rrects.size(2); ++x) {
28
+ auto rrect = rrectAccess[b][y][x];
29
+
30
+ auto quad = quadsAccess[b][y][x];
31
+
32
+ assign_rrect_to_quad(rrect, quad, cellSize, cellOff,
33
+ static_cast<T>(x),
34
+ static_cast<T>(y));
35
+ }
36
+ }
37
+ }
38
+
39
+ return quads;
40
+ }
41
+
42
+ torch::Tensor rrect_to_quads(torch::Tensor rrects, float cellSize)
43
+ {
44
+ if (rrects.is_cuda()) {
45
+ return rrect_to_quads_gpu(rrects, cellSize);
46
+ }
47
+
48
+ torch::Tensor quads;
49
+ AT_DISPATCH_FLOATING_TYPES(
50
+ rrects.scalar_type(),
51
+ "rrect_to_quads_impl",
52
+ ([&] {
53
+ quads = rrect_to_quads_impl<scalar_t>(rrects, scalar_t(cellSize));
54
+ })
55
+ );
56
+
57
+ return quads;
58
+ }
59
+
60
+
61
+ template<typename T>
62
+ torch::Tensor rrect_to_quads_backward_impl(torch::Tensor rrects, torch::Tensor gradOutput)
63
+ {
64
+ // BHW(5)
65
+ auto gradInput = torch::empty_like(rrects);
66
+
67
+ auto rrectAccess = rrects.accessor<T, 4>();
68
+ // BHW42
69
+ auto gradOutputAccess = gradOutput.accessor<T, 5>();
70
+ auto gradInputAccess = gradInput.accessor<T, 4>();
71
+
72
+ for (long b = 0; b < rrects.size(0); ++b) {
73
+ for (long y = 0; y < rrects.size(1); ++y) {
74
+ for (long x = 0; x < rrects.size(2); ++x) {
75
+ assign_grad_rrect_to_quad<T>(rrectAccess[b][y][x], gradOutputAccess[b][y][x], gradInputAccess[b][y][x]);
76
+ }
77
+ }
78
+ }
79
+
80
+ return gradInput;
81
+ }
82
+
83
+ torch::Tensor rrect_to_quads_backward_gpu(torch::Tensor rrects, torch::Tensor gradOutput);
84
+
85
+ torch::Tensor rrect_to_quads_backward(torch::Tensor rrects, torch::Tensor gradOutput)
86
+ {
87
+ if (rrects.is_cuda()) {
88
+ return rrect_to_quads_backward_gpu(rrects, gradOutput);
89
+ }
90
+
91
+ torch::Tensor gradInput;
92
+ AT_DISPATCH_FLOATING_TYPES(
93
+ rrects.scalar_type(),
94
+ "rrect_to_quads_backward_impl",
95
+ ([&] {
96
+ gradInput = rrect_to_quads_backward_impl<scalar_t>(rrects, gradOutput);
97
+ })
98
+ );
99
+
100
+ return gradInput;
101
+ }
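The CPU path above follows a common libtorch extension pattern: AT_DISPATCH_FLOATING_TYPES picks scalar_t at runtime and a typed TensorAccessor walks the tensor without raw pointer arithmetic. A minimal sketch of that pattern, assuming libtorch is available; scale_points is a made-up example op used only to illustrate the dispatch, not part of this extension.

// Minimal sketch of the dispatch + accessor pattern used by rrect_to_quads above (assumes libtorch).
#include <torch/torch.h>

template<typename T>
torch::Tensor scale_points_impl(torch::Tensor pts, T factor) {
    auto out = torch::empty_like(pts);
    auto in = pts.accessor<T, 2>();   // (N, 2)
    auto o = out.accessor<T, 2>();
    for (int64_t i = 0; i < pts.size(0); ++i) {
        o[i][0] = in[i][0] * factor;
        o[i][1] = in[i][1] * factor;
    }
    return out;
}

torch::Tensor scale_points(torch::Tensor pts, double factor) {
    torch::Tensor out;
    AT_DISPATCH_FLOATING_TYPES(pts.scalar_type(), "scale_points_impl", ([&] {
        out = scale_points_impl<scalar_t>(pts, static_cast<scalar_t>(factor));
    }));
    return out;
}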
nemo-retriever-ocr/cpp/geometry_api/geometry_api.h ADDED
@@ -0,0 +1,16 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #pragma once
6
+
7
+ #include <torch/torch.h>
8
+
9
+ torch::Tensor rrect_to_quads(torch::Tensor rrects, float cellSize);
10
+ torch::Tensor rrect_to_quads_backward(torch::Tensor rrects, torch::Tensor gradOutput);
11
+
12
+ torch::Tensor calc_poly_min_rrect(torch::Tensor vertices);
13
+
14
+ float get_rel_continuation_cos(torch::Tensor rrectA, torch::Tensor rrectB);
15
+
16
+ torch::Tensor get_poly_bounds_quad(torch::Tensor poly);
nemo-retriever-ocr/cpp/geometry_api/geometry_api_common.h ADDED
@@ -0,0 +1,121 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #pragma once
6
+
7
+ #include <torch/torch.h>
8
+
9
+ #include "../cuda_intellisense.cuh"
10
+ #include "../geometry.h"
11
+
12
+ #if defined(__NVCC__)
13
+ #include <math_constants.h>
14
+ #define GEO_PI CUDART_PI_F
15
+ #else
16
+ #include <math.h>
17
+ #define GEO_PI M_PI
18
+ #endif
19
+
20
+
21
+ template<typename access_t, typename point_t>
22
+ __device__
23
+ inline
24
+ void pt_assign(access_t acc, const point_t &p) {
25
+ acc[0] = p.X;
26
+ acc[1] = p.Y;
27
+ }
28
+
29
+ template<typename T, typename rrect_access_t>
30
+ __device__ __lib_inline__
31
+ InPlaceQuad_<T> cvt_rrect_to_quad(const rrect_access_t &rrect, T cellSize, T cellOff, T x, T y)
32
+ {
33
+ typedef Point_<T> Pointf;
34
+
35
+ Pointf prior{
36
+ x * cellSize + cellOff,
37
+ y * cellSize + cellOff
38
+ };
39
+
40
+ T dTop = rrect[0];
41
+ T dRight = rrect[1];
42
+ T dBottom = rrect[2];
43
+ T dLeft = rrect[3];
44
+ T theta = rrect[4];
45
+
46
+ T piOver2{GEO_PI / 2.0f};
47
+ Pointf vX{ cos(theta), sin(theta) };
48
+ Pointf vY{ cos(theta - piOver2), sin(theta - piOver2) };
49
+
50
+ InPlaceQuad_<T> ret;
51
+
52
+ ret[0] = prior - vX * dLeft + vY * dTop;
53
+ ret[1] = prior + vX * dRight + vY * dTop;
54
+ ret[2] = prior + vX * dRight - vY * dBottom;
55
+ ret[3] = prior - vX * dLeft - vY * dBottom;
56
+
57
+ return ret;
58
+ }
59
+
60
+ template<typename rrect_access_t, typename quad_access_t, typename T>
61
+ __device__ __lib_inline__
62
+ void assign_rrect_to_quad(const rrect_access_t &rrect, quad_access_t &quad,
63
+ T cellSize, T cellOff, T x, T y)
64
+ {
65
+ const InPlaceQuad_<T> cvQuad = cvt_rrect_to_quad<T>(rrect, cellSize, cellOff, x, y);
66
+
67
+ const T *pInQuad = reinterpret_cast<const T*>(&cvQuad);
68
+ T *pOutQuad = reinterpret_cast<T*>(quad.data());
69
+
70
+ #pragma unroll
71
+ for (uint32_t i = 0; i < 8; ++i) {
72
+ pOutQuad[i] = pInQuad[i];
73
+ }
74
+ }
75
+
76
+ template<typename T, typename rrect_access_t, typename quad_access_t>
77
+ __device__
78
+ inline
79
+ void assign_grad_rrect_to_quad(const rrect_access_t &rrect,
80
+ const quad_access_t &gradOutput,
81
+ rrect_access_t gradInput)
82
+ {
83
+ typedef Point_<T> Pointf;
84
+
85
+ T Top = rrect[0];
86
+ T Right = rrect[1];
87
+ T Bottom = rrect[2];
88
+ T Left = rrect[3];
89
+ T theta = rrect[4];
90
+
91
+ T piOver2{GEO_PI / 2.0f};
92
+ Pointf vX{ cos(theta), sin(theta) };
93
+ Pointf vY{ cos(theta - piOver2), sin(theta - piOver2) };
94
+
95
+ Pointf dVX{ -vX.Y, vX.X };
96
+ Pointf dVY{ -vY.Y, vY.X };
97
+
98
+ Pointf gP0 = gradOutput[0],
99
+ gP1 = gradOutput[1],
100
+ gP2 = gradOutput[2],
101
+ gP3 = gradOutput[3];
102
+
103
+ // Top
104
+ gradInput[0] = (gP0 * vY + gP1 * vY).Sum();
105
+ // Right
106
+ gradInput[1] = (gP1 * vX + gP2 * vX).Sum();
107
+ // Bottom
108
+ gradInput[2] = -(gP2 * vY + gP3 * vY).Sum();
109
+ // Left
110
+ gradInput[3] = -(gP0 * vX + gP3 * vX).Sum();
111
+
112
+ // Theta
113
+ gradInput[4] = (
114
+ gP0 * (-Left * dVX + Top * dVY) +
115
+ gP1 * (Right * dVX + Top * dVY) +
116
+ gP2 * (Right * dVX - Bottom * dVY) +
117
+ gP3 * (-Left * dVX - Bottom * dVY)
118
+ ).Sum();
119
+ }
120
+
121
+ #undef GEO_PI
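To make the rrect encoding concrete: each cell predicts distances (top, right, bottom, left) from a fixed per-cell prior point plus an angle theta, and cvt_rrect_to_quad walks those distances along the rotated axes to recover the four corners. The sketch below restates that arithmetic in plain std C++; the struct and function names are illustrative only.

// Plain C++ restatement of the rrect -> quad arithmetic from cvt_rrect_to_quad above.
#include <cmath>
#include <cstdio>

struct P { float x, y; };

void rrect_to_quad(const float rrect[5], float cellSize, int col, int row, P out[4]) {
    const float kPiOver2 = 1.57079632679f;
    const float cellOff = cellSize / 2.0f;
    const P prior{ col * cellSize + cellOff, row * cellSize + cellOff };   // fixed per-cell anchor

    const float t = rrect[0], r = rrect[1], b = rrect[2], l = rrect[3], theta = rrect[4];
    const P vx{ std::cos(theta), std::sin(theta) };                        // rotated x axis
    const P vy{ std::cos(theta - kPiOver2), std::sin(theta - kPiOver2) };  // rotated y axis

    // Same corner ordering as cvt_rrect_to_quad above
    out[0] = { prior.x - vx.x * l + vy.x * t, prior.y - vx.y * l + vy.y * t };
    out[1] = { prior.x + vx.x * r + vy.x * t, prior.y + vx.y * r + vy.y * t };
    out[2] = { prior.x + vx.x * r - vy.x * b, prior.y + vx.y * r - vy.y * b };
    out[3] = { prior.x - vx.x * l - vy.x * b, prior.y - vx.y * l - vy.y * b };
}

int main() {
    const float rrect[5] = { 1.0f, 2.0f, 1.0f, 2.0f, 0.0f };  // 4x2 box, no rotation
    P quad[4];
    rrect_to_quad(rrect, 4.0f, 0, 0, quad);                   // cell (0, 0) with a 4-pixel cell size
    for (const P &p : quad) std::printf("(%g, %g)\n", p.x, p.y);
}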
nemo-retriever-ocr/cpp/geometry_api/geometry_api_gpu.cu ADDED
@@ -0,0 +1,142 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #include "geometry_api.h"
6
+
7
+ #include "../geometry.h"
8
+ #include "../cuda_intellisense.cuh"
9
+ #include "geometry_api_common.h"
10
+
11
+ #include <trove/ptr.h>
12
+
13
+ using namespace std;
14
+
15
+
16
+ template<typename T>
17
+ struct RRect_ {
18
+ T Data[5];
19
+
20
+ template<typename index_t>
21
+ __device__
22
+ const T &operator[](index_t i) const { return Data[i]; }
23
+ template<typename index_t>
24
+ __device__
25
+ T &operator[](index_t i) { return Data[i]; }
26
+ };
27
+
28
+ template<typename T>
29
+ __global__
30
+ void device_rrect_to_quads_gpu(torch::PackedTensorAccessor64<T, 2> rrectAccess,
31
+ torch::PackedTensorAccessor64<T, 3> quadsAccess,
32
+ int64_t numRows, int64_t numCols,
33
+ T cellSize)
34
+ {
35
+ typedef Point_<T> Pointf;
36
+ typedef RRect_<T> RRectf;
37
+ typedef InPlaceQuad_<T> Quadf;
38
+ constexpr T TWO = 2;
39
+
40
+ const int64_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;
41
+
42
+ if (jobIdx >= rrectAccess.size(0)) {
43
+ return;
44
+ }
45
+
46
+ int64_t row = jobIdx / numCols;
47
+ const int64_t col = jobIdx - (row * numCols);
48
+ row = row % numRows;
49
+
50
+ auto rawRRect = reinterpret_cast<RRectf*>(rrectAccess.data());
51
+ auto rawQuad = reinterpret_cast<Quadf*>(quadsAccess.data());
52
+ #if defined(NDEBUG)
53
+ trove::coalesced_ptr<RRectf> pRRect(rawRRect);
54
+ trove::coalesced_ptr<Quadf> pQuad(rawQuad);
55
+ #else
56
+ auto pRRect = rawRRect;
57
+ auto pQuad = rawQuad;
58
+ #endif
59
+
60
+ RRectf rrect = pRRect[jobIdx];
61
+
62
+ T cellOff = cellSize / TWO;
63
+ Quadf cvQuad = cvt_rrect_to_quad<T>(rrect, cellSize, cellOff, col, row);
64
+
65
+ pQuad[jobIdx] = cvQuad;
66
+ }
67
+
68
+ torch::Tensor rrect_to_quads_gpu(torch::Tensor rrects, float cellSize)
69
+ {
70
+ if (!rrects.is_contiguous()) {
71
+ throw std::runtime_error("Expected the rrects to be contiguous!");
72
+ }
73
+
74
+ torch::Tensor quads = torch::empty({ rrects.size(0), rrects.size(1), rrects.size(2), 4, 2 }, rrects.options());
75
+
76
+ auto rrFlat = rrects.flatten(0, 2);
77
+ auto qFlat = quads.flatten(0, 2);
78
+
79
+ dim3 blockSize(96);
80
+ dim3 gridSize(div_up(qFlat.size(0), blockSize.x));
81
+
82
+ if (quads.numel() > 0) {
83
+ AT_DISPATCH_FLOATING_TYPES(
84
+ quads.scalar_type(),
85
+ "cuda_rrect_to_quads",
86
+ ([&] {
87
+
88
+ device_rrect_to_quads_gpu<scalar_t> KERNEL_ARG2(gridSize, blockSize) (
89
+ rrFlat.packed_accessor64<scalar_t, 2>(),
90
+ qFlat.packed_accessor64<scalar_t, 3>(),
91
+ rrects.size(1), rrects.size(2),
92
+ cellSize
93
+ );
94
+
95
+ })
96
+ );
97
+ }
98
+
99
+ return quads;
100
+ }
101
+
102
+ template<typename scalar_t>
103
+ __global__
104
+ void device_rrect_to_quads_backward_gpu(torch::PackedTensorAccessor64<scalar_t, 2> rrect,
105
+ torch::PackedTensorAccessor64<scalar_t, 3> gradOutput,
106
+ torch::PackedTensorAccessor64<scalar_t, 2> gradInput)
107
+ {
108
+ const int64_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;
109
+
110
+ if (jobIdx >= rrect.size(0)) return;
111
+
112
+ assign_grad_rrect_to_quad<scalar_t>(rrect[jobIdx], gradOutput[jobIdx], gradInput[jobIdx]);
113
+ }
114
+
115
+
116
+ torch::Tensor rrect_to_quads_backward_gpu(torch::Tensor rrects, torch::Tensor gradOutput)
117
+ {
118
+ auto gradInput = torch::empty_like(rrects);
119
+
120
+ auto flatRRects = rrects.reshape({ -1, 5 });
121
+ auto flatGradOutput = gradOutput.reshape({ -1, 4, 2 });
122
+ auto flatGradInput = gradInput.reshape({ -1, 5 });
123
+
124
+ dim3 blockSize(32);
125
+ dim3 gridSize(div_up(rrects.size(0) * rrects.size(1) * rrects.size(2), blockSize.x));
126
+
127
+ if (rrects.numel() > 0) {
128
+ AT_DISPATCH_FLOATING_TYPES(
129
+ rrects.scalar_type(),
130
+ "cuda_rrect_to_quads_backward",
131
+ ([&] {
132
+ device_rrect_to_quads_backward_gpu KERNEL_ARG2(gridSize, blockSize) (
133
+ flatRRects.packed_accessor64<scalar_t, 2>(),
134
+ flatGradOutput.packed_accessor64<scalar_t, 3>(),
135
+ flatGradInput.packed_accessor64<scalar_t, 2>()
136
+ );
137
+ })
138
+ );
139
+ }
140
+
141
+ return gradInput;
142
+ }
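A quick host-side check of the indexing scheme used by device_rrect_to_quads_gpu: the (B, H, W) grid of rrects is flattened into one job per cell, each job recovers its row and column from the flat index, and the launch is sized with div_up so every job is covered. The div_up below is a local stand-in that mirrors what the repo's helper of the same name is assumed to do.

// Host-side check of the flat-index arithmetic and launch sizing used by the kernel above.
#include <cassert>
#include <cstdint>

static int64_t div_up(int64_t n, int64_t d) { return (n + d - 1) / d; }

int main() {
    const int64_t B = 2, H = 3, W = 5;
    const int64_t numJobs = B * H * W;

    // Launch sizing: enough blocks of `blockSize` threads to cover every job
    const int64_t blockSize = 96;
    const int64_t gridSize = div_up(numJobs, blockSize);
    assert(gridSize * blockSize >= numJobs);

    // Per-job row/col recovery, mirroring the kernel's arithmetic
    for (int64_t jobIdx = 0; jobIdx < numJobs; ++jobIdx) {
        int64_t row = jobIdx / W;          // row within the flattened (B*H) rows
        const int64_t col = jobIdx - row * W;
        row = row % H;                     // wrap back into the image's row range
        assert(col >= 0 && col < W && row >= 0 && row < H);
    }
    return 0;
}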
nemo-retriever-ocr/cpp/geometry_api/get_rel_continuation_cos.cpp ADDED
@@ -0,0 +1,60 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #include "geometry_api.h"
6
+
7
+ #include "../geometry.h"
8
+
9
+ using namespace std;
10
+
11
+
12
+ float get_rel_continuation_cos(torch::Tensor rrectATensor, torch::Tensor rrectBTensor)
13
+ {
14
+ typedef Point_<float> Pointf;
15
+
16
+ if (rrectATensor.size(0) != 4 || rrectBTensor.size(0) != 4) {
17
+ throw runtime_error("Invalid rrect arguments. Both must have 4 vertices! A=" +
18
+ to_string(rrectATensor.size(0)) + ", B=" + to_string(rrectBTensor.size(0)));
19
+ }
20
+
21
+ auto rrectA = rrectATensor.accessor<float, 2>();
22
+ auto rrectB = rrectBTensor.accessor<float, 2>();
23
+
24
+ Pointf aPts[4] = {
25
+ rrectA[0], rrectA[1], rrectA[2], rrectA[3]
26
+ };
27
+
28
+ auto c1 = (aPts[0] + aPts[3]) / 2.0f;
29
+ auto c2 = (aPts[1] + aPts[2]) / 2.0f;
30
+
31
+ auto aDir = c2 - c1;
32
+ auto aLen = length(aDir);
33
+
34
+ if (aLen > 0) {
35
+ aDir /= aLen;
36
+ } else {
37
+ aDir = Pointf{ 1, 0 };
38
+ }
39
+
40
+ auto centerA = (c1 + c2) / 2.0f;
41
+
42
+ Pointf bPts[4] = {
43
+ rrectB[0], rrectB[1], rrectB[2], rrectB[3]
44
+ };
45
+
46
+ auto centerB = (bPts[0] + bPts[1] + bPts[2] + bPts[3]) / 4.0f;
47
+
48
+ auto connDir = centerB - centerA;
49
+ auto connLen = length(connDir);
50
+
51
+ if (connLen == 0.0f) {
52
+ return 1.0f;
53
+ }
54
+
55
+ connDir /= connLen;
56
+
57
+ auto cosT = dot(aDir, connDir);
58
+
59
+ return cosT;
60
+ }
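The continuation test above compares region A's left-to-right reading direction (midpoint of its left edge to midpoint of its right edge) with the vector from A's center to B's center: a cosine near 1 means B sits directly "after" A along that direction. A standalone restatement in plain std C++ follows; the helper names are illustrative only.

// Standalone restatement of get_rel_continuation_cos above.
#include <cmath>
#include <cstdio>

struct P { float x, y; };
static P sub(P a, P b) { return { a.x - b.x, a.y - b.y }; }
static float dotp(P a, P b) { return a.x * b.x + a.y * b.y; }
static float len(P a) { return std::sqrt(dotp(a, a)); }

float continuation_cos(const P a[4], const P b[4]) {
    const P leftMid { (a[0].x + a[3].x) / 2, (a[0].y + a[3].y) / 2 };
    const P rightMid{ (a[1].x + a[2].x) / 2, (a[1].y + a[2].y) / 2 };

    P dir = sub(rightMid, leftMid);
    const float dirLen = len(dir);
    if (dirLen > 0) { dir.x /= dirLen; dir.y /= dirLen; } else { dir = { 1, 0 }; }

    const P centerA{ (leftMid.x + rightMid.x) / 2, (leftMid.y + rightMid.y) / 2 };
    const P centerB{ (b[0].x + b[1].x + b[2].x + b[3].x) / 4,
                     (b[0].y + b[1].y + b[2].y + b[3].y) / 4 };

    P conn = sub(centerB, centerA);
    const float connLen = len(conn);
    if (connLen == 0.0f) return 1.0f;   // coincident centers count as a continuation
    conn.x /= connLen; conn.y /= connLen;

    return dotp(dir, conn);
}

int main() {
    const P a[4] = { {0, 0}, {4, 0}, {4, 2}, {0, 2} };
    const P b[4] = { {5, 0}, {9, 0}, {9, 2}, {5, 2} };   // directly to the right of a
    std::printf("cos = %f\n", continuation_cos(a, b));   // ~1.0
}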
nemo-retriever-ocr/cpp/geometry_api/matrix2x2.h ADDED
@@ -0,0 +1,93 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #pragma once
6
+
7
+ #include "../geometry.h"
8
+
9
+
10
+ struct contiguous_tag{};
11
+
12
+ struct transpose_tag{};
13
+
14
+ template<typename layout_t, uint32_t R, uint32_t C>
15
+ struct Matrix2x2_Offset;
16
+
17
+ template<uint32_t R, uint32_t C>
18
+ struct Matrix2x2_Offset<contiguous_tag, R, C>
19
+ {
20
+ static const uint32_t OFFSET = R * 2 + C;
21
+ };
22
+
23
+ template<uint32_t R, uint32_t C>
24
+ struct Matrix2x2_Offset<transpose_tag, R, C>
25
+ {
26
+ static const uint32_t OFFSET = C * 2 + R;
27
+ };
28
+
29
+
30
+ template<typename T, typename layout_t, uint32_t R, uint32_t C>
31
+ struct Matrix2x2_Indexor
32
+ {
33
+ static const uint32_t OFFSET = Matrix2x2_Offset<layout_t, R, C>::OFFSET;
34
+
35
+ static T &get(T *data) { return data[OFFSET]; }
36
+ static const T get(const T *data) { return data[OFFSET]; }
37
+ };
38
+
39
+
40
+ template<typename T>
41
+ struct Matrix2x2
42
+ {
43
+ Matrix2x2() = default;
44
+ Matrix2x2(T r0c0, T r0c1, T r1c0, T r1c1)
45
+ : m_data{ r0c0, r0c1, r1c0, r1c1 }
46
+ {
47
+ }
48
+ Matrix2x2(const Point_<T> &r0, const Point_<T> &r1)
49
+ : m_data{ r0.X, r0.Y, r1.X, r1.Y }
50
+ {
51
+ }
52
+ Matrix2x2(const Point_<T> &r0, const Point_<T> &r1, transpose_tag)
53
+ : m_data{ r0.X, r1.X, r0.Y, r1.Y }
54
+ {
55
+ }
56
+
57
+ inline T &operator[](uint32_t i) { return m_data[i]; }
58
+ inline const T operator[](uint32_t i) const { return m_data[i]; }
59
+
60
+ T m_data[4];
61
+ };
62
+
63
+ template<typename T, typename layout_t>
64
+ struct Matrix2x2_View
65
+ {
66
+ Matrix2x2_View(const Matrix2x2<T> &m) : m_data(m.m_data) {}
67
+
68
+ const T *m_data;
69
+ };
70
+
71
+ template<uint32_t R, uint32_t C, typename T, typename layout_t>
72
+ const T get(const Matrix2x2_View<T, layout_t> &m)
73
+ {
74
+ return Matrix2x2_Indexor<T, layout_t, R, C>::get(m.m_data);
75
+ }
76
+
77
+ template<typename T, typename get_pt_t, typename callback_t, typename layout_t = contiguous_tag>
78
+ inline
79
+ void matmul_fn(int64_t N, const get_pt_t &get_fn, const Matrix2x2<T> &mat, const callback_t &callback,
80
+ layout_t lt = layout_t{})
81
+ {
82
+ Matrix2x2_View<T, layout_t> m{ mat };
83
+
84
+ #pragma omp simd
85
+ for (int64_t i = 0; i < N; ++i) {
86
+ Point_<T> pt = get_fn(i);
87
+
88
+ T x = pt.X * get<0, 0>(m) + pt.Y * get<1, 0>(m);
89
+ T y = pt.X * get<0, 1>(m) + pt.Y * get<1, 1>(m);
90
+
91
+ callback(i, Point_<T>{ x, y });
92
+ }
93
+ }
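The point of the layout tags above is that the same 2x2 multiply template can read the matrix either as stored or transposed, with the choice made entirely at compile time, so there is no runtime branch and no copied matrix. The sketch below shows the same trick in plain std C++; the tag and function names here are illustrative, not the repo's.

// Illustrative sketch of compile-time layout selection for a 2x2 row-vector multiply.
#include <cstdint>
#include <cstdio>

struct row_major_tag {};
struct transposed_tag {};

template<typename tag, uint32_t R, uint32_t C> struct offset_of;
template<uint32_t R, uint32_t C> struct offset_of<row_major_tag, R, C>  { static const uint32_t value = R * 2 + C; };
template<uint32_t R, uint32_t C> struct offset_of<transposed_tag, R, C> { static const uint32_t value = C * 2 + R; };

template<typename tag>
void transform_point(const float m[4], float x, float y, float &ox, float &oy) {
    // Row-vector times matrix: [x y] * M, with M read through the chosen layout tag
    ox = x * m[offset_of<tag, 0, 0>::value] + y * m[offset_of<tag, 1, 0>::value];
    oy = x * m[offset_of<tag, 0, 1>::value] + y * m[offset_of<tag, 1, 1>::value];
}

int main() {
    // 90-degree rotation, stored row-major
    const float rot[4] = { 0, 1, -1, 0 };
    float x, y;
    transform_point<row_major_tag>(rot, 1, 0, x, y);
    std::printf("forward: (%g, %g)\n", x, y);    // (0, 1)
    transform_point<transposed_tag>(rot, 0, 1, x, y);
    std::printf("inverse: (%g, %g)\n", x, y);    // (1, 0): transpose of a rotation is its inverse
}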
nemo-retriever-ocr/cpp/geometry_api/poly_bounds_quad.cpp ADDED
@@ -0,0 +1,61 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #include "geometry_api.h"
6
+
7
+ using namespace std;
8
+
9
+
10
+ template<typename T>
11
+ void pt_assign(torch::TensorAccessor<T, 1> acc, T x, T y)
12
+ {
13
+ acc[0] = x;
14
+ acc[1] = y;
15
+ }
16
+
17
+
18
+ template<typename T>
19
+ void poly_bounds_quad_impl(torch::TensorAccessor<T, 2> poly, torch::TensorAccessor<T, 2> outBounds)
20
+ {
21
+ T minX = poly[0][0],
22
+ minY = poly[0][1],
23
+ maxX = poly[0][0],
24
+ maxY = poly[0][1];
25
+
26
+ const int64_t numVertices = poly.size(0);
27
+
28
+ for (int64_t i = 0; i < numVertices; ++i) {
29
+ auto vert = poly[i];
30
+
31
+ minX = min(minX, vert[0]);
32
+ maxX = max(maxX, vert[0]);
33
+
34
+ minY = min(minY, vert[1]);
35
+ maxY = max(maxY, vert[1]);
36
+ }
37
+
38
+ pt_assign(outBounds[0], minX, minY);
39
+ pt_assign(outBounds[1], maxX, minY);
40
+ pt_assign(outBounds[2], maxX, maxY);
41
+ pt_assign(outBounds[3], minX, maxY);
42
+ }
43
+
44
+
45
+ torch::Tensor get_poly_bounds_quad(torch::Tensor poly)
46
+ {
47
+ auto ret = torch::empty({ 4, 2 }, poly.options());
48
+
49
+ AT_DISPATCH_FLOATING_TYPES(
50
+ poly.scalar_type(),
51
+ "poly_bounds_quad_impl",
52
+ ([&] {
53
+ poly_bounds_quad_impl(
54
+ poly.accessor<scalar_t, 2>(),
55
+ ret.accessor<scalar_t, 2>()
56
+ );
57
+ })
58
+ );
59
+
60
+ return ret;
61
+ }
nemo-retriever-ocr/cpp/graph_detection/encode_util.cpp ADDED
@@ -0,0 +1,272 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #include "encode_util.h"
6
+
7
+ #include <algorithm>
8
+ #include <numeric>
9
+ #include <sstream>
10
+
11
+ #include "../third_party/clipper/clipper.hpp"
12
+
13
+ using namespace std;
14
+
15
+ namespace graph_detection {
16
+
17
+ template<typename T>
18
+ struct Candidate : Edge {
19
+ T C;
20
+
21
+ Candidate() = default;
22
+ Candidate(int32_t a, int32_t b, T c) : Edge(a, b), C(c) {}
23
+ };
24
+
25
+ struct DistStruct {
26
+ Candidate<Pointf> A;
27
+ Candidate<Pointf> B;
28
+ float Dist;
29
+
30
+ DistStruct() = default;
31
+ DistStruct(Candidate<Pointf> a, Candidate<Pointf> b, float dist) : A(a), B(b), Dist(dist) {}
32
+ };
33
+
34
+ template<typename T>
35
+ float vec_cos(const Point_<T> &a, const Point_<T> &b)
36
+ {
37
+ return dot(a, b) / (length(a) * length(b) + 1e-8);
38
+ }
39
+
40
+ template<typename T, typename Fn = std::less<T>>
41
+ vector<size_t> arg_sort(const vector<T> &vec, Fn comp = Fn())
42
+ {
43
+ vector<size_t> ret;
44
+ ret.reserve(vec.size());
45
+ for (size_t i = 0; i < vec.size(); ++i) {
46
+ ret.push_back(i);
47
+ }
48
+
49
+ sort(begin(ret), end(ret),
50
+ [&vec, &comp] (size_t idxA, size_t idxB) {
51
+ return comp(vec[idxA], vec[idxB]);
52
+ }
53
+ );
54
+
55
+ return ret;
56
+ }
57
+
58
+
59
+ float edge_length(const Polygon_<float> &poly, const vector<Edge> &edges);
60
+
61
+ vector<Edge> find_bottom(const Polygon_<float> &poly, bool useVertexOrder)
62
+ {
63
+ if (poly.Count < 4) {
64
+ throw runtime_error("Invalid polygon. Fewer than 4 vertices!");
65
+ }
66
+
67
+ // If we trust the source of the geometries, this both saves us computation
68
+ // and can be more reliable, since we won't reorder the vertices
69
+ if (useVertexOrder) {
70
+ if ((poly.Count % 2) == 1) {
71
+ throw runtime_error("Can't use trusted vertex order when the vertex count is odd!");
72
+ }
73
+ int32_t halfCt = poly.Count / 2;
74
+ return { { halfCt - 1, halfCt },
75
+ { static_cast<int32_t>(poly.Count) - 1, 0 } };
76
+ }
77
+
78
+ if (poly.Count == 4) {
79
+ float d1 = length(poly[1] - poly[0]) + length(poly[2] - poly[3]);
80
+ float d2 = length(poly[2] - poly[1]) + length(poly[0] - poly[3]);
81
+
82
+ if (4 * d1 < d2) {
83
+ return { { 0, 1 }, { 2, 3 } };
84
+ } else {
85
+ return { { 1, 2 }, { 3, 0 } };
86
+ }
87
+ }
88
+
89
+ auto idx_wrap = [&poly] (size_t idx) {
90
+ return poly[idx % poly.Count];
91
+ };
92
+
93
+ vector<Candidate<float>> candidates;
94
+ for (size_t i = 1; i < (poly.Count + 1); ++i) {
95
+ auto vPrev = idx_wrap(i) - idx_wrap(i - 1);
96
+ auto vNext = idx_wrap(i + 2) - idx_wrap(i + 1);
97
+
98
+ // We're looking for the segment where the preceding and following segments
99
+ // essentially travel in opposite directions
100
+ if (vec_cos(vPrev, vNext) < -0.875f) {
101
+ auto currSeg = idx_wrap(i) - idx_wrap(i + 1);
102
+ candidates.emplace_back(i % poly.Count, (i + 1) % poly.Count, length(currSeg));
103
+ }
104
+ }
105
+
106
+ if (candidates.size() != 2 || candidates[0].A == candidates[1].B || candidates[0].B == candidates[1].A) {
107
+ // If fewer than two candidates were found, or the two bottom edges are joined, select the two farthest edges
108
+ vector<Candidate<Pointf>> midList;
109
+ for (size_t i = 0; i < poly.Count; ++i) {
110
+ Pointf midPoint = (idx_wrap(i) + idx_wrap(i + 1)) / 2.0f;
111
+ midList.emplace_back(i, (i + 1) % poly.Count, midPoint);
112
+ }
113
+
114
+ vector<DistStruct> distList;
115
+
116
+ // Only found one good candidate, so search for the edge that's the furthest from this candidate
117
+ if (candidates.size() == 1) {
118
+ auto idx1a = candidates.back().A;
119
+ auto idx1b = candidates.back().B;
120
+ Candidate<Pointf> cand1{ idx1a, idx1b, (idx_wrap(idx1a) + idx_wrap(idx1b)) / 2.0f };
121
+ for (size_t j = 0; j < poly.Count; ++j) {
122
+ auto &cand2 = midList[j];
123
+
124
+ if (cand1.Touches(cand2)) continue;
125
+
126
+ float dist = length(cand1.C - cand2.C);
127
+ distList.emplace_back(cand1, cand2, dist);
128
+ }
129
+ } else {
130
+ for (size_t i = 0; i < poly.Count; ++i) {
131
+ for (size_t j = i + 1; j < poly.Count; ++j) {
132
+ auto &cand1 = midList[i];
133
+ auto &cand2 = midList[j];
134
+
135
+ if (cand1.Touches(cand2)) continue;
136
+
137
+ float dist = length(cand1.C - cand2.C);
138
+ distList.emplace_back(cand1, cand2, dist);
139
+ }
140
+ }
141
+ }
142
+ sort(begin(distList), end(distList), [] (auto a, auto b) { return a.Dist < b.Dist; });
143
+
144
+ if (distList.empty()) {
145
+ throw runtime_error("No valid bottom candidates found for this polygon!");
146
+ }
147
+
148
+ auto &bEdge = distList.back();
149
+ return vector<Edge>{ bEdge.A, bEdge.B };
150
+
151
+ } else {
152
+ return vector<Edge>{ candidates[0], candidates[1] };
153
+ }
154
+ }
155
+
156
+ void find_long_edges(const Polygon_<float> &poly, Edge *bottoms, vector<Edge> &outLongEdge1, vector<Edge> &outLongEdge2)
157
+ {
158
+ int32_t b1End = bottoms[0].B;
159
+ int32_t b2End = bottoms[1].B;
160
+
161
+ int32_t nPoints = poly.Count;
162
+
163
+ auto accum_into = [nPoints] (int32_t end1, int32_t end2, vector<Edge> &outEdge) {
164
+ int32_t i = (end1 + 1) % nPoints;
165
+ while ((i % nPoints) != end2) {
166
+ int32_t start = i > 0 ? i - 1 : nPoints - 1;
167
+ int32_t end = i % nPoints;
168
+ outEdge.emplace_back(start, end);
169
+ i = (i + 1) % nPoints;
170
+ }
171
+ };
172
+
173
+ accum_into(b1End, b2End, outLongEdge1);
174
+ accum_into(b2End, b1End, outLongEdge2);
175
+ }
176
+
177
+ float edge_length(const Polygon_<float> &poly, const vector<Edge> &edges)
178
+ {
179
+ float ret = 0.0f;
180
+ for (const Edge &e : edges) {
181
+ ret += length(poly[e.B] - poly[e.A]);
182
+ }
183
+ return ret;
184
+ }
185
+
186
+ vector<float> edge_lengths(const Polygon_<float> &poly, const vector<Edge> &edges)
187
+ {
188
+ if (edges.empty()) {
189
+ throw runtime_error("Found an empty edge!");
190
+ }
191
+
192
+ vector<float> ret;
193
+ ret.reserve(edges.size());
194
+
195
+ for (const Edge &e : edges) {
196
+ ret.push_back(length(poly[e.B] - poly[e.A]));
197
+ }
198
+
199
+ return ret;
200
+ }
201
+
202
+ void split_edge_sequence(const Polygon_<float> &poly, const vector<Edge> &edges,
203
+ const vector<float> &edgeLengths, float nParts,
204
+ vector<Pointf> &outPts);
205
+
206
+ void split_edge_sequence_by_step(const Polygon_<float> &poly, const vector<Edge> &longEdge1, const vector<Edge> &longEdge2,
207
+ float step, vector<Pointf> &outInnerPoints1, vector<Pointf> &outInnerPoints2)
208
+ {
209
+ auto edgeLengths1 = edge_lengths(poly, longEdge1);
210
+ auto edgeLengths2 = edge_lengths(poly, longEdge2);
211
+
212
+ float totalLength = (accumulate(begin(edgeLengths1), end(edgeLengths1), 0.0f) + accumulate(begin(edgeLengths2), end(edgeLengths2), 0.0f)) / 2;
213
+
214
+ float nParts = max<float>(ceil(totalLength / step), 2);
215
+
216
+ split_edge_sequence(poly, longEdge1, edgeLengths1, nParts, outInnerPoints1);
217
+ split_edge_sequence(poly, longEdge2, edgeLengths2, nParts, outInnerPoints2);
218
+ }
219
+
220
+ void split_edge_sequence(const Polygon_<float> &poly, const vector<Edge> &edges,
221
+ const vector<float> &edgeLengths, float nParts,
222
+ vector<Pointf> &outPts)
223
+ {
224
+ vector<float> elCumSum = vec_cumsum(edgeLengths);
225
+
226
+ float totalLength = elCumSum.back();
227
+ float lengthPerPart = totalLength / (nParts - 1);
228
+
229
+ size_t iNumParts = nParts;
230
+ size_t currNode = 0;
231
+ size_t ctr = 0;
232
+ for (float i = 0.0f; ctr < iNumParts; i += 1.0f, ++ctr) {
233
+ float t = min(i * lengthPerPart, totalLength);
234
+
235
+ while (t > elCumSum[currNode + 1]) {
236
+ ++currNode;
237
+ }
238
+
239
+ Edge currEdge = edges[currNode];
240
+ Pointf e1 = poly[currEdge.A];
241
+ Pointf e2 = poly[currEdge.B];
242
+
243
+ float currLen = edgeLengths[currNode];
244
+
245
+ Pointf sampledPt;
246
+
247
+ if (currLen > 0) {
248
+ float deltaT = t - elCumSum[currNode];
249
+ float ratio = deltaT / currLen;
250
+ sampledPt = e1 + ratio * (e2 - e1);
251
+ } else {
252
+ sampledPt = e1;
253
+ }
254
+
255
+ outPts.push_back(sampledPt);
256
+ }
257
+ }
258
+
259
+ string print_poly(const Polyf &poly) {
260
+ ostringstream oss;
261
+ oss << "[";
262
+ for (size_t i = 0; i < poly.Count; ++i) {
263
+ if (i > 0) {
264
+ oss << ", ";
265
+ }
266
+ oss << "(" << poly[i].X << ", " << poly[i].Y << ")";
267
+ }
268
+ oss << "]";
269
+ return oss.str();
270
+ }
271
+
272
+ } // namespace graph_detection
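split_edge_sequence above resamples a chain of edges at evenly spaced arc-length positions: it builds the cumulative edge lengths, steps a target length along them, and linearly interpolates inside whichever edge contains each target position. The sketch below restates that idea over a simple polyline in plain std C++; the names are illustrative only.

// Standalone sketch of arc-length resampling, mirroring split_edge_sequence above.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct P { float x, y; };

std::vector<P> resample_polyline(const std::vector<P> &pts, int nParts) {
    // Cumulative arc length, starting at 0
    std::vector<float> cum{ 0.0f };
    for (size_t i = 1; i < pts.size(); ++i) {
        const float dx = pts[i].x - pts[i - 1].x, dy = pts[i].y - pts[i - 1].y;
        cum.push_back(cum.back() + std::sqrt(dx * dx + dy * dy));
    }

    const float total = cum.back();
    const float step = total / (nParts - 1);

    std::vector<P> out;
    size_t seg = 0;
    for (int i = 0; i < nParts; ++i) {
        const float t = std::min(i * step, total);
        while (t > cum[seg + 1]) ++seg;           // advance to the edge containing t

        const float segLen = cum[seg + 1] - cum[seg];
        const float ratio = segLen > 0 ? (t - cum[seg]) / segLen : 0.0f;
        out.push_back({ pts[seg].x + ratio * (pts[seg + 1].x - pts[seg].x),
                        pts[seg].y + ratio * (pts[seg + 1].y - pts[seg].y) });
    }
    return out;
}

int main() {
    // An L-shaped polyline of total length 2, resampled into 5 evenly spaced points
    const std::vector<P> line{ {0, 0}, {1, 0}, {1, 1} };
    for (const P &p : resample_polyline(line, 5)) std::printf("(%g, %g)\n", p.x, p.y);
}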
nemo-retriever-ocr/cpp/graph_detection/encode_util.h ADDED
@@ -0,0 +1,184 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #pragma once
6
+
7
+ #include <vector>
8
+ #include <random>
9
+ #include <algorithm>
+ #include <numeric> // std::accumulate, used by avg_point below
10
+
11
+ #include "../geometry.h"
12
+
13
+ namespace graph_detection {
14
+
15
+
16
+
17
+ struct Edge {
18
+ int32_t A;
19
+ int32_t B;
20
+
21
+ Edge() = default;
22
+ Edge(int32_t a, int32_t b) : A(a), B(b) {}
23
+
24
+ bool Touches(int32_t idx) const { return A == idx || B == idx; }
25
+ bool Touches(const Edge &other) const;
26
+ };
27
+
28
+ inline
29
+ bool edge_touches(const Edge &edge, int32_t vertex) {
30
+ return edge.A == vertex || edge.B == vertex;
31
+ }
32
+
33
+ inline
34
+ bool Edge::Touches(const Edge &other) const {
35
+ return edge_touches(other, A) || edge_touches(other, B);
36
+ }
37
+
38
+ typedef Point_<float> Pointf;
39
+ typedef AABB_<float> AABBf;
40
+ typedef Polygon_<float> Polyf;
41
+ typedef std::vector<Pointf> Polyline;
42
+
43
+ std::vector<Edge> find_bottom(const Polygon_<float> &poly, bool useVertexOrder);
44
+
45
+ void find_long_edges(const Polygon_<float> &poly, Edge *bottoms, std::vector<Edge> &outLongEdge1, std::vector<Edge> &outLongEdge2);
46
+
47
+ void split_edge_sequence_by_step(const Polygon_<float> &poly, const std::vector<Edge> &longEdge1, const std::vector<Edge> &longEdge2,
48
+ float step, std::vector<Pointf> &outInnerPoints1, std::vector<Pointf> &outInnerPoints2);
49
+
50
+ std::string print_poly(const Polyf &poly);
51
+
52
+ template<typename T>
53
+ inline
54
+ std::vector<T> vec_cumsum(const std::vector<T> &v)
55
+ {
56
+ std::vector<T> ret;
57
+ ret.reserve(v.size() + 1);
58
+ ret.push_back(0);
59
+ for (T val : v) {
60
+ ret.push_back(ret.back() + val);
61
+ }
62
+ return ret;
63
+ }
64
+
65
+ template<typename RandEng, typename Fn>
66
+ inline
67
+ void n_choose_k(size_t n, size_t k, RandEng &randEng, Fn fn)
68
+ {
69
+ if (k == 0) return;
70
+
71
+ // TODO(mranzinger): This algorithm can be replaced with sampling from a geometric
72
+ // distribution, which drastically reduces the runtime complexity
73
+ for (size_t i = 0; i < n; ++i) {
74
+ size_t leftover = n - i;
75
+ if (leftover <= k) {
76
+ fn(i);
77
+ --k;
78
+ } else {
79
+ float p = std::uniform_real_distribution<float>(0.0f, 1.0f)(randEng);
80
+ float probSample = float{k} / float{leftover};
81
+ if (p < probSample) {
82
+ fn(i);
83
+ --k;
84
+ }
85
+ }
86
+ }
87
+ }
88
+
89
+ template<typename T>
90
+ inline T clamp(T val, T minVal, T maxVal) {
91
+ return std::max(std::min(val, maxVal), minVal);
92
+ }
93
+
94
+ inline
95
+ Pointf avg_point(const std::vector<Pointf> &points)
96
+ {
97
+ return std::accumulate(std::begin(points), std::end(points), Pointf(0,0)) / float(points.size());
98
+ }
99
+
100
+ inline
101
+ float vector_sin(const Pointf &pt)
102
+ {
103
+ // sin = y / len(pt)
104
+ return pt.Y / (length(pt) + 1e-8);
105
+ }
106
+
107
+ inline
108
+ float vector_cos(const Pointf &pt)
109
+ {
110
+ // cos = x / len(pt)
111
+ return pt.X / (length(pt) + 1e-8);
112
+ }
113
+
114
+ inline
115
+ void vector_cos_sin(const Pointf & pt, float &outCos, float &outSin)
116
+ {
117
+ float len = length(pt) + 1e-8;
118
+ outCos = pt.X / len;
119
+ outSin = pt.Y / len;
120
+ }
121
+
122
+ inline
123
+ float point_dist_to_line(const Pointf &l1, const Pointf &l2, const Pointf &pt)
124
+ {
125
+ auto d = l2 - l1;
126
+
127
+ auto lineLen = length(d);
128
+
129
+ if (lineLen > 0) {
130
+ float distance = abs(
131
+ d.Y * pt.X
132
+ - d.X * pt.Y
133
+ + l2.X * l1.Y
134
+ - l2.Y * l1.X
135
+ ) / lineLen;
136
+ return distance;
137
+ } else {
138
+ return length(pt - l1);
139
+ }
140
+ }
141
+
142
+ template<typename T>
143
+ T find_mode(std::vector<T> &inputs) {
144
+ using std::sort;
145
+ using std::begin;
146
+ using std::end;
147
+
148
+ if (inputs.empty()) {
149
+ throw std::runtime_error("Cannot find mode of empty distribution!");
150
+ }
151
+
152
+ sort(begin(inputs), end(inputs));
153
+
154
+ T currVal = inputs[0];
155
+ size_t currCount = 1;
156
+
157
+ T modeVal = inputs[0];
158
+ size_t modeCount = 1;
159
+
160
+ auto commitCurr = [&] () {
161
+ if (currCount > modeCount) {
162
+ modeCount = currCount;
163
+ modeVal = currVal;
164
+ }
165
+ };
166
+
167
+ for (size_t i = 1; i < inputs.size(); ++i) {
168
+ if (inputs[i] == currVal) {
169
+ ++currCount;
170
+ } else {
171
+ // Start of a new value
172
+ commitCurr();
173
+
174
+ currCount = 1;
175
+ currVal = inputs[i];
176
+ }
177
+ }
178
+
179
+ commitCurr();
180
+
181
+ return modeVal;
182
+ }
183
+
184
+ } // namespace graph_detection
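n_choose_k above performs sequential k-of-n sampling (Knuth's "selection sampling"): walk the n items once and accept item i with probability (items still needed) / (items still remaining), which yields a uniform random k-subset in increasing index order. A host-side sketch of that scheme follows; the names are illustrative only.

// Host-side sketch of the sequential k-of-n sampling used by n_choose_k above.
#include <cstdio>
#include <random>
#include <vector>

template<typename Fn>
void choose_k_of_n(size_t n, size_t k, std::mt19937 &rng, Fn accept) {
    std::uniform_real_distribution<float> unit(0.0f, 1.0f);
    for (size_t i = 0; i < n && k > 0; ++i) {
        const size_t remaining = n - i;
        // Accept with probability k / remaining (always accept once remaining <= k)
        if (remaining <= k || unit(rng) < static_cast<float>(k) / static_cast<float>(remaining)) {
            accept(i);
            --k;
        }
    }
}

int main() {
    std::mt19937 rng(1234);
    std::vector<size_t> picked;
    choose_k_of_n(10, 3, rng, [&](size_t i) { picked.push_back(i); });
    for (size_t i : picked) std::printf("%zu ", i);   // 3 distinct indices in increasing order
    std::printf("\n");
}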
nemo-retriever-ocr/cpp/half_ops.cu ADDED
@@ -0,0 +1,5 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #include "half_ops.cuh"
nemo-retriever-ocr/cpp/half_ops.cuh ADDED
@@ -0,0 +1,149 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #pragma once
6
+
7
+ #include <torch/torch.h>
8
+
9
+ #include "cuda_intellisense.cuh"
10
+
11
+ #ifndef __CUDACC__
12
+ #pragma message("__CUDACC__ not defined!")
13
+ #else
14
+ #pragma message("__CUDACC__ defined!")
15
+ #endif
16
+
17
+ #ifdef __NVCC__
18
+ #define __qr_device__ __device__
19
+ #define __qr_host__ __host__
20
+ #define __qr_inline__ __forceinline__
21
+ #else
22
+ #define __qr_device__
23
+ #define __qr_host__
24
+ #define __qr_inline__ inline
25
+ #endif
26
+
27
+ #ifdef __CUDACC__
28
+ #include <cuda.h>
29
+ #include <cuda_runtime.h>
30
+ #include <cuda_fp16.h>
31
+
32
+
33
+ __qr_inline__ __device__ __half operator-(__half v) {
34
+ return __hneg(v);
35
+ }
36
+
37
+ __qr_inline__ __device__ __half operator+(__half a, __half b) {
38
+ return __hadd(a, b);
39
+ }
40
+
41
+ __qr_inline__ __device__ __half operator-(__half a, __half b) {
42
+ return __hsub(a, b);
43
+ }
44
+
45
+ __qr_inline__ __device__ __half operator*(__half a, __half b) {
46
+ return __hmul(a, b);
47
+ }
48
+
49
+ __qr_inline__ __device__ __half operator/(__half a, __half b) {
50
+ return __hdiv(a, b);
51
+ }
52
+
53
+ __qr_inline__ __device__ bool operator==(__half a, __half b) {
54
+ return __heq(a, b);
55
+ }
56
+
57
+ __qr_inline__ __device__ bool operator<(__half a, __half b) {
58
+ return __hlt(a, b);
59
+ }
60
+
61
+ __qr_inline__ __device__ bool operator>(__half a, __half b) {
62
+ return __hgt(a, b);
63
+ }
64
+
65
+ __qr_inline__ __device__ __half sqrt(__half v) {
66
+ return hsqrt(v);
67
+ }
68
+
69
+ __qr_inline__ __device__ __half floor(__half v) {
70
+ return hfloor(v);
71
+ }
72
+
73
+ __qr_inline__ __device__ __half ceil(__half v) {
74
+ return hceil(v);
75
+ }
76
+
77
+ __qr_inline__ __device__ __half max(__half a, __half b) {
78
+ return a > b ? a : b;
79
+ }
80
+ #endif //__CUDACC__
81
+
82
+ template<typename Src, typename Dest>
83
+ struct Convert {
84
+ __qr_inline__ static __qr_host__ __qr_device__ constexpr Dest From(Src value) { return static_cast<Dest>(value); }
85
+ __qr_inline__ static __qr_host__ __qr_device__ constexpr Src To(Dest value) { return static_cast<Src>(value); }
86
+ __qr_inline__ static __qr_host__ __qr_device__ constexpr Dest LeftToRight(Src value) { return static_cast<Dest>(value); }
87
+ __qr_inline__ static __qr_host__ __qr_device__ constexpr Src RightToLeft(Dest value) { return static_cast<Src>(value); }
88
+ };
89
+
90
+ #ifdef __CUDACC__
91
+ template<>
92
+ struct Convert<__half, float> {
93
+ __qr_inline__ static __host__ __device__ float From(__half value) { return __half2float(value); }
94
+ __qr_inline__ static __host__ __device__ __half To(float value) { return __float2half(value); }
95
+ __qr_inline__ static __host__ __device__ float LeftToRight(__half value) { return __half2float(value); }
96
+ __qr_inline__ static __host__ __device__ __half RightToLeft(float value) { return __float2half(value); }
97
+ };
98
+
99
+ template<typename Dest>
100
+ struct Convert<__half, Dest> : Convert<__half, float> {
101
+
102
+ };
103
+
104
+ namespace at {
105
+
106
+ template<>
107
+ inline __half* TensorBase::mutable_data_ptr() const {
108
+ TORCH_CHECK(scalar_type() == ScalarType::Half,
109
+ "expected scalar type Half but found ",
110
+ c10::toString(scalar_type()));
111
+ return static_cast<__half*>(this->unsafeGetTensorImpl()->mutable_data());
112
+ }
113
+
114
+ template<>
115
+ inline __half* TensorBase::data_ptr() const {
116
+ TORCH_CHECK(scalar_type() == ScalarType::Half,
117
+ "expected scalar type Half but found ",
118
+ c10::toString(scalar_type()));
119
+ return static_cast<__half*>(this->unsafeGetTensorImpl()->mutable_data());
120
+ }
121
+
122
+ }
123
+
124
+ template<typename T>
125
+ struct remap_half {
126
+ typedef T type;
127
+ };
128
+
129
+ template<>
130
+ struct remap_half<at::Half> {
131
+ typedef __half type;
132
+ };
133
+
134
+ template<typename T>
135
+ __half to_half(T val) {
136
+ return Convert<__half, T>::RightToLeft(val);
137
+ }
138
+
139
+ template<typename T>
140
+ struct fp_promote {
141
+ typedef T type;
142
+ };
143
+
144
+ template<>
145
+ struct fp_promote<__half> {
146
+ typedef float type;
147
+ };
148
+
149
+ #endif //__CUDACC__
nemo-retriever-ocr/cpp/local_ips/local_ips.h ADDED
@@ -0,0 +1,11 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #pragma once
6
+
7
+ #include <torch/torch.h>
8
+
9
+ torch::Tensor ragged_quad_all_2_all_distance_v2(torch::Tensor embedQuads, torch::Tensor quadsPerExample,
10
+ float xFactor, float yFactor,
11
+ bool allowSelfDistance);
nemo-retriever-ocr/cpp/local_ips/quad_all_2_all_dist_v2.cu ADDED
@@ -0,0 +1,162 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+
6
+ #include <iostream>
7
+
8
+ #include <cooperative_groups.h>
9
+ #include <cooperative_groups/reduce.h>
10
+
11
+ #include <thrust/binary_search.h>
12
+ #include <thrust/execution_policy.h>
13
+
14
+ #include "local_ips.h"
15
+ #include "../cuda_intellisense.cuh"
16
+ #include "../common.h"
17
+ #include "../geometry.h"
18
+
19
+ using namespace std;
20
+ namespace cg = cooperative_groups;
21
+
22
+ typedef Point_<float> Pointf;
23
+
24
+ __device__ inline
25
+ float square(float val) { return val * val; }
26
+
27
+ __global__
28
+ void device_quad_all_2_all_distance_v2(torch::PackedTensorAccessor64<float, 4> allEmbedQuads,
29
+ torch::PackedTensorAccessor64<int64_t, 1> allRegionCounts,
30
+ torch::PackedTensorAccessor64<int64_t, 1> csWorkPerExample,
31
+ torch::PackedTensorAccessor64<float, 3> outDistances,
32
+ float xFactor, float yFactor,
33
+ bool allowSelfDistance)
34
+ {
35
+ // Note that the blockIdx.x is on purpose here
36
+ int64_t workIdx = blockIdx.x * blockDim.y + threadIdx.y;
37
+
38
+ if (workIdx >= csWorkPerExample[csWorkPerExample.size(0) - 1]) return;
39
+
40
+ auto exIter = thrust::upper_bound(thrust::seq,
41
+ csWorkPerExample.data(), csWorkPerExample.data() + csWorkPerExample.size(0),
42
+ workIdx);
43
+
44
+ const int64_t exIdx = exIter - csWorkPerExample.data();
45
+
46
+ const int64_t workStart = exIdx == 0 ? 0 : csWorkPerExample[exIdx - 1];
47
+ const int64_t workOff = workIdx - workStart;
48
+
49
+ const int64_t row = workOff / allRegionCounts[exIdx];
50
+ const int64_t col = workOff % allRegionCounts[exIdx];
51
+
52
+ auto taRowQuad = allEmbedQuads[exIdx][row];
53
+ auto taColQuad = allEmbedQuads[exIdx][col];
54
+
55
+ Quad_<float> rowQuad(taRowQuad.data()),
56
+ colQuad(taColQuad.data());
57
+
58
+ auto p1 = (rowQuad[0] + rowQuad[3]) / 2.0f;
59
+ auto p2 = (rowQuad[1] + rowQuad[2]) / 2.0f;
60
+
61
+ auto vX = p2 - p1;
62
+ auto lenVX = length(vX);
63
+ if (lenVX > 0) {
64
+ vX = vX / max(lenVX, 1e-8f);
65
+ } else {
66
+ vX = { 1, 0 };
67
+ }
68
+
69
+ Pointf vY{ -vX.Y, vX.X };
70
+
71
+ auto reproj = [&vX, &vY, xFactor, yFactor] (const Pointf &pt) {
72
+ auto dX = dot(pt, vX);
73
+ if (dX >= 0) {
74
+ dX *= xFactor;
75
+ }
76
+ auto dY = dot(pt, vY);
77
+ if (dY >= 0) {
78
+ dY *= yFactor;
79
+ }
80
+
81
+ return Pointf{ dX, dY };
82
+ };
83
+
84
+ auto tile16 = cg::tiled_partition<16>(cg::this_thread_block());
85
+
86
+ // Figure out which vertices this thread is processing
87
+ const int64_t rowVertexIdx = tile16.thread_rank() / 4;
88
+ const int64_t colVertexIdx = tile16.thread_rank() % 4;
89
+
90
+ float dist;
91
+ if (row != col) {
92
+ Segment_<float> rowSeg{ rowQuad[rowVertexIdx], rowQuad[(rowVertexIdx + 1) % 4] };
93
+ Segment_<float> colSeg{ colQuad[colVertexIdx], colQuad[(colVertexIdx + 1) % 4] };
94
+
95
+ Segment_<float> minSeg = shortest_line_between_segments(rowSeg, colSeg);
96
+
97
+ Point_<float> vSeg = minSeg.B - minSeg.A;
98
+
99
+ vSeg = reproj(vSeg);
100
+
101
+ dist = length(vSeg);
102
+ } else if (allowSelfDistance) {
103
+ dist = 0;
104
+ } else {
105
+ dist = numeric_limits<float>::infinity();
106
+ }
107
+
108
+ // Now find the minimum distance across the group
109
+ int lane = tile16.thread_rank();
110
+ // Each iteration halves the number of active threads
111
+ // Each thread gets the partial min[i] to min[lane+i]
112
+ #pragma unroll
113
+ for (uint32_t i = 1; i < 16; i <<= 1) {
114
+ auto otherDist = tile16.shfl_down(dist, i);
115
+ dist = min(dist, otherDist);
116
+ }
117
+
118
+ #ifndef NDEBUG
119
+ float lowestDist = tile16.shfl(dist, 0);
120
+ assert(dist >= lowestDist);
121
+ #endif
122
+
123
+ if (lane == 0) {
124
+ outDistances[exIdx][row][col] = dist;
125
+ }
126
+ }
127
+
128
+ torch::Tensor ragged_quad_all_2_all_distance_v2(torch::Tensor embedQuads, torch::Tensor regionCounts,
129
+ float xFactor, float yFactor,
130
+ bool allowSelfDistance)
131
+ {
132
+ if (!embedQuads.is_contiguous()) {
133
+ throw std::runtime_error("Expected `embedQuads` to be contiguous!");
134
+ }
135
+
136
+ auto outDistances = torch::zeros({ embedQuads.size(0), embedQuads.size(1), embedQuads.size(1) },
137
+ embedQuads.options());
138
+
139
+ if (embedQuads.numel() == 0) {
140
+ return outDistances;
141
+ }
142
+
143
+ auto workPerExample = regionCounts * regionCounts;
144
+
145
+ auto csWorkPerExample = torch::cumsum(workPerExample, 0);
146
+
147
+ int64_t totalWork = csWorkPerExample[-1].item<int64_t>();
148
+
149
+ dim3 blockSize(16, 2);
150
+ dim3 gridSize(div_up(totalWork, blockSize.y), 1);
151
+
152
+ device_quad_all_2_all_distance_v2 KERNEL_ARG2(gridSize, blockSize) (
153
+ embedQuads.packed_accessor64<float, 4>(),
154
+ regionCounts.packed_accessor64<int64_t, 1>(),
155
+ csWorkPerExample.packed_accessor64<int64_t, 1>(),
156
+ outDistances.packed_accessor64<float, 3>(),
157
+ xFactor, yFactor,
158
+ allowSelfDistance
159
+ );
160
+
161
+ return outDistances;
162
+ }
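The reduction at the end of the kernel above uses tile16.shfl_down so that, after offsets 1, 2, 4, 8, lane 0 holds the minimum of all 16 per-lane distances. Below is a host-side emulation of that pattern in plain std C++, for illustration only; on the device the array reads are replaced by the shuffle, and out-of-range source lanes are treated as returning the caller's own value.

// Host-side emulation of the 16-lane shfl_down min reduction used in the kernel above.
#include <algorithm>
#include <cstdio>

int main() {
    float lanes[16] = { 9, 3, 7, 5, 11, 2, 8, 6, 10, 4, 12, 1, 14, 13, 15, 0.5f };

    for (int off = 1; off < 16; off <<= 1) {
        float next[16];
        for (int lane = 0; lane < 16; ++lane) {
            // shfl_down(v, off): lane i reads lane (i + off); out-of-range lanes keep their own
            // value, which never changes the running minimum they already hold.
            const float other = (lane + off < 16) ? lanes[lane + off] : lanes[lane];
            next[lane] = std::min(lanes[lane], other);
        }
        std::copy(next, next + 16, lanes);
    }

    std::printf("tile minimum at lane 0: %g\n", lanes[0]);   // 0.5
}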
nemo-retriever-ocr/cpp/module.cpp ADDED
@@ -0,0 +1,125 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+
6
+ #include "quad_rectify/quad_rectify.h"
7
+ #include "non_maximal_suppression/non_maximal_suppression.h"
8
+ #include "geometry_api/geometry_api.h"
9
+ #include "beam_decode/beam_decode.h"
10
+ #include "better_grid_sample/grid_sample.h"
11
+ #include "sparse_select/sparse_select.h"
12
+ #include "text_region_grouping/text_region_grouping.h"
13
+ #include "local_ips/local_ips.h"
14
+
15
+ #include <torch/extension.h>
16
+ #include <pybind11/stl.h>
17
+
18
+
19
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20
+ m.def("quad_rectify_calc_quad_width", &quad_rectify_calc_quad_width,
21
+ "Quad Rectify Calc Quad Width C++",
22
+ py::arg("quads"),
23
+ py::arg("output_height"),
24
+ py::arg("round_factor") = 16,
25
+ py::arg("max_width") = 0
26
+ );
27
+ m.def("quad_rectify_forward", &quad_rectify_forward, "Quad Rectify Forward C++",
28
+ py::arg("quads"),
29
+ py::arg("image_height"), py::arg("image_width"),
30
+ py::arg("output_height"), py::arg("output_width"),
31
+ py::arg("isotropic") = true
32
+ );
33
+ m.def("quad_rectify_backward", &quad_rectify_backward, "Quad Rectify Backward C++",
34
+ py::arg("quads"), py::arg("grad_output"),
35
+ py::arg("image_height"), py::arg("image_width"),
36
+ py::arg("isotropic") = true
37
+ );
38
+ m.def("quad_non_maximal_suppression", &quad_non_maximal_suppression, "Quad Non-Maximal Suppression C++",
39
+ py::arg("quads"), py::arg("probs"),
40
+ py::arg("prob_threshold"), py::arg("iou_threshold"),
41
+ py::arg("kernel_height"), py::arg("kernel_width"),
42
+ py::arg("max_regions"),
43
+ py::arg("verbose") = false
44
+ );
45
+
46
+ py::class_<LanguageModel>(m, "LanguageModel");
47
+
48
+ m.def("beam_decode", &beam_decode, "beam_decode c++",
49
+ py::arg("probs"),
50
+ py::arg("beam_size") = 100,
51
+ py::arg("blank") = 0,
52
+ py::arg("min_prob") = 0.001,
53
+ py::arg("lang_model") = static_cast<LanguageModel*>(nullptr),
54
+ py::arg("lm_weight") = 1,
55
+ py::arg("combine_duplicates") = true
56
+ );
57
+
58
+ py::class_<TokenMappingWrapper, TokenMappingWrapper::Ptr>(m, "TokenMapping");
59
+
60
+ m.def("create_token_mapping", &create_token_mapping, "create token mapping c++",
61
+ py::arg("token_mapping")
62
+ );
63
+
64
+ m.def("decode_sequences", &decode_sequences, "decode_sequences c++",
65
+ py::arg("tokens"), py::arg("language_model"),
66
+ py::arg("probs") = nullptr
67
+ );
68
+
69
+ m.def("create_sbo_lm", &create_sbo_lm, "create_sbo_lm c++",
70
+ py::arg("data_file_path"),
71
+ py::arg("token_mapping"),
72
+ py::arg("backoff") = 0.4
73
+ );
74
+
75
+ m.def("indirect_grid_sample_forward", &indirect_grid_sample_forward, "indirect_grid_sample::forward c++",
76
+ py::arg("input"), py::arg("grid"), py::arg("input_indices"), py::arg("method")
77
+ );
78
+ m.def("indirect_grad_sample_backward", &indirect_grad_sample_backward, "indirect_grid_sample::backward c++",
79
+ py::arg("grad_output"), py::arg("input"), py::arg("grid"), py::arg("input_indices"), py::arg("method")
80
+ );
81
+ m.def("region_counts_to_indices", &region_counts_to_indices, "region counts to indices",
82
+ py::arg("region_counts"), py::arg("num_outputs")
83
+ );
84
+
85
+ m.def("rrect_to_quads", &rrect_to_quads, "convert rotated rectangle to quadrangles",
86
+ py::arg("rrects"), py::arg("cell_size")
87
+ );
88
+ m.def("rrect_to_quads_backward", &rrect_to_quads_backward, "gradient of rrect_to_quads",
89
+ py::arg("rrects"), py::arg("grad_output")
90
+ );
91
+
92
+ m.def("sparse_select", &sparse_select, "Select sparse tensor(s) given a set of indices",
93
+ py::arg("sparse_counts"), py::arg("sparse_tensors"), py::arg("select_indices")
94
+ );
95
+
96
+ m.def("text_region_grouping", &text_region_grouping, "Clusters all of the text into lines and phrases",
97
+ py::arg("quads"), py::arg("counts"),
98
+ py::arg("horizontal_tolerance") = 2.0f,
99
+ py::arg("vertical_tolerance") = 0.5f,
100
+ py::arg("verbose") = false
101
+ );
102
+
103
+ m.def("dense_relations_to_graph", &dense_relations_to_graph, "Converts a dense relational tensor to a graph",
104
+ py::arg("relations")
105
+ );
106
+
107
+ m.def("ragged_quad_all_2_all_distance_v2", &ragged_quad_all_2_all_distance_v2, "get the all-to-all distances in ragged-batch quad mode",
108
+ py::arg("embed_quads"), py::arg("region_counts"),
109
+ py::arg("x_factor") = 1.0f,
110
+ py::arg("y_factor") = 1.0f,
111
+ py::arg("allow_self_distance") = true
112
+ );
113
+
114
+ m.def("calc_poly_min_rrect", &calc_poly_min_rrect, "calculate a reasonable bounding rectangle for a given text polygon",
115
+ py::arg("vertices")
116
+ );
117
+
118
+ m.def("get_rel_continuation_cos", &get_rel_continuation_cos, "c++ get relation cosine between 2 regions",
119
+ py::arg("rrect_a"), py::arg("rrect_b")
120
+ );
121
+
122
+ m.def("get_poly_bounds_quad", &get_poly_bounds_quad, "c++ get polygon bounds",
123
+ py::arg("poly")
124
+ );
125
+ }
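For reference, a minimal, self-contained sketch of two binding patterns this module leans on: exposing an opaque C++ type purely as a handle (no constructor), and a raw-pointer argument defaulted to nullptr so Python callers can pass None. The `Counter` type and functions below are hypothetical; only the pybind11 mechanics mirror the module above.

#include <torch/extension.h>
#include <memory>

// Hypothetical type used only to illustrate the binding style; not part of this library.
struct Counter { int value = 0; };

std::shared_ptr<Counter> make_counter(int start) {
    auto c = std::make_shared<Counter>();
    c->value = start;
    return c;
}

int read_counter(const Counter *c) { return c ? c->value : -1; }

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Exposed without an __init__, so Python can only hold and pass the handle around,
    // the same way LanguageModel and TokenMapping are exposed above.
    py::class_<Counter, std::shared_ptr<Counter>>(m, "Counter");

    m.def("make_counter", &make_counter, py::arg("start") = 0);
    // A raw-pointer default of nullptr lets Python pass None, like `lang_model` above.
    m.def("read_counter", &read_counter, py::arg("counter") = static_cast<Counter*>(nullptr));
}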
nemo-retriever-ocr/cpp/non_maximal_suppression/cpu_non_maximal_suppression.cpp ADDED
@@ -0,0 +1,209 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #include "non_maximal_suppression.h"
6
+
7
+ #include <algorithm>
8
+ #include "../geometry.h"
9
+
10
+ using namespace std;
11
+
12
+
13
+ template<typename scalar_t>
14
+ void visit_node(
15
+ const torch::TensorAccessor<scalar_t, 4> &quads,
16
+ const torch::TensorAccessor<scalar_t, 2> &probs,
17
+ const torch::TensorAccessor<int32_t, 3> &adjacency,
18
+ MergeQuad_<scalar_t> &mQuad,
19
+ unordered_set<int32_t> &visited,
20
+ int64_t r, int64_t c, int32_t vIdx)
21
+ {
22
+ if (visited.count(vIdx)) {
23
+ return;
24
+ }
25
+ visited.insert(vIdx);
26
+
27
+ int32_t *pAdj = adjacency[r][c].data();
28
+
29
+ int32_t adjCt = pAdj[0];
30
+ assert(adjCt > 0);
31
+
32
+ mQuad.Append(Quad_<scalar_t>(quads[r][c].data()), probs[r][c]);
33
+
34
+ int32_t *pOff = pAdj + 2;
35
+ int32_t *pEnd = pAdj + adjCt + 1;
36
+
37
+ const int32_t W = quads.size(1);
38
+
39
+ for (; pOff != pEnd; ++pOff) {
40
+ int32_t vIdx2 = *pOff;
41
+ int32_t r2 = vIdx2 / W;
42
+ int32_t c2 = vIdx2 % W;
43
+
44
+ visit_node(quads, probs, adjacency, mQuad, visited, r2, c2, vIdx2);
45
+ }
46
+ }
47
+
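A standalone sketch of the adjacency-row layout that `visit_node` consumes, inferred from the pointer arithmetic above: slot 0 holds the entry count and the loop starts reading at slot 2, i.e. it skips the first entry (apparently the cell's own index). The concrete values below are made up for illustration.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // A row for one grid cell: [count, self index, neighbor, neighbor]
    std::vector<int32_t> adjRow = { 3, 17, 42, 58 };

    const int32_t *pAdj = adjRow.data();
    const int32_t adjCt = pAdj[0];
    assert(adjCt > 0);

    const int32_t *pOff = pAdj + 2;          // skip the count and the first (self) entry
    const int32_t *pEnd = pAdj + adjCt + 1;  // one past the last entry

    for (; pOff != pEnd; ++pOff) {
        std::cout << "visit neighbor " << *pOff << "\n";  // prints 42, then 58
    }
    return 0;
}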
48
+ template<typename scalar_t>
49
+ std::vector<torch::Tensor> quad_nms_from_adjacency_impl(
50
+ const torch::TensorAccessor<scalar_t, 5> &quads,
51
+ const torch::TensorAccessor<scalar_t, 3> &probs,
52
+ const torch::TensorAccessor<int32_t, 4> &adjacency,
53
+ scalar_t probThreshold, scalar_t iouThreshold,
54
+ int64_t maxRegions)
55
+ {
56
+ const uint64_t B = quads.size((int)0);
57
+ const int64_t H = quads.size((int)1);
58
+ const int64_t W = quads.size((int)2);
59
+
60
+ typedef MergeQuad_<scalar_t> MQuad;
61
+ typedef EmbedQuad_<scalar_t> EFQuad;
62
+
63
+ vector<vector<EFQuad>> batchQuads{ static_cast< const unsigned int >( B ) };
64
+ vector<vector<EFQuad>> allQuads{ static_cast< const unsigned int >( B ) };
65
+ vector<vector<vector<size_t>>> batchAdjIdxs{ static_cast< const unsigned int >( B ) };
66
+
67
+ #pragma omp parallel num_threads (8)
68
+ {
69
+ #pragma omp for
70
+ for (int64_t b = 0; b < B; ++b) {
71
+ unordered_set<int32_t> visited;
72
+
73
+ for (int64_t r = 0; r < H; ++r) {
74
+ for (int64_t c = 0; c < W; ++c) {
75
+ auto currProb = probs[b][r][c];
76
+
77
+ if (currProb < probThreshold) {
78
+ continue;
79
+ }
80
+
81
+ int32_t vIdx = r * W + c;
82
+
83
+ // Ensure that this quad hasn't already been merged
84
+ if (visited.count(vIdx)) {
85
+ continue;
86
+ }
87
+
88
+ MQuad mQuad{ZeroInitTag{}};
89
+ visit_node(quads[b], probs[b], adjacency[b], mQuad, visited, r, c, vIdx);
90
+
91
+ batchQuads[b].push_back(mQuad.Commit());
92
+ }
93
+ }
94
+ }
95
+
96
+ #pragma omp single
97
+ {
98
+ for (size_t b = 0; b < B; ++b) {
99
+ size_t numQuads = batchQuads[b].size();
100
+ batchAdjIdxs[b].resize(numQuads);
101
+ for (int64_t n = 0; n < numQuads; ++n) {
102
+ #pragma omp task default(none) shared(batchAdjIdxs, batchQuads, iouThreshold) firstprivate(b, numQuads, n)
103
+ {
104
+ for (int64_t m = n + 1; m < numQuads; ++m) {
105
+ vector<size_t> &adjIdxs = batchAdjIdxs[b][n];
106
+ vector<EFQuad> &quads = batchQuads[b];
107
+ auto iou = quads[n].IOU(quads[m]);
108
+
109
+ if (iou > iouThreshold) {
110
+ adjIdxs.push_back(m);
111
+ }
112
+ }
113
+ }
114
+ }
115
+ }
116
+
117
+ #pragma omp taskwait
118
+ }
119
+
120
+ #pragma omp for
121
+ for (int64_t batchIdx = 0; batchIdx < B; ++batchIdx) {
122
+ vector<vector<size_t>> &adjIdxs = batchAdjIdxs[batchIdx];
123
+ vector<EFQuad> &quads = batchQuads[batchIdx];
124
+ vector<EFQuad> &finalQuads = allQuads[batchIdx];
125
+
126
+ // Step 3: Using depth first search, merge the regions
127
+ unordered_set<size_t> visited;
128
+ for (int64_t n = 0; n < quads.size(); ++n) {
129
+ EFQuad currQuad;
130
+ visit_node(quads, n, adjIdxs, currQuad, visited);
131
+
132
+ if (currQuad.NumQuads > 0) {
133
+ currQuad.Prepare();
134
+
135
+ finalQuads.push_back(currQuad);
136
+ }
137
+ }
138
+
139
+ // Only sort the part that we want to keep
140
+ partial_sort(begin(finalQuads),
141
+ begin(finalQuads) + std::min<int64_t>(finalQuads.size(), maxRegions),
142
+ end(finalQuads),
143
+ [] (auto a, auto b) {
144
+ return a.Confidence > b.Confidence;
145
+ }
146
+ );
147
+
148
+ // Truncate the low confidence regions
149
+ if (finalQuads.size() > maxRegions) {
150
+ finalQuads.resize(maxRegions);
151
+ }
152
+
153
+ //cout << "Ex " << batchIdx << " quads:" << endl << finalQuads << endl << endl;
154
+ }
155
+
156
+ } // End parallel
157
+
158
+ int64_t numOutQuads = 0;
159
+ for (int64_t batchIdx = 0; batchIdx < B; ++batchIdx) {
160
+ numOutQuads += allQuads[batchIdx].size();
161
+ }
162
+
163
+ // Step 4: Convert the quads into tensor representation
164
+ auto outQuadTensor = torch::empty({ numOutQuads, 4, 2 }, torch::kFloat32);
165
+ auto outConfTensor = torch::empty({ numOutQuads }, torch::kFloat32);
166
+ torch::Tensor outCountTensor = torch::empty({ static_cast<int64_t>( allQuads.size() ) }, torch::kInt64);
167
+
168
+ auto outQuadAccess = outQuadTensor.accessor<float, 3>();
169
+ auto outConfAccess = outConfTensor.accessor<float, 1>();
170
+ auto outCountAccess = outCountTensor.accessor<int64_t, 1>();
171
+
172
+ int64_t offset = 0;
173
+ for (int64_t batchIdx = 0; batchIdx < allQuads.size(); ++batchIdx) {
174
+ vector<EFQuad> &exQuads = allQuads[batchIdx];
175
+
176
+ outCountAccess[batchIdx] = exQuads.size();
177
+
178
+ for (int64_t qIdx = 0; qIdx < exQuads.size(); ++qIdx, ++offset) {
179
+ copy_quad(exQuads[qIdx], outQuadAccess[offset].data());
180
+ outConfAccess[offset] = exQuads[qIdx].Confidence;
181
+ }
182
+ }
183
+
184
+ return { outQuadTensor, outConfTensor, outCountTensor };
185
+ }
186
+
187
+ std::vector<torch::Tensor> quad_nms_from_adjacency(
188
+ torch::Tensor quads, torch::Tensor probs, torch::Tensor adjacency,
189
+ float probThreshold, float iouThreshold,
190
+ int64_t maxRegions)
191
+ {
192
+ std::vector<torch::Tensor> ret;
193
+
194
+ AT_DISPATCH_FLOATING_TYPES(
195
+ quads.scalar_type(),
196
+ "quad_nms_from_adjacency",
197
+ ([&] {
198
+ ret = quad_nms_from_adjacency_impl<scalar_t>(
199
+ quads.accessor<scalar_t, 5>(),
200
+ probs.accessor<scalar_t, 3>(),
201
+ adjacency.accessor<int32_t, 4>(),
202
+ probThreshold, iouThreshold,
203
+ maxRegions
204
+ );
205
+ })
206
+ );
207
+
208
+ return ret;
209
+ }
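The function returns a ragged result: a flat quad tensor, a flat confidence tensor, and a per-example count tensor. A small illustrative helper for slicing that flat output back into per-example views might look like the following (the helper name is an assumption, not part of this library):

#include <torch/torch.h>
#include <utility>
#include <vector>

// Slice the flat (quads, confidences) output back into per-example views using the counts
// tensor. Assumes `counts` lives on CPU with dtype int64, matching the output above.
std::vector<std::pair<torch::Tensor, torch::Tensor>>
split_by_counts(torch::Tensor quads, torch::Tensor confs, torch::Tensor counts)
{
    std::vector<std::pair<torch::Tensor, torch::Tensor>> perExample;
    auto countAcc = counts.accessor<int64_t, 1>();

    int64_t offset = 0;
    for (int64_t b = 0; b < counts.size(0); ++b) {
        const int64_t n = countAcc[b];
        perExample.emplace_back(quads.narrow(0, offset, n), confs.narrow(0, offset, n));
        offset += n;
    }
    return perExample;
}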
nemo-retriever-ocr/cpp/non_maximal_suppression/cuda_non_maximal_suppression.cu ADDED
@@ -0,0 +1,1720 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #include "non_maximal_suppression.h"
6
+
7
+ #include <cooperative_groups.h>
8
+ #include <cooperative_groups/reduce.h>
9
+
10
+ #include <thrust/binary_search.h>
11
+ #include <thrust/device_vector.h>
12
+ #include <thrust/execution_policy.h>
13
+
14
+ #include <c10/cuda/CUDACachingAllocator.h>
15
+
16
+ #include <trove/ptr.h>
17
+
18
+ #include "../cuda_intellisense.cuh"
19
+ #include "../geometry.h"
20
+ #include "../common.h"
21
+ #include "../scope_timer.h"
22
+ #include "strided_quad.h"
23
+
24
+ // If this flag is turned on, then a bunch of checks will be inserted to ensure that the same results are produced by
25
+ // successive calls to NMS. Note that enabling it makes the library unusable outside of a debug context, so beware!
26
+ //#define NMS_VERIFY_CORRECTNESS
27
+
28
+ namespace cg = cooperative_groups;
29
+ namespace ix = torch::indexing;
30
+
31
+ inline
32
+ void print_tensor_stats2(const std::string &msg, const torch::Tensor& tensor) {
33
+
34
+ auto fTensor = tensor.to(torch::kDouble).cpu();
35
+
36
+ std::stringstream ss;
37
+ if (tensor.numel() > 1) {
38
+ ss << msg << " Size: " << tensor.sizes() << " Type: " << tensor.dtype() << " Device: " << tensor.device() << " Max: " << fTensor.max().item<double>() << " Min: " << fTensor.min().item<double>() << " Mean: " << fTensor.mean().item<double>() << " Std: " << fTensor.std().item<double>();
39
+ }
40
+ else if (tensor.numel() == 1) {
41
+ ss << msg << " Size: " << tensor.sizes() << " Type: " << tensor.dtype() << " Device: " << tensor.device() << " Value: " << fTensor.item<double>() << std::endl;
42
+ }
43
+ else {
44
+ ss << msg << " Size: " << tensor.sizes() << " Type: " << tensor.dtype() << " Device: " << tensor.device() << std::endl;
45
+ }
46
+ std::cout << ss.str() << std::endl;
47
+ }
48
+
49
+ inline
50
+ void print_tensor_vec_stats2(std::string msg, const std::vector<torch::Tensor>& tensorVec) {
51
+ std::cout << msg << " Size: " << tensorVec.size() << std::endl;
52
+ std::stringstream ss;
53
+ msg = " - ";
54
+ for (int i = 0; i < tensorVec.size(); ++i) {
55
+ ss << msg << "[" << i << "]:";
56
+ auto tensor = tensorVec[i];
57
+ print_tensor_stats2(ss.str(), tensor);
58
+ ss.str("");
59
+ }
60
+ }
61
+
62
+ std::ostream &operator<<(std::ostream &os, dim3 d)
63
+ {
64
+ return os << "(" << d.x << ", " << d.y << ", " << d.z << ")";
65
+ }
66
+
67
+ #define ADD_OP2(vector2_t) __device__ \
68
+ vector2_t operator+(const vector2_t &a, const vector2_t &b) { \
69
+ return { a.x + b.x, a.y + b.y }; \
70
+ }
71
+ ADD_OP2(float2);
72
+ ADD_OP2(double2);
73
+ #undef ADD_OP2
74
+
75
+ #define ADD_OP4(vector4_t) __device__ \
76
+ vector4_t operator+(const vector4_t &a, const vector4_t &b) { \
77
+ return { a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w }; \
78
+ }
79
+ ADD_OP4(float4);
80
+ ADD_OP4(double4);
81
+ #undef ADD_OP4
82
+
83
+ template<typename T, size_t Size>
84
+ __device__
85
+ std::array<T, Size> operator+(const std::array<T, Size> &a, const std::array<T, Size> &b) {
86
+ std::array<T, Size> ret;
87
+ #pragma unroll
88
+ for (size_t i = 0; i < Size; ++i) {
89
+ ret._Elems[i] = a._Elems[i] + b._Elems[i];
90
+ }
91
+ return ret;
92
+ }
93
+
94
+ #if __CUDA_ARCH__ >= 800
95
+ #define __reduce_add_full_warp(val) __reduce_add_sync(0xFFFFFFFF, val)
96
+ #define __reduce_max_full_warp(val) __reduce_max_sync(0xFFFFFFFF, val)
97
+ #define __reduce_min_full_warp(val) __reduce_min_sync(0xFFFFFFFF, val)
98
+ #else
99
+ #define __reduce_add_full_warp(val) cg::reduce(cg::tiled_partition<32>(cg::this_thread_block()), val, cg::plus<decltype(val)>())
100
+ #define __reduce_max_full_warp(val) cg::reduce(cg::tiled_partition<32>(cg::this_thread_block()), val, cg::greater<decltype(val)>())
101
+ #define __reduce_min_full_warp(val) cg::reduce(cg::tiled_partition<32>(cg::this_thread_block()), val, cg::less<decltype(val)>())
102
+ #endif
103
+
104
+ template<typename T>
105
+ struct TToVec;
106
+ template<>
107
+ struct TToVec<float> { typedef float2 type2; typedef float4 type4; };
108
+ template<>
109
+ struct TToVec<double> { typedef double2 type2; typedef double4 type4; };
110
+
111
+ template<typename T, typename accessor_t>
112
+ __device__
113
+ void write_embed_quad(accessor_t &acc, const MergeQuad_<T> &quad, int64_t storeOff)
114
+ {
115
+ constexpr auto EMBED_QUAD_SIZE = sizeof(EmbedQuad_<T>) / sizeof(T);
116
+ static_assert(EMBED_QUAD_SIZE == 10, "Unsupported embed quad size!");
117
+
118
+ const T *mergeBuff = reinterpret_cast<const T*>(&quad);
119
+
120
+ const T confidence = quad.Confidence;
121
+ const auto i = threadIdx.x;
122
+
123
+ if (i >= 10) {
124
+ return;
125
+ }
126
+
127
+ T outVal;
128
+ // Coordinates
129
+ if (i < 8) {
130
+ outVal = mergeBuff[i] / confidence;
131
+ // Confidence
132
+ } else if (i == 8) {
133
+ outVal = confidence / mergeBuff[9];
134
+ // NumQuads
135
+ } else {
136
+ outVal = mergeBuff[9];
137
+ }
138
+
139
+ acc[i][storeOff] = outVal;
140
+ }
141
+
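The normalization above implies the accumulator layout: eight confidence-weighted coordinate sums, then the summed confidence, then the number of merged quads. A host-side sketch of that accumulate/finalize pattern (an assumption inferred from the code, not the actual MergeQuad_ definition):

#include <array>
#include <cstddef>

struct MergeQuadSketch {
    std::array<float, 8> weightedCoords{};  // sum of conf * (x, y) over the merged corners
    float sumConf = 0.f;                    // sum of the merged confidences
    float numQuads = 0.f;                   // how many quads were merged in

    void append(const std::array<float, 8> &coords, float conf) {
        for (std::size_t i = 0; i < 8; ++i) weightedCoords[i] += conf * coords[i];
        sumConf += conf;
        numQuads += 1.f;
    }

    // Mirrors the normalization in write_embed_quad: coordinates become a confidence-weighted
    // mean, and the reported confidence is the mean confidence over the merged quads.
    std::array<float, 10> finalize() const {
        std::array<float, 10> out{};
        for (std::size_t i = 0; i < 8; ++i) out[i] = weightedCoords[i] / sumConf;
        out[8] = sumConf / numQuads;
        out[9] = numQuads;
        return out;
    }
};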
142
+
143
+ template<typename group_t, typename ...Args>
144
+ __device__
145
+ void ordered_print(group_t &group, const char *const fmt, const Args& ...args)
146
+ {
147
+ for (uint32_t i = 0; i < group.size(); ++i) {
148
+ if (group.thread_rank() == i) {
149
+ printf(fmt, args...);
150
+ }
151
+ group.sync();
152
+ }
153
+ }
154
+
155
+ template<typename T>
156
+ __global__
157
+ void device_row_collapse(torch::PackedTensorAccessor64<T, 5> allQuads,
158
+ torch::PackedTensorAccessor64<T, 3> allConfs,
159
+ T confThreshold, T iouThreshold,
160
+ torch::PackedTensorAccessor64<int32_t, 1> allOutCounts,
161
+ torch::PackedTensorAccessor64<T, 3> allOutEmbedQuads
162
+ #ifdef NMS_VERIFY_CORRECTNESS
163
+ , torch::PackedTensorAccessor64<int32_t, 2> allOutIds
164
+ #endif
165
+ )
166
+ {
167
+ typedef InPlaceQuad_<T> Quadf;
168
+ static_assert(sizeof(Quadf) == sizeof(T) * 8, "Invalid QuadMem size!");
169
+
170
+ constexpr uint32_t ALL_MASK = 0xFFFFFFFF;
171
+ constexpr uint32_t WARP_SIZE = 32;
172
+ constexpr T MIN_VALID_AREA = 8;
173
+
174
+ const uint32_t B = allQuads.size(0);
175
+ const uint32_t H = allQuads.size(1);
176
+
177
+ const uint32_t b = blockIdx.z;
178
+ const uint32_t r = blockIdx.y * blockDim.y + threadIdx.y;
179
+
180
+ if (r >= H) {
181
+ return;
182
+ }
183
+
184
+ #define threadRank threadIdx.x
185
+
186
+ auto rawQuads = reinterpret_cast<Quadf*>(allQuads[b][r].data());
187
+ #if defined(NDEBUG)
188
+ trove::coalesced_ptr<Quadf> quads(rawQuads);
189
+ #else
190
+ auto quads = rawQuads;
191
+ #endif
192
+
193
+ auto confs = allConfs[b][r];
194
+
195
+ T conf = confs[threadRank];
196
+
197
+ bool quadValid = conf >= confThreshold;
198
+ uint32_t ballot = __ballot_sync(ALL_MASK, quadValid);
199
+
200
+ // No valid quads in this window, so we're done!
201
+ if (ballot == 0) {
202
+ return;
203
+ }
204
+
205
+ const Quadf currQuad = quads[threadRank];
206
+
207
+ const T qArea = currQuad.Area();
208
+
209
+ quadValid = quadValid && qArea > MIN_VALID_AREA;
210
+ ballot = __ballot_sync(ALL_MASK, quadValid);
211
+ if (ballot == 0) {
212
+ return;
213
+ }
214
+ if (! quadValid) {
215
+ conf = 0;
216
+ }
217
+
218
+ MergeQuad_<T> qAccum{ZeroInitTag{}};
219
+
220
+ Quadf prevQuad;
221
+ auto pCurrQuad = reinterpret_cast<const T*>(&currQuad);
222
+ auto pPrevQuad = reinterpret_cast<T*>(&prevQuad);
223
+ #pragma unroll
224
+ for (uint32_t i = 0; i < 8; ++i) {
225
+ pPrevQuad[i] = __shfl_up_sync(ALL_MASK, pCurrQuad[i], 1);
226
+ }
227
+ T prevConf = __shfl_up_sync(ALL_MASK, conf, 1);
228
+
229
+ if (threadRank == 0) {
230
+ prevConf = 0;
231
+ }
232
+
233
+ bool iouValid = false;
234
+ T iou = 0;
235
+ if (quadValid) {
236
+ qAccum.Append(currQuad, conf);
237
+
238
+ if (prevConf >= confThreshold) {
239
+ iou = prevQuad.IOU_UpperBound(currQuad);
240
+ if (iou >= iouThreshold) {
241
+ iouValid = true;
242
+ }
243
+ }
244
+ }
245
+
246
+ // This is the start of a span if the current confidence is above threshold, but the quad to the left is either below threshold,
247
+ // or the IOU between the quads is below threshold
248
+ const bool isStartOfSpan = quadValid && !iouValid;
249
+
250
+ uint32_t label = isStartOfSpan;
251
+ // All labels start out as 0 or 1, and we'll then do a cumsum over the warp, which gives each thread an assigned label
252
+ // We also know that the final thread also contains the number of labels.
253
+ #pragma unroll
254
+ for (uint32_t offset = 1; offset < 32; offset <<= 1) {
255
+ auto inc = __shfl_up_sync(ALL_MASK, label, offset);
256
+ if (threadRank >= offset) {
257
+ label += inc;
258
+ }
259
+ }
260
+
261
+ // Before we zero out invalid labels, get the total number of labels
262
+ const uint32_t numLabels = __shfl_sync(ALL_MASK, label, WARP_SIZE - 1);
263
+
264
+ // Zero out the label if the current quad isn't valid
265
+ label = quadValid ? label : 0;
266
+
267
+ T* accumPtr = reinterpret_cast<T*>(&qAccum);
268
+ // Reduce all of the quads s.t. the left-most position in the span contains the full quad.
269
+ // We use `label` to decide whether to do the accumulation
270
+ #pragma unroll
271
+ for (uint32_t offset = 1; offset < 32; offset <<= 1) {
272
+ const auto otherLabel = __shfl_down_sync(ALL_MASK, label, offset);
273
+
274
+ // Regardless of whether the labels match, all threads in the warp must make the shfl_down
275
+ // call. So we use factor to modulate whether the given merge is valid
276
+ const T factor = otherLabel == label && offset + threadRank < WARP_SIZE ? 1.0f : 0.0f;
277
+
278
+ #pragma unroll
279
+ for (uint32_t i = 0; i < 10; ++i) {
280
+ accumPtr[i] += factor * __shfl_down_sync(ALL_MASK, accumPtr[i], offset);
281
+ }
282
+ }
283
+
284
+ // Elect thread-0 to figure out where to store the results
285
+ uint32_t storeOff = 0;
286
+ if (threadRank == 0) {
287
+ storeOff = atomicAdd(&allOutCounts[b], numLabels);
288
+ }
289
+ // Broadcast that offset to the whole warp
290
+ storeOff = __shfl_sync(ALL_MASK, storeOff, 0);
291
+
292
+ auto outEmbedQuads = allOutEmbedQuads[b];
293
+ // Now write out each quad, but collectively
294
+ for (uint32_t procLabel = 1; procLabel <= numLabels; ++procLabel) {
295
+ // Discover the index of the start of each label span
296
+ ballot = __ballot_sync(ALL_MASK, procLabel == label);
297
+ // ffs will find the (1-based) index of the least significant bit in ballot.
298
+ // This just so happens to be the start of the span for the current label
299
+ uint32_t startIdx = __ffs(ballot) - 1;
300
+
301
+ const T* inT = reinterpret_cast<T*>(&qAccum);
302
+ MergeQuad_<T> outQuad;
303
+ T* outT = reinterpret_cast<T*>(&outQuad);
304
+ #pragma unroll
305
+ for (uint32_t i = 0; i < 10; ++i) {
306
+ outT[i] = __shfl_sync(ALL_MASK, inT[i], startIdx);
307
+ }
308
+
309
+ write_embed_quad(outEmbedQuads, outQuad, storeOff + procLabel - 1);
310
+ #ifdef NMS_VERIFY_CORRECTNESS
311
+ if (threadRank == 0) {
312
+ allOutIds[b][storeOff + procLabel - 1] = r * 32 + startIdx;
313
+ }
314
+ #endif
315
+ }
316
+
317
+ if (threadRank == 0) {
318
+ // Increment the total number of quads by the number encountered on this row
319
+ atomicAdd(&allOutCounts[B], numLabels);
320
+ }
321
+
322
+ #undef threadRank
323
+ }
324
+
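The label assignment above is a textbook warp-level inclusive scan built from `__shfl_up_sync`. A standalone sketch of just that primitive, stripped of the NMS context:

#include <cstdint>

// One warp, one block: turn per-lane 0/1 flags into inclusive span labels,
// exactly as done with `label` above. Lane 31 ends up holding the total span count.
__global__ void warp_inclusive_scan_demo(const uint32_t *flags, uint32_t *labels)
{
    const uint32_t lane = threadIdx.x;   // assumes blockDim.x == 32
    uint32_t label = flags[lane];

    #pragma unroll
    for (uint32_t offset = 1; offset < 32; offset <<= 1) {
        const uint32_t inc = __shfl_up_sync(0xFFFFFFFF, label, offset);
        if (lane >= offset) {
            label += inc;
        }
    }

    labels[lane] = label;
}

// Example launch: warp_inclusive_scan_demo<<<1, 32>>>(d_flags, d_labels);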
325
+ template<bool IsSingleExample, typename T>
326
+ __global__
327
+ void device_a2a_adjacency_sparse(const uint64_t punCounts,
328
+ T iouThreshold,
329
+ torch::PackedTensorAccessor64<T, 3> embedQuads,
330
+ torch::PackedTensorAccessor64<bool, 2> outIsStart,
331
+ torch::PackedTensorAccessor64<int32_t, 2> outAdjCounts,
332
+ torch::PackedTensorAccessor64<int32_t, 3> outSparseAdj)
333
+ {
334
+ const uint32_t b = blockIdx.y;
335
+
336
+ const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
337
+
338
+ const int32_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;
339
+ const int32_t row = jobIdx / quadCt;
340
+ const int32_t col = jobIdx % quadCt;
341
+
342
+ // Only compute the upper triangular portion of the matrix
343
+ if (row >= quadCt || col < row) {
344
+ return;
345
+ }
346
+
347
+ T* exData = IsSingleExample ? embedQuads.data() : embedQuads[b].data();
348
+
349
+ const auto qRow = StridedEmbedQuad_<T>{ exData + row * embedQuads.stride(2), embedQuads.stride(1) }.Bounds(),
350
+ qCol = StridedEmbedQuad_<T>{ exData + col * embedQuads.stride(2), embedQuads.stride(1) }.Bounds();
351
+
352
+ T pctRow, pctCol, iou;
353
+ thrust::tie(pctRow, pctCol, iou) = geometry_region_sizes(qRow, qCol);
354
+
355
+ auto warpGroup = cg::tiled_partition<32>(cg::this_thread_block());
356
+
357
+ auto rowGroup = cg::labeled_partition(warpGroup, row);
358
+
359
+ const bool isValid = iou >= iouThreshold;
360
+
361
+ const uint32_t ballot = rowGroup.ballot(isValid);
362
+ const uint32_t numValid = __popc(ballot);
363
+
364
+ auto exAdjCounts = outAdjCounts[b].data();
365
+
366
+ int32_t storeOff = 0;
367
+ if (numValid > 0 && rowGroup.thread_rank() == 0) {
368
+ storeOff = atomicAdd(exAdjCounts + row, numValid);
369
+ }
370
+ storeOff = rowGroup.shfl(storeOff, 0);
371
+
372
+ if (isValid) {
373
+ // This will set all of the bits below this one (the less significant bits) to 1, otherwise 0.
374
+ // We can use this to count the number of bits that are set, and are less significant than this one,
375
+ // to get the local storage offset
376
+ uint32_t lowerMask = (1 << rowGroup.thread_rank()) - 1;
377
+
378
+ storeOff += __popc(ballot & lowerMask);
379
+
380
+ outSparseAdj[b][row][storeOff] = col;
381
+ if (row != col) {
382
+ // Because `col` gets merged into `row`, we mark it as inactive for reduction purposes.
383
+ // All of the quads that `col` is adjacent to will be absorbed by `row`.
384
+ outIsStart[b][col] = false;
385
+
386
+ // Also store the transposed relation
387
+ storeOff = atomicAdd(exAdjCounts + col, 1);
388
+ outSparseAdj[b][col][storeOff] = row;
389
+ }
390
+ } else if (pctRow > 0.8f || pctCol > 0.8f) {
391
+ T anchorHeight = qRow.Height();
392
+ T otherHeight = qCol.Height();
393
+
394
+ T ratio = anchorHeight > otherHeight ?
395
+ otherHeight / anchorHeight :
396
+ anchorHeight / otherHeight;
397
+ if (ratio > 0.9f) {
398
+ if (pctRow > 0.8f) {
399
+ // Other envelops anchor
400
+ outIsStart[b][row] = false;
401
+ }
402
+ else {
403
+ outIsStart[b][col] = false;
404
+ }
405
+ }
406
+ }
407
+ }
408
+
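The write path above uses the ballot/popcount idiom: one atomicAdd reserves space for every valid lane, and each lane derives its slot by counting the set ballot bits below its own position. A whole-warp sketch of that idiom (the kernel itself applies it per labeled partition):

#include <cstdint>

// Each lane with a valid result writes into a shared output array: lane 0 reserves a block
// with one atomicAdd, and every valid lane computes its slot from the ballot bits below it.
__device__ void compact_write(bool isValid, int32_t value, int32_t *outCount, int32_t *outValues)
{
    const uint32_t FULL_MASK = 0xFFFFFFFF;
    const uint32_t lane = threadIdx.x & 0x1F;

    const uint32_t ballot = __ballot_sync(FULL_MASK, isValid);
    const uint32_t numValid = __popc(ballot);

    int32_t base = 0;
    if (lane == 0 && numValid > 0) {
        base = atomicAdd(outCount, static_cast<int32_t>(numValid));
    }
    base = __shfl_sync(FULL_MASK, base, 0);

    if (isValid) {
        const uint32_t lowerMask = (1u << lane) - 1u;          // bits strictly below this lane
        outValues[base + __popc(ballot & lowerMask)] = value;  // contiguous, duplicate-free slots
    }
}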
409
+ template<uint32_t NumWarps, bool IsSingleExample, typename T, int32_t I_CELL_SIZE>
410
+ __global__
411
+ void device_a2a_adjacency_build_grid(const uint64_t punCounts,
412
+ torch::PackedTensorAccessor64<T, 3> embedQuads,
413
+ torch::PackedTensorAccessor64<int32_t, 4> outGridCells,
414
+ torch::PackedTensorAccessor64<int32_t, 3> outQuadCells)
415
+ {
416
+ constexpr T MIN_T = std::numeric_limits<T>::min();
417
+ constexpr T MAX_T = std::numeric_limits<T>::max();
418
+ constexpr uint32_t WARP_SIZE = 32;
419
+ constexpr uint32_t BLOCK_SIZE = NumWarps * WARP_SIZE;
420
+ constexpr uint32_t FULL_WARP = 0xFFFFFFFF;
421
+ constexpr uint32_t FIRST_16_THREADS = 0x0FFFF;
422
+ constexpr T CELL_SIZE = I_CELL_SIZE;
423
+ constexpr T INV_CELL_SIZE = 1 / CELL_SIZE;
424
+
425
+ const uint32_t b = blockIdx.z;
426
+
427
+ const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
428
+ const uint32_t quadIdx = blockIdx.y;
429
+
430
+ if (!IsSingleExample && quadIdx >= quadCt) {
431
+ return;
432
+ }
433
+
434
+ const uint32_t threadRank = threadIdx.x;
435
+ const uint32_t localThreadRank = threadRank & 0x1F;
436
+
437
+ auto exQuads = embedQuads[b];
438
+
439
+ const uint32_t numCells[2] = { outGridCells.size(2), outGridCells.size(1) };
440
+
441
+ const uint32_t numRows = outGridCells.size(1);
442
+ const uint32_t numCols = outGridCells.size(2);
443
+
444
+ // We use flip so that we can compute min and max simultaneously.
445
+ // The first 8 threads compute the min, the next 8 compute the max
446
+ T sign = localThreadRank < 8 ? 1.0f : -1.0f;
447
+ T myVal = sign * (localThreadRank < 16 ? exQuads[localThreadRank & 0x7][quadIdx] : MIN_T);
448
+ #pragma unroll
449
+ for (uint32_t offset = 2; offset < 8; offset <<= 1) {
450
+ T nextVal = __shfl_down_sync(FIRST_16_THREADS, myVal, offset);
451
+ myVal = min(myVal, nextVal);
452
+ }
453
+ const uint32_t cellVal = max(0.0f, sign * INV_CELL_SIZE * myVal);
454
+
455
+ uint32_t minCell[2] = { __shfl_sync(FULL_WARP, cellVal, 0), __shfl_sync(FULL_WARP, cellVal, 1) },
456
+ maxCell[2] = { __shfl_sync(FULL_WARP, cellVal, 8), __shfl_sync(FULL_WARP, cellVal, 9) };
457
+
458
+ #pragma unroll
459
+ for (uint32_t i = 0; i < 2; ++i) {
460
+ maxCell[i] = min(numCells[i] - 1, maxCell[i]);
461
+ }
462
+
463
+ const uint32_t sizes[2] = { maxCell[0] - minCell[0] + 1, maxCell[1] - minCell[1] + 1 };
464
+
465
+ const uint32_t totalCells = sizes[0] * sizes[1];
466
+
467
+ auto exGridCells = outGridCells[b];
468
+
469
+ for (uint32_t i = threadRank; i < totalCells; i += BLOCK_SIZE) {
470
+ uint32_t row = minCell[1] + i / sizes[0];
471
+ uint32_t col = minCell[0] + i % sizes[0];
472
+
473
+ int32_t *pCell = exGridCells[row][col].data();
474
+
475
+ // The first value in the array is the count, and the rest are the quad indices
476
+ int32_t storeOff = atomicAdd(pCell, 1) + 1;
477
+ pCell[storeOff] = quadIdx;
478
+ }
479
+
480
+ if (threadRank < 2) {
481
+ outQuadCells[b][quadIdx][threadRank] = minCell[threadRank];
482
+ } else if (threadRank < 4) {
483
+ outQuadCells[b][quadIdx][threadRank] = maxCell[threadRank - 2];
484
+ }
485
+ }
486
+
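A host-side sketch of the spatial binning this kernel performs: a quad's bounding box is mapped to the range of fixed-size grid cells it overlaps, and the quad index is appended to every such cell. The grid shape, cell size, and flat cell storage below are illustrative assumptions; the kernel stores each cell as a [count | entries...] array and appends with atomicAdd.

#include <algorithm>
#include <cstdint>
#include <vector>

// Append `quadIdx` to every grid cell its axis-aligned bounding box overlaps.
void bin_quad(int32_t quadIdx, float minX, float minY, float maxX, float maxY,
              float cellSize, int32_t numCols, int32_t numRows,
              std::vector<std::vector<int32_t>> &cells)   // numRows * numCols cell lists
{
    auto toCell = [&](float v, int32_t numCells) {
        return std::min(numCells - 1, std::max(0, static_cast<int32_t>(v / cellSize)));
    };
    const int32_t c0 = toCell(minX, numCols), c1 = toCell(maxX, numCols);
    const int32_t r0 = toCell(minY, numRows), r1 = toCell(maxY, numRows);

    for (int32_t r = r0; r <= r1; ++r) {
        for (int32_t c = c0; c <= c1; ++c) {
            cells[r * numCols + c].push_back(quadIdx);
        }
    }
}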
487
+ typedef uint8_t visit_mask_t;
488
+
489
+ template<uint32_t NumWarps, bool IsSingleExample, typename T>
490
+ __global__
491
+ void device_a2a_adjacency_with_grid(const uint64_t punCounts,
492
+ T iouThreshold,
493
+ torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
494
+ torch::PackedTensorAccessor64<int32_t, 4> allCells,
495
+ torch::PackedTensorAccessor64<int32_t, 3> allQuadExtents,
496
+ torch::PackedTensorAccessor64<bool, 2> outIsStart,
497
+ torch::PackedTensorAccessor64<int32_t, 2> outAdjCounts,
498
+ torch::PackedTensorAccessor64<int32_t, 3> outSparseAdj)
499
+ {
500
+ constexpr T MIN_T = std::numeric_limits<T>::min();
501
+ constexpr T MAX_T = std::numeric_limits<T>::max();
502
+ constexpr uint32_t WARP_SIZE = 32;
503
+ constexpr uint32_t BLOCK_SIZE = NumWarps * WARP_SIZE;
504
+
505
+ const uint32_t b = blockIdx.z;
506
+
507
+ const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
508
+ const uint32_t quadIdx = blockIdx.y;
509
+
510
+ if (!IsSingleExample && quadIdx >= quadCt) {
511
+ return;
512
+ }
513
+
514
+ const uint32_t threadRank = threadIdx.x;
515
+
516
+ auto exQuads = allEmbedQuads[b];
517
+
518
+ __shared__ T s_quadVerts[8];
519
+ __shared__ uint32_t s_quadExtent[4];
520
+ extern __shared__ uint32_t s_alreadyVisited[];
521
+
522
+ if (threadRank < 8) {
523
+ s_quadVerts[threadRank] = exQuads[threadRank][quadIdx];
524
+ } else if (threadRank < 12) {
525
+ s_quadExtent[threadRank - 8] = reinterpret_cast<uint32_t*>(allQuadExtents[b][quadIdx].data())[threadRank - 8];
526
+ }
527
+
528
+ uint32_t zeroTerm = (quadCt + 31u) >> 5u; // Fast version of div_up(quadCt, 32)
529
+ for (uint32_t col = threadRank; col < zeroTerm; col += BLOCK_SIZE) {
530
+ s_alreadyVisited[col] = 0;
531
+ }
532
+
533
+ __syncthreads();
534
+
535
+ auto exCells = allCells[b];
536
+ auto exAdjCounts = reinterpret_cast<uint32_t*>(outAdjCounts[b].data());
537
+ auto exAdjValues = outSparseAdj[b][quadIdx].data();
538
+
539
+ T *exData = IsSingleExample ? allEmbedQuads.data() : allEmbedQuads[b].data();
540
+
541
+ const auto bdsAnchor = Quad_<T>{ s_quadVerts }.Bounds();
542
+
543
+ const uint32_t startCol = s_quadExtent[0],
544
+ endCol = s_quadExtent[2];
545
+ for (uint32_t row = s_quadExtent[1], endRow = s_quadExtent[3]; row <= endRow; ++row) {
546
+ auto rowCells = exCells[row];
547
+
548
+ for (uint32_t col = startCol; col <= endCol; ++col) {
549
+ auto colCells = reinterpret_cast<const uint32_t*>(rowCells[col].data());
550
+
551
+ const uint32_t ct = colCells[0];
552
+
553
+ for (uint32_t i = threadRank + 1; i <= ct; i += BLOCK_SIZE) {
554
+ const uint32_t otherIdx = colCells[i];
555
+
556
+ const uint32_t maskIdx = otherIdx >> 5; // Divide by 32, since there are 32 bits per mask slot
557
+ const uint32_t maskBit = 1 << (otherIdx & 0x1F); // Set the relevant bit for this mask ID
558
+
559
+ const bool alreadyVisited = atomicOr(s_alreadyVisited + maskIdx, maskBit) & maskBit;
560
+
561
+ if (!alreadyVisited) {
562
+ const auto bdsOther = StridedEmbedQuad_<T>{ exData + otherIdx * allEmbedQuads.stride(2), allEmbedQuads.stride(1) }.Bounds();
563
+
564
+ T pctAnchor, pctOther, iou;
565
+ thrust::tie(pctAnchor, pctOther, iou) = geometry_region_sizes(bdsAnchor, bdsOther);
566
+
567
+ if (iou >= iouThreshold) {
568
+ auto validGroup = cg::coalesced_threads();
569
+
570
+ uint32_t storeOff = 0;
571
+ if (validGroup.thread_rank() == 0) {
572
+ storeOff = atomicAdd(exAdjCounts + quadIdx, validGroup.size());
573
+ }
574
+ storeOff = validGroup.shfl(storeOff, 0) + validGroup.thread_rank();
575
+
576
+ exAdjValues[storeOff] = otherIdx;
577
+
578
+ if (otherIdx > quadIdx) {
579
+ outIsStart[b][otherIdx] = false;
580
+ }
581
+ } else if (pctAnchor > 0.8f || pctOther > 0.8f) {
582
+ T anchorHeight = bdsAnchor.Height();
583
+ T otherHeight = bdsOther.Height();
584
+
585
+ T ratio = anchorHeight > otherHeight ?
586
+ otherHeight / anchorHeight :
587
+ anchorHeight / otherHeight;
588
+ if (ratio > 0.9f) {
589
+ if (pctAnchor > 0.8f) {
590
+ // Other envelops anchor
591
+ outIsStart[b][quadIdx] = false;
592
+ } else {
593
+ outIsStart[b][otherIdx] = false;
594
+ }
595
+ }
596
+ }
597
+ }
598
+ }
599
+ }
600
+ }
601
+ }
602
+
603
+ template<bool IsSingleExample>
604
+ __global__
605
+ void device_flatten_graph_iterative(const uint64_t punCounts,
606
+ torch::PackedTensorAccessor64<bool, 2> allIsStart,
607
+ volatile uint32_t *allAdjCounts,
608
+ volatile uint32_t *allAdjValues
609
+ #ifdef NMS_VERIFY_CORRECTNESS
610
+ , int32_t *maxDepth
611
+ #endif
612
+ )
613
+ {
614
+ constexpr uint32_t WARP_SIZE = 32;
615
+ constexpr uint32_t VISIT_STACK_SIZE = 9;
616
+ constexpr uint32_t TERM_VALUE = std::numeric_limits<uint32_t>::max();
617
+
618
+ constexpr visit_mask_t VISITED_MASK = 0b001;
619
+ constexpr visit_mask_t ADDED_MASK = 0b010;
620
+ constexpr visit_mask_t QUEUED_MASK = 0b100;
621
+ constexpr visit_mask_t QUEUED_OR_VISITED_MASK = VISITED_MASK | QUEUED_MASK;
622
+
623
+ const uint32_t b = blockIdx.z;
624
+ const uint32_t anchorRow = blockIdx.y;
625
+
626
+ const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
627
+
628
+ // Only need to check this if there are multiple examples, since in the case of a single example,
629
+ // the grid is precisely sized to that quadCt
630
+ if constexpr (!IsSingleExample) {
631
+ if (anchorRow >= quadCt) {
632
+ return;
633
+ }
634
+ }
635
+
636
+ auto isStart = allIsStart[b].data();
637
+
638
+ const uint32_t threadRank = threadIdx.x;
639
+
640
+ extern __shared__ visit_mask_t s_visitedMask[];
641
+
642
+ #ifndef NMS_VERIFY_CORRECTNESS
643
+ // Only need to process the anchor rows, since they're the only ones
644
+ // that will make it through the full NMS operation.
645
+ // NOTE: There's a race condition where some rows may be marked as anchor,
646
+ // but they'll later be marked non-anchor over the course of this kernel.
647
+ // That's fine. It's a bit of extra work, but there's no real way around it.
648
+ const bool anchorIsStart = isStart[anchorRow];
649
+ if (!anchorIsStart) {
650
+ return;
651
+ }
652
+ #endif
653
+
654
+ uint32_t *pIntVisitedMask = reinterpret_cast<uint32_t*>(s_visitedMask);
655
+ uint32_t zeroTerm = (quadCt + 3) >> 2; // Fast version of div_up(quadCt, 4)
656
+ for (uint32_t col = threadRank; col < zeroTerm; col += blockDim.x) {
657
+ pIntVisitedMask[col] = 0;
658
+ }
659
+
660
+ __syncthreads();
661
+
662
+ const uint32_t maxExCount = allIsStart.size(1);
663
+ auto adjCounts = allAdjCounts + (b * maxExCount);
664
+ auto adjValues = allAdjValues + (b * maxExCount * maxExCount);
665
+
666
+ auto adjAnchorValues = adjValues + (anchorRow * maxExCount);
667
+ // For the anchor row, set the visited mask to 0b10, which will signify that we haven't visited it yet,
668
+ // but that the value is already in the adjacency vector.
669
+ // 0bx1 signifies that the value has been visited
670
+ for (uint32_t i = threadRank, ct = adjCounts[anchorRow]; i < ct; i += blockDim.x) {
671
+ const auto adjCol = adjAnchorValues[i];
672
+ s_visitedMask[adjCol] = ADDED_MASK;
673
+ }
674
+
675
+ __syncthreads();
676
+
677
+ if (threadRank == 0) {
678
+ s_visitedMask[anchorRow] |= QUEUED_MASK;
679
+ }
680
+
681
+ __syncthreads();
682
+
683
+ // TODO(mranzinger): Is it worth incorporating these other threads?
684
+ // It seems like the vast majority of adjacency counts are <32
685
+ if (threadRank >= WARP_SIZE) {
686
+ return;
687
+ }
688
+
689
+ uint32_t visitStack[VISIT_STACK_SIZE];
690
+ visitStack[0] = TERM_VALUE;
691
+ visitStack[1] = anchorRow;
692
+ #ifndef NDEBUG
693
+ for (uint32_t i = 2; i < VISIT_STACK_SIZE; ++i) {
694
+ visitStack[i] = -2;
695
+ }
696
+ #endif
697
+ int32_t visitPtr = 1;
698
+
699
+ while (true) {
700
+ #ifdef NMS_VERIFY_CORRECTNESS
701
+ assert(visitPtr >= 0 && visitPtr < VISIT_STACK_SIZE);
702
+ #endif
703
+ const uint32_t threadNextCol = visitStack[visitPtr];
704
+ const uint32_t warpNextCol = __reduce_min_full_warp(threadNextCol);
705
+
706
+ // Check to see if this thread got chosen.
707
+ // If so, decrement the stack counter
708
+ if (threadNextCol == warpNextCol) {
709
+ #ifndef NDEBUG
710
+ // This makes it easier to debug where the pointer is
711
+ visitStack[visitPtr] = -2;
712
+ #endif
713
+ --visitPtr;
714
+ }
715
+
716
+ // If the maximum value encountered is -1, that means that none of the threads
717
+ // had another value to process
718
+ if (warpNextCol == TERM_VALUE) {
719
+ break;
720
+ }
721
+
722
+ const uint32_t procRow = warpNextCol;
723
+
724
+ __syncthreads();
725
+
726
+ bool isAlreadyVisited = s_visitedMask[procRow] & VISITED_MASK;
727
+
728
+ if (isAlreadyVisited) {
729
+ continue;
730
+ }
731
+
732
+ const uint32_t procAdjCount = adjCounts[procRow];
733
+ auto procAdjValues = adjValues + (procRow * maxExCount);
734
+
735
+ // Offsetting by the iteration number will help balance out the maximum depth of any stack in the warp.
736
+ // The reason is that otherwise, warp-0 will always get a new element, warp-1 iff the adj graph
737
+ // has more than one element, warp-2 iff the adj graph has more than two elements, and so on. Basically,
738
+ // the warps have decreasing pressure. The rotation mechanism helps to balance out stack usage.
739
+ for (uint32_t i = threadRank; i < procAdjCount; i += WARP_SIZE) {
740
+ const uint32_t adjCol = procAdjValues[i];
741
+
742
+ // This will set the queued flag for this column, if it's not already set.
743
+ // It also returns the old state. In our case, we only want to add this value to the
744
+ // stack iff it hasn't already been visited, and hasn't been queued elsewhere
745
+ // NOTE: CUDA doesn't support atomicOr on uint8_t :(, but it's not necessary that
746
+ // the operation be absolutely atomic, so the poor man's version is probably okay
747
+ const auto oldMask = s_visitedMask[adjCol];
748
+ auto newMask = oldMask;
749
+
750
+ bool alreadyAdded = oldMask & ADDED_MASK;
751
+
752
+ auto group = cg::coalesced_threads();
753
+ const uint32_t gThreadRank = group.thread_rank();
754
+ uint32_t notAddedBallot = group.ballot(!alreadyAdded);
755
+ if (notAddedBallot) {
756
+ // Only one warp will ever be adding values to a given row, which means
757
+ // that we don't need atomics. However, other warps may be reading data
758
+ // from anchorRow, which means that we need to add the values first,
759
+ // followed by incrementing the count. This order makes things
760
+ // concurrency safe.
761
+ const uint32_t globalStoreOff = adjCounts[anchorRow];
762
+ // Gets the count of the set bits below this thread's bit
763
+ const uint32_t localStoreOff = __popc(notAddedBallot & ((1 << gThreadRank) - 1));
764
+
765
+ if (!alreadyAdded) {
766
+ adjAnchorValues[globalStoreOff + localStoreOff] = adjCol;
767
+ if (adjCol > anchorRow) {
768
+ // Also, ensure that this quad is no longer marked as a starting quad
769
+ isStart[adjCol] = false;
770
+ }
771
+ newMask |= ADDED_MASK;
772
+ }
773
+
774
+ // Finally, commit the change by incrementing the counter
775
+ if (gThreadRank == 0) {
776
+ adjCounts[anchorRow] += __popc(notAddedBallot);
777
+ }
778
+ }
779
+
780
+ bool alreadyHandled = oldMask & QUEUED_OR_VISITED_MASK;
781
+
782
+ if (!alreadyHandled) {
783
+ #ifdef NMS_VERIFY_CORRECTNESS
784
+ newMask |= QUEUED_MASK;
785
+ ++visitPtr;
786
+ assert(visitPtr < VISIT_STACK_SIZE);
787
+ atomicMax(maxDepth, visitPtr);
788
+ visitStack[visitPtr] = adjCol;
789
+ #else
790
+ // Prefer potentially inconsistent results over buffer overflow
791
+ if (visitPtr < VISIT_STACK_SIZE - 1) {
792
+ newMask |= QUEUED_MASK;
793
+ ++visitPtr;
794
+ visitStack[visitPtr] = adjCol;
795
+ }
796
+ #endif
797
+ }
798
+
799
+ if (newMask != oldMask) {
800
+ s_visitedMask[adjCol] = newMask;
801
+ }
802
+ }
803
+
804
+ // We actually rely on the `pop_next` function largely to handle recursing down into the next row
805
+ __syncthreads();
806
+ }
807
+ }
808
+
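Conceptually, this kernel flattens the adjacency graph per anchor: every quad transitively reachable from the anchor gets appended to the anchor's own adjacency list, and absorbed quads lose their lead-row flag. A single-threaded sketch of that logic (the real kernel keeps separate counts over fixed-capacity buffers and bounds its visit stack):

#include <cstddef>
#include <cstdint>
#include <vector>

// Walk the adjacency graph from `anchor` and append every transitively reachable quad to the
// anchor's own list, clearing the lead-row flag of anything it absorbs.
void flatten_anchor(int32_t anchor,
                    std::vector<std::vector<int32_t>> &adjValues,
                    std::vector<bool> &isStart)
{
    std::vector<uint8_t> added(adjValues.size(), 0), visited(adjValues.size(), 0);
    added[anchor] = 1;
    for (int32_t v : adjValues[anchor]) added[v] = 1;

    std::vector<int32_t> stack{ anchor };
    while (!stack.empty()) {
        const int32_t row = stack.back();
        stack.pop_back();
        if (visited[row]) continue;
        visited[row] = 1;

        // Index-based loop: adjValues[anchor] may grow while we are iterating it.
        for (std::size_t i = 0; i < adjValues[row].size(); ++i) {
            const int32_t col = adjValues[row][i];
            if (!added[col]) {
                added[col] = 1;
                adjValues[anchor].push_back(col);
                if (col > anchor) isStart[col] = false;  // absorbed into the anchor
            }
            if (!visited[col]) stack.push_back(col);
        }
    }
}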
809
+ void add_to_set(const torch::TensorAccessor<int32_t, 1>& adjCounts,
810
+ const torch::TensorAccessor<int32_t, 2>& adjValues,
811
+ int32_t row,
812
+ std::unordered_set<int32_t>& possible)
813
+ {
814
+ if (possible.count(row)) {
815
+ return;
816
+ }
817
+
818
+ possible.insert(row);
819
+
820
+ const int32_t adjCount = adjCounts[row];
821
+ auto values = adjValues[row].data();
822
+
823
+ for (int32_t i = 0; i < adjCount; ++i) {
824
+ const int32_t col = values[i];
825
+ add_to_set(adjCounts, adjValues, col, possible);
826
+ }
827
+ }
828
+
829
+ template<bool IsSingleExample>
830
+ void cpu_flatten_graph(const uint64_t punCounts,
831
+ torch::Tensor isStartTensorGPU,
832
+ torch::Tensor adjCountsTensorGPU,
833
+ torch::Tensor adjValuesTensorGPU)
834
+ {
835
+ auto isStartTensor = isStartTensorGPU.cpu();
836
+ auto adjCountsTensor = adjCountsTensorGPU.cpu();
837
+ auto adjValuesTensor = adjValuesTensorGPU.cpu();
838
+
839
+ auto allIsStart = isStartTensor.accessor<bool, 2>();
840
+ auto allAdjCounts = adjCountsTensor.accessor<int32_t, 2>();
841
+ auto allAdjValues = adjValuesTensor.accessor<int32_t, 3>();
842
+
843
+ for (int32_t b = 0; b < allAdjCounts.size(0); ++b) {
844
+ const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
845
+
846
+ for (int32_t row = 0; row < quadCt; ++row) {
847
+ std::unordered_set<int32_t> fullAdjSet;
848
+ add_to_set(allAdjCounts[b], allAdjValues[b], row, fullAdjSet);
849
+
850
+ int32_t &currCt = allAdjCounts[b][row];
851
+ int32_t *currValues = allAdjValues[b][row].data();
852
+ std::unordered_set<int32_t> existingSet{ currValues, currValues + currCt };
853
+
854
+ for (int32_t adjCol : fullAdjSet) {
855
+ if (existingSet.count(adjCol)) {
856
+ continue;
857
+ }
858
+
859
+ currValues[currCt] = adjCol;
860
+ ++currCt;
861
+
862
+ if (adjCol > row) {
863
+ allIsStart[b][adjCol] = false;
864
+ }
865
+ }
866
+ }
867
+ }
868
+
869
+ isStartTensorGPU.copy_(isStartTensor);
870
+ adjCountsTensorGPU.copy_(adjCountsTensor);
871
+ adjValuesTensorGPU.copy_(adjValuesTensor);
872
+ }
873
+
874
+
875
+ __global__
876
+ void device_a2a_adj_cleanup(const int32_t *counts,
877
+ torch::PackedTensorAccessor64<uint8_t, 3> inOutAdjacency)
878
+ {
879
+ const uint32_t b = blockIdx.y;
880
+ const uint32_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;
881
+ const uint32_t numQuads = counts[b];
882
+ const uint32_t row = jobIdx / numQuads;
883
+ const uint32_t col = jobIdx % numQuads;
884
+
885
+ if (row >= numQuads) {
886
+ return;
887
+ }
888
+
889
+ auto adjacency = inOutAdjacency[b];
890
+
891
+ bool rowPivot = adjacency[row][row] > 0;
892
+ bool colPivot = adjacency[col][col] > 0;
893
+
894
+ if (!rowPivot || !colPivot) {
895
+ adjacency[row][col] = 0;
896
+ }
897
+ }
898
+
899
+ template<uint32_t NumWarps, typename T, bool IsSingleExample>
900
+ __global__
901
+ void device_a2a_collapse(const uint64_t punCounts,
902
+ torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
903
+ torch::PackedTensorAccessor64<bool, 2> allIsLeadRow,
904
+ const int64_t *regionCounts,
905
+ torch::PackedTensorAccessor64<int32_t, 2> allAdjCounts,
906
+ torch::PackedTensorAccessor64<int32_t, 3> allAdjValues,
907
+ //torch::PackedTensorAccessor64<int32_t, 2> allOutPositions,
908
+ torch::PackedTensorAccessor64<T, 3> outQuads,
909
+ T *outConf)
910
+ {
911
+ constexpr uint32_t WARP_SIZE = 32;
912
+ constexpr uint32_t FULL_WARP = 0xFFFFFFFF;
913
+ constexpr uint32_t BLOCK_WIDTH = NumWarps * WARP_SIZE;
914
+ constexpr size_t MERGE_QUAD_SIZE = sizeof(MergeQuad_<T>) / sizeof(T);
915
+
916
+ static_assert(NumWarps < WARP_SIZE, "Only a single warp currently supported!");
917
+
918
+ const uint32_t b = blockIdx.z;
919
+ const uint32_t row = blockIdx.y;
920
+
921
+ const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
922
+
923
+ if constexpr (!IsSingleExample) {
924
+ if (row >= quadCt) {
925
+ return;
926
+ }
927
+ }
928
+
929
+ // Only process the lead rows
930
+ const auto isLeadRow = IsSingleExample ? allIsLeadRow.data() : allIsLeadRow[b].data();
931
+ if (!isLeadRow[row]) {
932
+ return;
933
+ }
934
+
935
+ const uint32_t threadRank = threadIdx.x;
936
+ const uint32_t localThreadRank = threadRank & 0x1F;
937
+ const uint32_t warpIdx = threadRank >> 5;
938
+
939
+ __shared__ T s_mergeQuad[MERGE_QUAD_SIZE];
940
+
941
+ if constexpr (NumWarps > 1) {
942
+ if (threadRank < MERGE_QUAD_SIZE) {
943
+ s_mergeQuad[threadRank] = 0.0f;
944
+ }
945
+
946
+ __syncthreads();
947
+ }
948
+
949
+ T *exData = IsSingleExample ? allEmbedQuads.data() : allEmbedQuads[b].data();
950
+
951
+ const int32_t adjCount = allAdjCounts[b][row];
952
+ const int32_t *adjIdxs = allAdjValues[b][row].data();
953
+
954
+ MergeQuad_<T> localMerge{ZeroInitTag{}};
955
+
956
+ for (int32_t i = threadRank; i < adjCount; i += BLOCK_WIDTH) {
957
+ const int32_t currQuadIdx = adjIdxs[i];
958
+ const StridedEmbedQuad_<T> qCurr{ exData + currQuadIdx * allEmbedQuads.stride(2), allEmbedQuads.stride(1) };
959
+
960
+ localMerge.Append(qCurr);
961
+ }
962
+
963
+ T *mqV = reinterpret_cast<T*>(&localMerge);
964
+ #pragma unroll
965
+ for (uint32_t offset = 1; offset < WARP_SIZE; offset <<= 1) {
966
+ T mergeFactor = offset + localThreadRank < 32;
967
+ #pragma unroll
968
+ for (uint32_t i = 0; i < MERGE_QUAD_SIZE; ++i) {
969
+ mqV[i] += mergeFactor * __shfl_down_sync(FULL_WARP, mqV[i], offset);
970
+ }
971
+ }
972
+ #pragma unroll
973
+ for (uint32_t i = 0; i < MERGE_QUAD_SIZE; ++i) {
974
+ mqV[i] = __shfl_sync(FULL_WARP, mqV[i], 0);
975
+ }
976
+
977
+ // Only need to do a multi-warp merge if there are enough quads to justify it
978
+ if (NumWarps > 1 && adjCount > WARP_SIZE) {
979
+ if (localThreadRank < MERGE_QUAD_SIZE) {
980
+ atomicAdd(s_mergeQuad + localThreadRank, mqV[localThreadRank]);
981
+ }
982
+
983
+ __syncthreads();
984
+
985
+ mqV = s_mergeQuad;
986
+ }
987
+
988
+ // Figure out the output position
989
+ uint32_t writePosition = 0;
990
+ if constexpr (!IsSingleExample) {
991
+ for (int32_t i = threadRank; i < b; i += BLOCK_WIDTH) {
992
+ writePosition += regionCounts[i];
993
+ }
994
+ }
995
+
996
+ const int32_t numLongs = row >> 3; // Divide by 8
997
+ const uint8_t *pCurrIsLeadRow = reinterpret_cast<const uint8_t*>(isLeadRow);
998
+ const uint64_t *lpCurrIsLeadRow = reinterpret_cast<const uint64_t*>(pCurrIsLeadRow);
999
+
1000
+ for (int32_t i = threadRank; i < numLongs; i += BLOCK_WIDTH) {
1001
+ writePosition += __popcll(lpCurrIsLeadRow[i]);
1002
+ }
1003
+ for (int32_t i = (numLongs * 8) + threadRank; i < row; i += BLOCK_WIDTH) {
1004
+ if (pCurrIsLeadRow[i]) {
1005
+ ++writePosition;
1006
+ }
1007
+ }
1008
+ // Sum all of the individual offsets over the warp
1009
+ writePosition = __reduce_add_full_warp(writePosition);
1010
+ // Reduce across warps, if applicable
1011
+ if constexpr (NumWarps > 1) {
1012
+ __shared__ uint32_t s_threadWritePositions[NumWarps];
1013
+ if (localThreadRank == 0) {
1014
+ s_threadWritePositions[warpIdx] = writePosition;
1015
+ }
1016
+ __syncthreads();
1017
+ writePosition = threadRank < NumWarps ? s_threadWritePositions[threadRank] : 0;
1018
+ writePosition = __reduce_add_full_warp(writePosition);
1019
+ }
1020
+
1021
+ if (threadRank >= 9) {
1022
+ return;
1023
+ }
1024
+
1025
+ const T sumConfidence = mqV[8];
1026
+ const T numQuads = mqV[9];
1027
+ const T divisor = threadRank < 8 ? sumConfidence : numQuads;
1028
+
1029
+ const T myVal = mqV[threadRank] / divisor;
1030
+
1031
+ auto writeVerts = outQuads[writePosition].data();
1032
+
1033
+ if (threadRank < 8) {
1034
+ writeVerts[threadRank] = myVal;
1035
+ } else {
1036
+ outConf[writePosition] = myVal;
1037
+ }
1038
+ }
1039
+
1040
+ struct CollapseRowsResult {
1041
+ torch::Tensor ExCounts;
1042
+ torch::Tensor StridedMergeQuads;
1043
+ int32_t TotalNumQuads;
1044
+ // NOTE: This will only be available in Debug builds
1045
+ torch::Tensor QuadIds;
1046
+ int32_t ImageWidth;
1047
+ int32_t ImageHeight;
1048
+ };
1049
+
1050
+ template<typename scalar_t>
1051
+ CollapseRowsResult collapse_rows(
1052
+ torch::Tensor quads, torch::Tensor probs, scalar_t probThreshold, scalar_t iouThreshold
1053
+ )
1054
+ {
1055
+ if (! quads.is_contiguous()) {
1056
+ throw std::runtime_error("Expected `quads` to be contiguous!");
1057
+ }
1058
+
1059
+ if ((quads.size(2) % 32) != 0) {
1060
+ throw std::runtime_error("Expected the width of the `quads` buffer to be a multiple of 32!");
1061
+ }
1062
+
1063
+ int32_t imageWidth = quads.size(2) * 4;
1064
+ int32_t imageHeight = quads.size(1) * 4;
1065
+
1066
+ quads = quads.reshape({ quads.size(0), -1, 32, 4, 2 });
1067
+ probs = probs.reshape({ probs.size(0), -1, 32 });
1068
+
1069
+ if (quads.size(0) != probs.size(0) || quads.size(1) != probs.size(1)) {
1070
+ throw std::runtime_error("Dimension mismatch between `quads` and `probs`");
1071
+ }
1072
+
1073
+ // The final counter is for the total number of quads for the entire batch
1074
+ auto counts = torch::zeros({ quads.size(0) + 1 }, quads.options().dtype(torch::kInt32));
1075
+
1076
+ int64_t embedSize = sizeof(EmbedQuad_<scalar_t>) / sizeof(scalar_t);
1077
+ auto rowMergeTensor = torch::empty({ quads.size(0), embedSize, quads.size(1) * quads.size(2) }, quads.options());
1078
+
1079
+ #ifdef NMS_VERIFY_CORRECTNESS
1080
+ auto idsTensor = torch::full({ quads.size(0), quads.size(1) * quads.size(2) },
1081
+ std::numeric_limits<int32_t>::max(),
1082
+ counts.options().dtype(torch::kInt32));
1083
+ #else
1084
+ torch::Tensor idsTensor;
1085
+ #endif
1086
+
1087
+ dim3 blockSize(32, 3, 1);
1088
+ dim3 gridSize(1,
1089
+ div_up(quads.size(1), blockSize.y),
1090
+ quads.size(0));
1091
+
1092
+ device_row_collapse KERNEL_ARG2(gridSize, blockSize) (
1093
+ quads.packed_accessor64<scalar_t, 5>(),
1094
+ probs.packed_accessor64<scalar_t, 3>(),
1095
+ probThreshold, iouThreshold,
1096
+ counts.packed_accessor64<int32_t, 1>(),
1097
+ rowMergeTensor.packed_accessor64<scalar_t, 3>()
1098
+ #ifdef NMS_VERIFY_CORRECTNESS
1099
+ , idsTensor.packed_accessor64<int32_t, 2>()
1100
+ #endif
1101
+ );
1102
+
1103
+ #ifdef NMS_VERIFY_CORRECTNESS
1104
+ static std::unordered_set<int32_t> s_quadIds;
1105
+ auto cpuIdsTensor = idsTensor.cpu();
1106
+ const int32_t *idsPtr = cpuIdsTensor.data_ptr<int32_t>();
1107
+ if (s_quadIds.empty()) {
1108
+ s_quadIds.insert(idsPtr, idsPtr + idsTensor.numel());
1109
+ } else {
1110
+ std::unordered_set<int32_t> otherIds{ idsPtr, idsPtr + idsTensor.numel() };
1111
+
1112
+ if (s_quadIds != otherIds) {
1113
+ throw std::runtime_error("Inconsistent Ids!");
1114
+ }
1115
+ }
1116
+ #endif
1117
+
1118
+ // The final value in `counts` is actually the total number of quads for the entire batch
1119
+ int32_t totalQuads = counts[-1].item<int32_t>();
1120
+
1121
+ counts = counts.slice(/*dim=*/ 0, 0, counts.size(0) - 1);
1122
+
1123
+ #ifdef NMS_VERIFY_CORRECTNESS
1124
+ int64_t maxExCount;
1125
+ if (counts.size(0) > 1) {
1126
+ maxExCount = counts.max().item<int32_t>();
1127
+ } else {
1128
+ maxExCount = totalQuads;
1129
+ }
1130
+
1131
+ static bool s_sortOrder = false;
1132
+
1133
+ rowMergeTensor = rowMergeTensor.slice(2, 0, maxExCount);
1134
+ idsTensor = idsTensor.slice(1, 0, maxExCount);
1135
+ auto order = torch::argsort(idsTensor, /*dim=*/ 1, s_sortOrder); s_sortOrder = !s_sortOrder;
1136
+
1137
+ auto embOrder = order.unsqueeze(1).expand_as(rowMergeTensor);
1138
+
1139
+ rowMergeTensor = torch::gather(rowMergeTensor, /*dim=*/ 2, embOrder);
1140
+ idsTensor = torch::gather(idsTensor, /*dim=*/ 1, order);
1141
+ #endif
1142
+
1143
+ return { counts, rowMergeTensor, totalQuads, idsTensor, imageWidth, imageHeight };
1144
+ }
1145
+
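One bookkeeping detail worth calling out: `collapse_rows` allocates one counter per example plus a trailing slot that the kernel uses to accumulate the batch-wide total, and that slot is read and then sliced away. A minimal sketch of that convention in isolation (the function name is illustrative):

#include <torch/torch.h>
#include <utility>

// Per-example counters plus one trailing slot that accumulates the whole-batch total;
// the total is read out and the slot sliced away before returning, as in collapse_rows above.
std::pair<torch::Tensor, int64_t> counts_with_total(int64_t batchSize)
{
    auto counts = torch::zeros({ batchSize + 1 }, torch::kInt32);

    // ... a kernel would atomicAdd into counts[b] and counts[batchSize] here ...

    const int64_t total = counts[-1].item<int32_t>();
    counts = counts.slice(/*dim=*/0, 0, counts.size(0) - 1);  // drop the total slot
    return { counts, total };
}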
1146
+
1147
+
1148
+ void verify_row(const torch::TensorAccessor<int32_t, 1> &adjCounts,
1149
+ const torch::TensorAccessor<int32_t, 2> &adjValues,
1150
+ int32_t row)
1151
+ {
1152
+ // Traverse the graph, and accumulate all set flags across all rows marked
1153
+ // adjacent by the current row. If the merge_up algorithm works correctly, then
1154
+ // `possible` will contain exactly the same set of values as the current row
1155
+ std::unordered_set<int32_t> possible;
1156
+ add_to_set(adjCounts, adjValues, row, possible);
1157
+
1158
+ std::unordered_set<int32_t> thisRow{ row };
1159
+ const int32_t thisCount = adjCounts[row];
1160
+ auto thisValues = adjValues[row].data();
1161
+ thisRow.insert(thisValues, thisValues + thisCount);
1162
+
1163
+ if (thisRow != possible) {
1164
+ throw std::runtime_error("The merge_up algorithm is not correct!");
1165
+ }
1166
+ }
1167
+
1168
+ struct AdjacencyResult {
1169
+ // Shape: BxQ
1170
+ // Specifies whether the given row is a result row
1171
+ torch::Tensor IsLeadRow;
1172
+ // Shape: BxQ
1173
+ // The number of quads that need to be merged with the given quad
1174
+ torch::Tensor AdjCounts;
1175
+ // Shape: BxQx<Num Adjacent>
1176
+ // The indices of the adjacent quads.
1177
+ torch::Tensor AdjValues;
1178
+ int64_t MaxExCount;
1179
+ };
1180
+
1181
+ template<bool IsSingleExample, typename T>
1182
+ void cpu_a2a_adjacency_sparse(const uint64_t punCounts,
1183
+ const T iouThreshold,
1184
+ torch::Tensor embedQuadsTensor,
1185
+ torch::Tensor outIsStartTensorGPU,
1186
+ torch::Tensor outAdjCountsTensorGPU,
1187
+ torch::Tensor outSparseAdjTensorGPU)
1188
+ {
1189
+ embedQuadsTensor = embedQuadsTensor.cpu();
1190
+ auto outIsStartTensor = outIsStartTensorGPU.cpu();
1191
+ auto outAdjCountsTensor = outAdjCountsTensorGPU.cpu();
1192
+ auto outSparseAdjTensor = outSparseAdjTensorGPU.cpu();
1193
+
1194
+ auto embedQuads = embedQuadsTensor.accessor<T, 3>();
1195
+ auto isStart = outIsStartTensor.accessor<bool, 2>();
1196
+ auto adjCounts = outAdjCountsTensor.accessor<int32_t, 2>();
1197
+ auto adjValues = outSparseAdjTensor.accessor<int32_t, 3>();
1198
+
1199
+ for (int32_t b = 0; b < embedQuadsTensor.size(0); ++b) {
1200
+ const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
1201
+
1202
+ T *exData = embedQuads[b].data();
1203
+
1204
+ for (int32_t row = 0; row < quadCt; ++row) {
1205
+ const auto qRow = StridedEmbedQuad_<T>{ exData + row, embedQuads.stride(1) }.Bounds();
1206
+
1207
+ for (int32_t col = 0; col < quadCt; ++col) {
1208
+ const auto qCol = StridedEmbedQuad_<T>{ exData + col, embedQuads.stride(1) }.Bounds();
1209
+
1210
+ T pctRow, pctCol, iou;
1211
+ thrust::tie(pctRow, pctCol, iou) = geometry_region_sizes(qRow, qCol);
1212
+
1213
+ if (iou >= iouThreshold) {
1214
+ int32_t &storeIdx = adjCounts[b][row];
1215
+ adjValues[b][row][storeIdx] = col;
1216
+ ++storeIdx;
1217
+ if (row < col) {
1218
+ isStart[b][col] = false;
1219
+ }
1220
+ } else if (pctRow > 0.8f || pctCol > 0.8f) {
1221
+ T anchorHeight = qRow.Height();
1222
+ T otherHeight = qCol.Height();
1223
+
1224
+ T ratio = anchorHeight > otherHeight ?
1225
+ otherHeight / anchorHeight :
1226
+ anchorHeight / otherHeight;
1227
+ if (ratio > 0.9f) {
1228
+ if (pctRow > 0.8f) {
1229
+ // Other envelops anchor
1230
+ isStart[b][row] = false;
1231
+ }
1232
+ else {
1233
+ isStart[b][col] = false;
1234
+ }
1235
+ }
1236
+ }
1237
+ }
1238
+ }
1239
+ }
1240
+
1241
+ outIsStartTensorGPU.copy_(outIsStartTensor);
1242
+ outAdjCountsTensorGPU.copy_(outAdjCountsTensor);
1243
+ outSparseAdjTensorGPU.copy_(outSparseAdjTensor);
1244
+ }
1245
+
1246
+ template<typename T>
1247
+ std::string to_flat_string(torch::Tensor tensor) {
1248
+ tensor = tensor.flatten();
1249
+
1250
+ auto acc = tensor.accessor<T, 1>();
1251
+
1252
+ std::ostringstream oss;
1253
+ oss << "[";
1254
+ if (acc.size(0) > 0) {
1255
+ oss << acc[0];
1256
+ for (int64_t i = 1; i < acc.size(0); ++i) {
1257
+ oss << ", " << acc[i];
1258
+ }
1259
+ }
1260
+ oss << "]";
1261
+ return oss.str();
1262
+ }
1263
+
1264
+ template<typename scalar_t>
1265
+ AdjacencyResult compute_all_to_all_adjacency(
1266
+ const CollapseRowsResult &collapseResult,
1267
+ scalar_t iouThreshold)
1268
+ {
1269
+ torch::Tensor counts = collapseResult.ExCounts;
1270
+
1271
+ int64_t maxExCount;
1272
+ if (counts.size(0) > 1) {
1273
+ maxExCount = counts.max().item<int32_t>();
1274
+ } else {
1275
+ maxExCount = collapseResult.TotalNumQuads;
1276
+ }
1277
+
1278
+ auto isStartTensor = torch::ones({ counts.size(0), maxExCount }, counts.options().dtype(torch::kBool));
1279
+ auto adjCountsTensor = torch::zeros({ counts.size(0), maxExCount }, counts.options().dtype(torch::kInt32));
1280
+ #ifndef NMS_VERIFY_CORRECTNESS
1281
+ auto adjValuesTensor = torch::empty({ counts.size(0), maxExCount, maxExCount }, counts.options().dtype(torch::kInt32));
1282
+ #else
1283
+ auto adjValuesTensor = torch::full({ counts.size(0), maxExCount, maxExCount },
1284
+ 5000,
1285
+ counts.options().dtype(torch::kInt32));
1286
+ #endif
1287
+
1288
+ // If the batch is only a single example, instead of hitting global memory for the count, we can
1289
+ // just encode the count into the pointer instead
1290
+ uint64_t ptrCounts = reinterpret_cast<uint64_t>(counts.data_ptr<int32_t>());
1291
+ if (counts.size(0) == 1) {
1292
+ ptrCounts = maxExCount;
1293
+ }
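
Editorial sketch (not part of the commit): the kernels launched below are templated on whether the batch holds a single example, and `ptrCounts` either carries the count itself or a real pointer. A minimal decoding helper (hypothetical name `decode_example_count`), mirroring the branch already used in `cpu_a2a_adjacency_sparse` above:

template<bool IsSingleExample>
__host__ __device__
inline int32_t decode_example_count(uint64_t punCounts, int32_t exampleIdx)
{
    // Single example: the quad count itself was packed into the pointer value above.
    // Multiple examples: punCounts really is the address of the per-example counts array.
    return IsSingleExample
        ? static_cast<int32_t>(punCounts)
        : reinterpret_cast<const int32_t*>(punCounts)[exampleIdx];
}
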
1294
+
1295
+ #ifdef NMS_VERIFY_CORRECTNESS
1296
+ auto cpuAdjValuesTensor = adjValuesTensor.cpu();
1297
+ auto cpuAdjCountsTensor = adjCountsTensor.cpu();
1298
+ auto cpuIsStartTensor = isStartTensor.cpu();
1299
+ #endif
1300
+
1301
+ size_t smemSize;
1302
+ dim3 gridSize, blockSize;
1303
+
1304
+ ///////////////////
1305
+ // NOTE(mranzinger): This algorithm uses a fixed sized grid to spatially subdivide the canvas. For virtually all test conditions
1306
+ // I ran this through, it was slightly slower than the brute force approach that parallelizes better.
1307
+ // It's possible that there is some number of words present (e.g. >500) where this algorithm becomes
1308
+ // faster.
1309
+ //
1310
+ //constexpr int32_t CELL_SIZE = 100;
1311
+ //constexpr int64_t NUM_BINS_PER_CELL = 200;
1312
+ //int32_t numXCells = div_up(collapseResult.ImageWidth, CELL_SIZE);
1313
+ //int32_t numYCells = div_up(collapseResult.ImageHeight, CELL_SIZE);
1314
+ //auto gridCellsTensor = torch::zeros({ counts.size(0), numYCells, numXCells, NUM_BINS_PER_CELL }, adjCountsTensor.options());
1315
+ //auto quadCellExtentsTensor = torch::empty({ counts.size(0), maxExCount, 4 }, gridCellsTensor.options());
1316
+ //smemSize = div_up(static_cast<uint32_t>(maxExCount), 32);
1317
+
1318
+ //constexpr uint32_t GRID_NUM_WARPS = 3;
1319
+ //blockSize = dim3{ GRID_NUM_WARPS * 32, 1, 1 };
1320
+ //gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
1321
+
1322
+ //auto buildGridFn = counts.size(0) == 1 ?
1323
+ // device_a2a_adjacency_build_grid<GRID_NUM_WARPS, true, scalar_t, CELL_SIZE> :
1324
+ // device_a2a_adjacency_build_grid<GRID_NUM_WARPS, false, scalar_t, CELL_SIZE>;
1325
+
1326
+ //buildGridFn KERNEL_ARG2(gridSize, blockSize) (
1327
+ // ptrCounts,
1328
+ // collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
1329
+ // gridCellsTensor.packed_accessor64<int32_t, 4>(),
1330
+ // quadCellExtentsTensor.packed_accessor64<int32_t, 3>()
1331
+ //);
1332
+
1333
+ //auto adjGridFn = counts.size(0) == 1 ?
1334
+ // device_a2a_adjacency_with_grid<GRID_NUM_WARPS, true, scalar_t> :
1335
+ // device_a2a_adjacency_with_grid<GRID_NUM_WARPS, false, scalar_t>;
1336
+
1337
+ //adjGridFn KERNEL_ARG3(gridSize, blockSize, smemSize) (
1338
+ // ptrCounts,
1339
+ // iouThreshold,
1340
+ // collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
1341
+ // gridCellsTensor.packed_accessor64<int32_t, 4>(),
1342
+ // quadCellExtentsTensor.packed_accessor64<int32_t, 3>(),
1343
+ // isStartTensor.packed_accessor64<bool, 2>(),
1344
+ // adjCountsTensor.packed_accessor64<int32_t, 2>(),
1345
+ // adjValuesTensor.packed_accessor64<int32_t, 3>()
1346
+ //);
1347
+ ///////////////////
1348
+
1349
+ uint32_t totalWork = maxExCount * maxExCount;
1350
+
1351
+ blockSize = dim3{96, 1};
1352
+ gridSize = dim3{div_up(totalWork, blockSize.x),
1353
+ static_cast<uint32_t>(counts.size(0))};
1354
+
1355
+ auto adjFn = counts.size(0) == 1 ? device_a2a_adjacency_sparse<true, scalar_t> : device_a2a_adjacency_sparse<false, scalar_t>;
1356
+
1357
+ // This algorithm is O(n^2) with n being the current number of quads
1358
+ adjFn KERNEL_ARG2(gridSize, blockSize) (
1359
+ ptrCounts,
1360
+ iouThreshold,
1361
+ collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
1362
+ isStartTensor.packed_accessor64<bool, 2>(),
1363
+ adjCountsTensor.packed_accessor64<int32_t, 2>(),
1364
+ adjValuesTensor.packed_accessor64<int32_t, 3>()
1365
+ );
1366
+
1367
+
1368
+ #ifdef NMS_VERIFY_CORRECTNESS
1369
+ cpu_a2a_adjacency_sparse<true>(ptrCounts, iouThreshold,
1370
+ collapseResult.StridedMergeQuads, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
1371
+
1372
+ adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
1373
+
1374
+ assert(torch::all(cpuAdjCountsTensor == adjCountsTensor.cpu()).item<bool>());
1375
+ assert(torch::all(cpuIsStartTensor == isStartTensor.cpu()).item<bool>());
1376
+ assert(torch::all(cpuAdjValuesTensor == adjValuesTensor.cpu()).item<bool>());
1377
+
1378
+ std::cout << "\tA2A Is Start Count: " << isStartTensor.sum(torch::kInt32).item<int32_t>()
1379
+ << ", Most Adjacent: " << adjCountsTensor.max().item<int32_t>() << std::endl;
1380
+
1381
+ auto maxDepthTensor = torch::tensor(0, adjCountsTensor.options());
1382
+ #endif
1383
+
1384
+ auto traverseFn = counts.size(0) == 1 ?
1385
+ device_flatten_graph_iterative<true> :
1386
+ device_flatten_graph_iterative<false>;
1387
+
1388
+ blockSize = dim3{ 128, 1, 1 };
1389
+ gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
1390
+ smemSize = div_up(maxExCount * sizeof(visit_mask_t), sizeof(uint32_t)) * sizeof(uint32_t);
1391
+
1392
+ traverseFn KERNEL_ARG3(gridSize, blockSize, smemSize) (
1393
+ ptrCounts,
1394
+ isStartTensor.packed_accessor64<bool, 2>(),
1395
+ reinterpret_cast<uint32_t*>(adjCountsTensor.data_ptr<int32_t>()),
1396
+ reinterpret_cast<uint32_t*>(adjValuesTensor.data_ptr<int32_t>())
1397
+ #ifdef NMS_VERIFY_CORRECTNESS
1398
+ , maxDepthTensor.data_ptr<int32_t>()
1399
+ #endif
1400
+ );
1401
+
1402
+ #ifdef NMS_VERIFY_CORRECTNESS
1403
+ cpu_flatten_graph<true>(ptrCounts, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
1404
+
1405
+ cpuAdjValuesTensor = std::get<0>(torch::sort(cpuAdjValuesTensor, /*dim=*/ 2));
1406
+ adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
1407
+
1408
+ torch::Tensor diffStartIdxs = (cpuIsStartTensor != isStartTensor.cpu()).nonzero_numpy()[0];
1409
+
1410
+ assert(diffStartIdxs.numel() == 0);
1411
+
1412
+ torch::Tensor diffCountIdxs = (cpuAdjCountsTensor != adjCountsTensor.cpu()).nonzero_numpy()[0];
1413
+
1414
+ assert(diffCountIdxs.numel() == 0);
1415
+
1416
+ auto diffValuesTensor = torch::any(cpuAdjValuesTensor != adjValuesTensor.cpu(), /*dim=*/ 2, /*keepdim=*/ false).flatten().nonzero().flatten();
1417
+
1418
+ std::cout << "\t\tDiff Indices: " << to_flat_string<int64_t>(diffValuesTensor) << std::endl;
1419
+
1420
+ auto cpuDiffCountsTensor = cpuAdjCountsTensor.flatten().index({ diffValuesTensor });
1421
+ auto cpuDiffRowsTensor = cpuAdjValuesTensor.flatten(0, 1).index({ diffValuesTensor });
1422
+ auto gpuDiffRowsTensor = adjValuesTensor.cpu().flatten(0, 1).index({ diffValuesTensor });
1423
+
1424
+ for (int64_t i = 0, ct = cpuDiffRowsTensor.size(0); i < ct; ++i) {
1425
+ auto z = cpuDiffCountsTensor[i].item<int32_t>();
1426
+ auto diffRow = diffValuesTensor[i].item<int64_t>();
1427
+ std::cout << "\t\tRow " << diffRow << std::endl;
1428
+ std::cout << "\t\t\tExpected: " << to_flat_string<int32_t>(cpuDiffRowsTensor[i].slice(0, 0, z + 1)) << std::endl;
1429
+ std::cout << "\t\t\t GPU: " << to_flat_string<int32_t>(gpuDiffRowsTensor[i].slice(0, 0, z + 1)) << std::endl;
1430
+ }
1431
+
1432
+ assert(diffValuesTensor.size(0) == 0);
1433
+
1434
+ std::cout << "\tA2A - Flatten - Is Start Count: " << isStartTensor.sum(torch::kInt32).item<int32_t>()
1435
+ << ", Most Adjacent: " << adjCountsTensor.max().item<int32_t>()
1436
+ << ", Max Depth: " << maxDepthTensor.item<int32_t>() << std::endl;
1437
+
1438
+ cpuIsStartTensor = isStartTensor.cpu();
1439
+ cpuAdjCountsTensor = adjCountsTensor.cpu();
1440
+ cpuAdjValuesTensor = adjValuesTensor.cpu();
1441
+ auto cpuCounts = counts.cpu();
1442
+ auto cpuCollapseIds = collapseResult.QuadIds.cpu();
1443
+
1444
+ static std::vector<std::unordered_set<int32_t>> s_knownGroups;
1445
+ static std::unordered_map<int32_t, std::unordered_set<int32_t>> s_groupLookup;
1446
+
1447
+ std::vector<std::unordered_set<int32_t>> idGroups;
1448
+ decltype(s_groupLookup) groupLookup;
1449
+ for (int64_t b = 0; b < counts.size(0); ++b) {
1450
+ int64_t quadCt = cpuCounts[b].item<int32_t>();
1451
+ for (int64_t row = 0; row < quadCt; ++row) {
1452
+ bool isLeadRow = cpuIsStartTensor[b][row].item<bool>();
1453
+ auto bCountsTensor = cpuAdjCountsTensor[b];
1454
+ auto bValuesTensor = cpuAdjValuesTensor[b];
1455
+ auto bCounts = bCountsTensor.accessor<int32_t, 1>();
1456
+ auto bValues = bValuesTensor.accessor<int32_t, 2>();
1457
+
1458
+ auto bIdsTensor = cpuCollapseIds[b];
1459
+ auto bIds = bIdsTensor.accessor<int32_t, 1>();
1460
+
1461
+ std::unordered_set<int32_t> sIds;
1462
+ for (int32_t i = 0, ct = bCounts[row]; i < ct; ++i) {
1463
+ int32_t col = bValues[row][i];
1464
+ int32_t id = bIds[col];
1465
+ sIds.insert(id);
1466
+ }
1467
+
1468
+ if (sIds.empty()) {
1469
+ throw std::runtime_error("The ids tensor is empty!");
1470
+ }
1471
+
1472
+ groupLookup[bIds[row]] = sIds;
1473
+
1474
+ if (isLeadRow) {
1475
+ verify_row(bCounts, bValues, row);
1476
+ idGroups.push_back(move(sIds));
1477
+ }
1478
+ }
1479
+ }
1480
+
1481
+ if (s_knownGroups.empty()) {
1482
+ s_knownGroups = move(idGroups);
1483
+ s_groupLookup = move(groupLookup);
1484
+ } else {
1485
+ // Make a copy
1486
+ auto remOrigGroups = s_knownGroups;
1487
+ auto remOrigGroupLookup = s_groupLookup;
1488
+
1489
+ std::vector<int32_t> quadIds;
1490
+ for (auto &kv : remOrigGroupLookup) {
1491
+ quadIds.push_back(kv.first);
1492
+ }
1493
+ for (int32_t qId : quadIds) {
1494
+ assert(groupLookup.count(qId));
1495
+ }
1496
+ assert(groupLookup.size() == remOrigGroupLookup.size());
1497
+
1498
+ for (int32_t qId : quadIds) {
1499
+ auto &oldGroup = remOrigGroupLookup[qId];
1500
+ auto &newGroup = groupLookup[qId];
1501
+
1502
+ if (oldGroup == newGroup) {
1503
+ remOrigGroupLookup.erase(qId);
1504
+ groupLookup.erase(qId);
1505
+ } else {
1506
+ throw std::runtime_error("Group mismatch!");
1507
+ }
1508
+ }
1509
+
1510
+ for (int i = idGroups.size() - 1; i >= 0; --i) {
1511
+ for (int j = remOrigGroups.size() - 1; j >= 0; --j) {
1512
+ auto &idGroup = idGroups[i];
1513
+ auto &knownGroup = remOrigGroups[j];
1514
+
1515
+ if (idGroup == knownGroup) {
1516
+ idGroups.erase(begin(idGroups) + i);
1517
+ remOrigGroups.erase(begin(remOrigGroups) + j);
1518
+ break;
1519
+ }
1520
+ }
1521
+ }
1522
+
1523
+ if (!idGroups.empty() || !remOrigGroups.empty()) {
1524
+ auto group_str = [] (auto &group) {
1525
+ std::vector<int32_t> vGroup{ std::begin(group), std::end(group) };
1526
+ std::sort(std::begin(vGroup), std::end(vGroup));
1527
+
1528
+ auto id_str = [] (int32_t id) {
1529
+ std::ostringstream oss;
1530
+ //oss << "(" << (id / 32) << ", " << (id % 32) << ")";
1531
+ oss << id;
1532
+ return oss.str();
1533
+ };
1534
+
1535
+ std::ostringstream oss;
1536
+ oss << "[" << id_str(vGroup[0]);
1537
+ for (size_t i = 1; i < vGroup.size(); ++i) {
1538
+ oss << ", " << id_str(vGroup[i]);
1539
+ }
1540
+ oss << "]";
1541
+ return oss.str();
1542
+ };
1543
+
1544
+ std::cout << "\tEncountered a difference in groups!" << std::endl
1545
+ << "\t\tOrig groups:" << std::endl;
1546
+ for (auto &group : remOrigGroups) {
1547
+ std::cout << "\t\t\t" << group_str(group) << std::endl;
1548
+ }
1549
+ std::cout << "\t\tNew groups:" << std::endl;
1550
+ for (auto &group : idGroups) {
1551
+ std::cout << "\t\t\t" << group_str(group) << std::endl;
1552
+ }
1553
+ }
1554
+ }
1555
+ #endif
1556
+
1557
+ return { isStartTensor, adjCountsTensor, adjValuesTensor, maxExCount };
1558
+ }
1559
+
1560
+
1561
+
1562
+ template<typename scalar_t>
1563
+ nms_result_t
1564
+ all_to_all_collapse(
1565
+ const CollapseRowsResult &collapseRowsRes,
1566
+ const AdjacencyResult &adjResult)
1567
+ {
1568
+ auto counts = collapseRowsRes.ExCounts;
1569
+ auto embedQuads = collapseRowsRes.StridedMergeQuads;
1570
+
1571
+ if (!embedQuads.is_contiguous()) {
1572
+ throw std::runtime_error("Input embed quads were not contiguous!");
1573
+ }
1574
+
1575
+ torch::Tensor isLeadRow;
1576
+ if (counts.size(0) == 1) {
1577
+ isLeadRow = adjResult.IsLeadRow;
1578
+ } else {
1579
+ // For multiple examples: IsLeadRow will have true values beyond the extent of the number of quads
1580
+ // However, we know that Counts > 0 only occurs within the extent, so the set intersection
1581
+ // tells us which rows are actually lead
1582
+ isLeadRow = torch::logical_and(adjResult.IsLeadRow, adjResult.AdjCounts > 0);
1583
+ }
1584
+
1585
+ auto regionCounts = isLeadRow.sum(/*dim=*/ 1, /*keepdim=*/ false, torch::kInt64);
1586
+
1587
+ const int64_t numOutQuads = counts.size(0) == 1 ? regionCounts.item<int64_t>() : regionCounts.sum().item<int64_t>();
1588
+
1589
+ constexpr int32_t NUM_WARPS = 4;
1590
+ dim3 blockSize(NUM_WARPS * 32, 1, 1);
1591
+ dim3 gridSize(1, adjResult.MaxExCount, counts.size(0));
1592
+
1593
+ // If the batch is only a single example, instead of hitting global memory for the count, we can
1594
+ // just encode the count into the pointer instead
1595
+ uint64_t ptrCounts = reinterpret_cast<uint64_t>(counts.data_ptr<int32_t>());
1596
+ if (counts.size(0) == 1) {
1597
+ ptrCounts = adjResult.MaxExCount;
1598
+ }
1599
+
1600
+ torch::Tensor outQuads = torch::empty({ numOutQuads, 4, 2 }, embedQuads.options());
1601
+ torch::Tensor outConf = torch::empty({ numOutQuads }, embedQuads.options());
1602
+
1603
+ auto collapseFn = counts.size(0) == 1 ?
1604
+ device_a2a_collapse<NUM_WARPS, scalar_t, true> :
1605
+ device_a2a_collapse<NUM_WARPS, scalar_t, false>;
1606
+
1607
+ collapseFn KERNEL_ARG2(gridSize, blockSize) (
1608
+ ptrCounts,
1609
+ embedQuads.packed_accessor64<scalar_t, 3>(),
1610
+ isLeadRow.packed_accessor64<bool, 2>(),
1611
+ regionCounts.data_ptr<int64_t>(),
1612
+ adjResult.AdjCounts.packed_accessor64<int32_t, 2>(),
1613
+ adjResult.AdjValues.packed_accessor64<int32_t, 3>(),
1614
+ outQuads.packed_accessor64<scalar_t, 3>(),
1615
+ outConf.data_ptr<scalar_t>()
1616
+ );
1617
+
1618
+ return { outQuads, outConf, regionCounts };
1619
+ }
1620
+
1621
+ template<typename scalar_t>
1622
+ nms_result_t cuda_quad_non_maximal_suppression_impl(
1623
+ torch::Tensor quads, torch::Tensor probs,
1624
+ scalar_t probThreshold, scalar_t iouThreshold,
1625
+ int64_t maxRegions, bool verbose)
1626
+ {
1627
+ static const bool s_timerEnabled = true;
1628
+ static const bool s_verboseLevel2 = true;
1629
+
1630
+ // Make sure there's a batch dimension
1631
+ if (quads.dim() == 4) {
1632
+ // B,H,W,V,2
1633
+ quads = quads.unsqueeze(0);
1634
+ // B,H,W
1635
+ probs = probs.unsqueeze(0);
1636
+ }
1637
+
1638
+ //print_tensor_vec_stats2("NMS Input (quads, probs): ", { quads, probs });
1639
+
1640
+ double msRowCollapse = -1,
1641
+ msAdjacency = -1,
1642
+ msA2ACollapse = -1,
1643
+ msTotal = -1;
1644
+
1645
+ CollapseRowsResult collapseRows;
1646
+ AdjacencyResult adjacency;
1647
+ torch::Tensor retQuads, retConf, regionCounts;
1648
+
1649
+ {
1650
+ CudaStoreTimer tTotal{msTotal, s_timerEnabled};
1651
+ {
1652
+ CudaStoreTimer t{msRowCollapse, s_timerEnabled && verbose && s_verboseLevel2};
1653
+
1654
+ // First combine all of the quads in each row
1655
+ collapseRows = collapse_rows(quads, probs, probThreshold, iouThreshold);
1656
+
1657
+ if (collapseRows.TotalNumQuads == 0) {
1658
+ return {
1659
+ torch::empty({ 0, 4, 2 }, quads.options()),
1660
+ torch::empty({ 0 }, probs.options()),
1661
+ collapseRows.ExCounts.toType(torch::kInt64)
1662
+ };
1663
+ }
1664
+ }
1665
+ {
1666
+ CudaStoreTimer t{msAdjacency, s_timerEnabled && verbose && s_verboseLevel2};
1667
+ adjacency = compute_all_to_all_adjacency(collapseRows, iouThreshold);
1668
+ }
1669
+ {
1670
+ CudaStoreTimer t{msA2ACollapse, s_timerEnabled && verbose && s_verboseLevel2};
1671
+ std::tie(retQuads, retConf, regionCounts) = all_to_all_collapse<scalar_t>(collapseRows, adjacency);
1672
+ }
1673
+ }
1674
+
1675
+ #ifndef NDEBUG
1676
+ assert(regionCounts.sum().item<int64_t>() == retQuads.size(0));
1677
+ #endif
1678
+
1679
+ //print_tensor_vec_stats2(" Full NMS (quads, conf, counts): ", { retQuads, retConf, retCounts });
1680
+
1681
+ if (s_timerEnabled && verbose) {
1682
+ std::cout << "NMS Cuda " << retQuads.size(0)
1683
+ << " - Row Collapse (" << quads.size(0) << ", " << quads.size(1) << ", " << quads.size(2) << ") - (" << collapseRows.TotalNumQuads << "): " << msRowCollapse << "ms"
1684
+ << ", Adjacency (" << adjacency.AdjCounts.sum(torch::kInt32).item<int32_t>() << "): " << msAdjacency << "ms"
1685
+ << ", A2A Collapse (" << retQuads.size(0) << "): " << msA2ACollapse << "ms"
1686
+ << ", Total: " << msTotal << "ms"
1687
+ << std::endl;
1688
+ }
1689
+
1690
+ return { retQuads, retConf, regionCounts };
1691
+ }
1692
+
1693
+ nms_result_t cuda_quad_non_maximal_suppression(
1694
+ torch::Tensor quads, torch::Tensor probs,
1695
+ float probThreshold, float iouThreshold,
1696
+ int64_t kernelHeight, int64_t kernelWidth,
1697
+ int64_t maxRegions, bool verbose)
1698
+ {
1699
+ nms_result_t ret;
1700
+
1701
+ ret = cuda_quad_non_maximal_suppression_impl<float>(
1702
+ quads.toType(torch::kFloat32), probs.toType(torch::kFloat32),
1703
+ probThreshold, iouThreshold,
1704
+ maxRegions, verbose
1705
+ );
1706
+
1707
+ // AT_DISPATCH_FLOATING_TYPES_AND_HALF(
1708
+ // quads.scalar_type(),
1709
+ // "cuda_quad_non_maximal_suppression_impl",
1710
+ // ([&] {
1711
+ // ret = cuda_quad_non_maximal_suppression_impl<scalar_t>(
1712
+ // move(quads), move(probs),
1713
+ // probThreshold, iouThreshold,
1714
+ // maxRegions
1715
+ // );
1716
+ // })
1717
+ // );
1718
+
1719
+ return ret;
1720
+ }
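
For orientation, a minimal host-side call sketch (not part of the commit), assuming a CUDA build and the B,H,W,V,2 / B,H,W layout handled above; the shapes and thresholds below are arbitrary illustration values:

void nms_usage_example()
{
    auto opts  = torch::dtype(torch::kFloat32).device(torch::kCUDA);
    auto quads = torch::rand({1, 64, 64, 4, 2}, opts);  // B,H,W,V,2
    auto probs = torch::rand({1, 64, 64}, opts);        // B,H,W

    torch::Tensor outQuads, outConf, regionCounts;
    std::tie(outQuads, outConf, regionCounts) = cuda_quad_non_maximal_suppression(
        quads, probs,
        /*probThreshold=*/ 0.5f, /*iouThreshold=*/ 0.4f,
        /*kernelHeight=*/ 3, /*kernelWidth=*/ 3,
        /*maxRegions=*/ 1000, /*verbose=*/ false);
}
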
nemo-retriever-ocr/cpp/non_maximal_suppression/nms_common.h ADDED
@@ -0,0 +1,227 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #pragma once
6
+
7
+ #include <memory>
8
+ #include <vector>
9
+ #include <unordered_set>
10
+
11
+ #include "../geometry.h"
12
+ #include "../cuda_intellisense.cuh"
13
+ #include "strided_quad.h"
14
+
15
+
16
+
17
+ std::vector<torch::Tensor> quad_nms_from_adjacency(
18
+ torch::Tensor quads, torch::Tensor probs, torch::Tensor adjacency,
19
+ float probThreshold, float iouThreshold,
20
+ int64_t maxRegions);
21
+
22
+ template<typename T>
23
+ struct EmbedQuad_ : public QuadBase_<T, EmbedQuad_<T> > {
24
+ Point_<T> Vertices[4];
25
+ T Confidence;
26
+ T NumQuads = 0;
27
+
28
+ __device__
29
+ EmbedQuad_(T confidence = 0)
30
+ {
31
+ Reset();
32
+ Confidence = confidence;
33
+ }
34
+ __device__
35
+ EmbedQuad_(const EmbedQuad_ &other) = default;
36
+
37
+ __device__
38
+ void swap(EmbedQuad_ &other) noexcept {
39
+ using std::swap;
40
+
41
+ for (size_t i = 0; i < 4; ++i) {
42
+ swap(Vertices[i], other.Vertices[i]);
43
+ }
44
+
45
+ SWAP(Confidence, other.Confidence);
46
+ SWAP(NumQuads, other.NumQuads);
47
+ }
48
+
49
+ __device__
50
+ EmbedQuad_(EmbedQuad_ &&other) : EmbedQuad_() {
51
+ other.swap(*this);
52
+ }
53
+
54
+ __device__
55
+ EmbedQuad_ &operator=(EmbedQuad_ other) {
56
+ other.swap(*this);
57
+ return *this;
58
+ }
59
+
60
+ __device__
61
+ void Append(const EmbedQuad_ &other) {
62
+ Append(other, other.Confidence, other.NumQuads);
63
+ }
64
+
65
+ template<typename Derived>
66
+ __device__
67
+ void Append(const QuadBase_<T, Derived> &q, T conf, T numQuads = 1) {
68
+ Confidence *= NumQuads;
69
+
70
+ if (Confidence > 0) {
71
+ for (size_t i = 0; i < 4; ++i) {
72
+ Vertices[i] *= Confidence;
73
+ }
74
+ }
75
+
76
+ Confidence += conf * numQuads;
77
+
78
+ auto qVertices = static_cast<const Derived *>(&q)->Vertices;
79
+ for (size_t i = 0; i < 4; ++i) {
80
+ Vertices[i] += conf * numQuads * qVertices[i];
81
+ Vertices[i] /= Confidence;
82
+ }
83
+
84
+ NumQuads += numQuads;
85
+ Confidence /= NumQuads;
86
+ }
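
Editorial sketch (not part of the commit): `Append` above maintains the vertices as a running confidence-weighted mean rather than storing raw weighted sums, where the prior total weight is `Confidence * NumQuads`. The same update in scalar form, as a hypothetical standalone helper:

template<typename T>
inline void weighted_mean_update(T &mean, T &totalWeight, T sample, T sampleWeight)
{
    // Fold one weighted sample into a running weighted mean, as Append does per vertex.
    const T newWeight = totalWeight + sampleWeight;
    mean = (totalWeight * mean + sampleWeight * sample) / newWeight;
    totalWeight = newWeight;
}
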
87
+
88
+ __device__
89
+ void Prepare() {
90
+ // T factor = 1.0 / Confidence;
91
+ // for (size_t i = 0; i < 4; ++i) {
92
+ // Vertices[i] *= factor;
93
+ // }
94
+ // Confidence /= numQuads;
95
+ }
96
+
97
+ __device__
98
+ void Reset() {
99
+ for (size_t i = 0; i < 4; ++i) {
100
+ Vertices[i] = Point_<T>{0, 0};
101
+ }
102
+ Confidence = 0.0f;
103
+ NumQuads = 0;
104
+ }
105
+
106
+ __device__
107
+ const Point_<T> &operator[](size_t v) const { return Vertices[v]; }
108
+ __device__
109
+ Point_<T> &operator[](size_t v) { return Vertices[v]; }
110
+ };
111
+
112
+ struct ZeroInitTag {};
113
+
114
+ template<typename T>
115
+ struct MergeQuad_ : public QuadBase_<T, MergeQuad_<T>> {
116
+ Point_<T> Vertices[4];
117
+ T Confidence;
118
+ T NumQuads;
119
+
120
+ MergeQuad_() = default;
121
+
122
+ __device__
123
+ MergeQuad_(ZeroInitTag) : Confidence(0), NumQuads(0) {
124
+ for (size_t i = 0; i < 4; ++i) {
125
+ Vertices[i] = Point_<T>{0, 0};
126
+ }
127
+ }
128
+
129
+ template<typename Derived>
130
+ __device__
131
+ void Append(const QuadBase_<T, Derived> &q, T conf) {
132
+ Confidence += conf;
133
+ ++NumQuads;
134
+
135
+ auto &d = static_cast<const Derived&>(q);
136
+ for (size_t i = 0; i < 4; ++i) {
137
+ Vertices[i] += conf * d[i];
138
+ }
139
+ }
140
+ __device__
141
+ void Append(const EmbedQuad_<T> &q) {
142
+ T qConf = q.NumQuads * q.Confidence;
143
+
144
+ Confidence += qConf;
145
+ NumQuads += q.NumQuads;
146
+ for (size_t i = 0; i < 4; ++i) {
147
+ Vertices[i] += qConf * q.Vertices[i];
148
+ }
149
+ }
150
+ __device__
151
+ void Append(const StridedEmbedQuad_<T> &q) {
152
+ const T numQuads = q.NumQuads();
153
+ const T qConf = numQuads * q.Confidence();
154
+
155
+ Confidence += qConf;
156
+ NumQuads += numQuads;
157
+ for (size_t i = 0; i < 4; ++i) {
158
+ Vertices[i] += qConf * q[i];
159
+ }
160
+ }
161
+
162
+ __device__
163
+ EmbedQuad_<T> Commit() {
164
+ EmbedQuad_<T> ret;
165
+ for (size_t i = 0; i < 4; ++i) {
166
+ ret.Vertices[i] = Vertices[i] / Confidence;
167
+ }
168
+ ret.Confidence = Confidence / NumQuads;
169
+ ret.NumQuads = NumQuads;
170
+
171
+ return ret;
172
+ }
173
+
174
+ __device__
175
+ const Point_<T> &operator[](size_t v) const { return Vertices[v]; }
176
+ __device__
177
+ Point_<T> &operator[](size_t v) { return Vertices[v]; }
178
+ };
179
+
180
+ template<typename T, typename Intermediate=float>
181
+ __device__
182
+ inline T triangle_root(T val)
183
+ {
184
+ // It's easier to visualize this algorithm for a lower triangular matrix
185
+ // What we're trying to find is the `row` of a lower triangular matrix that a given `val` resides in.
186
+ // e.g. 0->0, 2->1, 4->2, etc.
187
+ //
188
+ // 0: 0
189
+ // 1: 1 2
190
+ // 2: 3 4 5
191
+ // 3: 6 7 8 9
192
+ //
193
+ // See https://math.stackexchange.com/questions/698961/finding-the-triangular-root-of-a-number for explanation
194
+ Intermediate numer = Intermediate(-1) + sqrt(Intermediate(1) + Intermediate(8) * Intermediate(val));
195
+ Intermediate denom = Intermediate(2);
196
+
197
+ Intermediate ret = floor(numer / denom);
198
+ return T(ret);
199
+ }
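
Editorial sketch (not part of the commit): a plain host-side copy of the same formula (hypothetical `host_triangle_root`), spelling out the index-to-row mapping the comment above describes:

#include <cassert>
#include <cmath>

inline int host_triangle_root(int val)
{
    // row = floor((-1 + sqrt(1 + 8*val)) / 2)
    return static_cast<int>(std::floor((-1.0 + std::sqrt(1.0 + 8.0 * val)) / 2.0));
}

inline void triangle_root_examples()
{
    assert(host_triangle_root(0) == 0);  // row 0 holds {0}
    assert(host_triangle_root(2) == 1);  // row 1 holds {1, 2}
    assert(host_triangle_root(5) == 2);  // row 2 holds {3, 4, 5}
    assert(host_triangle_root(9) == 3);  // row 3 holds {6, 7, 8, 9}
}
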
200
+
201
+ template<typename T>
202
+ void visit_node(const std::vector<EmbedQuad_<T>> &allQuads, size_t quadIdx,
203
+ const std::vector<std::vector<size_t>> &adjIdxs, EmbedQuad_<T> &currQuad,
204
+ std::unordered_set<size_t> &visited)
205
+ {
206
+ if (visited.count(quadIdx) > 0) return;
207
+
208
+ const EmbedQuad_<T> &vQuad = allQuads[quadIdx];
209
+
210
+ currQuad.Append(vQuad);
211
+ visited.insert(quadIdx);
212
+
213
+ for (size_t childIdx : adjIdxs[quadIdx]) {
214
+ visit_node(allQuads, childIdx, adjIdxs, currQuad, visited);
215
+ }
216
+ }
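
Editorial sketch (not part of the commit): a hypothetical host-side driver (name `merge_components` is assumed) showing how `visit_node` collapses each connected component of the adjacency graph into one merged quad, assuming `EmbedQuad_` is usable on the host here, as `visit_node` itself implies:

template<typename T>
std::vector<EmbedQuad_<T>> merge_components(const std::vector<EmbedQuad_<T>> &allQuads,
                                            const std::vector<std::vector<size_t>> &adjIdxs)
{
    std::vector<EmbedQuad_<T>> merged;
    std::unordered_set<size_t> visited;

    for (size_t i = 0; i < allQuads.size(); ++i) {
        if (visited.count(i) > 0) continue;   // already folded into an earlier component

        EmbedQuad_<T> component;              // accumulates a confidence-weighted average
        visit_node(allQuads, i, adjIdxs, component, visited);
        merged.push_back(component);
    }
    return merged;
}
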
217
+
218
+ template<typename T, typename Derived, typename scalar_t>
219
+ void copy_quad(const QuadBase_<T, Derived> &srcQuad, scalar_t *pDest)
220
+ {
221
+ auto vertices = static_cast<const Derived*>(&srcQuad)->Vertices;
222
+ for (size_t i = 0; i < 4; ++i) {
223
+ const Point_<T> &v = vertices[i];
224
+ *pDest++ = v.X;
225
+ *pDest++ = v.Y;
226
+ }
227
+ }
nemo-retriever-ocr/cpp/non_maximal_suppression/nms_kd_tree.h ADDED
@@ -0,0 +1,449 @@
1
+ // SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ // All rights reserved.
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ #pragma once
6
+
7
+ #include <memory>
8
+ #include <vector>
9
+ #include <stack>
10
+
11
+ #include "../geometry.h"
12
+
13
+
14
+ #define MODE_GEOMETRY 0x02ull
15
+ #define MODE_CHILDREN 0x00ull
16
+
17
+ #define DIM_X 0x0ull
18
+ #define DIM_Y 0x1ull
19
+
20
+ static const size_t INVALID_IDX = -1;
21
+
22
+ template<typename T>
23
+ struct NMS_BoundsWrapper
24
+ {
25
+ typedef std::unique_ptr<NMS_BoundsWrapper> Ptr;
26
+ typedef AABB_<typename T::inner_type> bds_t;
27
+
28
+ size_t GeoIdx;
29
+ const T *Geometry;
30
+ bds_t Bounds;
31
+
32
+ NMS_BoundsWrapper(size_t geoIdx, const T *geometry) : GeoIdx(geoIdx), Geometry(geometry), Bounds(geometry->Bounds()) { }
33
+ };
34
+
35
+ template<typename T>
36
+ class NMS_NodeAllocator;
37
+
38
+ template<typename T>
39
+ class NMS_KDTree;
40
+
41
+ template<typename T>
42
+ class NMS_BuildCache;
43
+
44
+ template<typename T>
45
+ class NMS_KDNode
46
+ {
47
+ friend class NMS_KDTree<T>;
48
+
49
+ public:
50
+ typedef NMS_BoundsWrapper<T> bds_t;
51
+ typedef std::unique_ptr<NMS_KDNode[]> UPtr;
52
+ typedef typename T::inner_type inner_type;
53
+ typedef std::vector<bds_t*> geo_vec_t;
54
+ typedef std::unique_ptr<geo_vec_t> geo_vec_ptr;
55
+
56
+ void Build(geo_vec_ptr geometries, const typename bds_t::bds_t &envelope,
57
+ NMS_NodeAllocator<T> &allocator, NMS_BuildCache<T> &buildCache);
58
+
59
+ template<typename Fn>
60
+ void FindIntersections(size_t geoIdx, const typename bds_t::bds_t &bds, const Fn &fn) const;
61
+
62
+ private:
63
+ inline uintptr_t Dim() const { return reinterpret_cast<uintptr_t>(m_ptr) & 0x01ull; }
64
+ inline uintptr_t Mode() const { return reinterpret_cast<uintptr_t>(m_ptr) & 0x02ull; }
65
+ inline void Children(NMS_KDNode *&children, inner_type &splitPos) const
66
+ {
67
+ auto vPtr = Geometries();
68
+ splitPos = *reinterpret_cast<inner_type*>(vPtr);
69
+ children = reinterpret_cast<NMS_KDNode*>(vPtr + sizeof(inner_type));
70
+ }
71
+ inline uint8_t* Geometries() const
72
+ {
73
+ return reinterpret_cast<uint8_t*>(reinterpret_cast<uintptr_t>(m_ptr) & ~0x3ull);
74
+ }
75
+ inline void SetPtr(uint8_t *vPtr, uintptr_t mode, uintptr_t dim)
76
+ {
77
+ m_ptr = reinterpret_cast<uint8_t*>(
78
+ reinterpret_cast<uintptr_t>(vPtr) | mode | dim
79
+ );
80
+ }
81
+ void AssignGeometries(geo_vec_ptr geometries, NMS_BuildCache<T> &buildCache);
82
+
83
+ uint8_t *m_ptr;
84
+ };
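
Editorial sketch (not part of the commit): the node above packs its state into the low two bits of `m_ptr` (bit 0 = split dimension, bit 1 = leaf/geometry mode) and masks them off in `Geometries()`; this relies on the payload pointers being at least 4-byte aligned, which `malloc` guarantees. A minimal round-trip check of the tagging scheme:

#include <cassert>
#include <cstdint>

inline void pointer_tagging_example()
{
    alignas(4) static uint8_t payload[16];

    const uintptr_t tagged = reinterpret_cast<uintptr_t>(payload) | MODE_GEOMETRY | DIM_Y;

    assert((tagged & 0x02ull) == MODE_GEOMETRY);                       // Mode()
    assert((tagged & 0x01ull) == DIM_Y);                               // Dim()
    assert(reinterpret_cast<uint8_t*>(tagged & ~0x3ull) == payload);   // Geometries()
}
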
85
+
86
+ template<typename T>
87
+ class NMS_NodeAllocator
88
+ {
89
+ public:
90
+ typedef NMS_KDNode<T> node_t;
91
+ typedef typename node_t::inner_type inner_type;
92
+
93
+ NMS_NodeAllocator(size_t initialGuess = 512);
94
+ ~NMS_NodeAllocator();
95
+
96
+ void Get(size_t numNodes, NMS_KDNode<T> *&outNodes, inner_type *&outSplitPos, uint8_t *&outRawPtr);
97
+
98
+ private:
99
+ std::vector<std::pair<size_t, uint8_t*>> m_buffers;
100
+ size_t m_offset;
101
+ };
102
+
103
+ template<typename T>
104
+ class NMS_BuildCache
105
+ {
106
+ public:
107
+ typedef typename NMS_KDNode<T>::bds_t bds_t;
108
+ typedef std::unique_ptr<NMS_BuildCache> Ptr;
109
+ typedef std::vector<bds_t*> geo_vec_t;
110
+ typedef std::unique_ptr<geo_vec_t> geo_vec_ptr;
111
+
112
+ NMS_BuildCache(size_t initialSize);
113
+ ~NMS_BuildCache();
114
+
115
+ geo_vec_ptr Get(size_t sizeHint);
116
+ bds_t** GetRawBuffer(size_t numGeos, uint8_t *&rawPtr);
117
+
118
+ void Release(geo_vec_ptr buff);
119
+
120
+ private:
121
+ std::stack<geo_vec_ptr> m_cache;
122
+ std::vector<std::pair<size_t, uint8_t*>> m_rawBuffers;
123
+ size_t m_rawOffset;
124
+ };
125
+
126
+
127
+ template<typename T>
128
+ class NMS_KDTree
129
+ {
130
+ typedef typename T::inner_type inner_type;
131
+ typedef NMS_BoundsWrapper<T> bds_t;
132
+ typedef NMS_KDNode<T> node_t;
133
+
134
+ public:
135
+ NMS_KDTree();
136
+ ~NMS_KDTree();
137
+
138
+ void Build(const std::vector<T> &geometries);
139
+
140
+ template<typename Fn>
141
+ void FindIntersections(size_t geoIdx, const Fn &fn) const;
142
+
143
+ template<typename Fn>
144
+ void FindIntersections(const T &geo, const Fn &fn) const;
145
+
146
+ private:
147
+ bds_t *m_wrappers;
148
+ NMS_NodeAllocator<T> m_allocator;
149
+ node_t m_root;
150
+ typename NMS_BuildCache<T>::Ptr m_buildCache;
151
+ };
152
+
153
+ template<typename T>
154
+ NMS_KDTree<T>::NMS_KDTree()
155
+ : m_wrappers(nullptr)
156
+ {
157
+ m_root.m_ptr = nullptr;
158
+ }
159
+
160
+ template<typename T>
161
+ NMS_KDTree<T>::~NMS_KDTree()
162
+ {
163
+ free(m_wrappers);
164
+ }
165
+
166
+ template<typename T>
167
+ void NMS_KDTree<T>::Build(const std::vector<T> &geometries)
168
+ {
169
+ if (geometries.empty()) {
170
+ m_root.m_ptr = nullptr;
171
+ return;
172
+ }
173
+
174
+ // Doing this so that we can perform placement-new on the array buffer, and thus
175
+ // can only perform a single memory allocation for all geometries at once
176
+ m_wrappers = reinterpret_cast<bds_t*>(malloc(sizeof(bds_t) * geometries.size()));
177
+
178
+ m_buildCache.reset(new NMS_BuildCache<T>(geometries.size()));
179
+
180
+ auto bdsGeos = m_buildCache->Get(geometries.size());
181
+
182
+ typename bds_t::bds_t envelope;
183
+
184
+ for (size_t i = 0; i < geometries.size(); ++i) {
185
+ // Placement new. Constructs the object in the place specified in the first (...)
186
+ new (m_wrappers + i) bds_t(i, &geometries[i]);
187
+
188
+ bdsGeos->push_back(m_wrappers + i);
189
+ if (i == 0) {
190
+ envelope = m_wrappers[i].Bounds;
191
+ } else {
192
+ envelope = envelope.Union(m_wrappers[i].Bounds);
193
+ }
194
+ }
195
+
196
+
197
+ m_root.Build(std::move(bdsGeos), envelope, m_allocator, *m_buildCache);
198
+ }
199
+
200
+ template<typename T>
201
+ void NMS_KDNode<T>::Build(geo_vec_ptr geometries, const typename bds_t::bds_t &envelope,
202
+ NMS_NodeAllocator<T> &allocator, NMS_BuildCache<T> &buildCache)
203
+ {
204
+ static const size_t MAX_GEOMETRIES = 8;
205
+
206
+ if (geometries->size() <= MAX_GEOMETRIES) {
207
+ AssignGeometries(std::move(geometries), buildCache);
208
+ } else {
209
+ geo_vec_ptr leftGeos = buildCache.Get(geometries->size()),
210
+ rightGeos = buildCache.Get(geometries->size());
211
+
212
+ inner_type szX = envelope[2] - envelope[0];
213
+ inner_type szY = envelope[3] - envelope[1];
214
+
215
+ int64_t dim = szX > szY ? 0 : 1;
216
+ auto emn = envelope[dim];
217
+ auto emx = envelope[dim + 2];
218
+
219
+ auto pivotPos = (emn + emx) / 2;
220
+ for (bds_t *g : *geometries) {
221
+ auto mn = g->Bounds[dim];
222
+ auto mx = g->Bounds[dim + 2];
223
+
224
+ if (mn < pivotPos) {
225
+ leftGeos->push_back(g);
226
+ }
227
+ if (mx > pivotPos) {
228
+ rightGeos->push_back(g);
229
+ }
230
+ }
231
+
232
+ if (leftGeos->size() == geometries->size() || rightGeos->size() == geometries->size()) {
233
+ AssignGeometries(std::move(geometries), buildCache);
234
+ buildCache.Release(std::move(leftGeos));
235
+ buildCache.Release(std::move(rightGeos));
236
+ } else {
237
+ buildCache.Release(std::move(geometries));
238
+
239
+ inner_type *nodeSplitPos;
240
+ uint8_t *nodeRawPtr;
241
+ NMS_KDNode *children;
242
+ allocator.Get(2, children, nodeSplitPos, nodeRawPtr);
243
+
244
+ SetPtr(nodeRawPtr, MODE_CHILDREN, dim);
245
+ *nodeSplitPos = pivotPos;
246
+
247
+ typename bds_t::bds_t leftEnv{envelope}, rightEnv{envelope};
248
+ // Set the max of the left envelope to the split plane
249
+ leftEnv[dim + 2] = pivotPos;
250
+ // Set the min of the right envelope to the split plane
251
+ rightEnv[dim] = pivotPos;
252
+
253
+ children[0].Build(std::move(leftGeos), leftEnv, allocator, buildCache);
254
+ children[1].Build(std::move(rightGeos), rightEnv, allocator, buildCache);
255
+ }
256
+ }
257
+ }
258
+
259
+ template<typename T>
260
+ void NMS_KDNode<T>::AssignGeometries(geo_vec_ptr geometries, NMS_BuildCache<T> &buildCache)
261
+ {
262
+ if (geometries->empty()) {
263
+ SetPtr(nullptr, MODE_GEOMETRY, 0);
264
+ } else {
265
+ uint8_t *vPtr;
266
+ bds_t **geoPtr = buildCache.GetRawBuffer(geometries->size(), vPtr);
267
+ std::copy(geometries->begin(), geometries->end(), geoPtr);
268
+
269
+ SetPtr(vPtr, MODE_GEOMETRY, 0);
270
+ }
271
+ buildCache.Release(std::move(geometries));
272
+ }
273
+
274
+ template<typename T>
275
+ template<typename Fn>
276
+ void NMS_KDTree<T>::FindIntersections(size_t geoIdx, const Fn &fn) const
277
+ {
278
+ if (!m_wrappers) return;
279
+
280
+ auto &bds = m_wrappers[geoIdx].Bounds;
281
+
282
+ m_root.FindIntersections(geoIdx, bds, fn);
283
+ }
284
+
285
+ template<typename T>
286
+ template<typename Fn>
287
+ void NMS_KDTree<T>::FindIntersections(const T &geo, const Fn &fn) const
288
+ {
289
+ if (!m_wrappers) return;
290
+
291
+ NMS_BoundsWrapper<T> bdsWrapper(INVALID_IDX, &geo);
292
+
293
+ m_root.FindIntersections(INVALID_IDX, bdsWrapper.Bounds, fn);
294
+ }
295
+
296
+ template<typename T>
297
+ template<typename Fn>
298
+ void NMS_KDNode<T>::FindIntersections(size_t geoIdx, const typename bds_t::bds_t &bds, const Fn &fn) const
299
+ {
300
+ auto mode = Mode();
301
+
302
+ if (mode == MODE_GEOMETRY) {
303
+ auto *vPtr = Geometries();
304
+
305
+ size_t numGeos = *reinterpret_cast<size_t*>(vPtr);
306
+ bds_t **geoPtr = reinterpret_cast<bds_t**>(vPtr + sizeof(size_t));
307
+
308
+ bds_t **endPtr = geoPtr + numGeos;
309
+ for (; geoPtr != endPtr; ++geoPtr) {
310
+ const bds_t *child = *geoPtr;
311
+
312
+ // Don't compute this against self
313
+ if (geoIdx != INVALID_IDX && child->GeoIdx <= geoIdx) continue;
314
+
315
+ typename bds_t::bds_t::inner_type pctN, pctM, iou;
316
+ std::tie(pctN, pctM, iou) = geometry_region_sizes(bds, child->Bounds);
317
+
318
+ if (iou > 0) {
319
+ fn(child->GeoIdx, pctN, pctM, iou);
320
+ }
321
+ }
322
+ } else {
323
+ auto dim = Dim();
324
+
325
+ auto mn = bds[dim];
326
+ auto mx = bds[dim + 2];
327
+
328
+ NMS_KDNode *children;
329
+ inner_type splitPos;
330
+ Children(children, splitPos);
331
+
332
+ if (mn < splitPos) {
333
+ children[0].FindIntersections(geoIdx, bds, fn);
334
+ }
335
+ if (mx > splitPos) {
336
+ children[1].FindIntersections(geoIdx, bds, fn);
337
+ }
338
+ }
339
+ }
340
+
341
+ template<typename T>
342
+ NMS_NodeAllocator<T>::NMS_NodeAllocator(size_t initialGuess)
343
+ : m_offset(0)
344
+ {
345
+ size_t allocSize = initialGuess * (sizeof(inner_type) + 2 * sizeof(node_t));
346
+ auto ptr = reinterpret_cast<uint8_t*>(malloc(allocSize));
347
+ m_buffers.emplace_back(allocSize, ptr);
348
+ }
349
+
350
+ template<typename T>
351
+ NMS_NodeAllocator<T>::~NMS_NodeAllocator()
352
+ {
353
+ for (auto &p : m_buffers) {
354
+ free(p.second);
355
+ }
356
+ }
357
+
358
+ template<typename T>
359
+ void NMS_NodeAllocator<T>::Get(size_t numNodes, node_t *&outNodes, inner_type *&outSplitPos, uint8_t *&outRawPtr)
360
+ {
361
+ auto &currBuff = m_buffers.back();
362
+
363
+ size_t rem = currBuff.first - m_offset;
364
+
365
+ size_t reqSize = sizeof(inner_type) + sizeof(node_t) * numNodes;
366
+
367
+ if (rem >= reqSize) {
368
+ outRawPtr = currBuff.second + m_offset;
369
+ outSplitPos = reinterpret_cast<inner_type*>(outRawPtr);
370
+ outNodes = reinterpret_cast<node_t*>(outRawPtr + sizeof(inner_type));
371
+ m_offset += reqSize;
372
+ return;
373
+ }
374
+
375
+ // Rounds up to the nearest multiple of 2
376
+ size_t allocSize = (std::max(currBuff.first * 2, reqSize) + 1) & ~0x01ull;
377
+ auto ptr = reinterpret_cast<uint8_t*>(malloc(allocSize));
378
+ m_buffers.emplace_back(allocSize, ptr);
379
+ m_offset = 0;
380
+
381
+ Get(numNodes, outNodes, outSplitPos, outRawPtr);
382
+ }
383
+
384
+ template<typename T>
385
+ NMS_BuildCache<T>::NMS_BuildCache(size_t initialSize)
386
+ : m_rawOffset(0)
387
+ {
388
+ auto allocSize = sizeof(bds_t*) * initialSize * 2;
389
+ auto raw1 = reinterpret_cast<uint8_t*>(malloc(allocSize));
390
+ m_rawBuffers.emplace_back(allocSize, raw1);
391
+ }
392
+
393
+ template<typename T>
394
+ NMS_BuildCache<T>::~NMS_BuildCache()
395
+ {
396
+ for (auto &p : m_rawBuffers) {
397
+ free(p.second);
398
+ }
399
+ }
400
+
401
+ template<typename T>
402
+ typename NMS_BuildCache<T>::geo_vec_ptr NMS_BuildCache<T>::Get(size_t sizeHint)
403
+ {
404
+ geo_vec_ptr ret;
405
+ if (! m_cache.empty()) {
406
+ ret = std::move(m_cache.top());
407
+ m_cache.pop();
408
+ ret->clear();
409
+ } else {
410
+ ret.reset(new std::vector<bds_t*>);
411
+ }
412
+
413
+ ret->reserve(sizeHint);
414
+ return ret;
415
+ }
416
+
417
+ template<typename T>
418
+ typename NMS_BuildCache<T>::bds_t** NMS_BuildCache<T>::GetRawBuffer(size_t numGeos, uint8_t *&rawPtr)
419
+ {
420
+ auto &currBuff = m_rawBuffers.back();
421
+ size_t rem = currBuff.first - m_rawOffset;
422
+
423
+ size_t reqSize = sizeof(size_t) + sizeof(bds_t*) * numGeos;
424
+
425
+ if (rem >= reqSize) {
426
+ rawPtr = currBuff.second + m_rawOffset;
427
+ m_rawOffset += reqSize;
428
+ reinterpret_cast<size_t*>(rawPtr)[0] = numGeos;
429
+ return reinterpret_cast<bds_t**>(rawPtr + sizeof(size_t));
430
+ }
431
+
432
+ size_t allocSize = (std::max(currBuff.first * 2, reqSize) + 1) & ~0x01ull;
433
+ auto ptr = reinterpret_cast<uint8_t*>(malloc(allocSize));
434
+ m_rawBuffers.emplace_back(allocSize, ptr);
435
+ m_rawOffset = 0;
436
+
437
+ return GetRawBuffer(numGeos, rawPtr);
438
+ }
439
+
440
+ template<typename T>
441
+ void NMS_BuildCache<T>::Release(geo_vec_ptr buff)
442
+ {
443
+ m_cache.push(std::move(buff));
444
+ }
445
+
446
+ #undef MODE_GEOMETRY
447
+ #undef MODE_CHILDREN
448
+ #undef DIM_X
449
+ #undef DIM_Y